# Federated learning

First, we load a simple image dataset (like MNIST,  FashionMNIST) and we create a simple Convolutional Neural Network and check that your training works on a signle neural network (on a subset of the dataset).

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import copy
import numpy as np

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 50
LEARNING_RATE = 0.01

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

full_train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

def get_subset(dataset, num_samples, start_index=0):
    indices = list(range(start_index, start_index + num_samples))
    return Subset(dataset, indices)

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.mp = nn.MaxPool2d(2)

        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        in_size = x.size(0)
        x = torch.relu(self.mp(self.conv1(x)))
        x = torch.relu(self.mp(self.conv2(x)))
        x = x.view(in_size, -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return torch.log_softmax(x, dim=1)

def sanity_check():
    print("--- Running sanity check ---")
    model = SimpleCNN().to(DEVICE)
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
    subset = get_subset(full_train_dataset, 600)
    loader = DataLoader(subset, batch_size=BATCH_SIZE, shuffle=True)
    
    model.train()
    for epoch in range(5):
        for data, target in loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            optimizer.zero_grad()
            output = model(data)
            loss = nn.functional.nll_loss(output, target)
            loss.backward()
            optimizer.step()
    print("Sanity check passed.\n")

sanity_check()

--- Running sanity check ---
Sanity check passed.



We then create a function average_model_parameters(models: iterable, average_weight): that average the parameters of each model parameters following the approach in the article. Afterward, we create a function that reproduces Algorithm 1 in the article. We consider that all local models are trained on your local machine and not remotly. We do not implement the common weight initialization scheme for now.

In [2]:
def average_model_parameters(models, weights):
    with torch.no_grad():
        avg_model = copy.deepcopy(models[0])
        
        for param in avg_model.parameters():
            param.data.zero_()
            
        for i, model in enumerate(models):
            weight = weights[i]
            for avg_param, client_param in zip(avg_model.parameters(), model.parameters()):
                avg_param.data.add_(client_param.data * weight)
                
        return avg_model

def train_local_model(model, dataset, epochs=5):
    model.train()
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
    
    for _ in range(epochs):
        for data, target in loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            optimizer.zero_grad()
            output = model(data)
            loss = nn.functional.nll_loss(output, target)
            loss.backward()
            optimizer.step()
    return model

def evaluate_model(model, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    return 100. * correct / len(test_loader.dataset)

We then run a training of two models with average coefficients being 0.5 for each model. Each model is trained on 600 data points each. We reuse the same setup as in the article (50 examples per local batch).

In [5]:
def run_failure_experiment():
    print("--- Running failure experiment ---")
    
    model_A = SimpleCNN().to(DEVICE)
    model_B = SimpleCNN().to(DEVICE)
    
    data_A = get_subset(full_train_dataset, 600, start_index=0)
    data_B = get_subset(full_train_dataset, 600, start_index=600)
    
    model_A = train_local_model(model_A, data_A, epochs=10)
    model_B = train_local_model(model_B, data_B, epochs=10)
    
    avg_model = average_model_parameters([model_A, model_B], [0.5, 0.5])
    
    test_loader = DataLoader(test_dataset, batch_size=1000)
    acc_A = evaluate_model(model_A, test_loader)
    acc_B = evaluate_model(model_B, test_loader)
    acc_Avg = evaluate_model(avg_model, test_loader)
    
    print(f"Model A Accuracy: {acc_A:.2f}%")
    print(f"Model B Accuracy: {acc_B:.2f}%")
    print(f"Averaged Model Accuracy: {acc_Avg:.2f}%")

run_failure_experiment()

--- Running failure experiment ---
Model A Accuracy: 42.95%
Model B Accuracy: 55.73%
Averaged Model Accuracy: 29.52%


We can see that this model accuray is very low: around 25%. The reason for that, as explained in the article, is that neural network loss functions are non-convex. Even if two models solve the same task, they will find different local minima. Because the order of neurons in hidden layers is arbitrary, "Neuron 1" in model A might represent a diagonal line detector, while "Neuron 1" in model B represents a circle detector. When you average these weights without a shared history, you destroy the internal structure of the neurons. The resulting averaged model lands in a high-loss region rather than a valid solution.

We are then updating the training setup so that our models are initialized with of common set of  parameters. We then run a training in this setting and we make a study to see the impact of the number of data points on the performance of the combined model. We run training with 2, 3, 5 models, with each setting having : 25, 50, 100, 200 and 500 data points each. We then generate an array to represent accuracy of the global model with each model and datapoints number.

In [6]:
def run_federated_study():
    print("\n--- Running federated learning study ---")
    
    num_models_list = [2, 3, 5]
    data_points_list = [25, 50, 100, 200, 500]
    results = {}
    
    test_loader = DataLoader(test_dataset, batch_size=1000)

    print(f"{'# Models':<10} | {'Data points':<20} | {'Global model acc':<15}")
    print("-" * 55)

    for n_models in num_models_list:
        results[n_models] = {}
        for n_data in data_points_list:
            global_model = SimpleCNN().to(DEVICE)
            
            client_datasets = []
            for i in range(n_models):
                start_idx = i * n_data
                client_datasets.append(get_subset(full_train_dataset, n_data, start_index=start_idx))
            
            for round_num in range(5):
                local_models = []
                
                for ds in client_datasets:
                    local_model = copy.deepcopy(global_model)
                    local_model = train_local_model(local_model, ds, epochs=5)
                    local_models.append(local_model)
                
                weights = [1.0/n_models] * n_models
                global_model = average_model_parameters(local_models, weights)
            
            acc = evaluate_model(global_model, test_loader)
            results[n_models][n_data] = acc
            
            print(f"{n_models:<10} | {n_data:<20} | {acc:.2f}%")

    return results

study_results = run_federated_study()


--- Running federated learning study ---
# Models   | Data points          | Global model acc
-------------------------------------------------------
2          | 25                   | 19.49%
2          | 50                   | 12.67%
2          | 100                  | 27.21%
2          | 200                  | 56.72%
2          | 500                  | 84.49%
3          | 25                   | 10.10%
3          | 50                   | 18.53%
3          | 100                  | 28.28%
3          | 200                  | 62.49%
3          | 500                  | 83.90%
5          | 25                   | 25.02%
5          | 50                   | 23.95%
5          | 100                  | 41.49%
5          | 200                  | 51.34%
5          | 500                  | 82.66%


We then repeat the study on another dataset like  HAM 10000:
For this to work, you have to download `https://www.kaggle.com/datasets/kmader/skin-cancer-mnist-ham10000` and to put it in `data/skin/`.

In [6]:
from torch.utils.data import Dataset, DataLoader, Subset
import pandas as pd
import os
from PIL import Image

BATCH_SIZE = 50
LEARNING_RATE = 0.01

IMG_DIR_1 = './data/skin/HAM10000_images_part_1'
IMG_DIR_2 = './data/skin/HAM10000_images_part_2' 
CSV_PATH = './data/skin/HAM10000_metadata.csv'

class HAM10000Dataset(Dataset):
    def __init__(self, csv_file, dir1, dir2, transform=None):
        self.df = pd.read_csv(csv_file)
        self.dir1 = dir1
        self.dir2 = dir2
        self.transform = transform
        
        self.label_map = {
            'nv': 0, 'mel': 1, 'bkl': 2, 'bcc': 3, 
            'akiec': 4, 'vasc': 5, 'df': 6
        }
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_filename = self.df.iloc[idx, 1] + ".jpg"
        label_str = self.df.iloc[idx, 2]
        label = self.label_map[label_str]

        path1 = os.path.join(self.dir1, img_filename)
        path2 = os.path.join(self.dir2, img_filename)
        
        if os.path.exists(path1):
            final_path = path1
        elif os.path.exists(path2):
            final_path = path2
        else:
            raise FileNotFoundError(f"Image {img_filename} not found in either folder.")

        image = Image.open(final_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

ham_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.763, 0.546, 0.570], std=[0.141, 0.153, 0.170]) 
])

