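Paper: Reddi et al., "Adaptive Federated Optimization" (ICLR 2021) — https://arxiv.org/abs/2003.00295

This notebook trains a small CNN on CIFAR-100 across simulated clients, aggregating with FedAvg or one of the adaptive server optimizers (FedAdagrad, FedYogi, FedAdam).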

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.ops import MLP
from PIL import Image
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100, CIFAR10, MNIST
import copy
from tqdm import tqdm
from collections import OrderedDict
import random
import math
import matplotlib.pyplot as plt
from client_selector import ClientSelector
from data_splitter import DataSplitter

In [None]:
# Colab setup: install/upgrade the Weights & Biases client and log in (prompts for an API key
# on first use).
# NOTE(review): the version is unpinned; for reproducible runs prefer
# `%pip install -q wandb==<version>` so the install targets the running kernel.
!pip install wandb -qU
import wandb
wandb.login()

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.7/6.7 MB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m289.0/289.0 kB[0m [31m19.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25h

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
# Mount Google Drive at /content/drive (Colab only). The model checkpoints this notebook
# saves and reloads live under '/content/drive/My Drive/fl/', so they survive the session.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Data

In [None]:
# Number of clients in the federation.
K = 100

# Experiment hyper-parameters. The whole dict is logged as the W&B run config.
params = {
    'K': K,                      # total number of clients
    'C': 0.1,                    # NOTE(review): presumably the fraction of clients sampled per round (read by ClientSelector)
    'B': 10,                     # local mini-batch size
    'J': 4,                      # local epochs per round
    'lr_server': 1e-1,           # server learning rate (adaptive methods only)
    'lr_client': 1e-2,           # client SGD learning rate
    'momentum': 0,               # client SGD momentum
    'method': 'fedavg',          # 'fedavg' or one of the adaptive methods: 'adagrad', 'yogi', 'adam'
    'beta1': 0.9,                # EMA factor of the server first moment m (adaptive methods only)
    'beta2': 0.99,               # decay factor of the server second moment v ('yogi' and 'adam')
    'tau': 1e-3,                 # adaptivity constant: v starts at tau**2 and tau is added to sqrt(v)
    'gamma': 0.1,                # NOTE(review): not read by the training code shown; presumably used by ClientSelector
    'participation': 'uniform',  # NOTE(review): presumably the client-sampling scheme of ClientSelector
    'rounds': 2000,              # number of communication rounds
    'seed': 0,                   # seed for python, numpy and torch RNGs
}

# Seed every RNG the experiment touches (client sampling, data splitting, shuffling, weight init)
# so a Restart & Run All reproduces the same run.
random.seed(params['seed'])
np.random.seed(params['seed'])
torch.manual_seed(params['seed'])

In [None]:
# Per-channel normalization constants (presumably CIFAR-100 channel mean/std), shared by the
# training and the evaluation pipelines.
normalize = transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])

# Training augmentation: random 28x28 crop + random horizontal flip. 28x28 is the input size
# LeNet5_circa's first linear layer (4 * 4 * 64 features) is sized for.
preprocess = transforms.Compose([
    transforms.RandomCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: deterministic center crop to the same 28x28 size and no flips. Test
# accuracy and loss therefore do not fluctuate with random augmentation between evaluations.
test_preprocess = transforms.Compose([
    transforms.CenterCrop((28, 28)),
    transforms.ToTensor(),
    normalize,
])

train_dataset = CIFAR100('datasets/cifar100', train=True, transform=preprocess, download=True)
test_dataset = CIFAR100('datasets/cifar100', train=False, transform=test_preprocess, download=True)

# NOTE(review): train_loader is not used by the federated training code below (each client
# builds its own loader from client_datasets).
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Partition the CIFAR-100 training set among the K clients with the 'iid' strategy.
# NOTE(review): assumes DataSplitter.split() returns one dataset per client, indexable by the
# client id (client_update reads client_datasets[k]). Confirm in data_splitter.py.
data_split_params = dict(K=K, split_method='iid')

data_splitter = DataSplitter(data_split_params, train_dataset)
client_datasets = data_splitter.split()

In [None]:
# Client sampler configured from the experiment params. train() calls client_selector.sample()
# once per round to obtain the indices of the participating clients.
# NOTE(review): presumably it reads 'K', 'C', 'participation' and 'gamma' from params. Confirm in client_selector.py.
client_selector = ClientSelector(params)

## Model

In [None]:
class LeNet5_circa(nn.Module):
    """LeNet-5-style CNN.

    Two 5x5 convolutions (64 channels each), each followed by ReLU and 2x2 max-pooling, then
    three fully connected layers 1024 -> 384 -> 192 -> num_classes.

    The first linear layer expects a 4x4x64 feature map, i.e. 3-channel 28x28 inputs:
    28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4.

    Args:
        num_classes: length of the output logits vector (default 100).
    """

    def __init__(self, num_classes=100):
        super().__init__()
        # Registration order is kept fixed: it determines state_dict keys/order (checkpoints)
        # and the order in which parameters are randomly initialized.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(4 * 4 * 64, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, num_classes)

    def forward(self, x):
        """Map a batch of images (N, 3, 28, 28) to unnormalized class logits (N, num_classes)."""
        x = self.pool(self.conv1(x).relu())
        x = self.pool(self.conv2(x).relu())
        x = torch.flatten(x, 1)
        x = self.fc1(x).relu()
        x = self.fc2(x).relu()
        return self.fc3(x)


# Global model, placed on the GPU once. The notebook assumes a CUDA device is available.
model = LeNet5_circa().cuda()

# Cross-entropy over the logits. The loss module holds no parameters, so no device move is needed.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Start a W&B run. The full hyper-parameter set (params + data split) is stored as the run config.
# The aggregation method leads the run name so fedavg / adagrad / yogi / adam runs with otherwise
# identical settings are distinguishable in the dashboard.
wandb.init(
    project='fl',
    name=f'{params["method"]} {data_split_params["split_method"]}, J={params["J"]}, lr={params["lr_client"]}',
    config={**params, **data_split_params}
)

VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [None]:
def path(t):
    """Return the Google Drive checkpoint file for round `t` of the current configuration.

    The file name encodes the data split method, J and lr_client.
    NOTE(review): it does not encode params['method'], so runs with different aggregation
    methods (but the same split/J/lr) write to and resume from the same files.
    """
    return f'/content/drive/My Drive/fl/{data_split_params["split_method"]}-J{params["J"]}-lr{params["lr_client"]}-{t}.pt'


# Round whose checkpoint to resume from; 0 means start from scratch (so the round-0
# checkpoint itself can never be resumed). train() also starts its round loop at this value.
backup = 0
if backup:
    model.load_state_dict(torch.load(path(backup)))
model

LeNet5_circa(
  (conv1): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1024, out_features=384, bias=True)
  (fc2): Linear(in_features=384, out_features=192, bias=True)
  (fc3): Linear(in_features=192, out_features=100, bias=True)
)

## Utils

In [None]:
def reduce_w(w_list, f):
    """Combine several state dicts key by key.

    For every key of the first dict (in its order), `f` receives the list of that key's values
    across all dicts in `w_list`. Its result becomes the value under that key in the returned
    OrderedDict.
    """
    combined = OrderedDict()
    for name in w_list[0].keys():
        combined[name] = f([state[name] for state in w_list])
    return combined


def tensor_sum(tensors_list):
    """Element-wise sum of equally shaped tensors (stacked along a new leading dim, then reduced)."""
    stacked = torch.stack(tensors_list)
    return stacked.sum(dim=0)


def w_sq_norm(w):
    """Squared global L2 norm of a state dict: sum of squared entries over all its tensors (float)."""
    total = sum(torch.linalg.vector_norm(tensor) ** 2 for tensor in w.values())
    return float(total)


def w_norm2(w):
    """Global L2 norm of a state dict, i.e. sqrt(w_sq_norm(w)), as a Python float."""
    return math.sqrt(w_sq_norm(w))


# Second-moment updates of the adaptive server optimizers. Here v is a single scalar built from
# the global norm of the averaged client update `delta` (NOTE(review): a simplification of the
# usual coordinate-wise formulation). v starts at tau**2 and the server step divides by sqrt(v),
# so v must accumulate the *squared* norm ||delta||^2, not ||delta||.

def fed_adagrad(v, delta, params):
    """FedAdagrad: v_t = v_{t-1} + ||delta||^2. `params` is unused (uniform signature)."""
    return v + w_sq_norm(delta)


def fed_yogi(v, delta, params):
    """FedYogi: v_t = v_{t-1} - (1 - beta2) * ||delta||^2 * sign(v_{t-1} - ||delta||^2)."""
    delta_sq = w_sq_norm(delta)
    # v and delta_sq are Python floats here; torch.sign only accepts tensors, so compute the
    # sign (-1, 0 or 1) directly.
    diff = float(v - delta_sq)
    sign = (diff > 0) - (diff < 0)
    return v - (1 - params['beta2']) * delta_sq * sign


def fed_adam(v, delta, params):
    """FedAdam: v_t = beta2 * v_{t-1} + (1 - beta2) * ||delta||^2."""
    return params['beta2'] * v + (1 - params['beta2']) * w_sq_norm(delta)


# params['method'] -> second-moment update used by the adaptive server step in train().
methods = {
    'adagrad': fed_adagrad,
    'yogi': fed_yogi,
    'adam': fed_adam
}

## Training

In [None]:
# Training schedule: total number of communication rounds, and how often (in rounds) to
# evaluate on the test set and to checkpoint the global model.
T, test_freq, save_freq = params['rounds'], 10, 100

In [None]:
def test(model):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    test_loss = test_loss / len(test_loader)
    test_accuracy = 100. * correct / total
    print(f'Test Loss: {test_loss:.6f} Acc: {test_accuracy:.2f}%')
    return test_accuracy, test_loss


def client_update(model, k, params, dataset=None, loss_fn=None):
    """Run local SGD for client `k` and return the resulting weights.

    The model is trained in place (train() passes a deep copy of the global model) for
    params['J'] epochs over the client's data, in mini-batches of params['B'] that are
    reshuffled every epoch.

    Args:
        model: local model copy; modified in place.
        k: client index into the global `client_datasets`.
        params: dict providing 'lr_client', 'momentum', 'B' and 'J'.
        dataset: optional dataset to use instead of client_datasets[k].
        loss_fn: optional loss to use instead of the global `criterion`.

    Returns:
        model.state_dict() after local training.
    """
    dataset = client_datasets[k] if dataset is None else dataset
    loss_fn = criterion if loss_fn is None else loss_fn
    # Move batches to wherever the model lives (the GPU in this notebook).
    device = next(model.parameters()).device

    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=params['lr_client'], momentum=params['momentum'])
    # shuffle=True: without it every local epoch replays the identical batch sequence.
    loader = DataLoader(dataset, batch_size=params['B'], shuffle=True)

    for _ in range(params['J']):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()

    return model.state_dict()


def _server_aggregate(w, w_clients, m, v, params):
    """Combine the clients' weights into the next global weights; returns (w, m, v).

    'fedavg' takes the unweighted mean of the client weights (m and v are returned unchanged).
    Every other method treats the mean client update (client - global) as a pseudo-gradient:
    m is its exponential moving average with factor beta1, the scalar v is updated by
    methods[params['method']], and the server step is w + lr_server * m / (sqrt(v) + tau).
    """
    if params['method'] == 'fedavg':
        w_new = reduce_w(w_clients, lambda x: tensor_sum(x) / len(w_clients))
        return w_new, m, v

    deltas = [reduce_w([w, w_client], lambda x: x[1] - x[0]) for w_client in w_clients]
    # Unweighted mean over the sampled clients (a size-weighted variant would use
    # len(client_datasets[k]) per client).
    delta = reduce_w(deltas, lambda x: tensor_sum(x) / len(deltas))
    m = reduce_w([m, delta], lambda x: params['beta1'] * x[0] + (1 - params['beta1']) * x[1])
    v = methods[params['method']](v, delta, params)
    w_new = reduce_w(
        [w, m],
        lambda x: x[0] + params['lr_server'] * x[1] / (math.sqrt(v) + params['tau'])
    )
    return w_new, m, v


def train(model, params):
    """Run federated training of `model` for rounds backup .. T-1 and return the test metrics.

    Each round:
    - samples clients with the global `client_selector`;
    - trains a deep copy of the global model on each selected client (client_update);
    - aggregates the results (_server_aggregate) and loads them into `model`.

    The model is evaluated every `test_freq` rounds and on the last round (also logged to
    W&B). It is checkpointed to path(t) every `save_freq` rounds and on the last round.

    Interrupting the kernel stops training early but still returns the metrics gathered so
    far, so a partial run can be plotted.

    NOTE(review): when resuming (backup > 0) the loop restarts at round `backup`, whose
    checkpoint already contains that round's update, so that round is effectively run twice.
    The server optimizer state (m, v) is not checkpointed, so adaptive methods restart it.

    Returns:
        (accuracies, losses): test accuracy (%) and test loss at each evaluated round.
    """
    accuracies = []
    losses = []
    # Server optimizer state; only used by the adaptive methods.
    v = params['tau'] ** 2
    w = model.state_dict()
    m = reduce_w([w], lambda x: torch.zeros_like(x[0]))

    t = backup
    try:
        for t in tqdm(range(backup, T)):
            selected = client_selector.sample()
            w_clients = [client_update(copy.deepcopy(model), k, params) for k in selected]

            w, m, v = _server_aggregate(w, w_clients, m, v, params)
            model.load_state_dict(w)

            if t % test_freq == 0 or t == T - 1:
                acc, loss = test(model)
                accuracies.append(acc)
                losses.append(loss)
                wandb.log({'acc': acc, 'loss': loss, 'round': t})

            if t % save_freq == 0 or t == T - 1:
                torch.save(model.state_dict(), path(t))
    except KeyboardInterrupt:
        print(f'Interrupted during round {t}; returning the {len(accuracies)} evaluations collected so far.')

    return accuracies, losses


# Federated training over rounds backup .. T-1. Returns the test accuracies (%) and test losses
# recorded at each evaluation round.
accuracies, losses = train(model, params)

  0%|          | 1/1000 [00:17<4:52:41, 17.58s/it]

Test Loss: 4.601394 Acc: 1.86%


  1%|          | 11/1000 [02:24<3:43:22, 13.55s/it]

Test Loss: 4.011751 Acc: 8.54%


  2%|▏         | 21/1000 [04:32<3:41:57, 13.60s/it]

Test Loss: 3.699341 Acc: 13.91%


  3%|▎         | 31/1000 [06:38<3:37:09, 13.45s/it]

Test Loss: 3.516832 Acc: 17.37%


  4%|▍         | 41/1000 [08:46<3:37:05, 13.58s/it]

Test Loss: 3.402261 Acc: 20.20%


  5%|▌         | 51/1000 [10:53<3:39:54, 13.90s/it]

Test Loss: 3.351591 Acc: 21.90%


  6%|▌         | 61/1000 [13:03<3:37:31, 13.90s/it]

Test Loss: 3.293699 Acc: 23.35%


  7%|▋         | 71/1000 [15:11<3:28:13, 13.45s/it]

Test Loss: 3.354141 Acc: 24.22%


  8%|▊         | 81/1000 [17:17<3:24:00, 13.32s/it]

Test Loss: 3.346975 Acc: 25.06%


  9%|▉         | 91/1000 [19:25<3:27:16, 13.68s/it]

Test Loss: 3.306051 Acc: 25.58%


  9%|▉         | 93/1000 [19:57<3:14:43, 12.88s/it]


KeyboardInterrupt: 

In [None]:
# x positions are the rounds at which train() evaluates: every `test_freq` rounds starting from
# `backup`, plus the final round T-1. Truncating to len(accuracies) keeps x and y aligned when a
# run was stopped early (np.arange(0, T + test_freq, test_freq) only matched a complete run from 0).
eval_rounds = [t for t in range(backup, T) if t % test_freq == 0 or t == T - 1][:len(accuracies)]

fig, ax = plt.subplots()
ax.plot(eval_rounds, accuracies, label=params['method'], marker='.')
ax.set_xlabel('rounds')
ax.set_ylabel('accuracy (%)')
ax.set_title(f'Test accuracy ({data_split_params["split_method"]}, J={params["J"]}, lr={params["lr_client"]})')
ax.legend()
plt.show()