In [None]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset, DataLoader

# 1. Load a simple image dataset (like MNIST,  FashionMNIST)

In [None]:
transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=False
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=1000,
    shuffle=False
)

# 2. Create a simple Convolutional Neural Network (2 convolutional layers and 2 dense layers, for example) and check that your training works on a single neural network (on a subset of the dataset)

In [None]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64*7*7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 64*7*7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CNN().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

In [None]:
subset_indices = list(range(64))
train_subset = Subset(train_dataset, subset_indices)

subset_train_loader = DataLoader(
    dataset=train_subset,
    batch_size=64,
    shuffle=True
)

model.train()
for data, target in subset_train_loader:
    data, target = data.to(device), target.to(device)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

for name, param in model.named_parameters():
    print(f"Layer: {name} | Shape: {param.shape}")

Layer: conv1.weight | Shape: torch.Size([32, 1, 3, 3])
Layer: conv1.bias | Shape: torch.Size([32])
Layer: conv2.weight | Shape: torch.Size([64, 32, 3, 3])
Layer: conv2.bias | Shape: torch.Size([64])
Layer: fc1.weight | Shape: torch.Size([128, 3136])
Layer: fc1.bias | Shape: torch.Size([128])
Layer: fc2.weight | Shape: torch.Size([10, 128])
Layer: fc2.bias | Shape: torch.Size([10])


# 3. Create a function average_model_parameters(models: iterable, average_weight): that average the parameters of each model parameters following the approach in the article.

In [None]:
def average_model_parameters(models: list, weights: list):
    if not models:
        return None
    if len(models) != len(weights):
        raise ValueError("Number of models and weights must be the same.")

    averaged_model = type(models[0])().to(device)
    averaged_state_dict = averaged_model.state_dict()

    for key in averaged_state_dict.keys():
        averaged_state_dict[key] = torch.zeros_like(averaged_state_dict[key], device='cpu')

    for i, model in enumerate(models):
        model_state_dict = model.state_dict()
        for key in averaged_state_dict.keys():
            averaged_state_dict[key] += weights[i] * model_state_dict[key].cpu()

    averaged_model.load_state_dict(averaged_state_dict)
    averaged_model.to(device)
    return averaged_model

# 4. Create a scrip or function that reproduces Algorithm 1 in the article. Consider that all local models are trained on your local machine and not remotly. Do not implement the common weight initialization scheme for now.

In [None]:
def get_client_dataloaders(train_dataset, num_clients, num_data_points_per_client, local_batch_size, random_seed=42):
    client_dataloaders = []
    total_data_points = len(train_dataset)

    if num_data_points_per_client * num_clients > total_data_points:
        raise ValueError(
            f"Not enough data for {num_clients} clients with {num_data_points_per_client} data points each. "
            f"Total available in dataset: {total_data_points}"
        )

    import random
    random.seed(random_seed)
    all_indices = list(range(total_data_points))
    random.shuffle(all_indices)

    for i in range(num_clients):
        client_indices_start = i * num_data_points_per_client
        client_indices_end = client_indices_start + num_data_points_per_client
        client_specific_indices = all_indices[client_indices_start:client_indices_end]

        client_subset = Subset(train_dataset, client_specific_indices)
        client_loader = DataLoader(
            dataset=client_subset,
            batch_size=local_batch_size,
            shuffle=True
        )
        client_dataloaders.append(client_loader)
    return client_dataloaders


In [None]:
def client_update(model, dataloader, criterion, optimizer, num_local_epochs, device):
    model.train()
    for epoch in range(num_local_epochs):
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    return model


In [None]:
def federated_averaging(
    global_model,
    train_dataset,
    num_clients,
    num_rounds,
    num_data_points_per_client,
    num_local_epochs,
    local_batch_size,
    learning_rate,
    criterion,
    optimizer_class,
    device
):

    client_models = [type(global_model)().to(device) for _ in range(num_clients)]

    client_dataloaders = get_client_dataloaders(
        train_dataset,
        num_clients,
        num_data_points_per_client,
        local_batch_size
    )

    for round_num in range(num_rounds):
        #print(f"round {round_num + 1}/{num_rounds}")

        # server distributes global model
        for client_model in client_models:
            client_model.load_state_dict(global_model.state_dict())

        # client training
        for i, client_model in enumerate(client_models):
            #print(f"client {i+1} training")
            client_optimizer = optimizer_class(client_model.parameters(), lr=learning_rate)
            client_update(client_model, client_dataloaders[i], criterion, client_optimizer, num_local_epochs, device)

        # parameter averaging
        weights = [1.0 / num_clients] * num_clients
        global_model = average_model_parameters(client_models, weights)

        #print(f"round {round_num + 1} averaged")

    return global_model


