In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy

In [3]:
class SampleNet(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_layer = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3) # 28*28*1 -> 26*26*4
        self.pooling = nn.MaxPool2d(kernel_size=2) 
        self.activation = nn.ReLU()
        self.conv = nn.Conv2d(in_channels=4, out_channels=16, kernel_size=3)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(in_features=11*11*16, out_features=512)
        self.dropout = nn.Dropout(0.2)
        self.output = nn.Linear(in_features=512, out_features=10)
    
    def forward(self, x):
        x = self.pooling(self.activation(self.input_layer(x))) # 28 * 28 * 1 -> 26 * 26 * 4 -> 13 * 13 * 4
        x = self.activation(self.conv(x)) # 13 * 13 * 4 -> 11 * 11 * 16
        return self.output(self.dropout(self.linear(self.flatten(x)))) # 11 * 11 * 16 -> 512 -> 10
        

In [4]:
def evaluate_model(model, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    accuracy = 100 * correct / total
    return accuracy

def client_update(model, data, target, lr=0.01):
    model = copy.deepcopy(model)
    model.train()
    
    optimizer = optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

    return model.state_dict()


In [5]:
def average_weights(w_list):
    avg_w = copy.deepcopy(w_list[0])
    for key in avg_w.keys():
        for i in range(1, len(w_list)):
            avg_w[key] += w_list[i][key]
        avg_w[key] = torch.div(avg_w[key], len(w_list))
    return avg_w


In [11]:
global_model = SampleNet()
num_rounds = 500
num_clients = 3

In [12]:
transform = transforms.ToTensor()
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

client_data_indices = np.array_split(np.arange(len(mnist_data)), 3)
client_loaders = [DataLoader(Subset(mnist_data, idx), batch_size=32, shuffle=True) for idx in client_data_indices]
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)


In [14]:
for round in range(num_rounds):
    print(f"\n--- Round {round+1} ---")
    local_weights = []

    for idx in range(num_clients):
        client = iter(client_loaders[idx])
        try:
            data, target = next(client)
        except StopIteration:
            continue
        
        # print(f"Client: {idx + 1}, Target: {target}")
        updated = client_update(global_model, data, target)
        local_weights.append(updated)
        
    if local_weights:
        global_weights = average_weights(local_weights)
        global_model.load_state_dict(global_weights)
    
    acc = evaluate_model(global_model, test_loader)
    print(f"Global Test Accuracy: {acc:.2f}%")


--- Round 1 ---
Global Test Accuracy: 12.50%

--- Round 2 ---
Global Test Accuracy: 12.30%

--- Round 3 ---
Global Test Accuracy: 12.35%

--- Round 4 ---
Global Test Accuracy: 13.97%

--- Round 5 ---
Global Test Accuracy: 14.96%

--- Round 6 ---
Global Test Accuracy: 14.53%

--- Round 7 ---
Global Test Accuracy: 13.98%

--- Round 8 ---
Global Test Accuracy: 13.25%

--- Round 9 ---
Global Test Accuracy: 14.09%

--- Round 10 ---
Global Test Accuracy: 13.96%

--- Round 11 ---
Global Test Accuracy: 14.72%

--- Round 12 ---
Global Test Accuracy: 14.89%

--- Round 13 ---
Global Test Accuracy: 15.59%

--- Round 14 ---
Global Test Accuracy: 15.65%

--- Round 15 ---
Global Test Accuracy: 16.02%

--- Round 16 ---
Global Test Accuracy: 16.15%

--- Round 17 ---
Global Test Accuracy: 16.99%

--- Round 18 ---
Global Test Accuracy: 16.29%

--- Round 19 ---
Global Test Accuracy: 16.43%

--- Round 20 ---
Global Test Accuracy: 14.01%

--- Round 21 ---
Global Test Accuracy: 13.90%

--- Round 22 ---
Glob