# Federated Learning: Local context

In [1]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

## 1: MNIST Dataset

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)

indices = torch.randperm(len(dataset))
subset1_indices = indices[:600]
subset2_indices = indices[600:1200]

subset1 = DataLoader(dataset, sampler=SubsetRandomSampler(subset1_indices), batch_size=32)
subset2 = DataLoader(dataset, sampler=SubsetRandomSampler(subset2_indices), batch_size=32)

## 2: Definition of a simple CNN

In [3]:
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

## 3: Model parameter averaging function

In [4]:
def average_model_parameters(models, average_weight):
    averaged_params = {}
    for key in models[0].state_dict().keys():
        averaged_params[key] = sum(weight * models[i].state_dict()[key] for i, weight in enumerate(average_weight))
    return averaged_params

## 4: Update of Model Parameters

In [5]:
def update_model(model, averaged_params):
    model.load_state_dict(averaged_params)

## 5: Federated Training Algorithm implementation

In [6]:
def train_model(model, data_loader, optimizer, criterion, epochs=1):
    model.train()
    for epoch in range(epochs):
        for inputs, targets in data_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

def federated_training(model1, model2, subset1, subset2, epochs, common_params=True):
    optimizer1 = torch.optim.Adam(model1.parameters())
    optimizer2 = torch.optim.Adam(model2.parameters())
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        train_model(model1, subset1, optimizer1, criterion)
        train_model(model2, subset2, optimizer2, criterion)

        if common_params:
            avg_params = average_model_parameters([model1, model2], [0.5, 0.5])
            update_model(model1, avg_params)
            update_model(model2, avg_params)

## 6: Training without common parameters

In [7]:
def evaluate_model(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    return correct / total

In [8]:
model1 = CNN().to(device)
model2 = CNN().to(device)
dataloader = DataLoader(dataset, batch_size=32)

federated_training(model1, model2, subset1, subset2, 10, common_params=False)
print("Training without common parameters")
print("Model 1 accuracy: ", evaluate_model(model1, dataloader))
print("Model 2 accuracy: ", evaluate_model(model2, dataloader))

Training without common parameters
Model 1 accuracy:  0.8952166666666667
Model 2 accuracy:  0.8978166666666667


## 7: Training with common parameters

In [9]:
model1 = CNN().to(device)
model2 = CNN().to(device)

federated_training(model1, model2, subset1, subset2, 10, common_params=True)
print("Training with common parameters")
print("Model 1 accuracy: ", evaluate_model(model1, dataloader))
print("Model 2 accuracy: ", evaluate_model(model2, dataloader))

Training with common parameters
Model 1 accuracy:  0.9102833333333333
Model 2 accuracy:  0.9102833333333333


## 8: Impact of batch size on accuracy

In [10]:
for batch_size in [64, 32, 16, 8]:
    subset1 = DataLoader(dataset, sampler=SubsetRandomSampler(subset1_indices), batch_size=batch_size)
    subset2 = DataLoader(dataset, sampler=SubsetRandomSampler(subset2_indices), batch_size=batch_size)
    model1 = CNN().to(device)
    model2 = CNN().to(device)

    federated_training(model1, model2, subset1, subset2, 10, common_params=True)
    
    accuracy = evaluate_model(model1, DataLoader(dataset, batch_size=batch_size))
    print(f"Batch Size {batch_size}, accuracy: {accuracy}")

Batch Size 64, accuracy: 0.8915166666666666
Batch Size 32, accuracy: 0.91775
Batch Size 16, accuracy: 0.9244333333333333
Batch Size 8, accuracy: 0.9345833333333333


## 9: Repeat Experiments on CIFAR-10

In [11]:
transform_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

cifar_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_cifar)

cifar_indices = torch.randperm(len(cifar_dataset))
cifar_subset1_indices = cifar_indices[:600]
cifar_subset2_indices = cifar_indices[600:1200]

batch_sizes = [64, 32, 16, 8]

for batch_size in batch_sizes:
    subset1 = DataLoader(dataset, sampler=SubsetRandomSampler(cifar_subset1_indices), batch_size=batch_size)
    subset2 = DataLoader(dataset, sampler=SubsetRandomSampler(cifar_subset2_indices), batch_size=batch_size)
    model1 = CNN().to(device)
    model2 = CNN().to(device)

    federated_training(model1, model2, subset1, subset2, 10, common_params=True)
    
    accuracy = evaluate_model(model1, DataLoader(dataset, batch_size=batch_size))
    print(f"Batch Size {batch_size}, accuracy: {accuracy}")

Files already downloaded and verified
Batch Size 64, accuracy: 0.90405
Batch Size 32, accuracy: 0.9110833333333334
Batch Size 16, accuracy: 0.9192666666666667
Batch Size 8, accuracy: 0.9361666666666667
