In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data
import random
import matplotlib.pyplot as plt

# define the global parameters

In [52]:
clients = 15  # the total number of clients
communicationRounds = 10  # the number of communications between the edges and the global model
numOfEdges = 2 # the number of edge-aggregators
edgeEpochs = 1 # the number of communications between an edge and its clients
numEpochs = 5  # the number of local epochs on clients
lr = 0.1  # training rate
localBatchSize = 10  # batch size

# load and process the MNIST data set

In [53]:
# download the data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainData = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testData = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testLoader = torch.utils.data.DataLoader(testData, batch_size=10, shuffle=False)
# perform a shuffle to ensure IID
indices = list(range(len(trainData)))
random.shuffle(indices)
dataShuffled = [trainData[i] for i in indices]
# calculating the size of the samples to ensure every client gets the same size
subsetSize = len(trainData) // clients
# create a subset for each client
clientData = [dataShuffled[i * subsetSize:(i + 1) * subsetSize] for i in range(clients)]

# define the global model

In [54]:
globalModel = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(64 * 12 * 12, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# define the edge-models

In [55]:
edgeModels = []
for _ in range(numOfEdges):
    edgeModel = nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=3),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=3),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Flatten(),
        nn.Linear(64 * 12 * 12, 128),
        nn.ReLU(),
        nn.Linear(128, 10)
    )
    edgeModels.append(edgeModel)

# training

In [56]:
accuracies = []
# iterate over all rounds
for rnd in range(communicationRounds):
    # print round number
    print(f"Cloud Communication Round: {rnd + 1}")
    # to save edges parameters
    params = []
    # iterate over all edge models
    for i, edgeModel in enumerate(edgeModels):
        # calculate the start and end indices of the clients on the current edge
        startClient = (i * (clients // numOfEdges))
        endClient = ((i + 1) * (clients // numOfEdges))
        # select clients for the current edge
        selectedClients = list(range(startClient, endClient))
        # to save local clients models parameters within the current edge
        edgeParams = []
        # iterate over selected clients on the current edge
        for edgeRound in range(edgeEpochs):
            print(f"Edge Communication Round: {edgeRound + 1} on edge {i+1}")
            for client in selectedClients:
                # local model
                clientModel = nn.Sequential(*edgeModel)
                # load the data to the model
                trainLoader = data.DataLoader(clientData[client], batch_size=localBatchSize, shuffle=False)
                # client optimizer and loss function
                criterion = nn.CrossEntropyLoss()
                optimizer = optim.SGD(clientModel.parameters(), lr=lr)
                # train model on client k for the defined local epochs
                for epoch in range(numEpochs):
                    # perform forward pass and backpropagation, updating the local model
                    for inputs, labels in trainLoader:
                        optimizer.zero_grad()
                        outputs = clientModel(inputs)
                        loss = criterion(outputs, labels)
                        loss.backward()
                        optimizer.step()
                # print client's k training progress
                print(f"Client {client + 1} belonging to edge {i + 1}, finished training for: {numEpochs} epoch(s)")
                # aggregate local models parameters of the clients of the current edge
                local_params = [param.data for param in clientModel.parameters()]
                edgeParams.append(local_params)
        # average the parameters from all local models within the current edge
        averagedParameters = [
            torch.mean(torch.stack(params), dim=0) for params in zip(*edgeParams)
        ]
        # update the model on the current edge
        for edgeParam, avgParam in zip(edgeModel.parameters(), averagedParameters):
            edgeParam.data = avgParam
    # aggregate all edge models
    averagedParameters = [
        torch.mean(torch.stack(params), dim=0) for params in zip(*(edgeModel.parameters() for edgeModel in edgeModels))
    ]
    # update the global model parameters
    for globalParam, avgParam in zip(globalModel.parameters(), averagedParameters):
        globalParam.data = avgParam
    # send the global model parameters back to the edges for the next communication round
    for edgeModel in edgeModels:
        edgeModel.load_state_dict(globalModel.state_dict())
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testLoader:
            outputs = globalModel(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    accuracies.append(accuracy)

print(f"Training is finished after {communicationRounds} communication rounds on a total of {clients} clients clustered across {numOfEdges} edges!")

# Plot the evolution of global model accuracy
plt.plot(range(1, communicationRounds + 1), accuracies)
plt.xlabel("Communication Round")
plt.ylabel("Accuracy (%)")
plt.title("Evolution of Global Model Accuracy")
plt.show()

Cloud Communication Round: 1
Edge Communication Round: 1 on edge 1
Client 1 belonging to edge 1, finished training for: 5 epoch(s)
Client 2 belonging to edge 1, finished training for: 5 epoch(s)
Client 3 belonging to edge 1, finished training for: 5 epoch(s)
Client 4 belonging to edge 1, finished training for: 5 epoch(s)
Client 5 belonging to edge 1, finished training for: 5 epoch(s)
Client 6 belonging to edge 1, finished training for: 5 epoch(s)
Client 7 belonging to edge 1, finished training for: 5 epoch(s)
Edge Communication Round: 1 on edge 2
Client 8 belonging to edge 2, finished training for: 5 epoch(s)
Client 9 belonging to edge 2, finished training for: 5 epoch(s)
Client 10 belonging to edge 2, finished training for: 5 epoch(s)
Client 11 belonging to edge 2, finished training for: 5 epoch(s)
Client 12 belonging to edge 2, finished training for: 5 epoch(s)
Client 13 belonging to edge 2, finished training for: 5 epoch(s)
Client 14 belonging to edge 2, finished training for: 5 epo

KeyboardInterrupt: 

# testing data

In [46]:
device = torch.device("cpu")
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testLoader:
        images, labels = images.to(device), labels.to(device)
        outputs = globalModel(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy}%")

Test Accuracy: 99.18%
