In [1]:
# from google.colab import drive
# drive.mount('/content/drive')
# %cd "/content/drive/MyDrive/ColabTemp"

In [3]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
from torch.utils.data import DataLoader
from torch import nn
from collections import OrderedDict
import random
import copy
import time
import numpy as np

In [3]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using '{device}' device")

Using 'cuda' device


In [4]:
mySeed = 42
torch.manual_seed(mySeed)
np.random.seed(mySeed)
random.seed(mySeed)
# tf.random.set_seed(mySeed)

In [5]:
# Hyper parameters

num_clients = 3
batch_size = 64
total_steps = 10
client_select_percentage = 1

learning_rate = 0.001
loss_fn = nn.CrossEntropyLoss()
client_epochs = 2

swap_step = 2
n_swap_bet_avg_p1 = 2 # p1=plus one to your number, if need 2 swap between avg, enter 3

remain = 1 # Remove some data for running faster in test

In [6]:
client_selects = None
client_weights = None

is_print_eval = False
start_bold = "\033[1m"
end_bold = "\033[0;0m"

## Load Data

In [7]:
# Download dataset
train_data = datasets.CIFAR10(
    root="../datasets",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10).scatter_(dim=0, index=torch.tensor(y), value=1)),
)

test_data = datasets.CIFAR10(
    root="../datasets",
    train=False,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10).scatter_(dim=0, index=torch.tensor(y), value=1)),
)

print(len(train_data))
print(train_data[0][0].shape)
print(train_data[0][1])

Files already downloaded and verified
Files already downloaded and verified
50000
torch.Size([3, 32, 32])
tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0.])


In [8]:
# Remove some data for running faster in test
print(len(train_data))
train_data = torch.utils.data.Subset(train_data, range(0, int(len(train_data)*remain)))
print(len(train_data))

50000
50000


In [9]:
### Random dataset split
client_data_size = np.array([len(train_data)//num_clients]*num_clients)
data_remain = len(train_data) % num_clients
for i in range(data_remain):
    client_data_size[-1-i] += 1

client_datasets = torch.utils.data.random_split(train_data, client_data_size)

### None random dataset split
# client_datasets = list()
# i = 0
# for j in client_data_size:
#     client_datasets.append(torch.utils.data.Subset(train_data, range(i, i+j)))
#     i += j

In [10]:
# Create dataloader for each client
client_dataloaders = np.zeros(num_clients, dtype=object)
for i, dataset in enumerate(client_datasets):
    client_dataloaders[i] = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

## Training

In [11]:
def calc_out_conv_max_layers(in_w, in_h, kernels, strides, paddings=None, dilations=None):
    # In MaxPool2d, strides must same with kernels
    if paddings == None:
        paddings = np.zeros(len(kernels))
    
    if dilations == None:
        dilations = np.ones(len(kernels))
    
    out_w = in_w
    out_h = in_h
    for ker, pad, dil, stri in zip(kernels, paddings, dilations, strides):
        out_w = np.floor((out_w + 2*pad - dil * (ker-1) - 1)/stri + 1)
        out_h = np.floor((out_h + 2*pad - dil * (ker-1) - 1)/stri + 1)

    return int(out_w), int(out_h)

In [12]:
# Define Model
input_flat_size = torch.flatten(train_data[0][0]).shape[0]
nClasses = train_data[0][1].shape[0]

class NeuralNetworkCifar10MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(input_flat_size, 256)),
            ('relu1', nn.ReLU()),
            ('fc2', nn.Linear(256, 128)),
            ('relu2', nn.ReLU()),
            ('fc3', nn.Linear(128, 64)),
            ('relu3', nn.ReLU()),
            ('fc4', nn.Linear(64, nClasses)),
        ]))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        probs = self.softmax(logits)
        return probs
    
    def get_weights(self):
        return list(self.parameters())
    
    def set_weights(self, parameters_list):
        for i, param in enumerate(self.parameters()):
            param.data = parameters_list[i].data



input_channels = train_data[0][0].shape[0]
conv_l1_out = 32
conv_kernel = 3
max_kernel = 2
kernels = [conv_kernel, max_kernel, conv_kernel, max_kernel]
strides = [1, max_kernel, 1, max_kernel]
out_w, out_h = calc_out_conv_max_layers(conv_l1_out, conv_l1_out, kernels, strides)

class NeuralNetworkCifar10Conv(nn.Module):
    def __init__(self):
        super().__init__()
        self.features_stack = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(input_channels, conv_l1_out, kernel_size=conv_kernel, stride=1, padding=0)),
            ('relu1', nn.ReLU(inplace=True)),
            ('pool1', nn.MaxPool2d(kernel_size=max_kernel)),
            ('conv2', nn.Conv2d(conv_l1_out, 64, kernel_size=conv_kernel)),
            ('relu2', nn.ReLU(inplace=True)),
            ('pool2', nn.MaxPool2d(kernel_size=max_kernel)),
            ('flat', nn.Flatten()),
            ('fc1', nn.Linear(64*out_w*out_h, 100)),
            ('relu3', nn.ReLU(inplace=True)),
            ('fc2', nn.Linear(100, nClasses)),
        ]))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        logits = self.features_stack(x)
        probs = self.softmax(logits)
        return probs

    def get_weights(self):
        return list(self.parameters())
    
    def set_weights(self, parameters_list):
        for i, param in enumerate(self.parameters()):
            param.data = parameters_list[i].data

