# WandbWriter

> A writer that logs experiment metrics and results to Weights & Biases (wandb).


In [None]:
#| default_exp wandb_writer

In [None]:
#| hide
from nbdev.showdoc import *  # type: ignore # noqa: F403

In [None]:
#| export 
import argparse
import numbers
import os

import pandas as pd
import wandb
from fastcore.utils import *

In [None]:
#| export
class WandbWriter:
    """Send experiment metrics to a Weights & Biases (wandb) run.

    Constructing an instance logs in to wandb and starts a new run; the run
    handle is stored on ``self.run`` for the ``write``/``finish`` calls.

    Args:
        cfg: mapping of configuration values with string keys. It must
            contain at least ``project_name`` and ``now``. The whole mapping
            (wrapped in an ``argparse.Namespace``) is handed to
            ``wandb.init`` as the run config.
    """

    def __init__(self, cfg):
        # Wrap the mapping so settings can be read as attributes (cfg.res_dir, ...).
        settings = argparse.Namespace(**cfg)
        self.cfg = settings

        # Run name is the project name directly followed by `now`, with no
        # separator. `+` requires both to be strings; presumably `now` is a
        # timestamp string -- TODO confirm against callers.
        self.exp_name = settings.project_name + settings.now

        # The key is None when WANDB_API_KEY is unset; wandb.login then
        # receives key=None.
        api_key = os.getenv("WANDB_API_KEY")
        wandb.login(key=api_key, verify=False)

        self.run = wandb.init(
            project=settings.project_name,
            name=self.exp_name,
            config=settings,
        )

In [None]:
#| export
@patch
def write(self: WandbWriter, lst_metrics, round):
    """Log one round of metrics to the wandb run.

    The raw entries are logged as a ``wandb.Table`` under the key
    ``"Round {round} Metrics"``. Alongside it, each metric that is numeric
    in every entry is logged as its mean across entries, under the metric's
    own name.

    Args:
        lst_metrics: list of dicts, each mapping metric name to value
            (presumably one dict per client/participant -- confirm with
            callers). Averaged keys are taken from the first dict.
        round: round index; used only to build the table's key. It shadows
            the builtin ``round`` but keeps its name for keyword callers.

    An empty ``lst_metrics`` logs nothing. Columns that are missing from
    some entries, or that hold non-numeric values (e.g. IDs), are not
    averaged but still appear in the table.
    """
    if not lst_metrics:
        # Nothing to tabulate or average. Indexing [0] / dividing by the
        # length would fail here.
        return

    table = wandb.Table(dataframe=pd.DataFrame(lst_metrics))
    all_metrics = {f"Round {round} Metrics": table}

    count = len(lst_metrics)
    for key in lst_metrics[0]:
        values = [entry.get(key) for entry in lst_metrics]
        # Skip a column that cannot be averaged instead of letting one
        # TypeError/KeyError abort the whole log call.
        if all(isinstance(value, numbers.Number) for value in values):
            all_metrics[key] = sum(values) / count

    self.run.log(all_metrics)

In [None]:
#| export   
@patch
def save(self: WandbWriter, res):
    """Write all collected results to ``<cfg.res_dir>/results.csv``.

    Args:
        res: iterable of per-round results. Each element must be accepted
            by ``pd.DataFrame`` (presumably a list of metric dicts -- the
            same shape ``write`` receives; confirm with callers).

    Rows from every element are stacked in order, and the DataFrame index
    is not written. ``cfg.res_dir`` is created if needed. An empty ``res``
    writes a CSV with no rows instead of raising.
    """
    frames = [pd.DataFrame(round_result) for round_result in res]
    # pd.concat raises ValueError("No objects to concatenate") on [].
    df = pd.concat(frames) if frames else pd.DataFrame()
    os.makedirs(self.cfg.res_dir, exist_ok=True)
    df.to_csv(os.path.join(self.cfg.res_dir, "results.csv"), index=False)

In [None]:
#| export   
@patch
def finish(self: WandbWriter):
    """Finish the wandb run that was started when this writer was constructed."""
    active_run = self.run
    active_run.finish()

In [None]:
#| hide
# Export the `#| export` cells to the library module (see `#| default_exp`
# above) via nbdev. This cell is hidden from the generated docs.
import nbdev
nbdev.nbdev_export() # type: ignore  # noqa: E702
