In [1]:
import torch
import numpy as np
import random, os

from utils.lenet import LeNet
from utils.evaluation import evaluate_model
from utils.fedavg import federated_avg
from utils.train import client_training
from utils.distillation import ensemble_distillation
from utils.datamodule import data_loaders

In [2]:
SEED = 42

# Seed every RNG the pipeline touches: Python's `random` (used below for
# per-round client sampling), NumPy, and PyTorch on CPU and all CUDA devices.
random.seed(SEED)
# NOTE: PYTHONHASHSEED is only read at interpreter start-up, so this affects
# Python processes launched after this line, not the already-running kernel.
os.environ['PYTHONHASHSEED'] = str(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Deterministic cuDNN kernels only help if autotuning is disabled: with
# benchmark=True cuDNN benchmarks candidate algorithms and may select a
# different one on each run, which breaks run-to-run reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
BATCHSIZE = 64

# Experiment configuration, forwarded as a whole to data_loaders,
# client_training and evaluate_model inside federated_learning.
conf = dict(
    # Total number of simulated clients (presumably how many partitions
    # data_loaders creates — TODO confirm in utils.datamodule).
    num_clients=20,
    # Distinct clients sampled per communication round (random.sample).
    C_count=5,
    # Fraction handed to data_loaders; presumably the share of data kept for
    # the server-side loader used in distillation — TODO confirm.
    server_clients_split=.05,
    batch_size=BATCHSIZE,
    # Presumably client-side optimiser settings consumed by client_training —
    # verify against utils.train.
    lr=.005,
    epochs=20,
    # Number of communication rounds run by federated_learning.
    num_rounds=10,
    # Flag read by helpers outside this notebook; semantics not visible here.
    TESTRUN=False,
    device=DEVICE,
    # Settings passed to ensemble_distillation for the server-side step.
    global_conf=dict(
        lr=1e-5,
        batch_size=BATCHSIZE,
        device=DEVICE,
        epochs=5,
    ),
)

In [4]:
def federated_learning(conf):
    """Run federated training with FedAvg followed by server-side ensemble distillation.

    For each of ``conf['num_rounds']`` communication rounds this:
      1. samples ``conf['C_count']`` distinct client indices with Python's
         global ``random`` module (so it follows ``random.seed``),
      2. trains those clients with ``client_training``, starting from the
         current global weights,
      3. aggregates the summed update into the global model via ``federated_avg``,
      4. runs ``ensemble_distillation`` on the server data loader with the
         per-client results and ``conf['global_conf']``, then loads the
         returned state dict into the global model,
      5. evaluates the global model on the test loader with ``evaluate_model``.

    Args:
        conf: experiment configuration dict. Keys read directly here are
            ``'device'``, ``'num_rounds'``, ``'C_count'`` and ``'global_conf'``;
            the whole dict is also forwarded to ``data_loaders``,
            ``client_training`` and ``evaluate_model``.

    Returns:
        The global ``LeNet`` model as it stands after the final round, so it
        can be saved or inspected after training.

    Raises:
        ValueError: if ``conf['C_count']`` is larger than the number of client
            data loaders returned by ``data_loaders``.
    """
    device = conf['device']
    clients_per_round = conf['C_count']

    # Initializations
    server_dataloader, clients_dataloaders, test_loader = data_loaders(conf)
    num_clients = len(clients_dataloaders)
    # random.sample would also reject this, but only once round 1 starts and
    # with a less specific message.
    if clients_per_round > num_clients:
        raise ValueError(
            f"C_count={clients_per_round} exceeds the number of clients ({num_clients})"
        )

    global_model = LeNet().to(device)

    # Communication rounds
    for roundidx in range(conf['num_rounds']):
        print(f"COMMUNICATION ROUND {roundidx+1}")
        # Pick C distinct client indices for this round.
        fraction_indices = random.sample(range(num_clients), clients_per_round)
        print(f"Current round's clients: {fraction_indices}")

        client_fraction_loaders = [clients_dataloaders[i] for i in fraction_indices]

        # Local training. Presumably returns the summed client updates and the
        # individual per-client updates — TODO confirm against client_training.
        delta_sum_dict, delta_dicts = client_training(
            client_fraction_loaders, global_model.state_dict(), fraction_indices, conf
        )

        # FedAvg aggregation produces the new global model.
        global_model = federated_avg(global_model, delta_sum_dict)

        # Server-side distillation refines the aggregated weights.
        global_state_dict = ensemble_distillation(
            delta_dicts, global_model.state_dict(), server_dataloader, conf['global_conf']
        )
        global_model.load_state_dict(global_state_dict)

        evaluate_model(global_model, test_loader, roundidx, conf)

        print(f"\nround {roundidx + 1} complete!", end='\n\n\n')

    return global_model


In [5]:
# Run the whole experiment; the trailing ';' keeps IPython from echoing the
# call's return value beneath the streamed training log.
federated_learning(conf);

COMMUNICATION ROUND 1
Current round's clients: [3, 0, 8, 7, 16]
Client 3 training complete!
Client 0 training complete!
Client 8 training complete!
Client 7 training complete!
Client 16 training complete!
Accuracy of the global model at round 1: 28.64%

round 1 complete!


COMMUNICATION ROUND 2
Current round's clients: [4, 3, 17, 2, 13]
Client 4 training complete!
Client 3 training complete!
Client 17 training complete!
Client 2 training complete!
Client 13 training complete!
Accuracy of the global model at round 2: 64.51%

round 2 complete!


COMMUNICATION ROUND 3
Current round's clients: [1, 0, 2, 6, 7]
Client 1 training complete!
Client 0 training complete!
Client 2 training complete!
Client 6 training complete!
Client 7 training complete!
Accuracy of the global model at round 3: 81.92%

round 3 complete!


COMMUNICATION ROUND 4
Current round's clients: [16, 0, 17, 6, 13]
Client 16 training complete!
Client 0 training complete!
Client 17 training complete!
Client 6 training complete