## 🎓 Federated Learning from Scratch with PyTorch


## Step 1: Setup and Data Preparation

First, let's import PyTorch and prepare our data. Federated learning trains a shared model on data that stays distributed across many clients, so we'll simulate this by splitting a single dataset (MNIST) into several smaller datasets, one for each "client."


In [2]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import copy # We will use this to deep-copy our model to clients
import random # For client selection

In [3]:
# Set a device for training (GPU if available, otherwise CPU)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Define the data transformations for our dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
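`transforms.Normalize((0.5,), (0.5,))` applies `(x - mean) / std` to each channel. `ToTensor()` has already scaled MNIST pixels into [0, 1], so this shifts them into [-1, 1]. The sketch below reproduces that arithmetic in plain Python. `normalize_pixel` is a hypothetical helper for illustration, not part of torchvision.

```python
def normalize_pixel(x, mean=0.5, std=0.5):
    """Hypothetical helper mirroring Normalize's per-channel formula: (x - mean) / std."""
    return (x - mean) / std


# ToTensor() maps raw 0..255 pixel values to floats in [0, 1] before Normalize runs
for raw in (0, 128, 255):
    scaled = raw / 255
    print(raw, "->", round(normalize_pixel(scaled), 3))
```

Black pixels end up at -1.0 and white pixels at 1.0, which centers the inputs around zero.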

#### Download and load the MNIST training and test datasets


In [5]:
# In a real-world scenario, this data would already be on the clients' devices.
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

In [6]:
# Let's define the number of clients we'll simulate
NUM_CLIENTS = 10
CLIENT_BATCH_SIZE = 32

#### Partition the data among the clients


In [7]:
# Split the training data into NUM_CLIENTS equal, non-overlapping subsets.
# The requested lengths must sum to len(train_dataset): 10 x 6000 = 60000.
client_data = torch.utils.data.random_split(
    train_dataset, [len(train_dataset) // NUM_CLIENTS] * NUM_CLIENTS
)


In [8]:
# Create a DataLoader for each client's data
client_trainloaders = [
    DataLoader(data, batch_size=CLIENT_BATCH_SIZE, shuffle=True) for data in client_data
]

In [9]:
# Create a single test DataLoader for the server's evaluation
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)


In [10]:
print(f"Data has been partitioned among {NUM_CLIENTS} clients.")
print(f"Each client has {len(client_data[0])} samples.")
print(f"Test dataset has {len(test_dataset)} samples.")
print(f"Test DataLoader created with batch size {test_dataloader.batch_size}.")

Data has been partitioned among 10 clients.
Each client has 6000 samples.
Test dataset has 10000 samples.
Test DataLoader created with batch size 128.


### Code Explanation:

torch, nn, optim: Standard PyTorch imports for building and training neural networks.

copy: We'll use copy.deepcopy to create independent copies of the global model for each client.

random: To randomly select a subset of clients for each training round.

datasets.MNIST: We use the MNIST dataset because it's simple and a great starting point.

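torch.utils.data.random_split: This function is our "magic wand" for simulating decentralized data. It randomly splits train_dataset into 10 non-overlapping subsets of 6,000 samples each, one per client. Because the split ignores labels, every client sees roughly the same mix of digits (an IID partition); real federated data is often non-IID, with each client's data skewed toward certain classes.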
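To make the partitioning concrete, here is a plain-Python sketch of the same idea: shuffle the sample indices once, then cut them into equal slices. `iid_partition` is a hypothetical helper; it mimics `random_split` with equal lengths but is not PyTorch's implementation.

```python
import random


def iid_partition(num_samples, num_clients, seed=0):
    """Shuffle sample indices and deal them into equal, non-overlapping shards.

    A sketch of equal-length random_split behavior, not torch's actual code.
    """
    if num_samples % num_clients != 0:
        raise ValueError("num_samples must be divisible by num_clients")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # label-agnostic shuffle -> IID shards
    shard_size = num_samples // num_clients
    return [indices[i * shard_size:(i + 1) * shard_size] for i in range(num_clients)]


shards = iid_partition(60_000, 10)
print([len(s) for s in shards])  # ten shards of 6000 indices each
```

A non-IID variant would instead sort or group indices by label before slicing, so each shard covers only a few digits.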

## Step 2: Defining the Neural Network Model

We'll use a simple Multi-Layer Perceptron (MLP) for this task. The model is defined once and will be used by both the server (as the global model) and the clients (as their local models).

In [11]:
# Define the MLP model architecture
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        # 28x28 images, so input size is 784
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10) # Output layer for 10 classes (digits 0-9)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


### Code Explanation:

This is a standard PyTorch nn.Module class. It defines the structure of our model.

The forward method specifies how data flows through the network.

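We're using a simple architecture: flatten each 28x28 image into a 784-dimensional vector, pass it through two hidden fully connected layers (784 → 128 → 64), each followed by a ReLU activation, and finish with an output layer that produces 10 logits, one per digit class.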
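Because every round ships the full set of weights between the server and each client, it helps to know how many there are. The arithmetic below counts the parameters in the three `nn.Linear` layers of `MLP`; `linear_params` is a hypothetical helper (ReLU and Flatten have no parameters).

```python
def linear_params(in_features, out_features):
    # An nn.Linear layer has an (out x in) weight matrix plus one bias per output
    return in_features * out_features + out_features


layer_sizes = [28 * 28, 128, 64, 10]  # input -> fc1 -> fc2 -> fc3
per_layer = [linear_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)
print(per_layer, total)  # [100480, 8256, 650] 109386
```

In PyTorch itself, `sum(p.numel() for p in MLP().parameters())` should report the same total.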

## Step 3: The Client-side Training Loop

Each client needs a function to perform local training. This function takes the current global model and the client's data, trains a local copy for a given number of epochs (one, in this notebook), and returns the updated model parameters.
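How much work is one local epoch? With 6,000 samples per client and `CLIENT_BATCH_SIZE = 32`, the DataLoader yields ceil(6000 / 32) = 188 batches, the last one holding only 16 samples, because `drop_last` defaults to `False`. `local_steps` is a hypothetical helper for that arithmetic.

```python
import math


def local_steps(num_samples, batch_size, epochs, drop_last=False):
    # DataLoader yields ceil(n / batch_size) batches per epoch unless drop_last=True,
    # in which case the final partial batch is discarded
    if drop_last:
        per_epoch = num_samples // batch_size
    else:
        per_epoch = math.ceil(num_samples / batch_size)
    return per_epoch * epochs


print(local_steps(6000, 32, epochs=1))  # 188 optimizer steps per client per round
```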


In [12]:
def client_training(model, trainloader, epochs=1):
    """
    Performs a single round of local training on a client's data.

    Args:
        model (nn.Module): The current global model sent by the server.
        trainloader (DataLoader): The client's local data loader.
        epochs (int): Number of local epochs to train for.

    Returns:
        OrderedDict: The updated state_dict (model parameters) after local training.
    """
    # Create a local copy of the model
    local_model = copy.deepcopy(model).to(DEVICE)
    local_model.train()  # Set the model to training mode

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(local_model.parameters(), lr=0.01)

    for epoch in range(epochs):
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = local_model(images)
            loss = criterion(outputs, labels)
            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            
    return local_model.state_dict()


### Code Explanation:

client_training function: This simulates a single client.

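copy.deepcopy(model): This is crucial! Each client needs its own independent copy of the global model to train on. Without it, each client would update the global model in place, so later clients in the same round would start from the previous client's weights instead of from the round's global model.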

local_model.train(): Sets the model to training mode.

criterion and optimizer: Each client minimizes cross-entropy loss (nn.CrossEntropyLoss) using plain stochastic gradient descent (optim.SGD with a learning rate of 0.01).

The outer loop repeats local training for the requested number of epochs; the inner loop is a standard PyTorch training loop over the client's mini-batches.

The function returns the state_dict, which is a dictionary containing the updated model's parameters (weights and biases).
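The cost of skipping `copy.deepcopy` is easy to see with plain Python containers. In this sketch a dict of lists is an assumed stand-in for a state_dict so it runs without PyTorch, and `local_update` is a hypothetical "training" step that edits weights in place, as `optimizer.step()` does.

```python
import copy


def local_update(weights, delta):
    # Simulate one client's training by nudging every weight in place
    for i in range(len(weights["fc1"])):
        weights["fc1"][i] += delta
    return weights


global_weights = {"fc1": [0.0, 0.0]}

# Without a copy, the "client" mutates the server's object directly
aliased = local_update(global_weights, 1.0)
print(global_weights["fc1"])  # [1.0, 1.0] (the global weights changed)

# With deepcopy, the client trains on an independent copy
global_weights = {"fc1": [0.0, 0.0]}
independent = local_update(copy.deepcopy(global_weights), 1.0)
print(global_weights["fc1"], independent["fc1"])  # [0.0, 0.0] [1.0, 1.0]
```

For this dict of lists, a shallow `copy.copy` would not help either, because the inner lists would still be shared.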

## Step 4: The Server-side Aggregation

The server's job is to collect the updated parameters from the clients and combine them. We will implement the most common aggregation algorithm, Federated Averaging (FedAvg). FedAvg calculates a weighted average of the client model parameters, where the weight is proportional to the size of the client's training data.
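The implementation in this step covers only the equal-weight case, so here is a plain-Python sketch of the weighted average just described. The `fedavg` function and its `client_sizes` argument are hypothetical, and parameters are lists of floats rather than tensors so the example runs without PyTorch.

```python
def fedavg(client_params, client_sizes):
    """Weighted FedAvg over flat parameter dicts (name -> list of floats).

    Client k gets weight n_k / n, where n_k is its number of training
    samples and n is the total across the participating clients.
    """
    if not client_params or len(client_params) != len(client_sizes):
        raise ValueError("need one size per client update")
    total = sum(client_sizes)
    averaged = {}
    for name in client_params[0]:
        length = len(client_params[0][name])
        averaged[name] = [
            sum(p[name][i] * n / total for p, n in zip(client_params, client_sizes))
            for i in range(length)
        ]
    return averaged


# Client A has 3x as much data as client B, so its weights count 3x as much
updates = [{"w": [1.0, 2.0]}, {"w": [5.0, 6.0]}]
print(fedavg(updates, client_sizes=[300, 100]))  # {'w': [2.0, 3.0]}
```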

In [13]:
def aggregate_parameters(client_updates):
    """
    Aggregates parameters from multiple clients using FedAvg.

    Args:
        client_updates (list): A list of client state_dicts (parameters).

    Returns:
        OrderedDict: The aggregated global state_dict.
    """
    if not client_updates:
        raise ValueError("No client updates provided for aggregation.")

    # We assume all clients have the same number of data points, so FedAvg
    # reduces to a simple (unweighted) mean. In a real-world scenario, you
    # would weight each client by the size of its local dataset.
    num_clients = len(client_updates)

    # Copy the first client's state_dict to get the right keys, shapes, and
    # devices, then reset every tensor to zero before accumulating.
    global_state_dict = copy.deepcopy(client_updates[0])
    for name in global_state_dict:
        global_state_dict[name] = torch.zeros_like(global_state_dict[name])

    # Accumulate each client's parameters, scaled by 1 / num_clients
    for client_state_dict in client_updates:
        for name, param in client_state_dict.items():
            global_state_dict[name] += param / num_clients

    return global_state_dict



    

### Code Explanation:

The aggregate_parameters function takes a list of state_dicts from the clients.

It creates global_state_dict by deep-copying the first client's state_dict (so the keys, tensor shapes, and devices match) and then resetting every entry to zeros.

It then iterates through each client's state_dict, divides every parameter by the number of clients, and adds it to the running total, which yields the average.

For simplicity, we assume each client has the same amount of data; in that case the weighted FedAvg average reduces to this simple mean.
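The equal-data shortcut follows directly from the FedAvg weighting. Writing $K$ for the number of participating clients, $n_k$ for client $k$'s sample count, $n = \sum_k n_k$, and $w_k$ for the parameters client $k$ returns (notation introduced here for illustration):

$$
w_{\text{global}} = \sum_{k=1}^{K} \frac{n_k}{n}\, w_k ,
\qquad
n_k = \frac{n}{K}\ \text{for all } k
\;\Longrightarrow\;
w_{\text{global}} = \sum_{k=1}^{K} \frac{1}{K}\, w_k = \frac{1}{K} \sum_{k=1}^{K} w_k .
$$

In this notebook each of the 5 clients selected per round holds 6,000 samples, so every client gets weight 6000 / 30000 = 1/5.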

## Step 5: The Federated Training Loop

Now we put all the pieces together in a main loop that simulates the federated training process over multiple rounds.

In [15]:
def server_evaluation(model, dataloader):
    """
    Evaluates the global model on the server's test dataset.

    Args:
        model (nn.Module): The global model to evaluate.
        dataloader (DataLoader): The test dataset DataLoader.

    Returns:
        float: The accuracy of the model on the test dataset.
    """
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total if total > 0 else 0
    return accuracy


In [17]:
# Initialize the global model on the server
global_model = MLP().to(DEVICE)
print("Global model initialized on the server.")
print(f"initial global model accuracy: {server_evaluation(global_model, test_dataloader):.2f}%")


Global model initialized on the server.
initial global model accuracy: 14.98%


In [None]:
# Federated Learning Main Loop
NUM_ROUNDS = 5
CLIENTS_PER_ROUND = 5 # Number of clients to select for each round

for round_num in range(NUM_ROUNDS):
    print(f"\n--- Starting Federated Learning Round {round_num + 1}/{NUM_ROUNDS} ---")

    # 1. Server selects a subset of clients for the current round
    participating_clients_indices = random.sample(range(NUM_CLIENTS), CLIENTS_PER_ROUND)
    print(f"Server selects clients: {participating_clients_indices}")

    # 2. Server sends the global model to the selected clients
    client_updates = []

    for client_idx in participating_clients_indices:
        # Simulate local training on each client
        print(f"Client {client_idx} is training...")
        client_dataloader = client_trainloaders[client_idx]
        local_state_dict = client_training(global_model, client_dataloader, epochs=1)
        client_updates.append(local_state_dict)

    # 3. Server aggregates the updates from all the participating clients
    new_global_state_dict = aggregate_parameters(client_updates)
    print("Global model parameters aggregated.")

    # 4. Server updates the global model with the aggregated parameters
    #    (aggregate_parameters raises on empty input, so no None check is needed)
    global_model.load_state_dict(new_global_state_dict)
    print("Global model parameters updated.")

    # 5. Evaluate the updated global model on the server's test dataset
    accuracy = server_evaluation(global_model, test_dataloader)
    print(f"Global model accuracy after round {round_num + 1}: {accuracy:.2f}%")
    print("Round completed.\n")

print("\n--- Federated Learning Finished ---")
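The same five-step loop can be exercised end to end without PyTorch or MNIST. The sketch below is a toy stand-in, not the notebook's model: each hypothetical client holds noise-free points from y = 3x, trains a single scalar weight with SGD, and the server averages the returned weights exactly as `aggregate_parameters` does for equal-sized clients. The constants `TRUE_W` and `LR` and the helpers `make_client_data` and `local_train` are assumptions made for illustration.

```python
import random

random.seed(0)  # make the toy run reproducible
NUM_CLIENTS, CLIENTS_PER_ROUND, NUM_ROUNDS = 10, 5, 20
TRUE_W, LR = 3.0, 0.1  # hypothetical target slope and local learning rate


def make_client_data(n=20):
    # Each hypothetical client holds n noise-free points on the line y = 3x
    xs = [random.uniform(-1.0, 1.0) for _ in range(n)]
    return [(x, TRUE_W * x) for x in xs]


def local_train(w, data, epochs=1):
    # Plain SGD on the squared error (w*x - y)^2, whose gradient is 2*(w*x - y)*x
    for _ in range(epochs):
        for x, y in data:
            w -= LR * 2 * (w * x - y) * x
    return w


clients = [make_client_data() for _ in range(NUM_CLIENTS)]
global_w = 0.0  # the server's initial global "model"

for round_num in range(NUM_ROUNDS):
    selected = random.sample(range(NUM_CLIENTS), CLIENTS_PER_ROUND)  # 1. select clients
    updates = [local_train(global_w, clients[k]) for k in selected]  # 2. local training
    global_w = sum(updates) / len(updates)  # 3-4. equal-weight FedAvg, update global model
    if (round_num + 1) % 5 == 0:  # 5. report progress
        print(f"round {round_num + 1}: w = {global_w:.4f}")
```

Because every client's data comes from the same line (an IID setting), each local update moves the weight toward 3.0, and so does their average. With non-IID client data the local optima would disagree, and the average could converge more slowly.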
