# Federated Learning Simple Implementation

In [73]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import seaborn as sns
import matplotlib.pyplot as plt

import numpy as np

## MLP Model

In [74]:
class MLP(nn.Module):
    """Fully connected feed-forward network that returns raw class scores (logits).

    Args:
        layers_shape: sequence of layer widths, input size first and output
            size last; consecutive pairs define one ``nn.Linear`` each.

    Raises:
        ValueError: if ``layers_shape`` has fewer than two entries, since no
            linear layer can be built from it.
    """

    def __init__(self, layers_shape):
        super().__init__()
        if len(layers_shape) < 2:
            raise ValueError("layers_shape needs at least an input and an output size")
        # One Linear per consecutive (in, out) pair of the shape specification
        self.linears = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(layers_shape[:-1], layers_shape[1:])]
        )
        self.relu = nn.ReLU() # activation function
        # Kept as an optional output function; forward() does not apply it.
        # dim=1 normalises over the class axis of a (batch, classes) tensor and
        # avoids the implicit-dimension deprecation warning.
        self.soft = nn.Softmax(dim=1)

    def forward(self, x):
        """Flatten each sample, apply the hidden layers with ReLU, return the logits."""
        x = x.view(x.size(0), -1) # flatten the input image
        for layer in self.linears[:-1]:
            x = self.relu(layer(x))
        # No softmax here: the raw scores are returned, which is what a loss
        # expecting unnormalised logits (e.g. nn.CrossEntropyLoss) needs.
        x = self.linears[-1](x)
        return x

## Relevant Parameters

In [93]:
# Training schedule
epoch = 5          # value passed as `epoch` to each client update
comm_cycles = 150  # number of communication rounds
num_clients = 10   # number of client models / data partitions

# Number of clients picked per round: 30% of the available clients, truncated
sample_size = int(0.3 * num_clients)

# Layer widths of the MLP, from input to output
net_parameters = [
    28 * 28,             # input
    512, 256, 128, 64,
    10,                  # output
]

## Data Loaders
- Divide the test & training data
- Divide the training data among the clients

In [76]:
# define transformation to apply to each image in the dataset
transform = transforms.Compose([
    transforms.ToTensor(), # convert the image to a PyTorch tensor
    transforms.Normalize((0.5,), (0.5,)) # normalize the image with mean=0.5 and std=0.5
])

