In [None]:
import torch
from IPython.display import clear_output
from torch.utils.tensorboard import SummaryWriter

from data import artificial_1D_linear as data
from experiments.artificial_1D_linear.documentation import (
    evaluate,
    plot_data_split,
    plot_predictions,
)
from experiments.artificial_1D_linear.smart_fed_avg_util import (
    register_client_test_losses,
    train_client,
)
from models import SmartAverageLayer
from utils.general import get_logging_dir

In [None]:
# Re-import edited modules automatically before each cell executes, so changes
# to the project's helper modules (data, models, experiments.*) take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# ---- Experiment configuration ------------------------------------------------
NUM_CLIENTS = 2  # number of federated clients

COMMUNICATION_ROUNDS = 10  # train -> aggregate -> redistribute cycles
CLIENT_EPOCHS = 100  # local training epochs per client in each round

# Passed to the dataloader factory to choose how training data is split across
# clients; also part of the run/log name.
SPLIT_TYPE = "random"

# Seed torch's global RNG so that model initialisation (and, presumably, any
# torch-based shuffling / noise in the helpers) is reproducible under
# Restart & Run All. The trailing ';' suppresses the Generator repr in Jupyter.
SEED = 42
torch.manual_seed(SEED);

In [None]:
# ---- Derived settings and model / training hyper-parameters ------------------
# Client ids 0..NUM_CLIENTS-1; the training loop indexes the per-client
# dataloaders with these ids.
CLIENT_IDs = range(0, NUM_CLIENTS)

BATCH_SIZE = 64

# One input feature mapped to one output feature.
INPUT_FEATURES, OUTPUT_FEATURES = 1, 1
# Layer sizes handed to SmartAverageLayer.initialize_from_scratch
# (presumably hidden-layer widths): [11, 12, 11, 12].
ARCHITECTURE = [11, 12] * 2

LOSS_FN = torch.nn.MSELoss(reduction="mean")  # mean squared error

In [None]:
def register_hyperparameters(writer, last_loss, extra_hparams=None):
    """Log this run's hyper-parameters and final test loss to TensorBoard's HParams view.

    Args:
        writer: the run's ``SummaryWriter``.
        last_loss: final test MSE of the global model; logged as the
            ``"MSE Test"`` metric.
        extra_hparams: optional mapping of further hyper-parameters (for
            example the aggregation similarity threshold) to log alongside the
            defaults. Entries override defaults that share a key. Values must
            be types ``add_hparams`` accepts (int, float, str, bool, tensor).

    Reads the module-level configuration constants, so it must be called after
    the configuration cells have run.
    """
    hparams = {
        "client_epochs": CLIENT_EPOCHS,
        "num_clients": NUM_CLIENTS,
        "communication_rounds": COMMUNICATION_ROUNDS,
        "split_type": SPLIT_TYPE,
        # add_hparams rejects list values, so log the layer sizes as text.
        "architecture": str(ARCHITECTURE),
        # Batch size also shapes training; without it, runs that differ only
        # in BATCH_SIZE are indistinguishable in the HParams table.
        "batch_size": BATCH_SIZE,
    }
    if extra_hparams:
        hparams.update(extra_hparams)

    writer.add_hparams(
        hparams,
        {"MSE Test": last_loss},
        # "." writes into the writer's own log dir instead of a new
        # timestamped sub-directory, keeping hparams next to the run's scalars.
        run_name=".",
    )

In [None]:
# ---- Run setup: TensorBoard writer, client models, client data ----------------
model_name = f"MyFed_{NUM_CLIENTS}clients_{SPLIT_TYPE}-split"
log_dir = get_logging_dir(model_name, "artificial_1D_linear")
writer = SummaryWriter(log_dir)

# One model per client id, each built from scratch.
clients = [
    SmartAverageLayer.initialize_from_scratch(
        INPUT_FEATURES,
        OUTPUT_FEATURES,
        NUM_CLIENTS,
        cid,
        ARCHITECTURE,
    )
    for cid in CLIENT_IDs
]

