In [None]:
import torch
from torch.utils.tensorboard import SummaryWriter

from utils.general import get_logging_dir
from data import artificial_1D_linear as data
from experiments.artificial_1D_linear.documentation import (
    evaluate,
    plot_data_split,
    plot_predictions,
)
from models import SmartAttentionLayer

from experiments.artificial_1D_linear.smart_fed_avg_util import (
    train_client,
    register_client_test_losses,
)


# Seed PyTorch's global random number generator so repeated runs are comparable.
torch.manual_seed(42)

In [None]:
# Load IPython's autoreload extension and use mode 2, which re-imports changed
# modules before each cell runs, so edits to the project's .py files are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Device handed to model creation, client training and evaluation.
DEVICE: torch.device = torch.device("cpu")
# Number of clients; also determines the client ids and the number of data splits.
NUM_CLIENTS = 10

# Number of iterations of the outer train -> aggregate -> redistribute loop.
COMMUNICATION_ROUNDS = 20
# Passed to train_client as `no_epochs` for every client in every round.
CLIENT_EPOCHS = 150

# Split scheme passed to data.get_client_train_dataloaders.
SPLIT_TYPE = "interval"

In [None]:
# One integer id per client: 0 .. NUM_CLIENTS - 1.
CLIENT_IDs = range(NUM_CLIENTS)

# Batch size passed to the client training dataloaders.
BATCH_SIZE = 64

# Arguments forwarded to SmartAttentionLayer.initialize_from_scratch.
# NOTE(review): the *_ARCHITECTURE lists presumably hold hidden-layer widths of
# the respective sub-networks — confirm in models.SmartAttentionLayer.
INPUT_FEATURES = 1
OUTPUT_FEATURES = 1
PRED_ARCHITECTURE = [5, 5, 5, 5]
INPUT_IMPRTNC_ARCHITECTURE = [5, 5]
CLIENT_IMPRTNC_ARCHITECTURE = [5, 5]
# Similarity threshold in degrees used when building the global model outside
# the cooling-off rounds.
SIMILARITY_THRESHOLD = 30

# Loss function passed to train_client.
LOSS_FN = torch.nn.MSELoss()

In [None]:
def register_hyperparameters(writer, last_loss):
    """Log the notebook's configuration together with the final test loss.

    Args:
        writer: Object exposing ``add_hparams`` (a ``SummaryWriter`` in this
            notebook).
        last_loss: Value recorded under the ``"MSE Test"`` metric.
    """
    # The architecture lists are stringified before being logged.
    hyperparameters = dict(
        client_epochs=CLIENT_EPOCHS,
        num_clients=NUM_CLIENTS,
        communication_rounds=COMMUNICATION_ROUNDS,
        split_type=SPLIT_TYPE,
        pred_architecture=str(PRED_ARCHITECTURE),
        input_imprtnc_architecture=str(INPUT_IMPRTNC_ARCHITECTURE),
        client_imprtnc_architecture=str(CLIENT_IMPRTNC_ARCHITECTURE),
        similarity_threshold=SIMILARITY_THRESHOLD,
    )
    metrics = {"MSE Test": last_loss}

    # run_name="." is forwarded unchanged to the writer.
    writer.add_hparams(hyperparameters, metrics, run_name=".")

In [None]:
# The run name encodes the client count and split type; it is also used to
# obtain the TensorBoard logging directory.
model_name = f"SmartAttentionLayer_{NUM_CLIENTS}clients_{SPLIT_TYPE}-split"
logging_dir = get_logging_dir(model_name, "artificial_1D_linear")
writer = SummaryWriter(logging_dir)

# Every client receives the same three architecture settings; only the id differs.
architectures = (
    PRED_ARCHITECTURE,
    INPUT_IMPRTNC_ARCHITECTURE,
    CLIENT_IMPRTNC_ARCHITECTURE,
)
clients = [
    SmartAttentionLayer.initialize_from_scratch(
        INPUT_FEATURES, OUTPUT_FEATURES, NUM_CLIENTS, cid, *architectures, device=DEVICE
    )
    for cid in CLIENT_IDs
]

# Training dataloaders, later indexed by client number in the training loop.
client_train_dataloaders = data.get_client_train_dataloaders(
    NUM_CLIENTS, SPLIT_TYPE, BATCH_SIZE, shuffle=True
)

# Record the data split via the same writer.
plot_data_split(client_train_dataloaders, writer)

In [None]:
# Print the first client model for inspection.
print(clients[0])

In [None]:
def is_cooling_off_epoch(cr: int):
    """Tell whether communication round ``cr`` belongs to the cooling-off phase.

    A round is in the cooling-off phase when its index is strictly greater than
    ``COMMUNICATION_ROUNDS // 2``.
    """
    cooling_off_begins_after = COMMUNICATION_ROUNDS // 2
    return cr > cooling_off_begins_after


# Federated training loop. Each communication round:
#   1. trains every client on its own dataloader,
#   2. logs the per-client test losses,
#   3. builds a global model from all clients and logs its test loss,
#   4. replaces every client with the model returned by
#      global_model.get_client_model.
for cr in range(COMMUNICATION_ROUNDS):
    # Determine the phase once per round; it steers aggregation and noise below.
    cooling_off = is_cooling_off_epoch(cr)

    # train each client individually
    for client_no, client in zip(CLIENT_IDs, clients):
        train_client(
            client_no=client_no,
            client_model=client,
            data_loader=client_train_dataloaders[client_no],
            loss_fn=LOSS_FN,
            no_epochs=CLIENT_EPOCHS,
            communication_round=cr,
            writer=writer,
            device=DEVICE,
        )

    register_client_test_losses(
        clients=clients,
        client_ids=CLIENT_IDs,
        writer=writer,
        communication_round=cr,
        device=DEVICE,
    )

    # Outside cooling-off: "combine" with the configured similarity threshold.
    # During cooling-off: "average" with a threshold of 181 degrees.
    # NOTE(review): 181 exceeds the 180 degree maximum angle, so it presumably
    # disables similarity filtering — confirm in SmartAttentionLayer.get_global_model.
    global_model = SmartAttentionLayer.get_global_model(
        clients,
        similarity_threshold_in_degree=181 if cooling_off else SIMILARITY_THRESHOLD,
        method="average" if cooling_off else "combine",
    )
    global_model.to(DEVICE)

    # The global test loss is logged at step cr * CLIENT_EPOCHS.
    writer.add_scalar(
        "test_loss", evaluate(global_model, device=DEVICE), cr * CLIENT_EPOCHS
    )

    # Noise is only requested outside the cooling-off phase.
    clients = [
        global_model.get_client_model(client_id, add_noise=not cooling_off)
        for client_id in CLIENT_IDs
    ]

    # Per-round prediction plot (currently disabled):
    # plot_predictions(global_model, model_name, writer, epoch=cr, device=DEVICE)

# Final prediction plot (currently disabled):
# plot_predictions(global_model, model_name, writer, device=DEVICE)

# Evaluate on the same device as inside the loop so the hparams metric is
# computed consistently with the logged "test_loss" scalar.
register_hyperparameters(writer, last_loss=evaluate(global_model, device=DEVICE))
writer.flush()
writer.close()

In [None]:
# Print the global model left over from the last communication round.
print(global_model)

In [None]:
# Print the full representation of the global model's query network.
print(global_model.query_network.full_representation())

In [None]:
# For every client, print its prediction network's full representation
# followed by its prediction mask.
for client in clients:
    print(client.prediction_network.full_representation())
    print(client.prediction_mask)