In [None]:
# Install a pinned torchvision build from the PyTorch wheel index (`+cu118` tag = CUDA 11.8 build).
# NOTE(review): torchvision 0.15.x is built against torch 2.0.x — confirm the installed torch matches.
# NOTE(review): the next cell force-reinstalls an *unpinned* torchvision, which overrides this pin;
# keep a single pinned install step (ideally `%pip install ...` so it targets the kernel's env).
!pip install torchvision==0.15.2+cu118 -f https://download.pytorch.org/whl/torch_stable.html







In [None]:
# Imports and environment setup for the whole notebook.
import torch
import sys
# NOTE(review): this runs *after* `import torch` and upgrades numpy/pillow/torchvision with no
# version pins, overriding the torchvision pin from the previous cell. An unpinned torchvision may
# require a different torch than the one already loaded in this kernel — prefer one pinned install
# in the first cell followed by a kernel restart.
!pip install --upgrade --force-reinstall numpy pillow torchvision
# Presumably puts the system site-packages first so the freshly reinstalled packages win on import.
# NOTE(review): hard-coded to a Python 3.11 install layout; breaks on other interpreters/hosts.
sys.path.insert(0, "/usr/local/lib/python3.11/dist-packages")
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import random_split, DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# NOTE(review): `random` and `np` are not used in the visible cells — drop them if nothing else needs them.
import random
import os
import matplotlib.pyplot as plt
import numpy as np

# Intended to switch torch.compile off.
# NOTE(review): confirm PyTorch actually reads TORCH_COMPILE_MODE / this value; it is also set after
# torch was imported, and nothing visible in this notebook calls torch.compile.
os.environ["TORCH_COMPILE_MODE"] = "force_compile_off"

In [None]:

# Use the GPU when one is available; models and batches below are moved to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_dataset = MNIST('./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = MNIST('./data', train=False, download=True, transform=transforms.ToTensor())

# Hold out a fixed-size dev split from the official training set.
# - The train size is derived from len(...), so the lengths always sum to the dataset size,
#   which random_split requires.
# - The seeded generator makes the split identical on every Restart & Run All.
DEV_SIZE = 10000
train_dataset, dev_dataset = random_split(
    train_dataset,
    [len(train_dataset) - DEV_SIZE, DEV_SIZE],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# batch_size:  mini-batch size used by every DataLoader (client training and evaluation).
# num_clients: number of simulated federated participants the training split is sharded across.
batch_size, num_clients = 64, 4

### Federated Model Architecture and Client Training Logic

In [None]:
# Define FederatedNet model architecture
class FederatedNet(nn.Module):
    """LeNet-style CNN classifying single-channel 28x28 digit images into 10 classes.

    The layer attribute names (conv1, conv2, fc1, fc2) are the state_dict keys that
    federated aggregation iterates over and checkpoints store, so keep them stable.
    """

    def __init__(self):
        super().__init__()
        # 1x28x28 -> 20x24x24 (5x5 conv, stride 1) -> 20x12x12 after 2x2 max-pool
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        # 20x12x12 -> 50x8x8 -> 50x4x4 after pooling, hence fc1's 4*4*50 inputs
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) log-probabilities."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        # Flatten per sample. Unlike view(-1, 800), this never folds spatial data into the
        # batch dimension: a mis-sized input now fails at fc1 instead of yielding a wrong batch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # log_softmax pairs with F.nll_loss used in training and evaluation.
        return F.log_softmax(self.fc2(x), dim=1)

class Client:
    """A simulated federated participant that owns one private shard of training data."""

    def __init__(self, client_id, dataset):
        self.client_id = client_id  # label used in the per-epoch log line
        self.dataset = dataset      # this client's local training subset

    def train(self, global_model, epochs, lr):
        """Train a private copy of ``global_model`` on this client's shard.

        ``global_model`` itself is only read: its weights are loaded into a fresh
        ``FederatedNet`` that is optimised with Adam for ``epochs`` passes.

        Args:
            global_model: model whose state_dict seeds the local copy.
            epochs: number of full passes over ``self.dataset``.
            lr: Adam learning rate.

        Returns:
            The state_dict of the locally trained model.
        """
        local_model = FederatedNet().to(device)
        local_model.load_state_dict(global_model.state_dict())
        optimizer = optim.Adam(local_model.parameters(), lr=lr)
        loader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)

        for epoch_idx in range(epochs):
            local_model.train()
            epoch_loss_sum = 0.0
            for inputs, labels in loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                batch_loss = F.nll_loss(local_model(inputs), labels)
                batch_loss.backward()
                optimizer.step()
                # Re-weight the batch-mean loss by batch size so the epoch figure below is a
                # per-sample mean even when the last batch is short.
                epoch_loss_sum += batch_loss.item() * inputs.size(0)
            mean_loss = epoch_loss_sum / len(loader.dataset)
            print(f'Client {self.client_id}, Epoch {epoch_idx+1}, Loss: {mean_loss}')

        return local_model.state_dict()

In [None]:
# Initialize the global model. nn.Conv2d / nn.Linear already draw random weights scaled to
# each layer's fan-in when constructed, so the model starts randomized. Do NOT overwrite them
# with torch.randn_like: unit-variance weights inflate activations layer after layer,
# producing enormous initial logits and loss.
global_model = FederatedNet().to(device)

