# WandbWriter

> A writer to write results to wandb.


In [None]:
#| default_exp wandb_writer

In [None]:
#| hide
from nbdev.showdoc import *  # type: ignore # noqa: F403

In [None]:
#| export 
from fastcore.utils import *
import pandas as pd
import wandb
import os
import argparse
import numpy as np

In [None]:
#| export
class WandbWriter:
    """Writer that sends experiment results to a Weights & Biases run."""

    def __init__(self, cfg):
        """Log in to wandb and start a run configured from `cfg`.

        Args:
            cfg: Mapping of configuration values. It is unpacked into an
                `argparse.Namespace`; `project_name` and `now` are read here
                to build the run name.
        """
        self.cfg = argparse.Namespace(**cfg)
        # Run name is the project name with the `now` config value appended.
        self.exp_name = self.cfg.project_name + self.cfg.now
        # The API key is taken from the environment (None if unset).
        wandb.login(key=os.getenv("WANDB_API_KEY"), verify=False)
        self.run = wandb.init(
            project=self.cfg.project_name,
            name=self.exp_name,
            config=self.cfg,
        )

Each entry of `lst_train_res` is of the shape `{'loss': 0.2, 'metrics': {"accuracy": 0.5, "f1": 0.6}}`, and
each entry of `lst_test_res` is of the same shape.

We log two different things:
- Tables of local train and test results
  - Note that the local train results table has length equal to the number of participating clients in round `t`.
- Average train results (the average of the clients' average local results) and average test results (over all clients, even non-participating ones).

In [None]:
# #| export
# @patch
# def write(self: WandbWriter, lst_active_ids, lst_train_res, lst_test_res, round):
    
#     #[{'loss': 0.2, 'metrics': {"accuracy": 0.5, "f1": 0.6}}, 
#     # {'loss': 0.4, 'metrics': {"accuracy": 0.3, "f1": 0.2}}]

#     local_train_losses = [r["loss"] for r in lst_train_res] # for participants clients
#     local_train_metrics = [r["metrics"] for r in lst_train_res] # for participants clients

#     lst_train_histories = [{"client_id": client_id, "loss": loss, **metrics} for client_id, loss, metrics in zip(lst_active_ids, local_train_losses, local_train_metrics)]
#     train_table = wandb.Table(dataframe= pd.DataFrame(lst_train_histories))
#     ########################################################################
#     local_test_losses = [r["loss"] for r in lst_test_res] # for all cleints
#     local_test_metrics = [r["metrics"] for r in lst_test_res] # for all cleints

#     lst_test_histories = [{"client_id": client_id, "loss": loss, **metrics} for client_id, loss, metrics in zip(list(range(self.cfg.num_clients)), local_test_losses, local_test_metrics)]
#     test_table = wandb.Table(dataframe= pd.DataFrame(lst_test_histories))
#     ########################################################################
#     avg_train_losses = np.mean(local_train_losses)
#     avg_train_metrics = self.avg_lst_dicts(local_train_metrics)

#     avg_test_losses = np.mean(local_test_losses)
#     avg_test_metrics = self.avg_lst_dicts(local_test_metrics)

#     ########################################################################
#     train_metrics = {f"train_{k}": v for k, v in avg_train_metrics.items()}
#     test_metrics = {f"test_{k}": v for k, v in avg_test_metrics.items()}
#     ########################################################################

#     to_log = {"train_loss": avg_train_losses,
#               **train_metrics,
#               "avg_test_loss": avg_test_losses,
#               **test_metrics,
#               f"Round {round} Train metrics": train_table,
#               f"Round {round} test metrics": test_table}


#     self.run.log(to_log)


In [None]:
#| export
@patch
def write(self: WandbWriter, lst_active_ids, lst_train_res, lst_test_res, round):
    """Log one round's per-client tables and averaged results to wandb.

    Args:
        lst_active_ids: Ids of the clients that trained this round, paired
            positionally with `lst_train_res`. `None` is treated as empty.
        lst_train_res: One `{"loss": ..., "metrics": {...}}` dict per
            participating client. `None` or empty means no train results.
        lst_test_res: Same shape as `lst_train_res`; the position of each
            entry in the list is used as its client id. `None` or empty
            means no test results.
        round: Round identifier, used only in the keys of the logged tables.
            (The name shadows the builtin but is kept so keyword callers
            keep working.)

    When a result list is empty, its average loss is logged as 0.0 and no
    averaged metrics are logged for it.
    """
    # Normalise missing inputs once so every later step sees a real sequence.
    train_res = lst_train_res if lst_train_res else []
    test_res = lst_test_res if lst_test_res else []
    active_ids = [] if lst_active_ids is None else lst_active_ids

    # Training results (for participating clients)
    local_train_losses = [r["loss"] for r in train_res]
    local_train_metrics = [r["metrics"] for r in train_res]
    lst_train_histories = [
        {"client_id": client_id, "loss": loss, **metrics}
        for client_id, loss, metrics in zip(active_ids, local_train_losses, local_train_metrics)
    ]
    train_table = wandb.Table(dataframe=pd.DataFrame(lst_train_histories))

    # Test results (for all clients). `enumerate` over the guarded list avoids
    # calling `len()` on a `None` argument, which previously raised TypeError.
    local_test_losses = [r["loss"] for r in test_res]
    local_test_metrics = [r["metrics"] for r in test_res]
    lst_test_histories = [
        {"client_id": client_id, "loss": loss, **metrics}
        for client_id, (loss, metrics) in enumerate(zip(local_test_losses, local_test_metrics))
    ]
    test_table = wandb.Table(dataframe=pd.DataFrame(lst_test_histories))

    # Averages fall back to neutral values when there is nothing to average.
    avg_train_losses = np.mean(local_train_losses) if local_train_losses else 0.0
    avg_train_metrics = self.avg_lst_dicts(local_train_metrics) if local_train_metrics else {}

    avg_test_losses = np.mean(local_test_losses) if local_test_losses else 0.0
    avg_test_metrics = self.avg_lst_dicts(local_test_metrics) if local_test_metrics else {}

    # Prefix metric names so train and test values do not collide in the log.
    train_metrics = {f"train_{k}": v for k, v in avg_train_metrics.items()}
    test_metrics = {f"test_{k}": v for k, v in avg_test_metrics.items()}

    to_log = {
        "train_loss": avg_train_losses,
        **train_metrics,
        "avg_test_loss": avg_test_losses,
        **test_metrics,
        f"Round {round} Train metrics": train_table,
        f"Round {round} Test metrics": test_table,
    }

    self.run.log(to_log)


In [None]:
#| export 
@patch
def avg_lst_dicts(self: WandbWriter, lst_dict):
    """Average a list of dicts key by key.

    Args:
        lst_dict: List of dicts holding numeric values. The keys of the first
            dict decide which keys are averaged, so every dict must contain
            them (a missing key raises `KeyError`).

    Returns:
        A dict mapping each key of the first dict to the mean of that key's
        values across the list, or `{}` when the list is empty.
    """
    # Nothing to average: return an empty result instead of failing on [0].
    if not lst_dict:
        return {}
    count = len(lst_dict)
    return {key: sum(d[key] for d in lst_dict) / count for key in lst_dict[0]}

In [None]:
#| export   
@patch
def save(self: WandbWriter, res):
    """Combine the records in `res` and write them to `<res_dir>/results.csv`.

    Args:
        res: Iterable of objects accepted by `pd.DataFrame`; each one becomes
            a frame and the frames are concatenated in order.
    """
    frames = []
    for record in res:
        frames.append(pd.DataFrame(record))
    combined = pd.concat(frames)

    # Create the output directory (if needed) only once the data is ready.
    out_dir = self.cfg.res_dir
    os.makedirs(out_dir, exist_ok=True)
    combined.to_csv(f"{out_dir}/results.csv", index=False)

In [None]:
#| export   
@patch
def finish(self: WandbWriter):
    """Finish the wandb run held in `self.run`."""
    active_run = self.run
    active_run.finish()

In [None]:
#| hide
# Invoke nbdev's export entry point. Presumably this writes the `#| export`
# cells to the module named by `default_exp` — confirm against the nbdev docs.
import nbdev
nbdev.nbdev_export() # type: ignore  # noqa: E702