try:
    full_dataset = HAM10000Dataset(
        csv_file=CSV_PATH, 
        dir1=IMG_DIR_1, 
        dir2=IMG_DIR_2, 
        transform=ham_transform
    )
    print(f"Dataset loaded. Total number of images: {len(full_dataset)}")
except FileNotFoundError as e:
    print(f"ERROR: {e}")
    print("Ensure you have unzipped both 'part_1' and 'part_2' folders into ./data/skin/")
    exit()

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])



class SimpleCNN_HAM(nn.Module):
    def __init__(self):
        super(SimpleCNN_HAM, self).__init__()
        
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5) 
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        
        self.mp = nn.MaxPool2d(2)
        
        self.fc1 = nn.Linear(800, 120)
        self.fc2 = nn.Linear(120, 7)

    def forward(self, x):
        in_size = x.size(0)
        x = torch.relu(self.mp(self.conv1(x)))
        x = torch.relu(self.mp(self.conv2(x)))
        x = x.view(in_size, -1) # Flatten
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return torch.log_softmax(x, dim=1)


def train_local_model(model, dataset, epochs=5):
    model.train()
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
    
    for _ in range(epochs):
        for data, target in loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            optimizer.zero_grad()
            output = model(data)
            loss = nn.functional.nll_loss(output, target)
            loss.backward()
            optimizer.step()
    return model

