In [19]:
# %%capture
# !pip install torch torchvision numpy matplotlib scikit-learn pandas

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
from typing import List, Dict, Tuple
from collections import defaultdict
import random
import matplotlib.pyplot as plt
import copy

%matplotlib inline

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Print PyTorch version and GPU availability
print(f"PyTorch Version: {torch.__version__}")
print(f"GPU Available: {torch.cuda.is_available()}")

PyTorch Version: 2.5.1
GPU Available: False


In [20]:
# Data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load MNIST dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                    download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False,
                                    download=True, transform=transform)

# Create data loaders
train_loader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)

In [21]:
class CNNModel(nn.Module):
    def __init__(self, num_classes=10):
        super(CNNModel, self).__init__()
        # Feature extraction layers
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1),  # 28x28 -> 28x28
            nn.ReLU(),
            nn.MaxPool2d(2, 2),         # 28x28 -> 14x14
            nn.Conv2d(32, 64, 3, 1, 1), # 14x14 -> 14x14
            nn.ReLU(),
            nn.MaxPool2d(2, 2),         # 14x14 -> 7x7
        )
        
        # Fully connected layers
        self.classifier = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128)
        )
        
        # Output layer
        self.fc_out = nn.Linear(128, num_classes)
        
        # Prototype layer
        self.prototype_layer = nn.Linear(128, num_classes)

    def forward(self, x):
        # Feature extraction
        x = self.features(x)
        x = x.view(x.size(0), -1)
        
        # Get features from classifier
        features = self.classifier(x)
        
        # Generate prototypes
        prototypes = self.prototype_layer(features)
        
        # Get classification output
        output = self.fc_out(features)
        
        return output, prototypes, features

# Initialize device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [22]:
def create_non_iid_data(dataset, num_clients=6, num_shards=200):
    """
    Create non-IID data distribution by dividing data into shards
    and distributing them among clients.
    """
    # Calculate number of items per shard
    num_items = len(dataset) // num_shards
    
    # Create index list and shuffle
    idx_shard = list(range(num_shards))
    random.shuffle(idx_shard)
    
    # Sort data by label
    label_indices = defaultdict(list)
    for idx, label in enumerate(dataset.targets):
        label_indices[label.item()].append(idx)
    
    # Distribute shards to clients
    client_data = [[] for _ in range(num_clients)]
    shards_per_client = num_shards // num_clients
    
    for client_idx in range(num_clients):
        start_shard = client_idx * shards_per_client
        end_shard = start_shard + shards_per_client
        
        for shard_idx in idx_shard[start_shard:end_shard]:
            start_idx = shard_idx * num_items
            end_idx = start_idx + num_items
            client_data[client_idx].extend(range(start_idx, end_idx))
    
    return [torch.utils.data.Subset(dataset, indices) for indices in client_data]

# Create non-IID data splits
client_datasets = create_non_iid_data(trainset, num_clients=6)

In [23]:
class FedProtoClient:
    def __init__(self, model, train_data, test_data, client_id, device, learning_rate=0.01):
        self.model = model.to(device)
        self.train_data = train_data
        self.test_data = test_data
        self.client_id = client_id
        self.device = device
        self.optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
        self.criterion = nn.CrossEntropyLoss()
    
    def train(self, epochs):
        self.model.train()
        train_loader = torch.utils.data.DataLoader(self.train_data, batch_size=32, shuffle=True)
        
        for epoch in range(epochs):
            epoch_loss = 0
            correct = 0
            total = 0
            
            for batch_idx, (data, target) in enumerate(train_loader):
                data, target = data.to(self.device), target.to(self.device)
                
                self.optimizer.zero_grad()
                output, prototypes, _ = self.model(data)
                
                # Compute losses
                cls_loss = self.criterion(output, target)
                proto_loss = self.compute_prototype_loss(prototypes, target)
                total_loss = cls_loss + 0.1 * proto_loss
                
                total_loss.backward()
                self.optimizer.step()
                
                epoch_loss += total_loss.item()
                
                # Calculate accuracy
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)
            
            acc = 100. * correct / total
            print(f'Client {self.client_id}, Epoch {epoch+1}: '
                  f'Loss: {epoch_loss/len(train_loader):.4f}, '
                  f'Accuracy: {acc:.2f}%')
    
    def compute_prototype_loss(self, prototypes, targets):
        prototype_loss = 0
        for c in range(10):  # 10 classes for MNIST
            class_mask = (targets == c)
            if class_mask.sum() > 0:
                class_protos = prototypes[class_mask]
                centroid = class_protos.mean(0)
                prototype_loss += ((class_protos - centroid)**2).mean()
        return prototype_loss
    
    def evaluate(self):
        self.model.eval()
        test_loader = torch.utils.data.DataLoader(self.test_data, batch_size=32)
        correct = 0
        total = 0
        
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output, _, _ = self.model(data)
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)
                
        accuracy = 100. * correct / total
        return accuracy

