In [17]:
# from google.colab import drive
# drive.mount('/content/drive')
# %cd "/content/drive/MyDrive/ColabTemp"

In [18]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
from torch.utils.data import DataLoader
from torch import nn
from collections import OrderedDict
import random
import copy
import time
import numpy as np

In [19]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using '{device}' device")

Using 'cuda' device


In [20]:
mySeed = 42
torch.manual_seed(mySeed)
np.random.seed(mySeed)
random.seed(mySeed)
# tf.random.set_seed(mySeed)

In [21]:
# Hyper parameters

num_clients = 3
batch_size = 1000
total_steps = 2

learning_rate = 0.01
loss_fn = nn.CrossEntropyLoss()
client_epochs = 2

remain = 0.1 # Remove some data for running faster in test

## Load Data

In [22]:
# Download dataset
train_data = datasets.CIFAR10(
    root="../datasets",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10).scatter_(dim=0, index=torch.tensor(y), value=1)),
)

test_data = datasets.CIFAR10(
    root="../datasets",
    train=False,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10).scatter_(dim=0, index=torch.tensor(y), value=1)),
)

print(len(train_data))
print(train_data[0][0].shape)
print(train_data[0][1])

Files already downloaded and verified
Files already downloaded and verified
50000
torch.Size([3, 32, 32])
tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0.])


In [23]:
# Remove some data for running faster in test
print(len(train_data))
train_data = torch.utils.data.Subset(train_data, range(0, int(len(train_data)*remain)))
print(len(train_data))

50000
5000


In [24]:
### Random dataset split
client_data_size = np.array([len(train_data)//num_clients]*num_clients)
data_remain = len(train_data) % num_clients
for i in range(data_remain):
    client_data_size[-1-i] += 1

client_datasets = torch.utils.data.random_split(train_data, client_data_size)

### None random dataset split
# client_datasets = list()
# i = 0
# for j in client_data_size:
#     client_datasets.append(torch.utils.data.Subset(train_data, range(i, i+j)))
#     i += j

In [25]:
# Create dataloader for each client
client_dataloaders = np.zeros(num_clients, dtype=object)
for i, dataset in enumerate(client_datasets):
    client_dataloaders[i] = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

## Training

In [26]:
# Define Model
input_flat_size = torch.flatten(train_data[0][0]).shape[0]
nClasses = train_data[0][1].shape[0]

class NeuralNetworkMnistMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(input_flat_size, 256)),
            ('relu1', nn.ReLU()),
            ('fc2', nn.Linear(256, 128)),
            ('relu2', nn.ReLU()),
            ('fc3', nn.Linear(128, 64)),
            ('relu3', nn.ReLU()),
            ('fc4', nn.Linear(64, nClasses)),
        ]))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        probs = self.softmax(logits)
        return probs
    
    def get_weights(self):
        return list(self.parameters())
    
    def set_weights(self, parameters_list):
        for i, param in enumerate(self.parameters()):
            param.data = parameters_list[i].data

In [27]:
global_model = NeuralNetworkMnistMLP().to(device)
print(global_model)

NeuralNetworkMnistMLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (fc1): Linear(in_features=3072, out_features=256, bias=True)
    (relu1): ReLU()
    (fc2): Linear(in_features=256, out_features=128, bias=True)
    (relu2): ReLU()
    (fc3): Linear(in_features=128, out_features=64, bias=True)
    (relu3): ReLU()
    (fc4): Linear(in_features=64, out_features=10, bias=True)
  )
  (softmax): Softmax(dim=1)
)


In [28]:
global_weights = global_model.get_weights()
client_selects = torch.arange(0, num_clients)
client_weights = [copy.deepcopy(global_weights)  for _ in range(len(client_selects))]

In [29]:
def train(dataloader, model, loss_fn, optimizer):
    data_size = len(dataloader.dataset)
    running_loss = 0

    for batch, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)

        # Compute prediction error
        pred = model(x)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item()
        print_step = np.ceil(len(dataloader)/10)
        if batch % print_step == 0:
            loss_per_batch = running_loss / print_step
            current_item = (batch+1)*len(x)
            print(f"loss: {loss_per_batch:>7f}  [{current_item:>5d}/{data_size:>5d}]")
            running_loss = 0

In [30]:
def train_clinet(dataloader, model):
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
    for epoch in range(client_epochs):
        train(dataloader, model, loss_fn, optimizer)

In [31]:
for step in range(0, total_steps):
    for client in client_selects:
        local_model = NeuralNetworkMnistMLP().to(device)
        local_model.set_weights(client_weights[client])
        train_clinet(client_dataloaders[client], local_model)
        client_weights[client] = local_model.get_weights()

        del local_model


loss: 2.302497  [ 1000/ 1666]
loss: 2.310790  [ 1332/ 1666]
loss: 2.324357  [ 1000/ 1666]
loss: 2.331970  [ 1332/ 1666]
loss: 2.302649  [ 1000/ 1667]
loss: 2.338618  [ 1334/ 1667]
loss: 2.300561  [ 1000/ 1667]
loss: 2.322997  [ 1334/ 1667]
loss: 2.302762  [ 1000/ 1667]
loss: 2.302717  [ 1334/ 1667]
loss: 2.345971  [ 1000/ 1667]
loss: 2.358479  [ 1334/ 1667]
loss: 2.375062  [ 1000/ 1666]
loss: 2.358223  [ 1332/ 1666]
loss: 2.352713  [ 1000/ 1666]
loss: 2.339923  [ 1332/ 1666]
loss: 2.348088  [ 1000/ 1667]
loss: 2.346797  [ 1334/ 1667]
loss: 2.364607  [ 1000/ 1667]
loss: 2.346546  [ 1334/ 1667]
loss: 2.351098  [ 1000/ 1667]
loss: 2.362106  [ 1334/ 1667]
loss: 2.358114  [ 1000/ 1667]
loss: 2.342354  [ 1334/ 1667]


In [32]:
print(client_weights[0][0][0])
print(client_weights[1][0][0])
print(client_weights[2][0][0])

tensor([-0.0056,  0.0065, -0.0187,  ...,  0.0074, -0.0216, -0.0271],
       device='cuda:0', grad_fn=<SelectBackward0>)
tensor([ 0.0081,  0.0240, -0.0023,  ..., -0.0110, -0.0406, -0.0453],
       device='cuda:0', grad_fn=<SelectBackward0>)
tensor([ 0.0030,  0.0145, -0.0104,  ...,  0.0154, -0.0109, -0.0206],
       device='cuda:0', grad_fn=<SelectBackward0>)