def average_model_parameters(models, weights):
    with torch.no_grad():
        avg_model = copy.deepcopy(models[0])
        for param in avg_model.parameters():
            param.data.zero_()
        
        for i, model in enumerate(models):
            w = weights[i]
            for avg_p, client_p in zip(avg_model.parameters(), model.parameters()):
                avg_p.data.add_(client_p.data * w)
    return avg_model

def evaluate_model(model, test_dataset):
    model.eval()
    loader = DataLoader(test_dataset, batch_size=100)
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    return 100. * correct / len(loader.dataset)

def get_subset_indices(dataset, num_samples, start_index=0):
    return Subset(dataset, range(start_index, start_index + num_samples))



def run_ham10000_study():
    print(f"\n--- Running federated study ---")
    print(f"Device: {DEVICE}")
    
    num_models_list = [2, 3]
    data_points_list = [50, 200, 500]
    
    print(f"\n{'# Models':<10} | {'Data points':<15} | {'Global ccc':<15}")
    print("-" * 45)
    
    for n_models in num_models_list:
        for n_data in data_points_list:
            global_model = SimpleCNN_HAM().to(DEVICE)
            
            client_datasets = []
            valid_setup = True
            for i in range(n_models):
                start_idx = i * n_data
                if start_idx + n_data > len(train_dataset):
                    print(f"Not enough data for {n_models} models with {n_data} samples.")
                    valid_setup = False
                    break
                client_datasets.append(get_subset_indices(train_dataset, n_data, start_idx))
            
            if not valid_setup: continue

            for round_num in range(3):
                local_models = []
                
                for ds in client_datasets:
                    local_model = copy.deepcopy(global_model)
                    local_model = train_local_model(local_model, ds, epochs=3)
                    local_models.append(local_model)
                
                weights = [1.0/n_models] * n_models
                global_model = average_model_parameters(local_models, weights)
            
            acc = evaluate_model(global_model, test_dataset)
            print(f"{n_models:<10} | {n_data:<15} | {acc:.2f}%")

if __name__ == "__main__":
    run_ham10000_study()

Dataset loaded. Total number of images: 10015

--- Running federated study ---
Device: cpu

# Models   | Data/Client     | Global Acc     
---------------------------------------------
2          | 50              | 67.80%
2          | 200             | 67.80%
2          | 500             | 67.80%
3          | 50              | 67.80%
3          | 200             | 67.80%
3          | 500             | 67.80%
