In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# Define the model
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)  # First fully connected layer: input size is 28*28 (MNIST image size), output size is 128
        self.relu = nn.ReLU()  # ReLU activation function
        self.fc2 = nn.Linear(128, 10)  # Second fully connected layer: input size is 128, output size is 10 (number of classes in MNIST)
        self.softmax = nn.Softmax(dim=1)  # Softmax function applied along the dimension 1 (class scores)

    def forward(self, x):
        x = self.flatten(x)  # Flatten the input tensor to a 1D vector
        x = self.fc1(x)  # Apply the first fully connected layer
        x = self.relu(x)  # Apply ReLU activation
        x = self.fc2(x)  # Apply the second fully connected layer
        return self.softmax(x)  # Apply the Softmax function to get class probabilities

# Load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# Transform: convert images to tensors and normalize them with mean=0.5 and std=0.5

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Load the training dataset

test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Load the test dataset

# Simulate federated learning clients
num_clients = 5  # Number of clients
client_data_size = len(train_dataset) // num_clients  # Size of data each client will have
clients = []

for i in range(num_clients):
    client_indices = list(range(i * client_data_size, (i + 1) * client_data_size))
    # Select data indices for each client
    clients.append(Subset(train_dataset, client_indices))
    # Create a subset of the training data for the client

# Federated learning process
global_model = SimpleMLP()  # Initialize the global model

def federated_avg(weights):
    avg_weights = [torch.mean(torch.stack([client_weights[layer] for client_weights in weights]), dim=0)
                   for layer in range(len(weights[0]))]
    # Compute the average weights across all clients for each layer
    return avg_weights

num_rounds = 5  # Number of federated learning rounds
for round_num in range(num_rounds):
    print(f'Federated Learning Round {round_num + 1}')
    client_weights = []  # List to store weights from all clients
    
    for client_data in clients:
        model = SimpleMLP()  # Initialize a new model for each client
        model.load_state_dict(global_model.state_dict())  # Load the current global model's weights
        optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer with learning rate 0.001
        criterion = nn.CrossEntropyLoss()  # CrossEntropyLoss for classification

        dataloader = DataLoader(client_data, batch_size=64, shuffle=True)
        # DataLoader for the client's data with batch size of 64

        model.train()  # Set model to training mode
        for epoch in range(1):  # Train for 1 epoch
            for x_client, y_client in dataloader:
                optimizer.zero_grad()  # Zero the gradients
                output = model(x_client)  # Forward pass
                loss = criterion(output, y_client)  # Compute loss
                loss.backward()  # Backward pass (compute gradients)
                optimizer.step()  # Update model weights
        
        client_weights.append([param.data.clone() for param in model.parameters()])
        # Store the model's weights after training on client data
    
    # Aggregate weights
    new_weights = federated_avg(client_weights)
    with torch.no_grad():  # Update global model weights without tracking gradients
        for param, new_weight in zip(global_model.parameters(), new_weights):
            param.data = new_weight

# Evaluate the global model on the test set
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
# DataLoader for the test data with batch size of 1000

global_model.eval()  # Set global model to evaluation mode
correct = 0  # Counter for correct predictions
total = 0  # Counter for total predictions
with torch.no_grad():  # No gradient computation needed
    for x_test, y_test in test_loader:
        output = global_model(x_test)  # Forward pass
        _, predicted = torch.max(output.data, 1)  # Get the class with the highest score
        total += y_test.size(0)  # Increment total by the batch size
        correct += (predicted == y_test).sum().item()  # Count the number of correct predictions
print(f'Accuracy of the global model on the test data: {100 * correct / total:.2f}%')
# Print the accuracy of the global model on the test data


Federated Learning Round 1
Federated Learning Round 2
Federated Learning Round 3
Federated Learning Round 4
Federated Learning Round 5
Accuracy of the global model on the test data: 92.51%
