# Importing libraries

In [1]:
import torch 
from torch import nn
from torch import optim 
from torchvision import datasets, transforms

In [2]:
import numpy as np
import copy 
import random

## Configuration

In [3]:
NUM_CLIENTS = 5
BATCH_SIZE = 64
LOCAL_EPOCHS = 2
FED_ROUNDS = 10
LR = 0.01
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Simple NN

In [4]:
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128,10)
        )
    
    def forward(self, x):
        return self.net(x)

## Load Dataset

In [5]:
transform = transforms.Compose([
    transforms.ToTensor()
])

In [6]:
train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)

100%|██████████| 9.91M/9.91M [00:11<00:00, 877kB/s] 
100%|██████████| 28.9k/28.9k [00:00<00:00, 134kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.25MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 2.03MB/s]


In [7]:
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=128, shuffle=False
)

In [8]:
train_dataset

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
           )

In [9]:
test_dataset

Dataset MNIST
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
           )

In [10]:
test_loader

<torch.utils.data.dataloader.DataLoader at 0x2d21999dd90>

## Split Data to Clients

In [11]:
def split_dataset(dataset, num_clients):
    data_per_client = len(dataset) // num_clients
    indices = np.random.permutation(len(dataset))
    client_data = []

    for i in range(num_clients):
        start = i * data_per_client
        end = start + data_per_client
        subset = torch.utils.data.Subset(dataset, indices[start:end])
        loader = torch.utils.data.DataLoader(
            subset, batch_size=BATCH_SIZE, shuffle=True
        )
        client_data.append(loader)
    return client_data

In [12]:
client_loader = split_dataset(train_dataset, NUM_CLIENTS)

In [14]:
for i in client_loader:
    print(len(i.dataset))

12000
12000
12000
12000
12000


## Client Training

In [15]:
def train_client(model, loader):
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()

    for _ in range(LOCAL_EPOCHS):
        for data, target in loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    return model.state_dict()

## Fed Average

In [16]:
def federated_average(weights):
    avg_weights = copy.deepcopy(weights[0])
    print("Average weights: ", avg_weights)

    for key in avg_weights.keys():
        for i in range(1, len(weights)):
            avg_weights[key] += weights[i][key]
        avg_weights[key] = avg_weights[key] / len(weights)

    return avg_weights

## Evaluation

In [17]:
def evaluate(model):
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            preds = torch.argmax(output, dim=1)
            correct += (preds == target).sum().item()
            total += target.size(0)
    accuracy = 100 * correct / total 
    return accuracy

## Federated Training loop

In [18]:
global_model = SimpleNN().to(DEVICE)

for round_num in range(FED_ROUNDS):
    print(f"Round {round_num+1}/{FED_ROUNDS}")
    client_weights = []

    for client_id in range(NUM_CLIENTS):
        local_model = copy.deepcopy(global_model)
        local_model.to(DEVICE)

        weights = train_client(local_model, client_loader[client_id])
        client_weights.append(weights)

    global_weights = federated_average(client_weights)
    global_model.load_state_dict(global_weights)

    acc = evaluate(global_model)
    print(f"Test Accuracy after round {round_num+1}: {acc:.2f}%\n")

print("Training Completed")

Round 1/10
Average weights:  OrderedDict({'net.1.weight': tensor([[-0.0057, -0.0183,  0.0232,  ...,  0.0197, -0.0251, -0.0137],
        [ 0.0263, -0.0041,  0.0186,  ...,  0.0032,  0.0221, -0.0117],
        [-0.0166,  0.0199,  0.0273,  ...,  0.0026,  0.0233, -0.0056],
        ...,
        [-0.0225, -0.0309, -0.0198,  ..., -0.0107, -0.0106, -0.0213],
        [-0.0155, -0.0158, -0.0323,  ...,  0.0128,  0.0333, -0.0135],
        [-0.0203,  0.0062, -0.0311,  ...,  0.0357,  0.0227,  0.0328]]), 'net.1.bias': tensor([ 0.0126,  0.0029, -0.0377,  0.0453,  0.0334,  0.0109,  0.0075, -0.0035,
         0.0332,  0.0072,  0.0438, -0.0044, -0.0238, -0.0327, -0.0156,  0.0237,
        -0.0117,  0.0091, -0.0072, -0.0277,  0.0288,  0.0126, -0.0018,  0.0389,
         0.0339, -0.0050,  0.0109, -0.0301, -0.0255,  0.0161,  0.0080,  0.0037,
         0.0319,  0.0215,  0.0207, -0.0026,  0.0044, -0.0115, -0.0275,  0.0083,
         0.0233, -0.0192, -0.0063,  0.0290, -0.0052, -0.0073,  0.0124,  0.0101,
         0.01