In [1]:
# from google.colab import drive
# drive.mount('/content/drive')
# %cd "/content/drive/MyDrive/ColabTemp"

In [2]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
from torch.utils.data import DataLoader
from torch import nn
from collections import OrderedDict
import random
import copy
import time
import numpy as np

In [3]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using '{device}' device")

Using 'cuda' device


In [4]:
mySeed = 42
torch.manual_seed(mySeed)
np.random.seed(mySeed)
random.seed(mySeed)
# tf.random.set_seed(mySeed)

In [5]:
# Hyper parameters

num_clients = 3
batch_size = 1000
total_steps = 2
client_select_percentage = 1

learning_rate = 0.01
loss_fn = nn.CrossEntropyLoss()
client_epochs = 2

swap_step = 2
n_swap_bet_avg_p1 = 3 # p1=plus one to your number, if need 2 swap between avg, enter 3

remain = 1 # Remove some data for running faster in test

In [6]:
client_selects = None
client_weights = None

## Load Data

In [7]:
# Download dataset
train_data = datasets.CIFAR10(
    root="../datasets",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10).scatter_(dim=0, index=torch.tensor(y), value=1)),
)

test_data = datasets.CIFAR10(
    root="../datasets",
    train=False,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10).scatter_(dim=0, index=torch.tensor(y), value=1)),
)

print(len(train_data))
print(train_data[0][0].shape)
print(train_data[0][1])

Files already downloaded and verified
Files already downloaded and verified
50000
torch.Size([3, 32, 32])
tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0.])


In [8]:
# Remove some data for running faster in test
print(len(train_data))
train_data = torch.utils.data.Subset(train_data, range(0, int(len(train_data)*remain)))
print(len(train_data))

50000
50000


In [9]:
### Random dataset split
client_data_size = np.array([len(train_data)//num_clients]*num_clients)
data_remain = len(train_data) % num_clients
for i in range(data_remain):
    client_data_size[-1-i] += 1

client_datasets = torch.utils.data.random_split(train_data, client_data_size)

### None random dataset split
# client_datasets = list()
# i = 0
# for j in client_data_size:
#     client_datasets.append(torch.utils.data.Subset(train_data, range(i, i+j)))
#     i += j

In [10]:
# Create dataloader for each client
client_dataloaders = np.zeros(num_clients, dtype=object)
for i, dataset in enumerate(client_datasets):
    client_dataloaders[i] = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

## Training

In [11]:
# Define Model
input_flat_size = torch.flatten(train_data[0][0]).shape[0]
nClasses = train_data[0][1].shape[0]

class NeuralNetworkMnistMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(input_flat_size, 256)),
            ('relu1', nn.ReLU()),
            ('fc2', nn.Linear(256, 128)),
            ('relu2', nn.ReLU()),
            ('fc3', nn.Linear(128, 64)),
            ('relu3', nn.ReLU()),
            ('fc4', nn.Linear(64, nClasses)),
        ]))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        probs = self.softmax(logits)
        return probs
    
    def get_weights(self):
        return list(self.parameters())
    
    def set_weights(self, parameters_list):
        for i, param in enumerate(self.parameters()):
            param.data = parameters_list[i].data

In [12]:
global_model = NeuralNetworkMnistMLP().to(device)
print(global_model)

NeuralNetworkMnistMLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (fc1): Linear(in_features=3072, out_features=256, bias=True)
    (relu1): ReLU()
    (fc2): Linear(in_features=256, out_features=128, bias=True)
    (relu2): ReLU()
    (fc3): Linear(in_features=128, out_features=64, bias=True)
    (relu3): ReLU()
    (fc4): Linear(in_features=64, out_features=10, bias=True)
  )
  (softmax): Softmax(dim=1)
)


In [13]:
def select_clients_and_assign_weights(global_weights):
    global client_selects
    global client_weights

    lst = np.arange(0, num_clients)
    np.random.shuffle(lst)
    client_selects = lst[: int(len(lst)*client_select_percentage)]

    client_weights = {i: copy.deepcopy(global_weights)  for i in client_selects}

In [14]:
global_weights = global_model.get_weights()
select_clients_and_assign_weights(global_weights)

In [15]:
def scale_model_weights(weights, scalar):
    """ Scale the model weights """

    scaled_weights = list()
    for i in range(len(weights)):
        scaled_weights.append(weights[i] * scalar)

    return scaled_weights

In [16]:
def sum_scaled_weights(client_scaled_weights):
    """ Return the sum of the listed scaled weights.
        axis_O is equivalent to the average weight of the weights """

    avg_weights = list()
    # get the average gradient accross all client gradients
    for gradient_list_tuple in zip(*client_scaled_weights):
        gradient_list_tuple = [tensor.tolist()  for tensor in gradient_list_tuple]
        layer_mean = torch.sum(torch.tensor(gradient_list_tuple), axis=0).to(device)
        avg_weights.append(layer_mean)

    return avg_weights


