In [1]:
import torch.nn as nn
from torchvision import datasets, transforms
import random
import torch
from torch.utils.data import DataLoader, Subset, random_split
import torch.optim as optim
import torch.nn.functional as F

## Load the MNIST dataset

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./', train=False, download=True, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 15877749.05it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 469194.49it/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4385234.16it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2858293.89it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw






## Extract two subsets of 600 data points each

In [3]:
random.seed(42)

indices = list(range(len(train_dataset)))
random.shuffle(indices)

# Split into two subsets of 600 data points each
subset1_indices = indices[:600]
subset2_indices = indices[600:1200]

# Create subset datasets
subset1 = torch.utils.data.Subset(train_dataset, subset1_indices)
subset2 = torch.utils.data.Subset(train_dataset, subset2_indices)

## Create a simple Convolutional Neural Network

In [4]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64 * 24 * 24, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

## Create a function average_model_parameters

In [5]:
def average_model_parameters(models, average_weight):
    averaged_params = {}
    for param_name in models[0].state_dict():
        avg_param = sum(weight * model.state_dict()[param_name] for weight, model in zip(average_weight, models))
        averaged_params[param_name] = avg_param
    return averaged_params

## Create a function that updates

In [6]:
def update_model_parameters(model, averaged_params):
    with torch.no_grad(): 
        for param_name, avg_param in averaged_params.items():
            model.state_dict()[param_name].copy_(avg_param)

## Federated training function and evaluation

