In [17]:
import torch
import torchvision
import copy
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
import numpy as np


In [18]:
class SimpleCNN(nn.Module):
    """Small three-stage convolutional classifier that outputs 10 class logits.

    Each stage is a 3x3 convolution (stride 1, padding 1, so spatial size is
    kept) followed by ReLU and a 2x2 max-pool (spatial size halved).  ``fc1``
    expects ``128 * 4 * 4`` features, which is what three halvings of a
    3-channel 32x32 image produce.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, xb):
        """Return raw (un-normalised) class logits for a batch of images.

        Args:
            xb: Image batch of shape ``(batch, 3, H, W)``; ``H = W = 32`` gives
                the ``128 * 4 * 4`` features that ``fc1`` requires.

        Returns:
            Tensor of shape ``(batch, 10)``.

        Raises:
            RuntimeError: If the flattened feature size does not match
                ``fc1`` (i.e. the input has an unsupported spatial size).
        """
        xb = self.pool(F.relu(self.conv1(xb)))
        xb = self.pool(F.relu(self.conv2(xb)))
        xb = self.pool(F.relu(self.conv3(xb)))
        # Flatten per sample and keep the batch dimension fixed.  A
        # ``view(-1, 128 * 4 * 4)`` would fold surplus spatial positions into
        # the batch dimension for larger inputs and silently return more rows
        # than there are samples; this way a wrong input size fails in fc1.
        xb = torch.flatten(xb, 1)
        xb = F.relu(self.fc1(xb))
        xb = self.fc2(xb)
        return xb


In [19]:
def load_datasets():
    """Create the CIFAR-10 training and test datasets stored under ``data/``.

    Every image is converted to a tensor and then normalised per channel with
    mean 0.5 and standard deviation 0.5.  Both splits are requested with
    ``download=True``.

    Returns:
        tuple: ``(train_dataset, test_dataset)``.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Options shared by both splits; only the ``train`` flag differs.
    shared_options = dict(root='data/', transform=preprocessing, download=True)

    training_split = datasets.CIFAR10(train=True, **shared_options)
    held_out_split = datasets.CIFAR10(train=False, **shared_options)
    return training_split, held_out_split
# Split dataset among clients
def split_data(train_dataset, num_clients):
    """Randomly partition ``train_dataset`` into ``num_clients`` disjoint subsets.

    Every sample is assigned to exactly one client.  When the dataset size is
    not a multiple of ``num_clients`` the first ``len(train_dataset) %
    num_clients`` clients receive one extra sample, so the requested lengths
    always sum to the dataset length as ``random_split`` requires.

    Args:
        train_dataset: Any object supporting ``len()`` and integer indexing.
        num_clients: Number of subsets to create; must satisfy
            ``1 <= num_clients <= len(train_dataset)``.

    Returns:
        The sequence of subsets produced by ``random_split``, one per client.

    Raises:
        ValueError: If ``num_clients`` is outside the valid range.
    """
    total = len(train_dataset)
    if not 1 <= num_clients <= total:
        raise ValueError(
            f"num_clients must be between 1 and the dataset size ({total}), "
            f"got {num_clients}"
        )

    base, remainder = divmod(total, num_clients)
    # Hand the leftover samples out one each to the first `remainder` clients
    # instead of dropping them (or failing the length check in random_split).
    lengths = [base + 1 if idx < remainder else base for idx in range(num_clients)]
    client_datasets = random_split(train_dataset, lengths)
    return client_datasets

In [20]:
def client_update(client_model, optimizer, train_loader, epochs=1):
    """Train ``client_model`` in place on one client's local data.

    Args:
        client_model: Module to train; it is switched to training mode.
        optimizer: Optimizer that steps ``client_model``'s parameters.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        epochs: Number of full passes over ``train_loader``.

    Returns:
        tuple: ``(state_dict, avg_loss, accuracy)`` where ``state_dict`` is
        ``client_model.state_dict()`` after training, ``avg_loss`` is the mean
        of the per-batch losses over every batch processed (all epochs), and
        ``accuracy`` is the percentage of correctly predicted training samples
        over all epochs.  Both metrics are NaN when nothing was processed
        (empty loader or ``epochs == 0``).
    """
    client_model.train()
    criterion = nn.CrossEntropyLoss()
    loss_sum = 0.0
    batches_seen = 0
    correct = 0
    samples_seen = 0

    for _ in range(epochs):
        for images, labels in train_loader:
            optimizer.zero_grad()
            output = client_model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            batches_seen += 1
            predicted = output.detach().argmax(dim=1)
            samples_seen += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Divide by the number of batches actually processed.  Dividing by
    # len(train_loader) would scale the reported loss by `epochs`, because the
    # sum above runs over every epoch.
    avg_loss = loss_sum / batches_seen if batches_seen else float("nan")
    accuracy = 100 * correct / samples_seen if samples_seen else float("nan")

    return client_model.state_dict(), avg_loss, accuracy


