In [None]:
import json

import torch
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from tqdm.notebook import tqdm

import data.adult
import models
import utils
import utils.evaluation  # explicit: `import utils` alone does not guarantee this submodule is loaded
from utils.general import get_logging_dir
from utils.general import get_logging_dir, make_values_scalar
from utils.training import get_loss_fn, get_optimizer, train_one_epoch

In [8]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
# Reproducibility: seed torch's global RNG before any model/data randomness below.
SEED = 42
torch.manual_seed(SEED)

# Shared run configuration.
BATCH_SIZE = 128
DEVICE = torch.device("cpu")

# Similarity threshold in degrees. No two vectors can be more than 180 degrees apart,
# so 181 exceeds every possible angle between them.
MAX_THETA = 181

# Held-out evaluation data; batches keep a fixed order (no shuffling) so that
# evaluations across rounds see the same sequence.
test_dataset = data.adult.get_dataset("test")
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [15]:
from copy import deepcopy

from models import SmartAttentionLayer


# Architecture specs (lists of ints passed to SmartAttentionLayer.initialize_from_scratch)
# for the model's three sub-networks.
prediction_network_architecture: list[int] = [40, 40, 40, 20, 10]
input_importance_network_architecture: list[int] = [10, 10, 10]
client_importance_network_architecture: list[int] = [10, 10, 10]

# Run hyperparameters. The dict is also logged to TensorBoard (add_hparams / add_text),
# so the architectures are stored as strings to keep every value a plain scalar.
# "---" presumably marks settings not used by this model — confirm.
hparams = dict(
    [
        ("optimizer", "adam"),
        ("learning_rate", 0.001),
        ("loss_fn", "bce_with_logits"),
        ("batch_norm", "---"),
        ("layer_norm", "---"),
        ("dropout_rate", "---"),
        ("client_epochs", 1),
        ("num_clients", 10),
        ("communication_rounds_training", 1),
        ("communication_rounds_aligning", 1),
        ("client_data_distribution", "random"),
        ("similarity_threshold_in_degree (theta)", 40),
        ("aligning_method", "combine"),
        ("added_noise_in_training", False),
        ("prediction_network_architecture", str(prediction_network_architecture)),
        ("input_importance_network_architecture", str(input_importance_network_architecture)),
        ("client_importance_network_architecture", str(client_importance_network_architecture)),
    ]
)

num_clients = hparams["num_clients"]

# Build the global model from scratch on DEVICE.
# NOTE: statement order below is kept as-is — model initialization and the client data
# split may both draw from the seeded RNG.
global_model = SmartAttentionLayer.initialize_from_scratch(
    107,  # presumably input width (matches the (1, 107) inputs used for summary/graph)
    1,  # presumably output width (single logit for bce_with_logits) — confirm
    num_clients,
    None,  # NOTE(review): positional None passed through unchanged — confirm its meaning
    prediction_network_architecture,
    input_importance_network_architecture,
    client_importance_network_architecture,
    device=DEVICE,
)

# Training dataloaders, indexed by client id, split per `client_data_distribution`.
dataloaders = data.adult.get_client_train_dataloaders(
    num_clients,
    hparams["client_data_distribution"],
    BATCH_SIZE,
    True,  # NOTE(review): positional flag, presumably shuffling — confirm
)

# TensorBoard writer for this run, in the directory chosen by get_logging_dir.
writer = SummaryWriter(log_dir=get_logging_dir("smart_attention", "adult"))

# Loss shared by client training and the per-round test-loss evaluation.
loss_fn = get_loss_fn(hparams["loss_fn"])