In [13]:
# global_model = NeuralNetworkCifar10MLP().to(device)
global_model = NeuralNetworkCifar10Conv().to(device)
print(global_model)

NeuralNetworkConv(
  (features_stack): Sequential(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
    (relu1): ReLU(inplace=True)
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (relu2): ReLU(inplace=True)
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (flat): Flatten(start_dim=1, end_dim=-1)
    (fc1): Linear(in_features=2304, out_features=100, bias=True)
    (relu3): ReLU(inplace=True)
    (fc2): Linear(in_features=100, out_features=10, bias=True)
  )
  (softmax): Softmax(dim=1)
)


In [14]:
def select_clients_and_assign_weights(global_weights):
    global client_selects
    global client_weights

    lst = np.arange(0, num_clients)
    np.random.shuffle(lst)
    client_selects = lst[: int(len(lst)*client_select_percentage)]

    client_weights = {i: copy.deepcopy(global_weights)  for i in client_selects}

In [15]:
global_weights = global_model.get_weights()
select_clients_and_assign_weights(global_weights)

global_history = {"times": {"train":list(), "swap":list(), "step":list()},
                  "accuracy": list(),
                  "loss": list()}

In [16]:
def scale_model_weights(weights, scalar):
    """ Scale the model weights """

    scaled_weights = list()
    for i in range(len(weights)):
        scaled_weights.append(weights[i] * scalar)

    return scaled_weights

In [17]:
def sum_scaled_weights(client_scaled_weights):
    """ Return the sum of the listed scaled weights.
        axis_O is equivalent to the average weight of the weights """

    avg_weights = list()
    # get the average gradient accross all client gradients
    for gradient_list_tuple in zip(*client_scaled_weights):
        gradient_list_tuple = [tensor.tolist()  for tensor in gradient_list_tuple]
        layer_mean = torch.sum(torch.tensor(gradient_list_tuple), axis=0).to(device)
        avg_weights.append(layer_mean)

    return avg_weights


### Explaining the function with example ###
# t = (torch.tensor([[[2, 3],[3, 4]], [[3, 4],[4, 5]], [[4, 5],[5, 6]]]),
#      torch.tensor([[[5, 6],[6, 7]], [[6, 7],[7, 8]], [[7, 8],[8, 9]]]))
# t = [i.tolist() for i in t]
# for y in zip(*t):
#     print(y)
#     print(torch.sum(torch.tensor(y), axis=0))

In [18]:
def fed_avg():
    # calculate total training data across clients
    global_count = 0
    for client in client_selects:
        global_count += len(client_dataloaders[client].dataset)

    # initial list to collect clients weight after scalling
    client_scaled_weights = list()
    for client in client_selects:
        local_count = len(client_dataloaders[client].dataset)
        scaling_factor = local_count / global_count
        scaled_weights = scale_model_weights(client_weights[client], scaling_factor)
        client_scaled_weights.append(scaled_weights)

    # to get the average over all the clients model, we simply take the sum of the scaled weights
    avg_weights = sum_scaled_weights(client_scaled_weights)

    return avg_weights

In [19]:
def fed_swap(client):
    random_num = random.randint(0, len(client_selects)-1)
    random_client = client_selects[random_num]

    temp_weights = client_weights[random_client]
    client_weights[random_client] = client_weights[client]

    return temp_weights

In [20]:
def test_neural_network(dataloader, model, loss_fn):
    data_size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct_items = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            test_loss += loss_fn(pred, y).item()
            correct_items += (pred.argmax(1) == y.argmax(1)).sum().item()

    avg_loss = test_loss / num_batches
    accuracy = correct_items / data_size
    # print(f"Test Error: \nAccuracy: {(accuracy*100):>0.1f}%, Loss: {avg_loss:>8f}\n")

    return accuracy, avg_loss

In [21]:
def train_neural_network(dataloader, model, loss_fn, optimizer):
    data_size = len(dataloader.dataset)
    running_loss = 0

    for batch, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)

        # Compute prediction error
        pred = model(x)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item()
        print_step = np.ceil(len(dataloader)/10)
        if batch % print_step == 0:
            loss_per_batch = running_loss / print_step
            current_item = (batch+1)*len(x)
            # print(f"loss: {loss_per_batch:>7f}  [{current_item:>5d}/{data_size:>5d}]")
            running_loss = 0

In [22]:
def train_clinet(dataloader, model):
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
    for epoch in range(client_epochs):
        train_neural_network(dataloader, model, loss_fn, optimizer)

In [None]:
def change_time_format(seconds):
    m, s = divmod(seconds, 60)
    h, m = divmod(m, 60)

    if h:
        return f"{h:.0f}h-{m:.0f}m-{s:.0f}s"
    elif m:
        return f"{m:.0f}m-{s:.0f}s"
    else:
        return f"{s:.2f}s"