# Partition the training split into one disjoint shard per client.
# random_split requires the lengths to sum to len(train_dataset), so any remainder of an
# uneven division is spread one sample at a time over the first shards.
base_shard_size, remainder = divmod(len(train_dataset), num_clients)
client_shard_sizes = [base_shard_size + (1 if i < remainder else 0) for i in range(num_clients)]
client_datasets = random_split(
    train_dataset,
    client_shard_sizes,
    generator=torch.Generator().manual_seed(0),  # fixed seed -> reproducible client shards
)

In [None]:
# One Client per data shard: client i trains only on client_datasets[i].
clients = [Client(client_id=idx, dataset=client_datasets[idx]) for idx in range(num_clients)]

# Federated training schedule shared by every client in every round.
epochs_per_round = 5  # local epochs each client runs before aggregation
learning_rate = 1e-4  # Adam step size for local training
rounds = 10           # number of train -> aggregate communication rounds

In [None]:
def evaluate(model, dataset):
    """Return ``(mean_loss, accuracy)`` of ``model`` over every sample in ``dataset``.

    Defined once, outside the round loop, so it exists even when ``rounds == 0`` and can be
    reused by the final test-set evaluation.

    Args:
        model: network returning log-probabilities, as ``FederatedNet`` does.
        dataset: dataset yielding ``(input, label)`` pairs.

    Returns:
        Tuple of (mean NLL loss per sample, fraction of correct predictions).
    """
    model.eval()
    dataloader = DataLoader(dataset, batch_size=batch_size)
    total_loss = 0.0
    total_correct = 0
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum (not mean) per batch so dividing by len(dataset) is exact even when the
            # final batch is smaller.
            total_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            total_correct += pred.eq(target.view_as(pred)).sum().item()
    return total_loss / len(dataset), total_correct / len(dataset)


# FedAvg weights each client by its share of the training samples. With equal-sized shards
# this is the plain mean; it stays correct when shard sizes differ.
total_client_samples = sum(len(client.dataset) for client in clients)

for round_num in range(rounds):
    print(f'Starting round {round_num + 1}...')

    # Every client trains from the same global weights. The weighted average is accumulated
    # incrementally, so only one client state_dict is held at a time.
    new_global_state_dict = {}
    for client in clients:
        client_model_state_dict = client.train(global_model, epochs_per_round, learning_rate)
        client_weight = len(client.dataset) / total_client_samples

        for key, value in client_model_state_dict.items():
            contribution = value * client_weight
            if key in new_global_state_dict:
                new_global_state_dict[key] += contribution
            else:
                new_global_state_dict[key] = contribution

    global_model.load_state_dict(new_global_state_dict)

    # Evaluation on training and validation sets
    train_loss, train_acc = evaluate(global_model, train_dataset)
    dev_loss, dev_acc = evaluate(global_model, dev_dataset)
    print(f'\nRound {round_num + 1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Dev Loss: {dev_loss:.4f}, Dev Acc: {dev_acc:.4f}\n')

In [None]:
# Score the aggregated global model once on the held-out MNIST test split.
test_metrics = evaluate(global_model, test_dataset)  # (mean loss, accuracy)
test_loss, test_acc = test_metrics
print(f'Final Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

In [None]:
# Persist the aggregated weights. Tensors are moved to CPU first so the checkpoint can be
# loaded with torch.load on a machine without CUDA, even when training ran on a GPU.
GLOBAL_MODEL_PATH = "global_model.pth"
checkpoint = global_model.state_dict()  # fresh dict each call; replacing values leaves the model untouched
for name in list(checkpoint):
    checkpoint[name] = checkpoint[name].cpu()
torch.save(checkpoint, GLOBAL_MODEL_PATH)

In [None]:
# Show actual vs predicted labels using the global model on one client's local data.
# `plt` comes from the imports cell at the top of the notebook.
CLIENT_TO_SHOW = 0       # index into `clients`
NUM_SAMPLES_TO_SHOW = 5  # images drawn == panels plotted (at most)

shown_client = clients[CLIENT_TO_SHOW]
sample_loader = DataLoader(shown_client.dataset, batch_size=NUM_SAMPLES_TO_SHOW, shuffle=True)
images, labels = next(iter(sample_loader))
images, labels = images.to(device), labels.to(device)

global_model.eval()
with torch.no_grad():
    outputs = global_model(images)
    preds = outputs.argmax(dim=1)

# Size the panel row from the batch actually drawn (a tiny shard can yield fewer images).
# squeeze=False keeps `axes` 2-D, so iterating `.flat` also works for a single panel.
num_shown = images.size(0)
fig, axes = plt.subplots(1, num_shown, figsize=(12, 3), squeeze=False)
for idx, ax in enumerate(axes.flat):
    ax.imshow(images[idx][0].cpu(), cmap="gray")
    ax.set_title(f"True: {labels[idx].item()}\nPred: {preds[idx].item()}")
    ax.axis("off")
fig.suptitle(f"Client {shown_client.client_id}: Actual vs Predicted")
plt.show()
