## Import Dependencies  

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

## Preparation
Loading data from **MNist** dataset, and define image embeddings \\

In this case, we have two separated datasets, the unpartitioned and partitioned. We perform traditional training on unpartitioned dataset and perform federated learning on partitioned dataset. The partitioned dataset would be splitted into 10 parts.

In [15]:
# Define transformations to apply to the data
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to tensors
    transforms.Normalize((0.5,), (0.5,))  # Normalize the pixel values to range [-1, 1]
])

# Load the MNIST dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Define data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)


## Define Model
Define the basic CNN model

In [16]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

### Non-partition

In [17]:
# Train function
def train(model, trainloader, criterion, optimizer, epochs=5):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        print(f'Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}')

# Test function
def test(model, testloader, verbose=True):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if verbose: print(f'Accuracy on test set: {100 * correct / total}%')
    return 100 * correct / total

In [18]:
# Train non-partitioned model
non_partitioned_model = CNN()
non_partitioned_optimizer = optim.Adam(non_partitioned_model.parameters(), lr=0.001)
non_partitioned_criterion = nn.CrossEntropyLoss()
print("Training non-partitioned model...")
train(non_partitioned_model, trainloader, non_partitioned_criterion, non_partitioned_optimizer)
print("Testing non-partitioned model...")
test(non_partitioned_model, testloader)

Training non-partitioned model...
Epoch 1, Loss: 0.128456686290695
Epoch 2, Loss: 0.043731860509406154
Epoch 3, Loss: 0.029874081825158404
Epoch 4, Loss: 0.02139184741571565
Epoch 5, Loss: 0.01635035736183151
Testing non-partitioned model...
Accuracy on test set: 98.91%


98.91

### Partitioned

In [22]:
# Train partitioned model (simulate federated learning)
partition_size = len(trainset) // 5
data_partitions = [torch.utils.data.Subset(trainset, range(i * partition_size, (i + 1) * partition_size))
                   for i in range(5)]

partitioned_models = []
partitioned_optimizers = []
for _ in range(5):
    model = CNN()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    print("Training partitioned model...")
    train(model, torch.utils.data.DataLoader(data_partitions[_], batch_size=32, shuffle=True), criterion, optimizer, epochs=8)
    partitioned_models.append(model)
    partitioned_optimizers.append(optimizer)
    
    
# Computing Weight Averaging
partition_accuracy = [test(model, testloader, verbose=False) for model in partitioned_models]
total_accuracy = sum(partition_accuracy)
weights = [accuracy / total_accuracy for accuracy in partition_accuracy]

# Aggregate model updates
print("Aggregating model updates...")
for i in range(1, 5):
    weight = weights[i]
    for params_source, params_target in zip(partitioned_models[i].parameters(), partitioned_models[0].parameters()):
        params_target.data += weight * params_source.data

# # Average aggregated model parameters
# for params_target in partitioned_models[0].parameters():
#     params_target.data /= 5

aggregated_model = partitioned_models[0]

Training partitioned model...
Epoch 1, Loss: 0.34047181243573627
Epoch 2, Loss: 0.08616057464169959
Epoch 3, Loss: 0.057027377796669804
Epoch 4, Loss: 0.03519517294538673
Epoch 5, Loss: 0.02629296962179554
Epoch 6, Loss: 0.017296416286961176
Epoch 7, Loss: 0.012251208646999051
Epoch 8, Loss: 0.011355925985445386
Training partitioned model...
Epoch 1, Loss: 0.3295359623568753
Epoch 2, Loss: 0.07961865293669204
Epoch 3, Loss: 0.04843127875666445
Epoch 4, Loss: 0.03427299424144439
Epoch 5, Loss: 0.02516851408572014
Epoch 6, Loss: 0.016794198129073873
Epoch 7, Loss: 0.011194924220808995
Epoch 8, Loss: 0.015952731387180393
Training partitioned model...
Epoch 1, Loss: 0.317948392957449
Epoch 2, Loss: 0.08486129604062687
Epoch 3, Loss: 0.05290936131449416
Epoch 4, Loss: 0.037477208024744565
Epoch 5, Loss: 0.028545641055641075
Epoch 6, Loss: 0.01572863311515539
Epoch 7, Loss: 0.015899837481566162
Epoch 8, Loss: 0.01115009690691295
Training partitioned model...
Epoch 1, Loss: 0.34192597750326
E

In [23]:
# Test federated model
print("Testing aggregated model...")
test(aggregated_model, testloader)

Testing aggregated model...
Accuracy on test set: 94.85%


94.85

In [21]:
print(partition_accuracy)

[98.06, 98.46, 98.57, 98.03, 98.52]