### Explaining the function with example ###
# t = (torch.tensor([[[2, 3],[3, 4]], [[3, 4],[4, 5]], [[4, 5],[5, 6]]]),
#      torch.tensor([[[5, 6],[6, 7]], [[6, 7],[7, 8]], [[7, 8],[8, 9]]]))
# t = [i.tolist() for i in t]
# for y in zip(*t):
#     print(y)
#     print(torch.sum(torch.tensor(y), axis=0))

In [17]:
def fed_avg():
    # calculate total training data across clients
    global_count = 0
    for client in client_selects:
        global_count += len(client_dataloaders[client].dataset)

    # initial list to collect clients weight after scalling
    client_scaled_weights = list()
    for client in client_selects:
        local_count = len(client_dataloaders[client].dataset)
        scaling_factor = local_count / global_count
        scaled_weights = scale_model_weights(client_weights[client], scaling_factor)
        client_scaled_weights.append(scaled_weights)

    # to get the average over all the clients model, we simply take the sum of the scaled weights
    avg_weights = sum_scaled_weights(client_scaled_weights)

    return avg_weights

In [18]:
def test_neural_network(dataloader, model, loss_fn):
    data_size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct_items = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            test_loss += loss_fn(pred, y).item()
            correct_items += (pred.argmax(1) == y.argmax(1)).sum().item()

    avg_loss = test_loss / num_batches
    accuracy = correct_items / data_size
    print(f"Test Error: \nAccuracy: {(accuracy*100):>0.1f}%, Loss: {avg_loss:>8f}\n")

    return accuracy, avg_loss

In [19]:
def train_neural_network(dataloader, model, loss_fn, optimizer):
    data_size = len(dataloader.dataset)
    running_loss = 0

    for batch, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)

        # Compute prediction error
        pred = model(x)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item()
        print_step = np.ceil(len(dataloader)/10)
        if batch % print_step == 0:
            loss_per_batch = running_loss / print_step
            current_item = (batch+1)*len(x)
            print(f"loss: {loss_per_batch:>7f}  [{current_item:>5d}/{data_size:>5d}]")
            running_loss = 0

In [20]:
def train_clinet(dataloader, model):
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
    for epoch in range(client_epochs):
        train_neural_network(dataloader, model, loss_fn, optimizer)

In [21]:
for step in range(0, total_steps):
    for client in client_selects:
        local_model = NeuralNetworkMnistMLP().to(device)
        local_model.set_weights(client_weights[client])
        train_clinet(client_dataloaders[client], local_model)
        client_weights[client] = local_model.get_weights()

        del local_model
    
    if (step % swap_step == 0) and (step % (swap_step*n_swap_bet_avg_p1) != 0):
        pass
    
    if (step % (swap_step*n_swap_bet_avg_p1) == 0):
        avg_weights = fed_avg()
        global_model.set_weights(avg_weights) # update global model
        select_clients_and_assign_weights(avg_weights)

        # test global model and print out metrics after each communication round
        global_acc, global_loss = test_neural_network(test_dataloader, global_model, loss_fn)

loss: 1.151384  [ 1000/16666]
loss: 2.323791  [ 3000/16666]
loss: 2.349789  [ 5000/16666]
loss: 2.366760  [ 7000/16666]
loss: 2.351067  [ 9000/16666]
loss: 2.369150  [11000/16666]
loss: 2.363650  [13000/16666]
loss: 2.365150  [15000/16666]
loss: 2.351346  [11322/16666]
loss: 1.188575  [ 1000/16666]
loss: 2.371650  [ 3000/16666]
loss: 2.357650  [ 5000/16666]
loss: 2.360150  [ 7000/16666]
loss: 2.357650  [ 9000/16666]
loss: 2.363650  [11000/16666]
loss: 2.351650  [13000/16666]
loss: 2.363150  [15000/16666]
loss: 2.368602  [11322/16666]
loss: 1.151338  [ 1000/16667]
loss: 2.325502  [ 3000/16667]
loss: 2.338013  [ 5000/16667]
loss: 2.349757  [ 7000/16667]
loss: 2.367121  [ 9000/16667]
loss: 2.353091  [11000/16667]
loss: 2.335261  [13000/16667]
loss: 2.337777  [15000/16667]
loss: 2.356467  [11339/16667]
loss: 1.187067  [ 1000/16667]
loss: 2.375157  [ 3000/16667]
loss: 2.367133  [ 5000/16667]
loss: 2.352149  [ 7000/16667]
loss: 2.365149  [ 9000/16667]
loss: 2.366650  [11000/16667]
loss: 2.35

In [22]:
print(client_weights[0][0][0])
print(client_weights[1][0][0])
print(client_weights[2][0][0])

tensor([-0.0501, -0.0390, -0.0499,  ..., -0.0143, -0.0179, -0.0232],
       device='cuda:0', grad_fn=<SelectBackward0>)
tensor([-0.0305, -0.0242, -0.0409,  ..., -0.0089, -0.0124, -0.0168],
       device='cuda:0', grad_fn=<SelectBackward0>)
tensor([-0.0155, -0.0046, -0.0155,  ...,  0.0180,  0.0144,  0.0096],
       device='cuda:0', grad_fn=<SelectBackward0>)
