In [30]:
import copy
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

In [31]:
# Load the MNIST training split (downloaded into ./data on first use).
# Images are converted to tensors and then normalised with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 12663522.26it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 311667.21it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:02<00:00, 732646.81it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2011459.06it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [79]:
# Carve two disjoint client datasets out of the MNIST training set.
# Dedicated, seeded generators make the split and the batch order identical on
# every "Restart & Run All" without touching the global `random`/`torch` RNG state.
SPLIT_SEED = 42
SAMPLES_PER_CLIENT = 600

indices = list(range(len(train_dataset)))
random.Random(SPLIT_SEED).shuffle(indices)

# Consecutive, non-overlapping slices of the shuffled index list.
subset1_indices = indices[:SAMPLES_PER_CLIENT]
subset2_indices = indices[SAMPLES_PER_CLIENT:2 * SAMPLES_PER_CLIENT]

subset1 = Subset(train_dataset, subset1_indices)
subset2 = Subset(train_dataset, subset2_indices)

batch_size = 50
# Each loader gets its own generator so the two clients shuffle independently
# but reproducibly.
train_loader1 = DataLoader(
    subset1,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader2 = DataLoader(
    subset2,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SPLIT_SEED + 1),
)

In [90]:
class SimpleCNN(nn.Module):
    """Small two-convolution network producing 10 class logits.

    The first fully connected layer expects ``64 * 4 * 4`` features, which is
    what two 5x5 convolutions, each followed by 2x2 max-pooling, produce for a
    single-channel 28x28 input (28 -> 24 -> 12 -> 8 -> 4).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return raw (un-normalised) logits of shape ``(batch, 10)``.

        Args:
            x: Image tensor of shape ``(batch, 1, 28, 28)``; a single
                unbatched ``(1, 28, 28)`` image is treated as a batch of one.

        Raises:
            RuntimeError: If the spatial size does not yield ``64 * 4 * 4``
                features per sample.
        """
        if x.dim() == 3:
            # Keep accepting a single unbatched image, as the previous
            # ``view(-1, ...)`` based implementation did.
            x = x.unsqueeze(0)
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten per sample. ``view(-1, 64*4*4)`` would silently regroup
        # elements across samples when the input size is wrong; this makes the
        # mismatch surface as a shape error in ``fc1`` instead.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# One independently initialised network per simulated client.
model1, model2 = SimpleCNN(), SimpleCNN()

In [91]:
def average_model_parameters(models, average_weight):
    """Build a new model whose parameters are a weighted sum of ``models``.

    The result is a deep copy of ``models[0]`` with every parameter replaced
    by ``sum(average_weight[i] * models[i].parameter)``, so it works for any
    architecture shared by the inputs. The input models are not modified.

    Args:
        models: Non-empty sequence of ``nn.Module`` objects with identical
            parameter layouts.
        average_weight: One scalar weight per model. The weights are used as
            given; they are not normalised.

    Returns:
        A new model of the same class as ``models[0]``.

    Raises:
        ValueError: If ``models`` is empty, if the number of weights differs
            from the number of models, or if the models expose different
            numbers of parameters.
    """
    models = list(models)
    average_weight = list(average_weight)
    if not models:
        raise ValueError("At least one model is required for averaging.")
    if len(average_weight) != len(models):
        raise ValueError(
            f"Expected {len(models)} averaging weights, got {len(average_weight)}."
        )

    per_model_params = [list(model.parameters()) for model in models]
    if len({len(params) for params in per_model_params}) != 1:
        raise ValueError("All models must have the same number of parameters.")

    # Copy the architecture from the first model instead of hard-coding a
    # class; every parameter is overwritten below.
    avg_model = copy.deepcopy(models[0])

    with torch.no_grad():
        # zip(*...) regroups the lists so each item holds the same parameter
        # taken from every model.
        for target, sources in zip(avg_model.parameters(), zip(*per_model_params)):
            target.copy_(sum(w * p for w, p in zip(average_weight, sources)))

    return avg_model


In [92]:
def update_model_parameters(model: "keras.Model", new_weights: list):
    """Replace all weights of a Keras-style model.

    NOTE: a later cell defines a PyTorch function with the same name, which
    replaces this one once that cell runs.

    The annotation is a string on purpose: this notebook never imports
    ``keras``, so an unquoted ``keras.Model`` would raise ``NameError`` when
    the cell is executed on a fresh kernel.

    Args:
        model: Object exposing ``get_weights()`` and ``set_weights()``.
        new_weights: Replacement weights; must have as many entries as
            ``model.get_weights()`` returns.

    Raises:
        ValueError: If the number of entries does not match.
    """
    if len(new_weights) != len(model.get_weights()):
        raise ValueError("The number of new weights must match the number of layers in the model.")

    model.set_weights(new_weights)

In [93]:
def update_model_parameters(model, params):
    """Copy the values of ``params`` into ``model``'s parameters in place.

    Args:
        model: ``nn.Module`` whose parameters are overwritten.
        params: Iterable of tensors (for example another model's
            ``.parameters()``), in the same order as ``model.parameters()``.

    Raises:
        ValueError: If the number of tensors differs from the number of model
            parameters. A plain ``zip`` would silently stop at the shorter
            sequence and leave the model partly updated.
    """
    # Materialise first: ``params`` is often a generator and must be counted.
    new_params = list(params)
    targets = list(model.parameters())
    if len(new_params) != len(targets):
        raise ValueError(
            f"Expected {len(targets)} parameter tensors, got {len(new_params)}."
        )

    # no_grad() allows in-place writes to leaf tensors that require grad.
    with torch.no_grad():
        for target, source in zip(targets, new_params):
            target.copy_(source)