In [7]:
def federated_training(subset1, subset2, epochs=20, batch_size=50, average_weight=[0.5, 0.5], initialize_common_params=False):
    model1 = SimpleCNN()
    model2 = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    # With SGD optimizer, I get low accucary (0.1). So I choose Adam for better performance
    optimizer1 = optim.Adam(model1.parameters(), lr=0.001)
    optimizer2 = optim.Adam(model2.parameters(), lr=0.001)

    loader1 = DataLoader(subset1, batch_size=batch_size, shuffle=True)
    loader2 = DataLoader(subset2, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        model1.train()
        total_loss = 0
        for images, labels in loader1:
            optimizer1.zero_grad()
            outputs = model1(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer1.step()
            total_loss += loss.item()
        
        model2.train()
        total_loss = 0
        for images, labels in loader2:
            optimizer2.zero_grad()
            outputs = model2(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer2.step()
            total_loss += loss.item()

        if initialize_common_params:
            # Average the model parameters
            averaged_params = average_model_parameters([model1, model2], average_weight)
            update_model_parameters(model1, averaged_params)
            update_model_parameters(model2, averaged_params)

    return model1, model2

In [8]:
def evaluate_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

## Result

In [9]:
print("Training without initializing common parameters:")
model1, model2 = federated_training(subset1, subset2, epochs=20, batch_size=50, average_weight=[0.5, 0.5], initialize_common_params=False)

test_loader = DataLoader(test_dataset, batch_size=50, shuffle=False)

accuracy_model1 = evaluate_model(model1, test_loader)
accuracy_model2 = evaluate_model(model2, test_loader)

print(f"Accuracy of model 1 (no common params): {accuracy_model1:.4f}")
print(f"Accuracy of model 2 (no common params): {accuracy_model2:.4f}")

print("\nTraining with initializing common parameters (Federated Averaging):")
model1, model2 = federated_training(subset1, subset2, epochs=20, batch_size=50, average_weight=[0.5, 0.5], initialize_common_params=True)

accuracy_model1 = evaluate_model(model1, test_loader)
accuracy_model2 = evaluate_model(model2, test_loader)

print(f"Accuracy of model 1 (with common params): {accuracy_model1:.4f}")
print(f"Accuracy of model 2 (with common params): {accuracy_model2:.4f}")

Training without initializing common parameters:
Accuracy of model 1 (no common params): 0.9028
Accuracy of model 2 (no common params): 0.9091

Training with initializing common parameters (Federated Averaging):
Accuracy of model 1 (with common params): 0.9262
Accuracy of model 2 (with common params): 0.9262


In [18]:
def find_minimum_data_points(train_dataset, test_dataset):
    data_points_array = [100, 200, 300, 400, 500]
    indices = list(range(len(train_dataset)))
    random.shuffle(indices)
    
    for elem in data_points_array:
        subset1_indices = indices[:elem]
        subset2_indices = indices[elem:elem * 2]
        subset1 = torch.utils.data.Subset(train_dataset, subset1_indices)
        subset2 = torch.utils.data.Subset(train_dataset, subset2_indices)
        model1, model2 = federated_training(subset1, subset2, epochs=20, batch_size=50, average_weight=[0.5, 0.5], initialize_common_params=True)

        test_loader = DataLoader(test_dataset, batch_size=50, shuffle=False)
        
        accuracy_model1 = evaluate_model(model1, test_loader)
        accuracy_model2 = evaluate_model(model2, test_loader)
        
        print(f"Accuracy of model 1 with {elem} data points : {accuracy_model1:.4f}")
        print(f"Accuracy of model 2 with {elem} data points : {accuracy_model2:.4f}")
    
    return
                
        

In [19]:
find_minimum_data_points(train_dataset, test_dataset)

Accuracy of model 1 with 100 data points : 0.2396
Accuracy of model 2 with 100 data points : 0.2396
Accuracy of model 1 with 200 data points : 0.2982
Accuracy of model 2 with 200 data points : 0.2982
Accuracy of model 1 with 300 data points : 0.3393
Accuracy of model 2 with 300 data points : 0.3393
Accuracy of model 1 with 400 data points : 0.3690
Accuracy of model 2 with 400 data points : 0.3690
Accuracy of model 1 with 500 data points : 0.3968
Accuracy of model 2 with 500 data points : 0.3968


If we accept an accuracy of 90%, 400 data points is enough

## CIFAR-10

In [12]:
train_dataset = datasets.CIFAR10(root='./', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./', train=False, download=True, transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 49143547.93it/s]


Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified


In [13]:
random.seed(42)

indices = list(range(len(train_dataset)))
random.shuffle(indices)

subset1_indices = indices[:600]
subset2_indices = indices[600:1200]

subset1 = torch.utils.data.Subset(train_dataset, subset1_indices)
subset2 = torch.utils.data.Subset(train_dataset, subset2_indices)

In [14]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [15]:
print("Training without initializing common parameters:")
model1, model2 = federated_training(subset1, subset2, epochs=40, batch_size=50, average_weight=[0.5, 0.5], initialize_common_params=False)

test_loader = DataLoader(test_dataset, batch_size=50, shuffle=False)

accuracy_model1 = evaluate_model(model1, test_loader)
accuracy_model2 = evaluate_model(model2, test_loader)

print(f"Accuracy of model 1 (no common params): {accuracy_model1:.4f}")
print(f"Accuracy of model 2 (no common params): {accuracy_model2:.4f}")

print("\nTraining with initializing common parameters (Federated Averaging):")
model1, model2 = federated_training(subset1, subset2, epochs=40, batch_size=50, average_weight=[0.5, 0.5], initialize_common_params=True)

accuracy_model1 = evaluate_model(model1, test_loader)
accuracy_model2 = evaluate_model(model2, test_loader)

print(f"Accuracy of model 1 (with common params): {accuracy_model1:.4f}")
print(f"Accuracy of model 2 (with common params): {accuracy_model2:.4f}")

Training without initializing common parameters:
Accuracy of model 1 (no common params): 0.3446
Accuracy of model 2 (no common params): 0.3491

Training with initializing common parameters (Federated Averaging):
Accuracy of model 1 (with common params): 0.3969
Accuracy of model 2 (with common params): 0.3969


In [20]:
find_minimum_data_points(train_dataset, test_dataset)

Accuracy of model 1 with 100 data points : 0.2655
Accuracy of model 2 with 100 data points : 0.2655
Accuracy of model 1 with 200 data points : 0.2975
Accuracy of model 2 with 200 data points : 0.2975
Accuracy of model 1 with 300 data points : 0.3356
Accuracy of model 2 with 300 data points : 0.3356
Accuracy of model 1 with 400 data points : 0.3666
Accuracy of model 2 with 400 data points : 0.3666
Accuracy of model 1 with 500 data points : 0.3805
Accuracy of model 2 with 500 data points : 0.3805


In CIFAR-10 study's, the accuracy is very lower than MNIST study's.
Indeed CIFAR-10's images are more complex than MNIST's, then we should add more layers on the CNN architecture or using data augmentation in order to let model learn pattern more precisely.