In [1]:
from box import Box
import random

import numpy as np
import torch

# Experiment configuration: every tunable knob for the run lives here so the
# cells below read from `cfg` instead of repeating literals.
cfg = Box(
    {
        'model': 'MLP',
        'dataset': {
            'name': 'mnist',
            'num_clients': 10,
            'split_type': 'iid',
            'require_local_test': True,
            'global_local_ratio': 0.5,
            'val': False,
        },
        'train': {
            'algorithm': 'FedAvg',
            'num_clients_per_round': 3,
            'local_rounds': 3,
            'epochs': 10,
            'optimizer': 'SGD',
            'lr': 0.001,
            # Fall back to CPU so Restart & Run All works on machines without a GPU.
            'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        },
        'loader': {
            'train_batch_size': 64,
            'eval_batch_size': 1000,
        },
        # Single seed for every RNG used below (client sampling, data split, weight init).
        'seed': 42,
    }
)

# Seed all common RNGs before any data splitting or model construction so the
# run is reproducible. NOTE(review): assumes the project modules draw from
# `random`, numpy or torch — confirm if results still vary between runs.
random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)

In [2]:
from datasets import load_dataset
from box import Box
# Load the train/test splits described by cfg.dataset. No validation set is
# requested (cfg.dataset.val is False) because temperature scaling is not used yet.
# NOTE(review): `datasets` is presumably the project-local module (its
# load_dataset takes data_cfg), not the Hugging Face package of the same name.
trainset, testset, valset = load_dataset(
    data_cfg=cfg.dataset)
split_sizes = f'trainset: {len(trainset)} | testset: {len(testset)}'
if valset is not None:
    # Only report the validation split when one was actually returned.
    split_sizes += f' | valset: {len(valset)}'
print(split_sizes)

trainset: 60000 | testset 10000


## Split the datasets across clients (IID)

In [3]:
from split import IIDSplitter
# Partition the data across cfg.dataset.num_clients clients with an IID splitter.
# The training set is split with local=False; the test set is split with
# local=True, using the ratio from the config instead of a repeated literal.
splitter = IIDSplitter(cfg.dataset.num_clients)
train_data_map = splitter.split(trainset, train=True, local=False)
test_data_map = splitter.split(testset, train=False, local=True,
                               global_local_ratio=cfg.dataset.global_local_ratio)

Splitting dataset into 10 clients.
{0: 6000, 1: 6000, 2: 6000, 3: 6000, 4: 6000, 5: 6000, 6: 6000, 7: 6000, 8: 6000, 9: 6000}
Splitting dataset into 10 clients.
{0: 500, 1: 500, 2: 500, 3: 500, 4: 500, 5: 500, 6: 500, 7: 500, 8: 500, 9: 500}


## Initialize the model

In [4]:
from model import ModelFactory
# Build the network selected by cfg.model through the project factory, then
# print its layer structure.
# NOTE(review): the saved output shows in_features=1024 (i.e. 32x32 inputs)
# although native MNIST images are 28x28 — presumably load_dataset resizes; confirm.
factory = ModelFactory()
model = factory.create_model(cfg)
print(model)

MLP(
  (layers): ModuleList(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


In [5]:
from algorithm import FedAvg
from torch.optim import SGD, Adam  # NOTE(review): unused here; the optimizer is passed by name via cfg.train.optimizer
# Wire the data maps and the config into the FedAvg driver. The local test /
# validation flags come from cfg.dataset, so they cannot drift from the data
# setup above.
# NOTE(review): no validation data map is built in this notebook
# (valid_data_map=None), so cfg.dataset.val is expected to stay False.
fedavg = FedAvg(
    model=model,
    num_clients=cfg.dataset.num_clients,
    num_clients_per_round=cfg.train.num_clients_per_round,
    local_rounds=cfg.train.local_rounds, epochs=cfg.train.epochs,
    optimizer=cfg.train.optimizer, lr=cfg.train.lr,
    device=cfg.train.device,
    trainset=trainset, testset=testset, valset=valset,
    local_test=cfg.dataset.require_local_test, local_val=cfg.dataset.val,
    train_data_map=train_data_map, test_data_map=test_data_map,
    valid_data_map=None,
    train_batch_size=cfg.loader.train_batch_size,
    eval_batch_size=cfg.loader.eval_batch_size
)

In [6]:
# Run federated training. Each global round in the saved output shows selected
# clients training locally, a per-client local test, then aggregation and a
# global test.
# NOTE(review): the saved run stopped with
# "TypeError: 'NoneType' object is not subscriptable" right after the round-1
# global test — presumably inside FedAvg.run (possibly a None argument such as
# valset or valid_data_map being indexed). TODO: trace and fix before trusting results.
fedavg.run()

Global round 1 started
Client 2 started local training
Client 2 local round 1/3


100%|██████████| 94/94 [00:03<00:00, 28.39it/s, Running Average Loss=2.1087]


Client 2 local round 2/3


100%|██████████| 94/94 [00:01<00:00, 63.68it/s, Running Average Loss=1.3975]


Client 2 local round 3/3


100%|██████████| 94/94 [00:01<00:00, 63.46it/s, Running Average Loss=0.7983]


Client 2 finished local training
Local Training {'average_loss': 0.0006557872196038564, 'accuracy': 0.8488333333333333}
MulticlassAccuracy: 0.8119308948516846
MulticlassCalibrationError: 0.23069055378437042
Local Test: {'MulticlassAccuracy': 0.8119308948516846, 'MulticlassCalibrationError': 0.23069055378437042}
Client 9 started local training
Client 9 local round 1/3


100%|██████████| 94/94 [00:01<00:00, 62.54it/s, Running Average Loss=2.0382]


Client 9 local round 2/3


100%|██████████| 94/94 [00:01<00:00, 63.18it/s, Running Average Loss=1.1614]


Client 9 local round 3/3


100%|██████████| 94/94 [00:01<00:00, 63.88it/s, Running Average Loss=0.6046]


Client 9 finished local training
Local Training {'average_loss': 0.00048611686130364736, 'accuracy': 0.8905}
MulticlassAccuracy: 0.8219518661499023
MulticlassCalibrationError: 0.1708812266588211
Local Test: {'MulticlassAccuracy': 0.8219518661499023, 'MulticlassCalibrationError': 0.1708812266588211}
Client 8 started local training
Client 8 local round 1/3


100%|██████████| 94/94 [00:01<00:00, 60.65it/s, Running Average Loss=2.0670]


Client 8 local round 2/3


100%|██████████| 94/94 [00:01<00:00, 63.60it/s, Running Average Loss=1.2822]


Client 8 local round 3/3


100%|██████████| 94/94 [00:01<00:00, 62.73it/s, Running Average Loss=0.7183]


Client 8 finished local training
Local Training {'average_loss': 0.0005895339945952097, 'accuracy': 0.8533333333333334}
MulticlassAccuracy: 0.8204490542411804
MulticlassCalibrationError: 0.1793261468410492
Local Test: {'MulticlassAccuracy': 0.8204490542411804, 'MulticlassCalibrationError': 0.1793261468410492}
All clients finished local training
Aggregating global model
MulticlassAccuracy: 0.8421181201934814
MulticlassCalibrationError: 0.17298672199249268
Global Test: {'MulticlassAccuracy': 0.8421181201934814, 'MulticlassCalibrationError': 0.17298672199249268}


TypeError: 'NoneType' object is not subscriptable