# load the MNIST training and testing datasets
train_dataset = datasets.MNIST(root='data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='data/', train=False, transform=transform, download=True)

# split the training data among the clients.
# random_split requires the lengths to sum to the dataset size, so the
# remainder of the division is handed out one sample at a time to the first
# clients instead of being dropped (which would raise for client counts that
# do not divide the dataset size evenly).
samples_per_client, leftover = divmod(len(train_dataset), num_clients)
split_sizes = [samples_per_client + (1 if i < leftover else 0) for i in range(num_clients)]
train_split = torch.utils.data.random_split(train_dataset, split_sizes)

# create data loaders to load the datasets in batches during training and testing
train_loader = [torch.utils.data.DataLoader(split, batch_size=64, shuffle=True) for split in train_split]
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

## Helper Functions for Federated Training
- The `client_update` function trains a client model on that client's private data. This is the local training round that takes place on each of the `sample_size` selected clients.
- `server_aggregate` function aggregates the model weights received from every client and updates the global model with the updated weights.

In [77]:
def client_update(  client_model : torch.nn.Module,
                    optimizer,
                    criterion,
                    data_loader : torch.utils.data.DataLoader,
                    device,
                    epoch = 5):
    """
    Train the client model on client data

    Args:
        client_model: model to train in place.
        optimizer: optimizer stepping the parameters of ``client_model``.
        criterion: loss callable taking ``(output, labels)``.
        data_loader: iterable yielding ``(images, labels)`` batches.
        device: device the batches are moved to before the forward pass.
        epoch: number of passes over ``data_loader``.

    Returns:
        Per-sample mean loss over the final epoch, or ``0.0`` when nothing was
        trained (``epoch == 0`` or an empty loader).
    """
    client_model.train()
    mean_loss = 0.0
    for _ in range(epoch):
        running_loss = 0.0
        seen = 0
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad() # reset the gradients to zero
            output = client_model(images) # forward pass
            loss = criterion(output, labels) # compute the loss
            loss.backward() # compute the gradients
            optimizer.step() # update the parameters
            # Weight by batch size so a smaller last batch does not skew the
            # mean (assumes the criterion returns a batch-mean loss).
            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            seen += batch_size
        if seen:
            mean_loss = running_loss / seen

    return mean_loss

def server_aggregate(global_model : torch.nn.Module, client_models):
    """
    The means of the weights of the client models are aggregated to the global model

    Every entry of the global ``state_dict`` is replaced by the element-wise
    mean of the same entry across ``client_models``; afterwards each client
    model is overwritten with the new global weights.

    Args:
        global_model: model receiving the averaged weights (updated in place).
        client_models: non-empty list of models sharing the global model's
            ``state_dict`` keys and shapes (all updated in place).

    Raises:
        ValueError: if ``client_models`` is empty.
    """
    if not client_models:
        raise ValueError("server_aggregate needs at least one client model")

    client_states = [model.state_dict() for model in client_models]
    global_dict = global_model.state_dict() # Get a copy of the global model state_dict
    for key in global_dict:
        # .float() must be *called*: the mean is computed in floating point
        # and then cast back so the entry keeps its original dtype.
        stacked = torch.stack([state[key].float() for state in client_states], 0)
        global_dict[key] = stacked.mean(0).to(global_dict[key].dtype)
    global_model.load_state_dict(global_dict)

    # Update the client models using the global model
    for model in client_models:
        model.load_state_dict(global_model.state_dict())

## Test Function
Similar to the previous implementation, the global model is evaluated on the test data, returning the test loss and the test accuracy.

In [78]:

def test(global_model, criterion, test_loader, device):
    """This function test the global model on test data and returns test loss and test accuracy """
    global_model.eval()
    test_loss = 0
    test_acc = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            output = global_model(images) # forward pass
            loss = criterion(output, labels)
            test_loss += loss.item() * images.size(0)# sum up batch loss
            _, pred = torch.max(output, 1) # get the predicted labels
            test_acc += torch.sum(pred == labels.data).item() # compute the testing accuracy

    test_loss /= len(test_loader.dataset)
    test_acc /= len(test_loader.dataset)

    return test_loss, test_acc

## Global & Clients instantiation
Implement the same elements as before, but:
- We need more instances of the model
- An optimizer for each model

In [79]:
# Pick the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One global model plus one model per client, all living on `device`
global_model = MLP(net_parameters).to(device)
client_models = [MLP(net_parameters).to(device) for _ in range(num_clients)]

# Every client starts from the global model's weights
for model in client_models:
    model.load_state_dict(global_model.state_dict())

# Shared loss function and one Adam optimizer per client model
criterion = nn.CrossEntropyLoss()
optimizers = [optim.Adam(model.parameters(), lr=1e-3) for model in client_models]


In [80]:
# initialize lists to store the training and testing losses and accuracies
train_losses = []
test_losses = []
train_accs = []
test_accs = []

for cycle in range(comm_cycles):
    # Select random clients: the first `sample_size` entries of a random permutation
    client_idx = np.random.permutation(num_clients)[:sample_size]

    # Local training on each selected client; the returned losses are summed
    train_loss = 0
    for idx in client_idx:
        train_loss += client_update(client_models[idx],
                                    optimizers[idx],
                                    criterion, train_loader[idx],
                                    device, epoch)
    train_losses.append(train_loss)

    # Aggregate
    # NOTE(review): the average is taken over *all* client models, including
    # the ones that were not trained this round - confirm this is intended.
    server_aggregate(global_model, client_models)

    # Evaluate the aggregated global model on the test data
    test_loss, test_acc = test(global_model, criterion, test_loader, device)

    test_losses.append(test_loss)
    test_accs.append(test_acc)

    print('%d-th round' % cycle)
    print('average train loss %0.3g | test loss %0.3g | test acc: %0.3f' % (train_loss / sample_size, test_loss, test_acc))
    
    

SyntaxError: incomplete input (749670344.py, line 9)