# Federated training: `communication_rounds_training` plain training rounds followed by
# `communication_rounds_aligning` aligning rounds. Each round trains every client locally,
# aggregates them into a new global model, and logs test loss and parameter count.
total_rounds = (
    hparams["communication_rounds_training"] + hparams["communication_rounds_aligning"]
)
for communication_round in tqdm(range(total_rounds)):
    # Rounds past the training phase are aligning rounds.
    is_aligning_round = communication_round >= hparams["communication_rounds_training"]
    # Noise (when enabled) is only requested during training rounds, never while aligning.
    add_noise = hparams["added_noise_in_training"] and not is_aligning_round
    # TensorBoard step at the start of this round (one step per client epoch).
    round_start_step = communication_round * hparams["client_epochs"]

    # Derive one model per client from the current global model.
    client_models = {
        client_id: global_model.get_client_model(client_id, add_noise)
        for client_id in range(hparams["num_clients"])
    }
    # Move the models to DEVICE *before* constructing optimizers, so every optimizer is
    # built over parameters that already live where training will happen (per the
    # torch.optim guidance; previously the move happened after optimizer creation).
    for client_model in client_models.values():
        client_model.to(DEVICE)
    optimizers = {
        client_id: get_optimizer(
            hparams["optimizer"],
            hparams["learning_rate"],
            client_model.parameters(),
        )
        for client_id, client_model in client_models.items()
    }

    # Local training on each client's own data.
    for client_id, client_model in client_models.items():
        for epoch in tqdm(range(hparams["client_epochs"]), leave=False):
            epoch_loss = train_one_epoch(
                client_model,
                dataloaders[client_id],
                optimizers[client_id],
                loss_fn,
                DEVICE,
            )
            writer.add_scalar(
                f"Loss/train/client{client_id}",
                epoch_loss,
                round_start_step + epoch,
            )

    # Aggregate the clients into a new global model. Training rounds always use "combine"
    # with MAX_THETA (larger than any possible angle, so presumably no client is split
    # off); aligning rounds use the configured threshold and method.
    global_model = SmartAttentionLayer.get_global_model(
        list(client_models.values()),
        hparams["similarity_threshold_in_degree (theta)"] if is_aligning_round else MAX_THETA,
        method=hparams["aligning_method"] if is_aligning_round else "combine",
    )

    # Test-set loss of the new global model, logged at this round's start step.
    y_hats, ys = utils.evaluation.evaluate(
        global_model, test_dataloader, from_logits=True, return_outputs_only=True
    )
    test_loss = loss_fn(y_hats.to(torch.float), ys.to(torch.float)).item()
    writer.add_scalar("Loss/test", test_loss, round_start_step)

    # Only the parameter count is needed: verbose=0 stops torchinfo from printing the
    # full summary table every round (its default outside Jupyter).
    writer.add_scalar(
        "total_params",
        summary(global_model, verbose=0).total_params,
        round_start_step,
    )


# Final evaluation of the aggregated global model on the held-out test set.
metrics = utils.evaluation.evaluate(global_model, test_dataloader, from_logits=True)

# Record hyperparameters next to the final metrics. run_name="." makes add_hparams write
# into the writer's own log dir instead of a new timestamped sub-directory.
writer.add_hparams(hparams, dict(metrics), run_name=".")

# Shape of one model input (a single sample with 107 features), shared by the summary
# and the graph trace so the two can't drift apart.
MODEL_INPUT_SHAPE = (1, 107)

# verbose=0: the summary is only logged as text, not printed to stdout.
writer.add_text(
    "Model Summary",
    str(summary(global_model, input_size=MODEL_INPUT_SHAPE, verbose=0)),
)
writer.add_text("hparams", json.dumps(hparams, indent=4))

# Example input for tracing the graph, created on DEVICE rather than always on the CPU.
# NOTE(review): assumes global_model's parameters live on DEVICE — confirm that
# get_global_model keeps them there.
# Tracing may emit TracerWarnings for shape-dependent branches in the model; the graph
# is traced with this batch-size-1 input.
dummy_input = torch.randn(*MODEL_INPUT_SHAPE, device=DEVICE)
writer.add_graph(global_model, dummy_input)
writer.flush()
writer.close()

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  if x.shape[0] != self._cached_batch_size:
  scale_factor = 1 / sqrt(query.size(-1))
