In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
import yaml

In [None]:
#| hide
def load_config(file_path):
    """Read a YAML file and return its parsed contents.

    Args:
        file_path: Path (str or path-like) of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document -- a dict for a
        top-level mapping, or ``None`` for an empty file.
    """
    # Pin the encoding so non-ASCII config values decode the same way on every
    # platform instead of depending on the locale's default text encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        # safe_load builds only plain Python types, never arbitrary objects.
        return yaml.safe_load(file)

In [None]:
#| hide
# Parse the experiment configuration; the path is relative to the kernel's working directory.
cfg = load_config(file_path='cfg.yaml')

In [None]:
#| hide
from omegaconf import OmegaConf
# Wrap the parsed YAML in an OmegaConf config so nested keys support attribute access.
cfg = OmegaConf.create(cfg)
# Force the dataset to MNIST regardless of what cfg.yaml specifies.
# NOTE(review): a later run of the federated loop in this notebook failed with
# "ConfigAttributeError: Missing key save_dir" -- cfg.yaml presumably needs a
# top-level `save_dir` entry; confirm and add it there.
cfg['data']['name'] = 'MNIST'

In [None]:
import os
import torch.nn as nn
import torch
from fedai.vision.VisionBlock import VisionBlock
from fedai.utils import * # noqa: F403

In [None]:
class MLP(torch.nn.Module):
    """Two-layer perceptron: Linear -> Dropout -> ReLU -> Linear.

    Args:
        dim_in: Number of input features after flattening each sample.
        dim_hidden: Width of the hidden layer.
        dim_out: Number of output logits.
    """

    def __init__(self, dim_in, dim_hidden, dim_out):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable so
        # checkpoints saved/loaded elsewhere still match.
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        """Flatten every non-batch dimension of ``x`` and return raw logits.

        Dim 0 is treated as the batch axis; the product of the remaining
        dimensions must equal ``dim_in``. Works for ``(N, C, H, W)``,
        ``(N, H, W)`` and already-flat ``(N, D)`` inputs.
        """
        # flatten(start_dim=1) neither assumes a fixed rank nor requires the
        # tensor to be contiguous (it copies only when a view is impossible).
        x = torch.flatten(x, start_dim=1)
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return x

In [None]:
def get_block(cfg, id, train=True):
    """Build the data block for client ``id`` according to ``cfg.data.modality``.

    Args:
        cfg: Experiment config; ``cfg.data.modality`` selects the block class.
        id: Client index forwarded to the block constructor.
        train: Forwarded as ``train=``; ``client_fn`` passes ``False`` when
            building its ``test_block``.

    Returns:
        ``VisionBlock(cfg, id, train=train)`` when ``cfg.data.modality == ['Vision']``.

    Raises:
        ValueError: if the configured modality has no block implementation.
    """
    modality = cfg.data.modality
    if modality == ['Vision']:
        return VisionBlock(cfg, id, train=train)
    # Fail with an actionable message instead of trying to call None.
    raise ValueError(
        f"Unsupported data modality {modality!r}; only ['Vision'] is implemented."
    )

In [None]:
def client_fn(client_cls, cfg, id, latest_round):
    """Assemble one federated client from a fresh model, a loss, and its data blocks.

    Args:
        client_cls: Agent class to instantiate (e.g. ``FLAgent``).
        cfg: Experiment config, forwarded to the data blocks and the agent.
        id: Client index.
        latest_round: Mapping of client id -> communication round. When ``id`` is
            a key, the model is replaced by ``load_state_from_disk`` for that round.

    Returns:
        ``client_cls(id, cfg, state, block=[train_block, test_block])``.
    """
    # NOTE(review): input size 28*28 and 10 outputs are hard-coded (MNIST-sized);
    # other datasets would need these derived from cfg.
    net = MLP(28*28, 128, 10)
    client_state = {
        'model': net,
        'optimizer': None,
        'criterion': torch.nn.CrossEntropyLoss(),
    }
    blocks = [get_block(cfg, id), get_block(cfg, id, train=False)]

    if id in latest_round:
        # Swap in whatever load_state_from_disk returns for this client's recorded round.
        client_state['model'] = load_state_from_disk(cfg, net, id, latest_round[id])

    return client_cls(id, cfg, client_state, block=blocks)


In [None]:
from fedai.federated.agents import * # noqa: F403
from fedai.client_selector import *  # noqa: F403
from torch.nn.modules import CrossEntropyLoss # noqa: F403

# The same agent class serves both the clients and the server.
client_cls = FLAgent  # noqa: F405
# NOTE(review): presumably picks the active client ids per round (see the
# commented-out loop's use of client_selector.select()); confirm.
client_selector = BaseClientSelector(cfg)  # noqa: F405
server = client_cls(
    id=0,
    role=AgentRole.SERVER,  # noqa: F405
    cfg=cfg,
    state=None,
    block=None,
)
# client id -> communication round; client_fn only reloads a model for ids present here.
latest_round = {}

In [None]:
# Build client 0 for inspection below.
# NOTE(review): the server was also created with id=0 -- confirm ids cannot collide (e.g. on disk).
client = client_fn(client_cls=client_cls, cfg=cfg, id=0, latest_round=latest_round)

data/MNIST/train data/MNIST/test

Dataset already generated.

data/MNIST/train data/MNIST/test

Dataset already generated.



In [None]:
# Show the agent's text summary (index, model, criterion, optimizer).
print(str(client))

FLAgent: FLAgent
    Index : 0
    Model: MLP
    Criterion: CrossEntropyLoss
    Optimizer: Adam


In [None]:
# from fedai.trainers import * # noqa: F403
# trainer = Trainer(client) # noqa: F405

In [None]:
# all_metrics = trainer.train()


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

In [None]:
# all_metrics

{'train_loss': 0.02933939113452834,
 'test_loss': 0.022610422519483794,
 'train_accuracy': 0.9012605042016826,
 'test_accuracy': 0.9498464912280699}

In [None]:
# all_ids = client_selector.select()
# for t in range(1, 3):
#     lst_active_ids = all_ids[t]
#     len_clients_ds = []
#     for id in lst_active_ids:
#         client = client_fn(client_cls, cfg, id, latest_round)
#         len_clients_ds.append(200)
#         server.communicate(client, t) # read from the disk
#         trainer = Trainer(client) # the trainer object takes a client and make local training on its dataset.
#         client_history = trainer.train() # actual training loop
#         client.communicate(server, t) # save the state of the client to the disk

#     server.aggregate(lst_active_ids, t) # aggregate the models of the clients


data/MNIST/train data/MNIST/test

Dataset already generated.

data/MNIST/train data/MNIST/test

Dataset already generated.





[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

ConfigAttributeError: Missing key save_dir
    full_key: save_dir
    object_type=dict