In [74]:
# from google.colab import drive
# drive.mount('/content/drive')
# %cd "/content/drive/MyDrive/ColabTemp"

In [75]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
from torch.utils.data import DataLoader
from torch import nn
from collections import OrderedDict
import random
import numpy as np

In [76]:
# Pick the compute device in order of preference: NVIDIA CUDA, then Apple MPS,
# falling back to the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using '{device}' device")

Using 'cuda' device


In [77]:
mySeed = 42  # single seed shared by every RNG below

# NumPy's legacy global RNG, PyTorch and Python's `random` keep independent
# state, so each one is seeded explicitly. `random.seed` (returns None) stays
# last so the cell displays nothing.
np.random.seed(mySeed)
torch.manual_seed(mySeed)
random.seed(mySeed)

In [78]:
# Hyper parameters
# num_clients: how many client datasets the training set is split into.
# batch_size: mini-batch size used by every client DataLoader and the test loader.
num_clients, batch_size = 5, 8

# Fraction of the training set to keep; values < 1 trim the data so test runs are fast.
remain = 0.01

## Load Data

In [79]:
# Download dataset
DATA_ROOT = "../datasets"
NUM_CIFAR10_CLASSES = 10


def one_hot_target(label, num_classes=NUM_CIFAR10_CLASSES):
    """Encode an integer class label as a float32 one-hot vector of length ``num_classes``.

    A single named function (instead of an inline ``lambda`` copy-pasted for each
    split) keeps the train/test encodings in sync and, unlike a lambda, can be
    pickled -- so the datasets keep working if a DataLoader is later given
    ``num_workers > 0`` under a spawn start method.
    """
    target = torch.zeros(num_classes)
    target[label] = 1.0
    return target


def load_cifar10(train):
    """Return the CIFAR-10 train (``train=True``) or test split, downloading it if needed.

    Images are converted to tensors with ``ToTensor``; targets are one-hot encoded.
    """
    return datasets.CIFAR10(
        root=DATA_ROOT,
        train=train,
        download=True,
        transform=ToTensor(),
        target_transform=one_hot_target,
    )


train_data = load_cifar10(train=True)
test_data = load_cifar10(train=False)

print(len(train_data))
print(train_data[0][0].shape)
print(train_data[0][1])

Files already downloaded and verified
Files already downloaded and verified
50000
torch.Size([3, 32, 32])
tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0.])


In [80]:
# Remove some data for running faster in test.
# Always slice from the full dataset: rebinding `train_data` to a subset of itself
# would shrink it again on every re-run of this cell (50000 -> 500 -> 5 -> 0).
full_train_data = (
    train_data.dataset if isinstance(train_data, torch.utils.data.Subset) else train_data
)
print(len(full_train_data))

n_keep = int(len(full_train_data) * remain)
# The client split gives each client len // num_clients samples; keeping fewer
# samples than clients would leave clients empty (and 0 samples breaks the later
# `train_data[0]` lookups used to size the model).
if n_keep < num_clients:
    raise ValueError(
        f"remain={remain} keeps only {n_keep} samples, fewer than num_clients={num_clients}"
    )

train_data = torch.utils.data.Subset(full_train_data, range(n_keep))
print(len(train_data))

50000
500


In [81]:
### Random dataset split
# Every client gets len(train_data) // num_clients samples; the leftover
# len(train_data) % num_clients samples are handed out one each to the LAST clients.
base_size, data_remain = divmod(len(train_data), num_clients)
client_data_size = np.array(
    [base_size] * (num_clients - data_remain) + [base_size + 1] * data_remain
)

client_datasets = torch.utils.data.random_split(train_data, client_data_size)

### None random dataset split
# client_datasets = list()
# i = 0
# for j in client_data_size:
#     client_datasets.append(torch.utils.data.Subset(train_data, range(i, i+j)))
#     i += j

In [82]:
# One shuffled DataLoader per client, stored in a NumPy object array indexed by client position.
client_dataloaders = np.zeros(num_clients, dtype=object)
for client_idx in range(len(client_datasets)):
    client_dataloaders[client_idx] = DataLoader(
        client_datasets[client_idx],
        batch_size=batch_size,
        shuffle=True,
    )

# Loader over the full (un-trimmed) test split.
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

## Training