# 5. Run a training of two models with average coefficients being 0.5 for each model. Each model should be trained on 600 data points each. Reuse the same setup as in the article (50 examples per local batch). You should see that the approach does not work.  Why ?

In [None]:
num_clients = 2
num_data_points_per_client = 600
num_rounds = 5
num_local_epochs = 1
local_batch_size = 50
learning_rate = 0.01

In [None]:
global_model = CNN().to(device)
criterion = nn.NLLLoss()
optimizer_class = optim.SGD

final_global_model = federated_averaging(
    global_model=global_model,
    train_dataset=train_dataset,
    num_clients=num_clients,
    num_rounds=num_rounds,
    num_data_points_per_client=num_data_points_per_client,
    num_local_epochs=num_local_epochs,
    local_batch_size=local_batch_size,
    learning_rate=learning_rate,
    criterion=criterion,
    optimizer_class=optimizer_class,
    device=device
)

def test_model(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item() * data.size(0)  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    return test_loss, accuracy

test_loss, accuracy = test_model(final_global_model, test_loader, criterion, device)
print(f'loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')


loss: 2.2831, Accuracy: 9.86%


according to the article the client models need to have the same initialization or they wll diverge during training which makes averaging ineffective.

# 6. Update the training setup so that your models are initialized with a common set of  parameters. Run a training in this setting.

In [None]:
num_clients_common_init = 2
num_data_points_per_client_common_init = 600
num_rounds_common_init = 5
num_local_epochs_common_init = 1
local_batch_size_common_init = 50
learning_rate_common_init = 0.01

In [None]:
global_model_common_init = CNN().to(device)
criterion = nn.NLLLoss()
optimizer_class = optim.SGD

final_global_model_common_init = federated_averaging(
    global_model=global_model_common_init,
    train_dataset=train_dataset,
    num_clients=num_clients_common_init,
    num_rounds=num_rounds_common_init,
    num_data_points_per_client=num_data_points_per_client_common_init,
    num_local_epochs=num_local_epochs_common_init,
    local_batch_size=local_batch_size_common_init,
    learning_rate=learning_rate_common_init,
    criterion=criterion,
    optimizer_class=optimizer_class,
    device=device
)

test_loss_common_init, accuracy_common_init = test_model(final_global_model_common_init, test_loader, criterion, device)
print(f'Loss: {test_loss_common_init:.4f}, Accuracy: {accuracy_common_init:.2f}%')

Loss: 2.2783, Accuracy: 27.35%


# 7. Make a study to see the impact of the number of data points on the performance of the combined model. Run training with 2, 3, 5 models. With each setting having :
 - 25, 50, 100, 200 and 500 data points each.

In [None]:
num_clients_settings = [2, 3, 5]
num_data_points_per_client_settings = [25, 50, 100, 200, 500]

study_num_rounds = 5
study_num_local_epochs = 1
study_local_batch_size = 50
study_learning_rate = 0.01

In [None]:
results = []

for num_clients in num_clients_settings:
    for num_data_points_per_client in num_data_points_per_client_settings:
        print(f"\nclients: {num_clients} data_points: {num_data_points_per_client}")

        # Initialize a new global model for each run to ensure common initialization
        current_global_model = CNN().to(device)

        final_model_study = federated_averaging(
            global_model=current_global_model,
            train_dataset=train_dataset,
            num_clients=num_clients,
            num_rounds=study_num_rounds,
            num_data_points_per_client=num_data_points_per_client,
            num_local_epochs=study_num_local_epochs,
            local_batch_size=study_local_batch_size,
            learning_rate=study_learning_rate,
            criterion=criterion,
            optimizer_class=optimizer_class,
            device=device
        )

        test_loss, accuracy = test_model(final_model_study, test_loader, criterion, device)
        print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {accuracy:.2f}%')

        results.append({
            'num_clients': num_clients,
            'num_data_points_per_client': num_data_points_per_client,
            'test_loss': test_loss,
            'test_accuracy': accuracy
        })


clients: 2 data_points: 25
Test Loss: 2.3012, Test Accuracy: 8.93%

clients: 2 data_points: 50
Test Loss: 2.3033, Test Accuracy: 11.11%

clients: 2 data_points: 100
Test Loss: 2.3034, Test Accuracy: 9.05%

clients: 2 data_points: 200
Test Loss: 2.2961, Test Accuracy: 9.59%

clients: 2 data_points: 500
Test Loss: 2.2789, Test Accuracy: 34.56%

clients: 3 data_points: 25
Test Loss: 2.3034, Test Accuracy: 11.25%

clients: 3 data_points: 50
Test Loss: 2.2996, Test Accuracy: 11.99%

clients: 3 data_points: 100
Test Loss: 2.2963, Test Accuracy: 17.29%

clients: 3 data_points: 200
Test Loss: 2.2957, Test Accuracy: 9.74%

clients: 3 data_points: 500
Test Loss: 2.2916, Test Accuracy: 17.24%

clients: 5 data_points: 25
Test Loss: 2.3021, Test Accuracy: 10.00%

clients: 5 data_points: 50
Test Loss: 2.2996, Test Accuracy: 12.95%

clients: 5 data_points: 100
Test Loss: 2.2991, Test Accuracy: 13.55%

clients: 5 data_points: 200
Test Loss: 2.2925, Test Accuracy: 23.04%

clients: 5 data_points: 500
T

# 8. Report the results in a table.

In [None]:
from IPython.display import display

display(df_results)

Unnamed: 0,num_clients,num_data_points_per_client,test_loss,test_accuracy
0,2,25,2.302991,9.84
1,2,50,2.299743,10.28
2,2,100,2.288402,13.41
3,2,200,2.297872,11.13
4,2,500,2.291102,20.99
5,3,25,2.298772,12.36
6,3,50,2.301378,12.63
7,3,100,2.300431,10.52
8,3,200,2.300315,13.05
9,3,500,2.279836,22.37


# 9. Repeat the study on another dataset like  HAM 10000 (skin lesion dataset)

\* extraction of ham10000 dataset causes it to break. i will be doing study on fashionmnist instead

In [94]:
transform_fashion = transforms.Compose([
    transforms.ToTensor()
])

train_dataset_fashion = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform_fashion)
test_dataset_fashion = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform_fashion)

