In [1]:
%load_ext autoreload
%autoreload 2

import copy, os, socket, sys, time
from pathlib import Path
from tqdm import tqdm

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../")))
from libs import data, fl, model

In [2]:
class FedArgs:
    """Hyper-parameters and shared objects for a single federated-learning client.

    All attributes are plain mutable fields. Later cells override some of them
    (e.g. ``tb`` and ``num_clients``) before training starts.
    """

    def __init__(self):
        # Client identity / federation size
        self.name = "client-x"
        self.num_clients = 50
        # Training schedule
        self.epochs = 51
        self.local_rounds = 1
        self.client_batch_size = 32
        self.test_batch_size = 128
        # Optimizer settings passed to train_func
        self.learning_rate = 1e-4
        self.weight_decay = 1e-5
        # Runtime / data
        self.cuda = False
        self.seed = 1
        self.dataset = "mnist"
        # Seed *before* building the model. ModelMNIST presumably initialises its
        # weights from torch's global RNG at construction time (standard for
        # nn.Module layers); the later seeding cell runs too late to cover that,
        # which made the initial model differ from run to run.
        torch.manual_seed(self.seed)
        self.model = model.ModelMNIST()
        self.train_func = fl.train_model
        self.eval_func = fl.evaluate
        # Placeholder writer; the run-config cell replaces it with a per-client one.
        self.tb = SummaryWriter('./../out/runs/fl/client-run', comment="fl")

fedargs = FedArgs()

In [3]:
# Per-client run configuration. To run a different client, set its name here
# before the run name is built, e.g. fedargs.name = "client-1".
project = 'fl-kafka-client'
name = 'fedavg-cnn-' + fedargs.dataset + '-na-' + fedargs.name
# Close the current writer (the placeholder from FedArgs(), or the one from a
# previous run of this cell) before replacing it. Otherwise its event file is
# left open and never flushed.
fedargs.tb.close()
fedargs.tb = SummaryWriter(str(Path('../out/runs') / project / name), comment="fl")
fedargs.num_clients = 1

In [4]:
# Seed torch's global RNG from the config for repeatable runs.
torch.manual_seed(fedargs.seed)

# CUDA is used only when it is both requested in the config and available.
use_cuda = fedargs.cuda and torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Extra loader options for the CUDA case.
# NOTE(review): not passed to data.load_dataset in the visible cells.
if use_cuda:
    kwargs = {"num_workers": 1, "pin_memory": True}
else:
    kwargs = {}

In [5]:
# Label this client as "<hostname>: <client name>". The label is used in log
# lines and as the suffix of the TensorBoard scalar tags.
host = socket.gethostname()
clients = ["{}: {}".format(host, fedargs.name)]

In [6]:
# The global model starts as a copy of the configured model.
global_model = copy.deepcopy(fedargs.model)

# Load this client's train/test loaders.
# NOTE(review): only client_batch_size is passed; fedargs.test_batch_size is not
# used in the visible cells. Confirm how load_dataset sizes the test loader.
train_loader, test_loader = data.load_dataset(fedargs.dataset, fedargs.client_batch_size)

# Per-client state. The client trains on its own copy of the global model.
client_details = dict(
    name=clients[0],
    train_loader=train_loader,
    test_loader=test_loader,
    model=copy.deepcopy(global_model),
    model_update=None,
)

In [7]:
def process(client, epoch, model, train_loader, test_loader, fedargs, device):
    """Run one epoch for a single client: local training, then evaluation.

    Args:
        client: Client label, used in the log line and as TensorBoard tag suffix.
        epoch: Zero-based epoch index; metrics are logged at step ``epoch + 1``.
        model: The client's current model, handed to ``fedargs.train_func``.
        train_loader: Data passed to the training function.
        test_loader: Data passed to the evaluation function.
        fedargs: Config providing ``train_func``, ``eval_func``, ``tb`` and the
            optimizer settings (``learning_rate``, ``weight_decay``, ``local_rounds``).
        device: Device forwarded to the train and eval functions.

    Returns:
        The model returned by ``fedargs.train_func``.
    """
    print("Epoch: {}, Processing Client {}".format(epoch, client))

    # Local training. The model update and loss it also returns are not used here.
    _update, trained_model, _loss = fedargs.train_func(
        model,
        train_loader,
        fedargs.learning_rate,
        fedargs.weight_decay,
        fedargs.local_rounds,
        device,
    )

    # TensorBoard steps are one-based: epoch 0 is logged at step 1.
    step = epoch + 1
    metrics = fedargs.eval_func(trained_model, test_loader, device)
    for tag_prefix, metric_key in (("Accuracy/", "accuracy"), ("Test Loss/", "test_loss")):
        fedargs.tb.add_scalar(tag_prefix + client, metrics[metric_key], step)

    return trained_model

In [None]:
# Federated Training: one local training + evaluation pass per epoch for this client.
for epoch in tqdm(range(fedargs.epochs)):
    # tqdm.write prints without corrupting the active progress bar.
    tqdm.write("Federated Training Epoch {} of {}".format(epoch, fedargs.epochs))

    client_details['model'] = process(client_details['name'], epoch,
                                      client_details['model'],
                                      client_details['train_loader'],
                                      client_details['test_loader'],
                                      fedargs, device)

# SummaryWriter writes pending events to disk only periodically. Flush now so the
# final epochs' metrics are persisted as soon as training finishes.
fedargs.tb.flush()

  0%|          | 0/51 [00:00<?, ?it/s]

Federated Training Epoch 0 of 51
Epoch: 0, Processing Client bladecluster.iitp.org: client-x


  2%|▏         | 1/51 [00:42<35:04, 42.09s/it]

Federated Training Epoch 1 of 51
Epoch: 1, Processing Client bladecluster.iitp.org: client-x


  4%|▍         | 2/51 [01:23<33:55, 41.55s/it]

Federated Training Epoch 2 of 51
Epoch: 2, Processing Client bladecluster.iitp.org: client-x


  6%|▌         | 3/51 [02:06<34:00, 42.52s/it]

Federated Training Epoch 3 of 51
Epoch: 3, Processing Client bladecluster.iitp.org: client-x


  8%|▊         | 4/51 [02:50<33:40, 42.99s/it]

Federated Training Epoch 4 of 51
Epoch: 4, Processing Client bladecluster.iitp.org: client-x


 10%|▉         | 5/51 [03:33<32:48, 42.80s/it]

Federated Training Epoch 5 of 51
Epoch: 5, Processing Client bladecluster.iitp.org: client-x


 12%|█▏        | 6/51 [04:16<32:13, 42.97s/it]

Federated Training Epoch 6 of 51
Epoch: 6, Processing Client bladecluster.iitp.org: client-x