In [349]:
# Define Model
# Layer sizes are read off the first training sample rather than hard-coded:
# a (3, 32, 32) image flattens to 3072 inputs, and the one-hot target has 10 entries.
input_flat_size = train_data[0][0].numel()
nClasses = len(train_data[0][1])

class NeuralNetworkMnistMLP(nn.Module):
    """Fully connected classifier: flatten -> 3 hidden ReLU layers -> softmax.

    Despite the name, this notebook feeds it CIFAR-10 images. Input and output
    sizes come from the module-level ``input_flat_size`` and ``nClasses``.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(input_flat_size, 256)),
            ('relu1', nn.ReLU()),
            ('fc2', nn.Linear(256, 128)),
            ('relu2', nn.ReLU()),
            ('fc3', nn.Linear(128, 64)),
            ('relu3', nn.ReLU()),
            ('fc4', nn.Linear(64, nClasses)),
        ]))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities of shape (batch, nClasses).

        NOTE(review): the output is already softmax-normalised. If the training
        loop uses ``nn.CrossEntropyLoss`` (which expects raw logits), softmax is
        applied twice -- confirm against the loss used.
        """
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        probs = self.softmax(logits)
        return probs

    def get_weights(self):
        """Return detached copies of all parameters, in ``self.parameters()`` order.

        Copies (not the live Parameters) are returned so that weight snapshots,
        e.g. global or per-client weights, are not silently modified when this
        model is trained afterwards. Copies stay on the model's device.
        """
        return [param.detach().clone() for param in self.parameters()]

    def set_weights(self, parameters_list):
        """Load ``parameters_list`` into this model's parameters, copying the values.

        Values are copied into the existing Parameters (not aliased), so training
        this model later does not write through into the caller's tensors. The
        incoming tensors may live on a different device; ``copy_`` moves them.

        Raises:
            ValueError: if the number of tensors or any tensor shape does not
                match this model's parameters.
        """
        params = list(self.parameters())
        if len(parameters_list) != len(params):
            raise ValueError(
                f"expected {len(params)} tensors, got {len(parameters_list)}"
            )
        with torch.no_grad():
            for idx, (param, new_value) in enumerate(zip(params, parameters_list)):
                # copy_ would silently broadcast a smaller tensor, so check shapes explicitly.
                if tuple(new_value.shape) != tuple(param.shape):
                    raise ValueError(
                        f"parameter {idx}: expected shape {tuple(param.shape)}, "
                        f"got {tuple(new_value.shape)}"
                    )
                param.copy_(new_value)

In [350]:
model = NeuralNetworkMnistMLP()
model.to(device)  # nn.Module.to moves the parameters in place (and returns the same module)
print(model)

NeuralNetworkMnistMLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (fc1): Linear(in_features=3072, out_features=256, bias=True)
    (relu1): ReLU()
    (fc2): Linear(in_features=256, out_features=128, bias=True)
    (relu2): ReLU()
    (fc3): Linear(in_features=128, out_features=64, bias=True)
    (relu3): ReLU()
    (fc4): Linear(in_features=64, out_features=10, bias=True)
  )
  (softmax): Softmax(dim=1)
)


In [351]:
# Snapshot the freshly initialised model as the global weights. `get_weights()` may
# hand back the model's live Parameters, so clone explicitly: the snapshot must not
# change when `model` is trained later.
global_weights = [w.detach().clone() for w in model.get_weights()]
client_selects = torch.arange(0, num_clients)
# Give every selected client its own independent copy. `[global_weights] * n` would
# store the *same* list n times, so updating one client's weights updated them all.
client_weights = [[w.clone() for w in global_weights] for _ in range(len(client_selects))]

Parameter containing:
tensor([[-0.0130, -0.0093,  0.0144,  ..., -0.0096, -0.0016, -0.0117],
        [ 0.0154,  0.0090,  0.0077,  ...,  0.0140,  0.0155, -0.0038],
        [-0.0115,  0.0137, -0.0087,  ...,  0.0096,  0.0049, -0.0007],
        ...,
        [-0.0061, -0.0133,  0.0085,  ..., -0.0007, -0.0012, -0.0113],
        [ 0.0166, -0.0064, -0.0151,  ...,  0.0163, -0.0075,  0.0139],
        [ 0.0161,  0.0070,  0.0145,  ...,  0.0043, -0.0153, -0.0122]],
       device='cuda:0', requires_grad=True)