In [21]:
def server_aggregate(global_model, client_weights):
    """Replace the global model's state with the element-wise client average.

    Each entry of the global state dict is set to the unweighted mean of the
    corresponding entries in ``client_weights`` and then loaded back into
    ``global_model``.

    Args:
        global_model: Module whose state dict defines the keys and dtypes.
        client_weights: Non-empty sequence of state dicts, each containing
            every key of ``global_model.state_dict()``.

    Returns:
        ``global_model`` itself, updated in place.

    Raises:
        ValueError: If ``client_weights`` is empty.
    """
    if len(client_weights) == 0:
        raise ValueError("client_weights must contain at least one state dict")

    global_dict = global_model.state_dict()

    for key in list(global_dict.keys()):
        target = global_dict[key]
        # Average in at least float32 (needed for integer/half entries) but do
        # not down-cast wider types such as float64, then restore the dtype the
        # global model uses for this entry.
        acc_dtype = torch.promote_types(target.dtype, torch.float32)
        stacked = torch.stack(
            [weights[key].to(acc_dtype) for weights in client_weights], dim=0
        )
        global_dict[key] = stacked.mean(dim=0).to(target.dtype)

    global_model.load_state_dict(global_dict)
    return global_model


In [22]:
def federated_learning(train_dataset, num_clients=5, num_rounds=5, epochs=1,
                       batch_size=32, lr=0.01):
    """Train a ``SimpleCNN`` with federated averaging and return it.

    The training set is split across ``num_clients`` clients.  In every round
    each client starts from a deep copy of the current global model, trains it
    locally with plain SGD, and the server then averages the resulting client
    states into the global model.  After each round the mean of the clients'
    training loss and training accuracy is printed.

    Args:
        train_dataset: Dataset handed to ``split_data``.
        num_clients: Number of simulated clients.
        num_rounds: Number of communication rounds.
        epochs: Local epochs each client trains per round.
        batch_size: Batch size of every client's ``DataLoader``.
        lr: Learning rate of every client's SGD optimizer.

    Returns:
        The global model after the last aggregation.
    """
    # Initialize global model
    global_model = SimpleCNN()

    # Split the dataset between clients and wrap each share in a shuffling loader
    client_datasets = split_data(train_dataset, num_clients)
    client_loaders = [
        DataLoader(client_datasets[i], batch_size=batch_size, shuffle=True)
        for i in range(num_clients)
    ]

    # `round_idx` rather than `round`, so the builtin is not shadowed.
    for round_idx in range(num_rounds):
        print(f"Round {round_idx + 1}/{num_rounds}")

        client_weights = []
        client_losses = []
        client_accuracies = []

        # Train each client on its own copy of the current global model
        for client_loader in client_loaders:
            client_model = copy.deepcopy(global_model)
            optimizer = optim.SGD(client_model.parameters(), lr=lr)

            client_weight, client_loss, client_accuracy = client_update(
                client_model, optimizer, client_loader, epochs
            )
            client_weights.append(client_weight)
            client_losses.append(client_loss)
            client_accuracies.append(client_accuracy)

        # Aggregate client models
        global_model = server_aggregate(global_model, client_weights)

        # These are training metrics measured before aggregation, averaged
        # over clients; they are not an evaluation of the aggregated model.
        avg_loss = np.mean(client_losses)
        avg_accuracy = np.mean(client_accuracies)

        print(f"Average Loss: {avg_loss:.4f}, Average Accuracy: {avg_accuracy:.2f}%")

    return global_model


In [23]:
def test(global_model, test_loader):
    global_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            outputs = global_model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    print(f'Test Accuracy: {100 * correct / total:.2f}%')


In [24]:
if __name__ == "__main__":
    # Run configuration, gathered in one place so it is easy to tune.
    SEED = 42
    NUM_CLIENTS = 5
    NUM_ROUNDS = 100
    LOCAL_EPOCHS = 1
    TEST_BATCH_SIZE = 32

    # Seed torch's global RNG so the client split, the weight initialisation
    # and the batch shuffling start from the same random state on every run.
    torch.manual_seed(SEED)

    # Load datasets
    train_dataset, test_dataset = load_datasets()
    test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False)

    # Perform federated learning
    global_model = federated_learning(
        train_dataset,
        num_clients=NUM_CLIENTS,
        num_rounds=NUM_ROUNDS,
        epochs=LOCAL_EPOCHS,
    )

    # Test the global model
    test(global_model, test_loader)


Files already downloaded and verified
Files already downloaded and verified
Round 1/100
Average Loss: 2.2960, Average Accuracy: 11.77%
Round 2/100
Average Loss: 2.2387, Average Accuracy: 19.85%
Round 3/100
Average Loss: 2.0674, Average Accuracy: 25.95%
Round 4/100
Average Loss: 1.9178, Average Accuracy: 31.07%
Round 5/100
Average Loss: 1.8060, Average Accuracy: 34.72%
Round 6/100
Average Loss: 1.7175, Average Accuracy: 38.05%
Round 7/100
Average Loss: 1.6420, Average Accuracy: 40.54%
Round 8/100
Average Loss: 1.5799, Average Accuracy: 42.90%
Round 9/100
Average Loss: 1.5266, Average Accuracy: 44.57%
Round 10/100
Average Loss: 1.4843, Average Accuracy: 46.39%
Round 11/100
Average Loss: 1.4510, Average Accuracy: 47.56%
Round 12/100
Average Loss: 1.4208, Average Accuracy: 48.87%
Round 13/100
Average Loss: 1.3900, Average Accuracy: 50.08%
Round 14/100
Average Loss: 1.3624, Average Accuracy: 50.95%
Round 15/100
Average Loss: 1.3357, Average Accuracy: 52.14%
Round 16/100
Average Loss: 1.3092