# Utils

## Imports

In [4]:
import torch
import torchvision
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import random_split, Dataset, DataLoader, ConcatDataset, Subset
import torch.optim.lr_scheduler as lr_scheduler
from sklearn.model_selection import train_test_split
from torchsummary import summary
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

## Model

In [5]:
class CNN(nn.Module):
  def __init__(self):
      super(CNN, self).__init__()
      self.layer1 = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3),
                                  nn.ReLU(),
                                  nn.MaxPool2d(kernel_size=2, stride=2),
                                  nn.BatchNorm2d(32))
      self.layer2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3),
                                  nn.ReLU(),
                                  nn.BatchNorm2d(64),
                                  nn.MaxPool2d(kernel_size=2, stride=2))
      self.layer3 = nn.Sequential(nn.Linear(2304, 128),
                                  nn.ReLU(),
                                  nn.Dropout(p = 0.5))
      self.fc = nn.Linear(128, 10)

  def forward(self, x):
    x = self.layer1(x)
    x = self.layer2(x)
    x = x.view(x.size(0), -1)
    x = self.layer3(x)
    x = self.fc(x)
    return x


## Datasets

In [6]:
def get_experiment_datasets(dataset_name = 'CIFAR10'):

  if dataset_name == 'CIFAR10':
    transform = transforms.Compose([transforms.
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

  return train_dataset, test_dataset




def partition_training_dataset(train_dataset, num_clients):
  return torch.utils.data.random_split(train_dataset, [len(train_dataset) // num_clients] * num_clients)



## Training Loops

In [14]:
# Training loop on the client side
def local_training(client_dataloader, ind, model, optimizer, device, local_training_args = dict(), local_logs = False):

    criterion = local_training_args['criterion']

    for epoch in range(local_training_args['num_epochs']):
      model.train()
      loss = 0.0

      for i, (images, labels) in enumerate(client_dataloader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss += loss.item() #* images.size(0)

      epoch_loss = loss / len(client_dataloader)

      if local_logs:
        print(f"Client {ind} Training: Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")



# Training loop on the server side + aggregation
def federated_training(federated_training_args, local_training_args, server_model, client_data, device):

  clients = [[]]*federated_training_args['num_clients']

  print('Federated Training!')
  for round in tqdm(range(federated_training_args['num_rounds'])):

    print(f"---------- Round {round+1} ----------")
    server_model.train()
    server_model.zero_grad()


    for client_idx in range(federated_training_args['num_clients']):

      client_model = CNN().to(device)
      client_model.load_state_dict(server_model.state_dict())

      optimizer = torch.optim.Adam(client_model.parameters(), **local_training_args['optimizer_args'])
      client_dataloader = DataLoader(client_data[client_idx], batch_size=local_training_args['batch_size'], shuffle=True)

      # Train on the client's local dataset
      local_training(client_dataloader, client_idx+1, client_model,
                      optimizer, device, local_training_args)
      clients[client_idx] = client_model


    server_model = federated_aggregation(server_model, clients, federated_training_args)

  return server_model


def federated_aggregation(server_model, clients, federated_training_args):

  sd_server = server_model.state_dict()
  sd_clients = [cl.state_dict() for cl in clients]
  for key in sd_server:
    for sd_client in sd_clients:
      sd_server[key] += sd_client[key]
    sd_server[key] = (sd_server[key] / federated_training_args['num_clients']).float()

  server_model.load_state_dict(sd_server)

  return server_model



## Evaluations

In [18]:

def evaluate_acc(test_dataloader, model, criterion, device):
  loss_test = []

  with torch.no_grad():
    correct = 0
    total = 0
    for i, (input, target) in enumerate(test_dataloader):

      target = target.to(device)
      input = input.to(device)

      # compute output
      output = server_model(input)
      loss = criterion(output, target)
      loss_test.append(loss.item())

      total += target.size(0)
      _, predicted = torch.max(output.data, 1)
      correct += (predicted == target).sum().item()

    print('Accuracy on the test images: {} %'.format(100 * correct / total))


# Experiments

In [16]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the server model
server_model = CNN().to(device)

# Define the hyperparameters

local_training_args = {'optimizer_args' : {'lr' : 1e-3}, 'num_epochs' : 5, 'batch_size' : 128, 'criterion' : nn.CrossEntropyLoss()}

federated_training_args = {'num_rounds' : 4, 'num_clients' : 5}

train_dataset, test_dataset = get_experiment_datasets('CIFAR10')
client_data = partition_training_dataset(train_dataset, federated_training_args['num_clients'])


Files already downloaded and verified
Files already downloaded and verified


In [17]:

server_model = federated_training(federated_training_args, local_training_args, server_model, client_data, device)

Federated Training!


  0%|          | 0/4 [00:00<?, ?it/s]

---------- Round 1 ----------


 25%|██▌       | 1/4 [01:11<03:33, 71.29s/it]

---------- Round 2 ----------


 50%|█████     | 2/4 [02:24<02:24, 72.41s/it]

---------- Round 3 ----------


 75%|███████▌  | 3/4 [03:43<01:15, 75.24s/it]

---------- Round 4 ----------


100%|██████████| 4/4 [04:57<00:00, 74.34s/it]


In [20]:

test_dataloader = DataLoader(test_dataset, batch_size = 128)

evaluate_acc(test_dataloader, server_model, local_training_args['criterion'], device)

Accuracy on the test images: 67.39 %


# Global Model

In [21]:
lr = 1e-3
num_epochs = 20
batch_size = 128
criterion = nn.CrossEntropyLoss()

global_data = train_dataset
global_dataloader = DataLoader(global_data, batch_size=batch_size, shuffle=True)

global_model = CNN().to(device)
optimizer = torch.optim.Adam(global_model.parameters(), lr = lr)

print("Global Model Training:")
for epoch in tqdm(range(num_epochs)):
  loss_global = 0.0

  for i, (input, target) in enumerate(global_dataloader):

    target = target.to(device)
    input = input.to(device)

    # find loss
    output = global_model(input)
    loss = criterion(output, target)
    loss_global += loss.item()

    # optimizer
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  epoch_loss = loss_global / len(global_dataloader)
  # print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")


Global Model Training:


100%|██████████| 20/20 [04:22<00:00, 13.11s/it]


In [23]:
evaluate_acc(test_dataloader, global_model, local_training_args['criterion'], device)

Accuracy on the test images: 67.34 %
