In [9]:
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Normalize, Compose
from torchvision.datasets import MNIST
from torch.utils.data import random_split

import torch.nn as nn
import torch.nn.functional as F
import copy
#define global variables: FL Node amount, batch size, learning rate, epochs
global FL_NODE_AMOUNT 
global BATCH_64
global BATCH_128
global LEARNING_RATE
global EPOCHS

FL_NODE_AMOUNT = 5
BATCH_64 = 64
BATCH_128 = 128
LEARNING_RATE = 0.01
EPOCHS = 8


grab dataset from MNIST and use random_split to seperate the dataset into parts for federated learning.

In [3]:
def get_mnist(data_path: str = './data'):
    '''This function downloads the MNIST dataset into the `data_path`
    directory if it is not there already. WE construct the train/test
    split by converting the images into tensors and normalising them'''
    
    # transformation to convert images to tensors and apply normalisation
    tr = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

    # prepare train and test set
    trainset = MNIST(data_path, train=True, download=True, transform=tr)
    testset = MNIST(data_path, train=False, download=True, transform=tr)

    return trainset, testset

def prepare_dataset(num_partitions: int,
                    batch_size: int,
                    val_ratio: float = 0.1):

    """This function partitions the training set into N disjoint
    subsets, each will become the local dataset of a client. This
    function also subsequently partitions each traininset partition
    into train and validation. The test set is left intact and will
    be used by the central server to asses the performance of the
    global model. """

    # get the MNIST dataset
    trainset, testset = get_mnist()

    # split trainset into `num_partitions` trainsets
    num_images = len(trainset) // num_partitions

    partition_len = [num_images] * num_partitions

    trainsets = random_split(trainset, partition_len, torch.Generator().manual_seed(2023))

    # create dataloaders with train+val support
    trainloaders = []
    valloaders = []
    for trainset_ in trainsets:
        num_total = len(trainset_)
        num_val = int(val_ratio * num_total)
        num_train = num_total - num_val

        for_train, for_val = random_split(trainset_, [num_train, num_val], torch.Generator().manual_seed(2023))

        trainloaders.append(DataLoader(for_train, batch_size=batch_size, shuffle=True, num_workers=2))
        valloaders.append(DataLoader(for_val, batch_size=batch_size, shuffle=False, num_workers=2))

    # create dataloader for the test set
    testloader = DataLoader(testset, batch_size=128)

    return trainloaders, valloaders, testloader




class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

def train(net, trainloader, optimizer, epochs):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    net.train()
    for _ in range(epochs):
        model_loss = 0.0
        for images, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            model_loss += loss.item() / len(trainloader)
        #set the loss to have 4 decimal places
        model_loss = round(model_loss, 4)
    return net, model_loss

def test(net, testloader):
    """Validate the network on the entire test set."""
    
    correct = 0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
    accuracy = correct / len(testloader.dataset)
    return accuracy


def run_centralised(trainloader, testloader, epochs: int, lr: float, momentum: float=0.9):
    """A minimal (but complete) training loop"""
    # instantiate the model
    model = Net()
    print("Model initialised")
    # define optimiser with hyperparameters supplied
    optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    # get dataset and construct a dataloaders
    
    print("@train test loaders all clear@")
    # train for the specified number of 
    print("training process stating...")
    trained_model, loss = train(model, trainloader, optim, epochs)
    print("@training completed@")
    # training is completed, then evaluate model on the test set
    print("testing process starting ...")
    accuracy = test(trained_model, testloader)
    print("@testing  completed@")
    print(f"{loss = }")
    print(f"{accuracy = }")
    return loss, accuracy, trained_model
    

In [19]:
def weighted_federated_averaging(model_sets, weights):
    """
    Weighted federated averaging.

    Args:
        model_sets (list of model): The list containing the models from each client.
        weights (list of float): The list containing the weights for each model, which could be based on their loss or accuracy.

    Returns:
        global_model (model): The global model after weighted federated averaging.
    """

    global_model = Net()

    # Normalize weights
    total_weight = sum(weights)
    normalized_weights = [weight / total_weight for weight in weights]

    # Ensure the global model parameters are initialized to zero
    for global_param in global_model.parameters():
        global_param.data *= 0 

    # Accumulate weighted parameters from each model
    for model, weight in zip(model_sets, normalized_weights):
        for global_param, model_param in zip(global_model.parameters(), model.parameters()):
            global_param.data += model_param.data * weight
            
    return global_model

In [None]:
trainset, testset = get_mnist()


print("tradition CNN model for MNIST STARTS")

trainloader = DataLoader(trainset, batch_size=BATCH_64, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=BATCH_128)

loss, accracy, trained_model= run_centralised(trainloader,testloader,epochs=EPOCHS, lr=LEARNING_RATE)


print("\n------------------------------------\n")

In [12]:

print("FL process with CNN for MNIST STARTS")
trainloaders, valloaders, testloader = prepare_dataset(num_partitions=FL_NODE_AMOUNT, batch_size=BATCH_64)
print("@Data preparation completed@")

model_sets =[]
accuracy_sets = []
loss_sets = [] 
for client_index in range(FL_NODE_AMOUNT):
    print(f"Training client {client_index}")
    
    trainloader = trainloaders[client_index]
    # partial_model = Net(num_classes=10)
    # partial_optimizer = torch.optim.SGD(partial_model.parameters(), lr=LEARNING_RATE, momentum=0.9)
    # partial_model = train(trained_model, trainloader,partial_optimizer , EPOCHS)
    loss,accracy,partial_model = run_centralised(trainloader, testloader, epochs=EPOCHS, lr=LEARNING_RATE)
    model_sets.append(partial_model)
    accuracy_sets.append(accracy)
    loss_sets.append(loss)
    
    
    print(f"Client {client_index} training completed")
    print()


FL process with CNN for MNIST STARTS
@Data preparation completed@
Training client 0
Model initialised
@train test loaders all clear@
training process stating...
@training completed@
testing process starting ...
@testing  completed@
loss = 0.0246
accuracy = 0.9764
Client 0 training completed

Training client 1
Model initialised
@train test loaders all clear@
training process stating...
@training completed@
testing process starting ...
@testing  completed@
loss = 0.0344
accuracy = 0.9792
Client 1 training completed

Training client 2
Model initialised
@train test loaders all clear@
training process stating...
@training completed@
testing process starting ...
@testing  completed@
loss = 0.037
accuracy = 0.9796
Client 2 training completed

Training client 3
Model initialised
@train test loaders all clear@
training process stating...
@training completed@
testing process starting ...
@testing  completed@
loss = 0.0302
accuracy = 0.9779
Client 3 training completed

Training client 4
Model ini

In [21]:
#finally I get the average of the weights of the trained models
# this is the global model
print(accuracy_sets)
print("Weighted federated averaging Starting ...")
global_model = weighted_federated_averaging(model_sets, accuracy_sets)
print("calculating accuracy of the global model ...")
accuracy = test(global_model, testloader)

print(f"Global model accuracy: {accuracy}")
print("FL process with CNN for MNIST ENDS")

[0.9764, 0.9792, 0.9796, 0.9779, 0.9775]
Weighted federated averaging Starting ...
calculating accuracy of the global model ...
Global model accuracy: 0.1781
FL process with CNN for MNIST ENDS
