In [1]:
# This is a Convolutional Neural Network (CNN) trained with the Federated Averaging optimization algorithm.
# 10 virtual clients are selected to train the models at each round with a limited client drop-rate,
# hence we should expect possibly fewer than 10 clients, chosen at random, to train at any given round.
# These communicate through a virtually simulated orchestrating server.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import syft as sy
import copy
import numpy as np
import time

import sys
sys.path.append('../')
from FLDataset.FLDataset import load_dataset
from FLDataset.FLDataset import getActualImgs
from FLDataset.utils import averageModels, averageGradients

import warnings
warnings.filterwarnings('ignore')

In [3]:
class Arguments:
    """Bundle of hyper-parameters and run-time switches for this experiment.

    All values are fixed in ``__init__``; ``split_size`` and ``samples`` are
    derived from ``images`` and ``clients``.
    """

    def __init__(self):
        # --- dataset / federation layout ---
        total_images, num_clients = 60000, 10
        self.images = total_images
        self.clients = num_clients

        # --- optimisation ---
        self.epochs = 6            # number of federated rounds run by the driver loop
        self.local_batches = 64    # batch size handed to the data loaders
        self.lr = 0.01

        # --- client selection (not referenced by the cells visible here) ---
        self.C = 0.8
        self.drop_rate = 0.2

        # --- reproducibility / logging / data split mode ---
        self.torch_seed = 0
        self.log_interval = 10     # log every N-th batch during training
        self.iid = 'iid'

        # --- derived quantities ---
        per_client = int(total_images / num_clients)
        self.split_size = per_client
        self.samples = per_client / total_images

        # --- runtime switches ---
        self.use_cuda = False
        self.save_model = False


args = Arguments()

# CUDA is used only when it is both requested and actually available.
use_cuda = args.use_cuda and torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Extra DataLoader keyword arguments that only make sense on a CUDA run.
if use_cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

In [4]:
# Hook PyTorch with PySyft so tensors and models gain the remote-execution
# methods (.send / .get) used during training.
hook = sy.TorchHook(torch)

# Create one virtual worker per client
clients = []

for i in range(args.clients):
    clients.append({'hook': sy.VirtualWorker(hook, id="client{}".format(i+1))})

# Load dataset and split between clients
global_train, global_test, train_group, test_group = load_dataset(args.clients, args.iid)

for inx, client in enumerate(clients):
    # Each client trains on at most the first 200 of its assigned training
    # indices and evaluates on all of its assigned test indices.
    client['trainset'] = getActualImgs(global_train, list(train_group[inx])[:200], args.local_batches)
    client['testset'] = getActualImgs(global_test, list(test_group[inx]), args.local_batches)

    # Ensure that 'trainset' and 'testset' are not None
    if client['trainset'] is None or client['testset'] is None:
        raise ValueError(f"Client {client['hook'].id} does not have valid train or test data.")

# The total must be taken only AFTER every client has its 'trainset'.
# Computing it inside the loop above normalised client k over clients 1..k
# only, so the per-client weights did not sum to 1.
# NOTE(review): len() here is whatever getActualImgs' return value reports
# (presumably a DataLoader, i.e. a batch count) - confirm this is the intended unit.
total_trainset_samples = sum(len(c['trainset']) for c in clients)
if total_trainset_samples == 0:
    raise ValueError("No training data was assigned to any client.")

# Proportion of the overall training data held by each client
for client in clients:
    client['samples'] = len(client['trainset']) / total_trainset_samples

# Define the global test data loader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

global_test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
global_test_loader = DataLoader(global_test_dataset, batch_size=args.local_batches, shuffle=True)

