In [1]:
import torch
from torch.utils.tensorboard import SummaryWriter

from data import artificial_1D_linear as data
from experiments.artificial_1D_linear.documentation import (
    evaluate,
    plot_data_split,
    plot_predictions,
)
from models import SmartAttentionLayer

from experiments.artificial_1D_linear.smart_fed_avg_util import (
    train_client,
    register_client_test_losses,
)
from utils.general import get_logging_dir


# Seed PyTorch's global random number generator so repeated runs of the
# notebook draw the same random numbers.
torch.manual_seed(42)

<torch._C.Generator at 0x127c58310>

In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell runs, so edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Device handed to the model constructor, train_client, evaluate and the
# plotting helpers.
DEVICE: torch.device = torch.device("cpu")
# Number of federated clients; CLIENT_IDs is derived from it.
NUM_CLIENTS = 5

# Iterations of the inner loop of the training cell, run once per boosting round.
COMMUNICATION_ROUNDS = 20
# Iterations of the outer loop of the training cell; each one adds a new
# boosting layer to the global model.
BOOSTING_ROUNDS = 10
# Passed to train_client as no_epochs, and used to scale the TensorBoard step.
CLIENT_EPOCHS = 100

# Split strategy name passed to data.get_client_train_dataloaders; it also
# becomes part of the run name.
SPLIT_TYPE = "interval"

In [None]:
# Client indices 0 .. NUM_CLIENTS - 1; used to request client models and to
# index the per-client dataloaders.
CLIENT_IDs = range(NUM_CLIENTS)

# Batch size passed to data.get_client_train_dataloaders.
BATCH_SIZE = 64

# First two positional arguments of the SmartAttentionBoosting constructor.
INPUT_FEATURES = 1
OUTPUT_FEATURES = 1
# Architecture lists passed to the SmartAttentionBoosting constructor.
# NOTE(review): presumably hidden-layer widths of the prediction and the two
# importance networks — confirm against the models package.
PRED_ARCHITECTURE = [5, 5, 5, 5]
INPUT_IMPRTNC_ARCHITECTURE = [5, 5]
CLIENT_IMPRTNC_ARCHITECTURE = [5, 5]
# Passed as similarity_threshold_in_degree to register_new_client_models
# while the run is not in its cooling-off phase.
SIMILARITY_THRESHOLD = 30

# Mean-squared-error loss passed to train_client.
LOSS_FN = torch.nn.MSELoss()

In [None]:
def register_hyperparameters(writer, last_loss):
    """Record the notebook's configuration and final loss as a TensorBoard hparams entry.

    Args:
        writer: Object exposing ``add_hparams``; the notebook passes its
            ``SummaryWriter``.
        last_loss: Value stored under the ``"MSE Test"`` metric.
    """
    # The architecture lists are logged as strings rather than as lists.
    hyperparameters = {
        "client_epochs": CLIENT_EPOCHS,
        "num_clients": NUM_CLIENTS,
        "communication_rounds": COMMUNICATION_ROUNDS,
        "split_type": SPLIT_TYPE,
        "pred_architecture": str(PRED_ARCHITECTURE),
        "input_imprtnc_architecture": str(INPUT_IMPRTNC_ARCHITECTURE),
        "client_imprtnc_architecture": str(CLIENT_IMPRTNC_ARCHITECTURE),
        "similarity_threshold": SIMILARITY_THRESHOLD,
        "boosting_rounds": BOOSTING_ROUNDS,
    }
    final_metrics = {"MSE Test": last_loss}

    # NOTE(review): run_name="." presumably keeps the hparams record in the
    # writer's own log directory instead of a separate sub-run — confirm
    # against the SummaryWriter.add_hparams documentation.
    writer.add_hparams(hyperparameters, final_metrics, run_name=".")

In [None]:
from models import SmartAttentionBoosting
# Run identifier: names the TensorBoard log directory and is passed to
# plot_predictions in the training cell.
model_name = f"SmartAttentionBoosting_{NUM_CLIENTS}clients_{SPLIT_TYPE}-split"

log_dir = get_logging_dir(model_name, "artificial_1D_linear")
writer = SummaryWriter(log_dir)