In [103]:
def train_local_model(model, train_loader, epochs=1, lr=0.01):
    """Train ``model`` in place with plain SGD and cross-entropy loss.

    A fresh ``optim.SGD`` optimiser is created on every call. After every
    tenth batch of an epoch the mean loss of those ten batches is printed and
    the accumulator is reset; it is also reset at the start of each epoch.

    Args:
        model: Network to optimise; switched to training mode.
        train_loader: Iterable of ``(data, target)`` batches.
        epochs: Number of passes over ``train_loader``.
        lr: SGD learning rate.
    """
    sgd = optim.SGD(model.parameters(), lr=lr)
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            sgd.zero_grad()
            batch_loss = F.cross_entropy(model(inputs), labels)
            batch_loss.backward()
            sgd.step()
            running_loss += batch_loss.item()

            # Only report on the 10th, 20th, ... batch of the epoch.
            if batch_idx % 10 != 9:
                continue
            print(f'Epoch [{epoch+1}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {running_loss / 10:.4f}')
            running_loss = 0.0

def federated_averaging(models, train_loaders, global_epochs=10, local_epochs=1, lr=0.01, avg_weight=None):
    """Run federated-averaging rounds over a set of client models.

    Each round trains every client on its own loader, combines the client
    parameters with ``average_model_parameters`` and copies the combined
    parameters back into every client, so all clients end each round with
    identical weights. The client models are modified in place.

    Args:
        models: Client models, one per entry in ``train_loaders``.
        train_loaders: Client data loaders, aligned with ``models``.
        global_epochs: Number of communication rounds.
        local_epochs: Local training epochs per client per round.
        lr: Learning rate forwarded to ``train_local_model``.
        avg_weight: One aggregation weight per client. ``None`` (default)
            weights all clients equally, i.e. ``[0.5, 0.5]`` for two clients.

    Returns:
        The aggregated global model of the last round, or a freshly
        initialised ``SimpleCNN`` when ``global_epochs`` is 0.

    Raises:
        ValueError: If ``models`` is empty or its length differs from
            ``train_loaders``.
    """
    models = list(models)
    train_loaders = list(train_loaders)
    if not models:
        raise ValueError("At least one client model is required.")
    if len(models) != len(train_loaders):
        raise ValueError(
            f"Got {len(models)} models but {len(train_loaders)} train loaders."
        )
    if avg_weight is None:
        # A None sentinel avoids a mutable default and adapts to any number
        # of clients.
        avg_weight = [1.0 / len(models)] * len(models)

    # Placeholder that is only returned when no round is executed.
    global_model = SimpleCNN()

    for round_idx in range(global_epochs):
        for model, loader in zip(models, train_loaders):
            train_local_model(model, loader, epochs=local_epochs, lr=lr)

        global_model = average_model_parameters(models, avg_weight)

        # Broadcast the aggregate so every client starts the next round from
        # the same weights.
        for model in models:
            update_model_parameters(model, global_model.parameters())

        print(f'Round {round_idx+1} complete')

    return global_model

# Train both clients for 120 communication rounds on their respective loaders.
federated_averaging(
    [model1, model2],
    [train_loader1, train_loader2],
    global_epochs=120,
)

Epoch [1], Batch [10/12], Loss: 0.0183
Epoch [1], Batch [10/12], Loss: 0.0299
Round 1 complete
Epoch [1], Batch [10/12], Loss: 0.0193
Epoch [1], Batch [10/12], Loss: 0.0317
Round 2 complete
Epoch [1], Batch [10/12], Loss: 0.0191
Epoch [1], Batch [10/12], Loss: 0.0318
Round 3 complete
Epoch [1], Batch [10/12], Loss: 0.0192
Epoch [1], Batch [10/12], Loss: 0.0311
Round 4 complete
Epoch [1], Batch [10/12], Loss: 0.0193
Epoch [1], Batch [10/12], Loss: 0.0260
Round 5 complete
Epoch [1], Batch [10/12], Loss: 0.0206
Epoch [1], Batch [10/12], Loss: 0.0298
Round 6 complete
Epoch [1], Batch [10/12], Loss: 0.0195
Epoch [1], Batch [10/12], Loss: 0.0323
Round 7 complete
Epoch [1], Batch [10/12], Loss: 0.0176
Epoch [1], Batch [10/12], Loss: 0.0315
Round 8 complete
Epoch [1], Batch [10/12], Loss: 0.0165
Epoch [1], Batch [10/12], Loss: 0.0272
Round 9 complete
Epoch [1], Batch [10/12], Loss: 0.0181
Epoch [1], Batch [10/12], Loss: 0.0286
Round 10 complete
Epoch [1], Batch [10/12], Loss: 0.0181
Epoch [1],

In [104]:
def test_model(model, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = correct / len(test_loader.dataset)
    return accuracy

# Evaluate both clients on the MNIST test split, using the same transform as
# for training and a fixed, unshuffled batch order.
test_loader = DataLoader(
    datasets.MNIST('./data', train=False, transform=transform),
    batch_size=1000,
    shuffle=False,
)

accuracy1, accuracy2 = [test_model(client, test_loader) for client in (model1, model2)]

print(f'Model 1 accuracy: {accuracy1 * 100:.2f}%')
print(f'Model 2 accuracy: {accuracy2 * 100:.2f}%')

Model 1 accuracy: 95.37%
Model 2 accuracy: 95.37%