In [24]:
class FedProtoServer:
    def __init__(self, global_model, device, num_clients=6):
        self.global_model = global_model.to(device)
        self.device = device
        self.num_clients = num_clients
    
    def aggregate_models(self, client_models):
        """
        Aggregate client models using FedAvg algorithm
        """
        global_state_dict = self.global_model.state_dict()
        
        # Aggregate parameters
        for key in global_state_dict.keys():
            global_state_dict[key] = torch.stack([
                client_models[i].state_dict()[key].to(self.device) 
                for i in range(len(client_models))
            ]).mean(0)
        
        self.global_model.load_state_dict(global_state_dict)
    
    def distribute_model(self):
        """
        Create a copy of the global model for distribution
        """
        return copy.deepcopy(self.global_model)

In [25]:
# Initialize global model
global_model = CNNModel().to(device)

# Create server
server = FedProtoServer(global_model, device)

# Initialize clients
clients = [
    FedProtoClient(
        copy.deepcopy(global_model),
        client_datasets[i],
        testset,
        client_id=i,
        device=device
    ) for i in range(6)
]

# Training parameters
num_rounds = 10
local_epochs = 5
global_accuracies = []

# Training loop
for round_idx in range(num_rounds):
    print(f"\nRound {round_idx + 1}")
    
    # Distribute global model to clients
    for client in clients:
        client.model = copy.deepcopy(server.global_model)
        
    # Local training
    for client in clients:
        client.train(local_epochs)
        
    # Aggregate models
    client_models = [client.model for client in clients]
    server.aggregate_models(client_models)
    
    # Evaluate global model
    accuracies = []
    for client in clients:
        client.model = copy.deepcopy(server.global_model)
        acc = client.evaluate()
        accuracies.append(acc)
        print(f"Client {client.client_id} Accuracy: {acc:.2f}%")
    
    avg_accuracy = sum(accuracies) / len(accuracies)
    global_accuracies.append(avg_accuracy)
    print(f"Round {round_idx + 1} Average Accuracy: {avg_accuracy:.2f}%")


Round 1
Client 0, Epoch 1: Loss: 2.3058, Accuracy: 9.49%
Client 0, Epoch 2: Loss: 2.3057, Accuracy: 9.42%
Client 0, Epoch 3: Loss: 2.3051, Accuracy: 9.92%
Client 0, Epoch 4: Loss: 2.3055, Accuracy: 9.47%
Client 0, Epoch 5: Loss: 2.3054, Accuracy: 9.26%
Client 1, Epoch 1: Loss: 2.3056, Accuracy: 9.09%
Client 1, Epoch 2: Loss: 2.3048, Accuracy: 10.68%
Client 1, Epoch 3: Loss: 2.3053, Accuracy: 9.62%
Client 1, Epoch 4: Loss: 2.3050, Accuracy: 9.82%
Client 1, Epoch 5: Loss: 2.3050, Accuracy: 9.94%
Client 2, Epoch 1: Loss: 2.3058, Accuracy: 9.54%
Client 2, Epoch 2: Loss: 2.3065, Accuracy: 9.08%
Client 2, Epoch 3: Loss: 2.3062, Accuracy: 9.47%
Client 2, Epoch 4: Loss: 2.3065, Accuracy: 9.51%
Client 2, Epoch 5: Loss: 2.3068, Accuracy: 9.39%
Client 3, Epoch 1: Loss: 2.3044, Accuracy: 9.72%
Client 3, Epoch 2: Loss: 2.3054, Accuracy: 9.58%
Client 3, Epoch 3: Loss: 2.3050, Accuracy: 9.06%
Client 3, Epoch 4: Loss: 2.3048, Accuracy: 9.20%
Client 3, Epoch 5: Loss: 2.3049, Accuracy: 9.99%
Client 4, 

KeyboardInterrupt: 

In [None]:
# Plot the accuracy progression
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(global_accuracies) + 1), global_accuracies, marker='o')
plt.title('Global Model Accuracy over Communication Rounds')
plt.xlabel('Communication Round')
plt.ylabel('Average Accuracy (%)')
plt.grid(True)
plt.show()

# Save the final model
torch.save(server.global_model.state_dict(), 'fedproto_mnist_model.pth')