In [None]:
def print_log(training_time, swapping_time, step_time, step, metric_index=-1):
    training_time = change_time_format(training_time)
    swapping_time = change_time_format(swapping_time)
    step_time = change_time_format(step_time)
    print(f"round: {step} | training_time: {training_time} | swapping_time: {swapping_time} | step_time: {step_time}")

    global is_print_eval
    if is_print_eval:
        is_print_eval = False
        print(f"round: {step} / global_acc: {start_bold}{global_history['accuracy'][metric_index]:.4%}{end_bold} / global_loss: {start_bold}{global_history['loss'][metric_index]:.4f}{end_bold}\n")

In [23]:
for step in range(0, total_steps):
    training_time_start = time.time()
    for client in client_selects:
        # local_model = NeuralNetworkCifar10MLP().to(device)
        local_model = NeuralNetworkCifar10Conv().to(device)
        local_model.set_weights(client_weights[client])
        train_clinet(client_dataloaders[client], local_model)
        client_weights[client] = local_model.get_weights()

        del local_model
    
    training_time = time.time() - training_time_start
    global_history["times"]["train"].append(training_time)


    swapping_time_start = time.time()
    if (step % swap_step == 0) and (step % (swap_step*n_swap_bet_avg_p1) != 0):
        for client in client_selects:
            client_weights[client] = fed_swap(client)
    
    swapping_time = time.time() - swapping_time_start
    global_history["times"]["swap"].append(swapping_time)
    

    if (step % (swap_step*n_swap_bet_avg_p1) == 0):
        avg_weights = fed_avg()
        global_model.set_weights(avg_weights) # update global model
        select_clients_and_assign_weights(avg_weights)

        is_print_eval = True
        # test global model and print out metrics after each communication round
        global_acc, global_loss = test_neural_network(test_dataloader, global_model, loss_fn)
        global_history["accuracy"].append(global_acc)
        global_history["loss"].append(global_loss)
    
    step_time = time.time() - training_time_start
    global_history["times"]["step"].append(step_time)
    print_log(training_time, swapping_time, step_time, step)

loss: 1.151160  [ 1000/16666]
loss: 2.328869  [ 3000/16666]
loss: 2.300258  [ 5000/16666]
loss: 2.296053  [ 7000/16666]
loss: 2.279709  [ 9000/16666]
loss: 2.251862  [11000/16666]
loss: 2.224060  [13000/16666]
loss: 2.216454  [15000/16666]
loss: 2.223981  [11322/16666]
loss: 1.105631  [ 1000/16666]
loss: 2.200179  [ 3000/16666]
loss: 2.198383  [ 5000/16666]
loss: 2.183627  [ 7000/16666]
loss: 2.200420  [ 9000/16666]
loss: 2.176920  [11000/16666]
loss: 2.163659  [13000/16666]
loss: 2.160729  [15000/16666]
loss: 2.146627  [11322/16666]
loss: 1.151361  [ 1000/16667]
loss: 2.299383  [ 3000/16667]
loss: 2.278898  [ 5000/16667]
loss: 2.288036  [ 7000/16667]
loss: 2.264523  [ 9000/16667]
loss: 2.233517  [11000/16667]
loss: 2.231377  [13000/16667]
loss: 2.225331  [15000/16667]
loss: 2.212402  [11339/16667]
loss: 1.106602  [ 1000/16667]


In [None]:
print(client_weights[0][0][0])
print(client_weights[1][0][0])
print(client_weights[2][0][0])

tensor([[[-0.0600,  0.0692, -0.2063],
         [ 0.0159, -0.0937,  0.0617],
         [-0.0744,  0.0236,  0.1050]],

        [[-0.1349, -0.2547,  0.1150],
         [-0.1693, -0.0140,  0.0597],
         [-0.1047,  0.0774,  0.0587]],

        [[-0.1031, -0.0083, -0.0446],
         [ 0.1298,  0.0762, -0.1789],
         [-0.0396,  0.0385,  0.0804]]], device='cuda:0')
tensor([[[-0.0600,  0.0692, -0.2063],
         [ 0.0159, -0.0937,  0.0617],
         [-0.0744,  0.0236,  0.1050]],

        [[-0.1349, -0.2547,  0.1150],
         [-0.1693, -0.0140,  0.0597],
         [-0.1047,  0.0774,  0.0587]],

        [[-0.1031, -0.0083, -0.0446],
         [ 0.1298,  0.0762, -0.1789],
         [-0.0396,  0.0385,  0.0804]]], device='cuda:0')
tensor([[[-0.0600,  0.0692, -0.2063],
         [ 0.0159, -0.0937,  0.0617],
         [-0.0744,  0.0236,  0.1050]],

        [[-0.1349, -0.2547,  0.1150],
         [-0.1693, -0.0140,  0.0597],
         [-0.1047,  0.0774,  0.0587]],

        [[-0.1031, -0.0083, -0.0446],


## Result