In [None]:
# Automatically reload edited library modules (e.g. the local `fedai` package)
# before code runs, so package edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
import yaml

In [None]:
#| hide
def load_config(file_path):
    """Parse a YAML configuration file and return its contents.

    The file is opened in binary mode so that PyYAML detects the encoding
    itself (UTF-8, or UTF-16 via a BOM). Text mode would use the platform's
    default encoding, which is not UTF-8 on every system, and could mis-decode
    non-ASCII values in the config.

    Args:
        file_path: path to the YAML file.

    Returns:
        The parsed document as plain Python objects (a dict for a top-level
        mapping), as produced by ``yaml.safe_load``.
    """
    with open(file_path, 'rb') as file:
        return yaml.safe_load(file)

In [None]:
#| hide
# Raw experiment configuration parsed from cfg.yaml (plain Python objects).
cfg = load_config(file_path='cfg.yaml')

In [None]:
#| hide
import os
from datetime import datetime

from omegaconf import OmegaConf

# Build the run config from a fresh read of cfg.yaml instead of from the current
# `cfg`. Deriving it from `cfg` made this cell non-idempotent: every re-run
# appended another timestamp to an already-timestamped save_dir (save_dir/T1/T2/...).
cfg = OmegaConf.create(load_config('cfg.yaml'))
cfg.data.name = 'MNIST'
# Give each run its own output directory: <save_dir>/<YYYYmmdd_HHMMSS>.
cfg.save_dir = os.path.join(cfg.save_dir, datetime.now().strftime("%Y%m%d_%H%M%S"))

In [None]:
import os
import torch.nn as nn
import torch
from fedai.vision.VisionBlock import VisionBlock
from fedai.utils import * # noqa: F403

In [None]:
class MLP(torch.nn.Module):
    """Two-layer perceptron: Linear -> Dropout -> ReLU -> Linear.

    Args:
        dim_in: number of input features after flattening each sample.
        dim_hidden: width of the hidden layer.
        dim_out: number of output logits.
        dropout: drop probability applied to the hidden pre-activations.
            Defaults to 0.5, the default of ``nn.Dropout()``.
    """

    def __init__(self, dim_in, dim_hidden, dim_out, dropout=0.5):
        super().__init__()
        # Submodule names are unchanged so saved state_dict keys still load.
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        """Map a batch of samples to logits of shape ``(batch, dim_out)``.

        Every dimension after the first (batch) one is flattened. Inputs shaped
        (N, C, H, W), (N, H, W) or already-flat (N, D) therefore all work, as
        long as the flattened size equals ``dim_in``.
        """
        # flatten(start_dim=1) keeps the batch dimension. The previous
        # view(-1, shape[1]*shape[-2]*shape[-1]) only worked for 4-D input and
        # could fail on non-contiguous tensors.
        x = torch.flatten(x, start_dim=1)
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return x

In [None]:
def get_block(cfg, id, train=True):
    """Build the data block for client ``id`` according to ``cfg.data.modality``.

    Only the vision modality (``cfg.data.modality == ['Vision']``) is
    implemented. It returns ``VisionBlock(cfg, id, train=train)``.

    Args:
        cfg: experiment configuration; ``cfg.data.modality`` selects the block.
        id: client identifier forwarded to the block.
        train: build the training split when True, the test split otherwise.

    Raises:
        ValueError: if the configured modality has no block implementation.
            Without this check, the code would attempt to call ``None`` and
            fail with an unhelpful TypeError.
    """
    if cfg.data.modality == ['Vision']:
        return VisionBlock(cfg, id, train=train)
    raise ValueError(
        f"Unsupported data modality {cfg.data.modality!r}; only ['Vision'] is implemented."
    )