# Global model that the training cell extends with boosting layers and from
# which the per-client models are derived.
global_model = SmartAttentionBoosting(
    INPUT_FEATURES,
    OUTPUT_FEATURES,
    NUM_CLIENTS,
    PRED_ARCHITECTURE,
    INPUT_IMPRTNC_ARCHITECTURE,
    CLIENT_IMPRTNC_ARCHITECTURE,
    device=DEVICE,
)

# One training dataloader per client, indexed by client id in the training cell.
client_train_dataloaders = data.get_client_train_dataloaders(
    NUM_CLIENTS,
    SPLIT_TYPE,
    BATCH_SIZE,
    shuffle=True,
)

# Log a plot of the per-client data split through the writer.
plot_data_split(client_train_dataloaders, writer)

In [None]:
def is_cooling_off_epoch(cr: int) -> bool:
    """Tell whether a communication round belongs to the cooling-off phase.

    Args:
        cr: Zero-based communication-round index within a boosting round.

    Returns:
        True when ``cr`` is strictly greater than half of
        ``COMMUNICATION_ROUNDS`` (integer division), False otherwise.
    """
    midpoint = COMMUNICATION_ROUNDS // 2
    return midpoint < cr


# Federated boosting loop: every boosting round adds one layer to the global
# model and then runs COMMUNICATION_ROUNDS rounds of client training followed
# by merging the client models back into the global model.
for br in range(BOOSTING_ROUNDS):
    global_model.add_new_boosting_layer()

    for cr in range(COMMUNICATION_ROUNDS):
        # Decide the phase once per round; it controls the noise flag, the
        # similarity threshold and the merge method below.
        cooling_off = is_cooling_off_epoch(cr)

        # Round index counted across all boosting rounds, and the same
        # position expressed in client epochs (used as TensorBoard step).
        global_round = br * COMMUNICATION_ROUNDS + cr
        global_step = global_round * CLIENT_EPOCHS

        # `clients` stays a module-level name: a later cell inspects the
        # client models of the final round.
        clients = [
            global_model.get_client_model(client_id, add_noise=not cooling_off)
            for client_id in CLIENT_IDs
        ]

        # Train each client individually on its own dataloader.
        # NOTE(review): train_client receives the round index while the
        # test-loss logging below uses the epoch-scaled step; presumably
        # train_client scales by epochs itself — confirm.
        for client_no, client in zip(CLIENT_IDs, clients):
            train_client(
                client_no=client_no,
                client_model=client,
                data_loader=client_train_dataloaders[client_no],
                loss_fn=LOSS_FN,
                no_epochs=CLIENT_EPOCHS,
                communication_round=global_round,
                writer=writer,
                device=DEVICE,
            )

        register_client_test_losses(
            clients=clients,
            client_ids=CLIENT_IDs,
            writer=writer,
            communication_round=global_step,
            device=DEVICE,
            plot_client_predictions=True,
        )

        # Outside the cooling-off phase the client models are combined using
        # SIMILARITY_THRESHOLD; during cooling-off they are averaged with a
        # threshold of 181.
        # NOTE(review): 181 presumably exceeds any possible angle in degrees
        # so that every pair counts as similar — confirm.
        global_model.register_new_client_models(
            clients,
            similarity_threshold_in_degree=(
                181 if cooling_off else SIMILARITY_THRESHOLD
            ),
            method="average" if cooling_off else "combine",
        )
        global_model.to(DEVICE)

        writer.add_scalar(
            "test_loss",
            evaluate(global_model, device=DEVICE),
            global_step,
        )

        plot_predictions(
            global_model,
            model_name,
            writer,
            epoch=global_round,
            device=DEVICE,
        )

        print(global_model)

# Final summary: log the last predictions and the hyperparameters together
# with the final test loss. The loss is evaluated on DEVICE, like the
# per-round evaluation inside the loop.
plot_predictions(global_model, model_name, writer, device=DEVICE)
register_hyperparameters(
    writer, last_loss=evaluate(global_model, device=DEVICE)
)
writer.flush()
writer.close()

In [None]:
# Print the global model as left by the training cell.
print(global_model)

In [None]:
# Print whatever the query network's full_representation() method returns.
print(global_model.query_network.full_representation())

In [None]:
# `clients` is the list built in the last communication round of the training
# cell, so this inspects the final round's client models; that cell must have
# been run first.
for client in clients:
    print(client.prediction_network.full_representation())
    print(client.prediction_mask)