# 🎓 Advanced Federated Learning with FedProx and a CNN

Welcome back! In this tutorial, we build on the basics of federated learning by implementing a more advanced algorithm, FedProx, on a new dataset using a Convolutional Neural Network (CNN).

A major challenge in real-world federated learning is data heterogeneity: each client's data follows a different distribution (non-IID). FedProx is designed to handle this by adding a proximal term to each client's local objective, which penalizes local weights for drifting far from the current global model during local training.
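
As a rough preview of that regularization, the sketch below computes the proximal penalty `(mu / 2) * ||w - w_global||^2` that FedProx adds to a client's task loss. It is framework-free on purpose: the function names, the flat lists of floats standing in for model weights, and the value of `mu` are illustrative assumptions, not code from this tutorial.

```python
def proximal_penalty(local_weights, global_weights, mu):
    """Return (mu / 2) * squared L2 distance between local and global weights."""
    squared_distance = sum(
        (w - g) ** 2 for w, g in zip(local_weights, global_weights)
    )
    return 0.5 * mu * squared_distance


def fedprox_objective(task_loss, local_weights, global_weights, mu):
    """Task loss plus the proximal penalty (hypothetical helper)."""
    return task_loss + proximal_penalty(local_weights, global_weights, mu)


# At the start of a round the client holds a copy of the global weights,
# so the penalty is zero.
penalty_at_start = proximal_penalty([0.5, -1.0], [0.5, -1.0], mu=0.1)

# After local updates move the weights, the penalty grows with the drift.
penalty_after_drift = proximal_penalty([1.5, 1.0], [0.5, -1.0], mu=0.1)

print(penalty_at_start, penalty_after_drift)
```

With `mu = 0` the objective reduces to the plain local loss, while larger values of `mu` keep clients closer to the global model.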

## Step 1: Setup and Data Preparation (with Non-IID Data)

First, we set up our environment. For this tutorial, we will use the CIFAR-10 dataset. To demonstrate the value of FedProx, we will deliberately create a non-IID data scenario in which each client's data is heavily skewed towards a few classes.

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import copy
import random
import numpy as np

In [2]:
# Set a device for training
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Define the data transformations for CIFAR-10
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

### Download and load the CIFAR-10 datasets


In [None]:
train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10('./data', train=False, download=True, transform=transform)


In [6]:
# Define the number of clients and the number of classes each client sees
NUM_CLIENTS = 10
CLASSES_PER_CLIENT = 2  # Each client will get data from only 2 classes

In [9]:
def create_non_iid_partitions(dataset, num_clients, classes_per_client):
    """
    Create non-IID data for clients by assigning each client a subset of classes.
    """
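    # Group sample indices by class label. Reading `dataset.targets` avoids
    # loading and transforming every image just to get its label.
    class_indices = [[] for _ in range(10)]
    for idx, label in enumerate(dataset.targets):
        class_indices[label].append(idx)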

    client_indices = [[] for _ in range(num_clients)]

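    # Draw `classes_per_client` random classes for each client. The draws are
    # independent, so clients may share classes (and the samples of those
    # classes), and some classes may not be assigned to any client.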
    class_sets = [set(random.sample(range(10), classes_per_client)) for _ in range(num_clients)]
    for client_idx in range(num_clients):
        for class_idx in class_sets[client_idx]:
            client_indices[client_idx].extend(class_indices[class_idx])
    
    # Shuffle the indices for each client and create client subsets
    partitions = [torch.utils.data.Subset(dataset, random.sample(indices, len(indices))) for indices in client_indices]

    return partitions

In [11]:
client_data = create_non_iid_partitions(train_dataset, NUM_CLIENTS, CLASSES_PER_CLIENT)
client_trainloaders = [DataLoader(data, batch_size=32, shuffle=True) for data in client_data]
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

print(f"Data has been partitioned among {NUM_CLIENTS} clients, with each client having data from {CLASSES_PER_CLIENT} classes.")

Data has been partitioned among 10 clients, with each client having data from 2 classes.


### Code Explanation:

`create_non_iid_partitions`: This function is the key to simulating a real-world scenario. Instead of splitting the data randomly, it gives each client samples from only `classes_per_client` classes, which makes the federated learning problem more challenging. Because each client's classes are drawn independently, two clients can end up with the same class.

The rest of the data preparation follows the same pattern as before, adapted to CIFAR-10, whose images are 32x32 pixels with 3 color channels (RGB).
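
One way to sanity-check the skew is to count how many samples of each class a client holds. The sketch below does this on small toy label lists that stand in for each client's labels; the helper name `label_histogram` and the toy data are illustrative assumptions, not part of the tutorial code.

```python
from collections import Counter


def label_histogram(labels):
    """Return {class_id: count} for one client's labels, sorted by class id."""
    return dict(sorted(Counter(labels).items()))


# Hypothetical per-client labels; in a 2-classes-per-client split,
# each client should show at most two distinct class ids.
toy_client_labels = [
    [0, 0, 3, 3, 3],
    [1, 7, 7, 1, 1, 7],
]

histograms = [label_histogram(labels) for labels in toy_client_labels]
for client_id, histogram in enumerate(histograms):
    print(f"client {client_id}: {histogram}")
```

With the real partitions, the labels for one client could plausibly be gathered as `[train_dataset.targets[i] for i in subset.indices]`, assuming each partition is a `torch.utils.data.Subset` of the CIFAR-10 training set.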

## Step 2: Defining the Convolutional Neural Network (CNN) Model

A CNN is far more suitable for image data like CIFAR-10 than the simple MLP we used before.

In [None]:
# Define the CNN model architecture
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5) # 3 input channels (RGB), 6 output channels, 5x5 kernel
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120) # 16 channels, 5x5 feature map
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10) # 10 output classes for CIFAR-10

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)  # Flatten to (batch_size, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
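
The `16 * 5 * 5` input size of `fc1` follows from how each layer shrinks a 32x32 CIFAR-10 image. The sketch below retraces that arithmetic with the standard output-size formula for convolution and pooling layers (no dilation); the helper name `conv_out` is an illustrative assumption.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                  # CIFAR-10 images are 32x32
size = conv_out(size, kernel=5)            # conv1 (5x5 kernel): 32 -> 28
size = conv_out(size, kernel=2, stride=2)  # max pool (2x2):     28 -> 14
size = conv_out(size, kernel=5)            # conv2 (5x5 kernel): 14 -> 10
size = conv_out(size, kernel=2, stride=2)  # max pool (2x2):     10 -> 5

flat_features = 16 * size * size           # 16 channels from conv2
print(size, flat_features)                 # 5 400
```

If the kernel sizes, padding, or input resolution change, the first argument of `fc1` needs to be recomputed in the same way.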