In [None]:
def client_fn(client_cls, cfg, id, latest_round):
    """Construct the agent for client ``id``.

    Builds a fresh ``MLP(28*28, 128, 10)`` with a cross-entropy criterion and
    the client's train and test data blocks. If ``latest_round`` records a
    round for this client, the model is replaced by
    ``load_state_from_disk(cfg, model, id, round)``. The result is passed to
    ``client_cls(id, cfg, state, block=[train_block, test_block])``.
    """
    net = MLP(28*28, 128, 10)
    loss_fn = torch.nn.CrossEntropyLoss()
    # Train block first, then test block (same order as the split flags).
    blocks = [get_block(cfg, id, train=is_train) for is_train in (True, False)]
    state = dict(model=net, optimizer=None, criterion=loss_fn)

    # Resume from the client's last saved round, if it has one.
    if id in latest_round:
        state['model'] = load_state_from_disk(cfg, net, id, latest_round[id])

    return client_cls(id, cfg, state, block=blocks)


In [None]:
from fedai.federated.agents import * # noqa: F403
from fedai.client_selector import *  # noqa: F403
from torch.nn.modules import CrossEntropyLoss # noqa: F403

# Selector that decides which client ids participate in each round.
client_selector = BaseClientSelector(cfg)  # noqa: F405
client_cls = FLAgent  # noqa: F405

# The server is the same agent class as the clients, distinguished by id 0 and
# the SERVER role; it has no data block and no initial state.
server = client_cls(
    cfg=cfg,
    block=None,
    id=0,
    state=None,
    role=AgentRole.SERVER,  # noqa: F405
)

# client id -> most recent round that client communicated; client_fn reads it
# to decide whether to restore a model from disk.
latest_round = dict()

In [None]:
# # Uncomment to run the federated training loop for rounds t = 1 and 2.
# from fedai.trainers import *
# all_ids = client_selector.select()
# for t in range(1, 3):
#     lst_active_ids = all_ids[t]
#     len_clients_ds = []
#     for id in lst_active_ids:
#         client = client_fn(client_cls, cfg, id, latest_round)
#         len_clients_ds.append(200)
#         server.communicate(client, t) # read from the disk
#         trainer = Trainer(client) # the trainer takes a client and runs local training on that client's dataset.
#         client_history = trainer.train() # actual training loop
#         client.communicate(server, t) # save the state of the client to the disk
#         latest_round[id] = t # make sure you tell the client_fn where to look
#     server.aggregate(lst_active_ids, t, len_clients_ds, all_ids) # aggregate the models of the clients
    

data/MNIST/train data/MNIST/test

Dataset already generated.

data/MNIST/train data/MNIST/test

Dataset already generated.



  0%|          | 0/484 [08:15<?, ?it/s]
  0%|          | 0/484 [08:04<?, ?it/s]


role is AgentRole.CLIENT
role is AgentRole.CLIENT
data/MNIST/train data/MNIST/test

Dataset already generated.

data/MNIST/train data/MNIST/test

Dataset already generated.



  0%|          | 0/513 [08:55<?, ?it/s]
  0%|          | 0/513 [09:07<?, ?it/s]


role is AgentRole.CLIENT
role is AgentRole.CLIENT
role is AgentRole.SERVER
role is AgentRole.SERVER
data/MNIST/train data/MNIST/test

Dataset already generated.

data/MNIST/train data/MNIST/test

Dataset already generated.



  client_state_dict = torch.load(model_path, map_location='cpu')
  0%|          | 0/264 [04:41<?, ?it/s]
  0%|          | 0/264 [04:27<?, ?it/s]


role is AgentRole.CLIENT
role is AgentRole.CLIENT
data/MNIST/train data/MNIST/test

Dataset already generated.

data/MNIST/train data/MNIST/test

Dataset already generated.



  0%|          | 0/97 [01:42<?, ?it/s]
  0%|          | 0/97 [01:42<?, ?it/s]


role is AgentRole.CLIENT
role is AgentRole.CLIENT
role is AgentRole.SERVER
role is AgentRole.SERVER


  client_state_dict = torch.load(model_path, map_location='cpu')


In [None]:
# len_clients_ds, lst_active_ids, latest_round

([200, 200], array([10,  7]), {})

In [None]:
# server.aggregate([16, 19], 1, len_clients_ds, one_model= True)

  client_state_dict = torch.load(model_path, map_location='cpu')