test_loader_fashion = DataLoader(
    dataset=test_dataset_fashion,
    batch_size=1000,
    shuffle=False
)


In [95]:
results_fashion = []

for num_clients in num_clients_settings:
    for num_data_points_per_client in num_data_points_per_client_settings:
        print(f"\nFashionMNIST - clients: {num_clients} data_points: {num_data_points_per_client}")

        current_global_model_fashion = CNN().to(device)

        final_model_study_fashion = federated_averaging(
            global_model=current_global_model_fashion,
            train_dataset=train_dataset_fashion, # Use FashionMNIST train dataset
            num_clients=num_clients,
            num_rounds=study_num_rounds,
            num_data_points_per_client=num_data_points_per_client,
            num_local_epochs=study_num_local_epochs,
            local_batch_size=study_local_batch_size,
            learning_rate=study_learning_rate,
            criterion=criterion,
            optimizer_class=optimizer_class,
            device=device
        )

        test_loss_fashion, accuracy_fashion = test_model(final_model_study_fashion, test_loader_fashion, criterion, device) # Use FashionMNIST test loader
        print(f'Test Loss: {test_loss_fashion:.4f}, Test Accuracy: {accuracy_fashion:.2f}%')

        results_fashion.append({
            'num_clients': num_clients,
            'num_data_points_per_client': num_data_points_per_client,
            'test_loss': test_loss_fashion,
            'test_accuracy': accuracy_fashion
        })



FashionMNIST - clients: 2 data_points: 25
Test Loss: 2.2963, Test Accuracy: 12.90%

FashionMNIST - clients: 2 data_points: 50
Test Loss: 2.3047, Test Accuracy: 10.00%

FashionMNIST - clients: 2 data_points: 100
Test Loss: 2.2976, Test Accuracy: 20.75%

FashionMNIST - clients: 2 data_points: 200
Test Loss: 2.2920, Test Accuracy: 25.75%

FashionMNIST - clients: 2 data_points: 500
Test Loss: 2.2591, Test Accuracy: 29.65%

FashionMNIST - clients: 3 data_points: 25
Test Loss: 2.2987, Test Accuracy: 10.07%

FashionMNIST - clients: 3 data_points: 50
Test Loss: 2.2933, Test Accuracy: 15.67%

FashionMNIST - clients: 3 data_points: 100
Test Loss: 2.2979, Test Accuracy: 10.00%

FashionMNIST - clients: 3 data_points: 200
Test Loss: 2.2773, Test Accuracy: 26.48%

FashionMNIST - clients: 3 data_points: 500
Test Loss: 2.2667, Test Accuracy: 10.29%

FashionMNIST - clients: 5 data_points: 25
Test Loss: 2.3006, Test Accuracy: 10.11%

FashionMNIST - clients: 5 data_points: 50
Test Loss: 2.2997, Test Acc

In [96]:
import pandas as pd
from IPython.display import display

df_results_fashion = pd.DataFrame(results_fashion)
display(df_results_fashion)


Unnamed: 0,num_clients,num_data_points_per_client,test_loss,test_accuracy
0,2,25,2.296307,12.9
1,2,50,2.304664,10.0
2,2,100,2.297592,20.75
3,2,200,2.292038,25.75
4,2,500,2.259111,29.65
5,3,25,2.298743,10.07
6,3,50,2.293273,15.67
7,3,100,2.297863,10.0
8,3,200,2.277279,26.48
9,3,500,2.266665,10.29