In [5]:
# Define CNN model
class Net(nn.Module):
    """Small CNN: two 5x5 convolutions followed by two linear layers.

    ``forward`` returns log-probabilities over 10 classes. The flatten size
    of 4*4*50 corresponds to single-channel 28x28 inputs
    (28 -> 24 -> 12 -> 8 -> 4 through the conv/pool stages).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        # Conv stage 1: convolution -> ReLU -> 2x2 max-pool
        feats = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=2)

        # Conv stage 2: convolution -> ReLU -> 2x2 max-pool
        feats = F.max_pool2d(F.relu(self.conv2(feats)), kernel_size=2, stride=2)

        # Flatten to (N, 800) and apply the hidden linear layer
        flat = feats.view(-1, 4 * 4 * 50)
        hidden = F.relu(self.fc1(flat))

        # Output layer, normalised to log-probabilities per row
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

In [6]:
def train(args, clients, device, epoch):
    """Run one pass over a single client's training data, computing gradients only.

    Args:
        args: object providing ``log_interval`` (log every N-th batch).
        clients: despite the plural name, ONE client dict with the keys
            ``'model'``, ``'optim'``, ``'trainset'`` and ``'hook'``.
        device: torch device the batch is moved to before the forward pass.
        epoch: round number, used only in the log line.

    No optimizer step is taken here: gradients are left on the model's
    parameters for the caller to aggregate. Because ``zero_grad()`` runs at
    the start of every batch, only the last batch's gradients remain when
    the function returns.
    """
    # Use the argument rather than whatever module-level `client` happens to
    # exist; the parameter keeps its original name for existing callers.
    client = clients
    client['model'].train()

    for batch_idx, (data, target) in enumerate(client['trainset']):
        # Ship the batch to the client's worker, then send the model to
        # wherever the data now lives.
        data = data.send(client['hook'])
        target = target.send(client['hook'])
        client['model'].send(data.location)

        data, target = data.to(device), target.to(device)
        client['optim'].zero_grad()
        output = client['model'](data)
        loss = F.nll_loss(output, target)
        loss.backward()

        # Bring the model (with its freshly computed gradients) back.
        client['model'].get()

        if batch_idx % args.log_interval == 0:
            # The loss is still remote; fetch it only when it is logged.
            loss = loss.get()
            print('Model {} Train Epoch: {} \tLoss: {:.6f}'.format(
                client['hook'].id, epoch, loss))
      

In [7]:
# This is the test function that will be implemented
def test(args, model, device, test_loader, name):
    model.eval()   
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss

            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()
            
    print(f"\nTesting {name} model with {len(test_loader.dataset)} samples")

    test_loss /= len(test_loader.dataset)
    accuracy = 100 * (1 - (correct / len(test_loader.dataset)))
    print('\nTest set: Average loss for {} model: {:.2f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        name, test_loss, correct, len(test_loader.dataset), accuracy))

    print(f"\n{name} Model Prediction Error: {(100-accuracy):.2f}%\n")


In [8]:
class FedSGDOptim(optim.Optimizer):
    """Plain SGD update whose gradients may come from a separate model.

    ``step(grad_model)`` applies ``param <- param - lr * g`` where ``g`` is
    the ``.grad`` of the positionally matching parameter of ``grad_model``.
    """

    def __init__(self, params, lr=args.lr):
        defaults = dict(lr=lr)
        super(FedSGDOptim, self).__init__(params, defaults)

    def step(self, grad_model=None, closure=None):
        """Apply one update to the optimised parameters.

        Args:
            grad_model: module whose parameters carry the gradients to apply.
                Its parameters are paired positionally with the optimised
                parameters, across all param groups in order. When ``None``
                each parameter's own ``.grad`` is used instead.
            closure: optional callable re-evaluating the loss; its result is
                returned.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # One shared iterator so that several param groups consume successive
        # grad_model parameters instead of each restarting from the first.
        if grad_model is None:
            source_iter = None
        else:
            source_iter = iter(grad_model.parameters())

        for group in self.param_groups:
            lr = group['lr']
            for param in group['params']:
                source = param if source_iter is None else next(source_iter, None)
                # Skip on the gradient that is actually read below; the old
                # check looked at the target parameter's .grad instead.
                if source is None or source.grad is None:
                    continue
                # Keyword form of the scaled in-place add; the positional
                # add_(scalar, tensor) overload is deprecated.
                param.data.add_(source.grad.data, alpha=-lr)

        return loss

In [9]:
# Driver cell: build the global model and per-client models, then run
# args.epochs federated rounds.

# Seed before constructing the global model so its initial weights are reproducible.
torch.manual_seed(args.torch_seed)
global_model = Net().to(device)
optimizer = FedSGDOptim(global_model.parameters(), lr=args.lr)
# Initial value only; grad_model is reassigned from averageGradients() in every round.
grad_model = Net().to(device)


for client in clients:
    # Re-seed with the same value before each construction so every client's
    # Net() is built from the same RNG state as global_model was.
    torch.manual_seed(args.torch_seed)
    client['model'] = Net().to(device)
    client['optim'] = optim.SGD(client['model'].parameters(), lr=args.lr)


for epoch in range(1, args.epochs + 1):
    print(f"FEDERATED LEARNING MODEL ROUND: {epoch}")

    # Every client in `clients` runs its local pass in every round.
    for client in clients:
        train(args, client, device, epoch)

    # averageGradients is imported from FLDataset.utils and not visible here;
    # presumably it combines the clients' gradients - TODO confirm what it
    # returns and whether it uses client['samples'] as weights.
    grad_model = averageGradients(global_model, clients)

    # Evaluate the global model before the server-side update ...
    test(args, global_model, device, global_test_loader, 'Global')
    optimizer.step(grad_model) # Apply the update to global_model using grad_model's gradients
    test(args, global_model, device, global_test_loader, 'Global') # ... and again after it

    # Push the updated global weights back to every client for the next round.
    for client in clients:
        client['model'].load_state_dict(global_model.state_dict())

# Optionally persist the final global weights.
if (args.save_model):
    torch.save(global_model.state_dict(), "FedSGD.pt")

FEDERATED LEARNING MODEL ROUND: 1
Model client1 Train Epoch: 1 	Loss: 2.331440
Model client2 Train Epoch: 1 	Loss: 2.344059
Model client3 Train Epoch: 1 	Loss: 2.340851
Model client4 Train Epoch: 1 	Loss: 2.307723
Model client5 Train Epoch: 1 	Loss: 2.301387
Model client6 Train Epoch: 1 	Loss: 2.282655
Model client7 Train Epoch: 1 	Loss: 2.302820
Model client8 Train Epoch: 1 	Loss: 2.283996
Model client9 Train Epoch: 1 	Loss: 2.314949
Model client10 Train Epoch: 1 	Loss: 2.305409

Testing Global model with 10000 samples

Test set: Average loss for Global model: 2.31, Accuracy: 1004/10000 (89.96%)


Global Model Prediction Error: 10.04%


Testing Global model with 10000 samples

Test set: Average loss for Global model: 2.30, Accuracy: 1095/10000 (89.05%)


Global Model Prediction Error: 10.95%

FEDERATED LEARNING MODEL ROUND: 2
Model client1 Train Epoch: 2 	Loss: 2.307828
Model client2 Train Epoch: 2 	Loss: 2.320753
Model client3 Train Epoch: 2 	Loss: 2.319476
Model client4 Train Epoch: