In [21]:
from utils.general import get_logging_dir
import torch
import utils
import models
from tqdm.notebook import tqdm

from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary

from utils.training import get_loss_fn, get_optimizer, train_one_epoch
from utils.general import get_logging_dir, make_values_scalar

import data.adult

In [22]:
# Re-import edited project modules (utils, models, data, ...) automatically
# before each cell executes, so changes to the .py files are picked up live.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
# Shared run configuration used by the cells below.
DEVICE = torch.device("cpu")
BATCH_SIZE = 128

# Seed torch's global RNG before any dataset/model objects are built.
torch.manual_seed(42)

# Test split of the Adult dataset, iterated in a fixed order (no shuffling).
test_dataloader = torch.utils.data.DataLoader(
    data.adult.get_dataset("test"),
    shuffle=False,
    batch_size=BATCH_SIZE,
)

In [24]:
from copy import deepcopy

# Width of the feature vectors fed to the network.
# NOTE(review): hard-coded; presumably matches the encoded feature width
# produced by data.adult — confirm if the preprocessing changes.
INPUT_DIM = 107

hparams = {
    "optimizer": "adam",
    "learning_rate": 0.001,
    "loss_fn": "bce_with_logits",
    "batch_norm": False,
    "layer_norm": False,
    "dropout_rate": 0.0,
    "client_epochs": 50,
    "num_clients": 10,
    "communication_rounds": 10,
    "client_data_distribution": "random",
    # Stored with the other hparams so the seed is logged alongside the run.
    "seed": 42,
}
architecture: list[int] = [512, 512, 512, 256, 128]

# Reseed inside this cell so re-running it on its own reproduces the same
# weight initialisation (and any torch-RNG-driven steps that follow).
torch.manual_seed(hparams["seed"])

model_config = models.DenseNetwork.Config(
    INPUT_DIM,
    architecture,
    1,  # presumably the output width: a single logit for the binary target
    torch.nn.ReLU,
    use_batch_norm=hparams["batch_norm"],
    use_layer_norm=hparams["layer_norm"],
    dropout_rate=hparams["dropout_rate"],
)

# The server model lives on DEVICE so every client copy starts there too.
global_model = models.DenseNetwork(model_config).to(DEVICE)

# One training dataloader per client, partitioned according to
# hparams["client_data_distribution"]. The trailing positional True is passed
# through unchanged (NOTE(review): presumably `shuffle` — confirm).
dataloaders = data.adult.get_client_train_dataloaders(
    hparams["num_clients"], hparams["client_data_distribution"], BATCH_SIZE, True
)

loss_fn = get_loss_fn(hparams["loss_fn"])

# The context manager guarantees the writer is closed (and flushed) even if
# training raises part-way through.
with SummaryWriter(get_logging_dir("fed_avg", "adult")) as writer:
    for communication_round in tqdm(
        range(hparams["communication_rounds"]), desc="communication rounds"
    ):
        client_models: dict = {}
        optimizers: dict = {}

        for client_id in range(hparams["num_clients"]):
            # Each client trains an independent copy of the current global model.
            # The copy is already on DEVICE, so the optimizer is built after the
            # model has been placed (as the torch.optim docs recommend).
            client_models[client_id] = deepcopy(global_model)
            optimizers[client_id] = get_optimizer(
                hparams["optimizer"],
                hparams["learning_rate"],
                client_models[client_id].parameters(),
            )

            for epoch in tqdm(
                range(hparams["client_epochs"]),
                desc=f"client {client_id}",
                leave=False,
            ):
                epoch_loss = train_one_epoch(
                    client_models[client_id],
                    dataloaders[client_id],
                    optimizers[client_id],
                    loss_fn,
                    DEVICE,
                )
                # Step keeps increasing across rounds so each client's loss
                # curve is continuous in TensorBoard.
                writer.add_scalar(
                    f"Loss/train/client{client_id}",
                    epoch_loss,
                    communication_round * hparams["client_epochs"] + epoch,
                )

        # Aggregate the client models (in client-id order) into the new server
        # model; move it back to DEVICE for the next round and for logging below.
        global_model = utils.federated_learning.average_models(
            list(client_models.values())
        ).to(DEVICE)

        # NOTE(review): the test loss uses the logits-based loss_fn, which assumes
        # `return_outputs_only=True` yields raw logits — confirm in utils.evaluation.
        y_hats, ys = utils.evaluation.evaluate(
            global_model, test_dataloader, from_logits=True, return_outputs_only=True
        )
        writer.add_scalar(
            "Loss/test",
            loss_fn(y_hats.to(torch.float), ys.to(torch.float)).item(),
            communication_round,
        )

    metrics = utils.evaluation.evaluate(global_model, test_dataloader, from_logits=True)

    writer.add_hparams(
        {**make_values_scalar(dict(model_config)), **hparams}, dict(metrics), run_name="."
    )

    writer.add_text(
        "Model Summary", str(summary(global_model, input_size=(1, INPUT_DIM)))
    )
    # add_graph traces the model, so the example input must be on the model's device.
    dummy_input = torch.randn(1, INPUT_DIM, device=DEVICE)
    writer.add_graph(global_model, dummy_input)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]