# Per-client shuffled training data; the training loop indexes this by client id.
client_train_dataloaders = data.get_client_train_dataloaders(
    NUM_CLIENTS, SPLIT_TYPE, BATCH_SIZE, shuffle=True
)

# Plot the client data split (the writer is passed, presumably to log the figure).
plot_data_split(client_train_dataloaders, writer)

In [None]:
# Show the structure of the first client's freshly built model.
first_client = clients[0]
print(first_client)

In [None]:
# ---- Federated training loop --------------------------------------------------
# Every communication round:
#   1. each client trains locally on its own data loader,
#   2. per-client test losses are logged,
#   3. the clients are merged into one global model, which is evaluated,
#   4. the next round's client models are derived from the global model.

# Similarity threshold (in degrees) handed to SmartAverageLayer.get_global_model
# when merging the clients; previously an unnamed literal inside the loop.
SIMILARITY_THRESHOLD_IN_DEGREE = 15

if COMMUNICATION_ROUNDS < 1:
    # With zero rounds `global_model` would never be assigned and the final
    # plotting / hparam logging below would fail with a confusing NameError.
    raise ValueError(
        f"COMMUNICATION_ROUNDS must be at least 1, got {COMMUNICATION_ROUNDS}"
    )

try:
    for cr in range(COMMUNICATION_ROUNDS):
        # 1. Local training: client i trains on client_train_dataloaders[i].
        for client_no, client in zip(CLIENT_IDs, clients):
            train_client(
                client_no=client_no,
                client_model=client,
                data_loader=client_train_dataloaders[client_no],
                loss_fn=LOSS_FN,
                no_epochs=CLIENT_EPOCHS,
                communication_round=cr,
                writer=writer,
            )

        # 2. Log each locally trained client's test loss before aggregation.
        register_client_test_losses(
            clients=clients, client_ids=CLIENT_IDs, writer=writer, communication_round=cr
        )

        # 3. Merge the clients into a single global model and log its test loss.
        global_model = SmartAverageLayer.get_global_model(
            clients,
            similarity_threshold_in_degree=SIMILARITY_THRESHOLD_IN_DEGREE,
        )
        # Step is counted in client epochs (round index * CLIENT_EPOCHS).
        # NOTE(review): presumably meant to line up with train_client's steps — confirm.
        writer.add_scalar("test_loss", evaluate(global_model), cr * CLIENT_EPOCHS)

        # Progress display: replace the previous round's output rather than
        # appending to it. `cr` is 0-based, so report `cr + 1` to the reader.
        clear_output(wait=True)
        print(f"Communication Round {cr + 1}/{COMMUNICATION_ROUNDS}")
        print("global model:", global_model)

        # 4. Derive next round's client models from the global model.
        # Noise is requested in every round except the last one.
        clients = [
            global_model.get_client_model(
                client_id, add_noise=cr < COMMUNICATION_ROUNDS - 1
            )
            for client_id in CLIENT_IDs
        ]

        # Per-round prediction plot; the round index is passed as `epoch`.
        plot_predictions(global_model, model_name, writer, epoch=cr)

    # Final model: plot its predictions and log the run's hyper-parameters
    # together with its test MSE.
    plot_predictions(global_model, model_name, writer)
    register_hyperparameters(writer, last_loss=evaluate(global_model))
finally:
    # Flush and close the TensorBoard event files even if training fails or
    # is interrupted, so data logged up to that point is not lost.
    writer.close()

In [None]:
# Inspect the final aggregated model: full network representation, then mask.
global_network = global_model.prediction_network
print(global_network.full_representation())
print(global_model.prediction_mask)

In [None]:
# Inspect every client model derived from the final global model.
for client_model in clients:
    representation = client_model.prediction_network.full_representation()
    print(representation)
    print(client_model.prediction_mask)