In [8]:
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.ops import MLP
from PIL import Image
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100, CIFAR10, MNIST
import copy
from tqdm import tqdm
from collections import OrderedDict
import random
import math
import matplotlib.pyplot as plt
from client_selector import ClientSelector
from data_splitter import DataSplitter

In [9]:
!pip install wandb -qU
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mivanludvig[0m ([33mivanludvigdev[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [10]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Data

In [11]:
K = 100

params = {
    'K': K,
    'C': 0.1,
    'B': 64,
    'J': 4,
    # 'lr_server': 1e-1,
    'lr_client': 1e-1,
    'method': 'fedavg',
    'tau': 1e-3,
    # 'gamma': 5,
    'participation': 'uniform',
    'rounds': 2000
}

In [12]:
preprocess = transforms.Compose([
    transforms.RandomCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])
])

train_dataset = CIFAR100('datasets/cifar100', train=True, transform=preprocess, download=True)
test_dataset = CIFAR100('datasets/cifar100', train=False, transform=preprocess, download=True)

test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [13]:
data_split_params = {
    'K': K,
    'split_method': 'non-iid',
    'n_labels': 5
}

data_splitter = DataSplitter(data_split_params, train_dataset)
client_datasets = data_splitter.split()

samples_per_client: 500
samples_per_label: 100


100%|██████████| 100/100 [00:01<00:00, 57.31it/s]


In [14]:
def split_train_test(d):
  train_size = int(0.8 * len(d))
  test_size = len(d) - train_size
  train_dataset, test_dataset = torch.utils.data.random_split(d, [train_size, test_size])

  return train_dataset, test_dataset

train_test_client_datasets = [split_train_test(d) for d in client_datasets]
train_client_datasets = [d[0] for d in train_test_client_datasets]
test_client_datasets = [d[1] for d in train_test_client_datasets]

In [15]:
client_selector = ClientSelector(params)

## Model

In [16]:
class LeNet5_circa(nn.Module):
    def __init__(self):
        super( LeNet5_circa, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(4 * 4 * 64, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 100)

    def forward(self, x):
        x = self.pool(self.conv1(x).relu())
        x = self.pool(self.conv2(x).relu())
        x = torch.flatten(x, 1)
        x = self.fc1(x).relu()
        x = self.fc2(x).relu()
        x = self.fc3(x)

        return x


model = LeNet5_circa().cuda()
model.to('cuda')

criterion = torch.nn.CrossEntropyLoss().cuda()

In [17]:
wandb.init(
    project='fl',
    name=f'fed mod {data_split_params["split_method"]}, J={params["J"]}, lr={params["lr_client"]}, n_labels={data_split_params["n_labels"]}',
    config={**params, **data_split_params}
)

In [18]:
path = lambda t: f'/content/drive/My Drive/fl/{data_split_params["split_method"]}-mod-J{params["J"]}-lr{params["lr_client"]}-n_labels{data_split_params["n_labels"]}-{t}.pt'

backup = 0
if backup:
    model.load_state_dict(torch.load(path(backup)))
model

LeNet5_circa(
  (conv1): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1024, out_features=384, bias=True)
  (fc2): Linear(in_features=384, out_features=192, bias=True)
  (fc3): Linear(in_features=192, out_features=100, bias=True)
)

## Utils

In [19]:
def reduce_w(w_list, f):
    return OrderedDict([
            (key, f([x[key] for x in w_list])) for key in w_list[0].keys()
        ])


def tensor_sum(tensors_list, weights=None):
    if weights:
      return torch.sum(torch.stack([t*w for t, w in zip(tensors_list, weights)]), dim=0)
    return torch.sum(torch.stack(tensors_list), dim=0)


def w_norm2(w):
    res = 0
    for key in w.keys():
        res += torch.linalg.vector_norm(w[key]) ** 2
    return math.sqrt(res)


def fed_adagrad(v, delta, params):
    delta_norm2 = w_norm2(delta)
    return v + delta_norm2


def fed_yogi(v, delta, params):
    delta_norm2 = w_norm2(delta)
    return v - (1-params['beta2']) * delta_norm2 * torch.sign(v - delta_norm2)


def fed_adam(v, delta, params):
    delta_norm2 = w_norm2(delta)
    return params['beta2'] * v + (1-params['beta2']) * delta_norm2


methods = {
    'adagrad': fed_adagrad,
    'yogi': fed_yogi,
    'adam': fed_adam
}

## Training

In [20]:
T = params['rounds']
test_freq = 50
save_freq = 200

In [21]:
def test(model, loader, verbose=False):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    test_loss = test_loss / len(loader)
    test_accuracy = 100. * correct / total
    if verbose:
        print(f'Test Loss: {test_loss:.6f} Acc: {test_accuracy:.2f}%')
    return test_accuracy, test_loss


def client_update(model, k, params):
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=params['lr_client'], weight_decay=4e-4)
    loader = DataLoader(train_client_datasets[k], batch_size=params['B'], shuffle=True)
    test_client_loader = DataLoader(test_client_datasets[k], batch_size=params['B'], shuffle=True)

    i = 0
    for i in range(params['J']):
        for batch_idx, (inputs, targets) in enumerate(loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            i += 1

            if i >= params['J']:
              return model.state_dict(), *test(model, test_client_loader)

    return model.state_dict(), *test(model, test_client_loader)


def train(model, params):
    accuracies = []
    losses = []
    v = params['tau'] ** 2
    w = model.state_dict()
    m = reduce_w([w], lambda x: torch.mul(x[0], 0.0))
    for t in tqdm(range(T)):
        s = client_selector.sample()

        w_clients = []
        w_losses = []
        for k in s:
            client_w, client_acc, client_loss = client_update(copy.deepcopy(model), k, params)
            w_clients.append(client_w)
            w_losses.append(client_loss)

        w_sum = sum(w_losses)
        weights = [l/w_sum for l in w_losses]
        if params['method'] == 'fedavg':
            w = reduce_w(
                w_clients,
                lambda x: tensor_sum(x, weights)
            )
        else:
            deltas = [
                reduce_w(
                    [w, w_client],
                    lambda x: x[1] - x[0]
                ) for w_client in w_clients
            ]

            delta = reduce_w(
                deltas,
                lambda x: tensor_sum(x) / len(deltas)
            )

            m = reduce_w(
                [m, delta],
                lambda x: params['beta1'] * x[0] + (1-params['beta1']) * x[1]
            )

            v = methods[params['method']](v, delta, params)
            w = reduce_w(
                [w, m],
                lambda x: x[0] + params['lr_server'] * x[1] / (math.sqrt(v) + params['tau'])
            )

        model.load_state_dict(w)

        if t % test_freq == 0 or t == T-1:
            acc, loss = test(model, test_loader, True)
            accuracies.append(acc)
            losses.append(loss)
            wandb.log({'acc': acc, 'loss': loss, 'round': t})

        if t % save_freq == 0 or t == T-1:
            torch.save(model.state_dict(), path(t))

    return accuracies, losses


accuracies, losses = train(model, params)

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  0%|          | 1/2000 [00:05<3:09:47,  5.70s/it]

Test Loss: 4.605740 Acc: 1.27%


  3%|▎         | 51/2000 [00:20<1:05:48,  2.03s/it]

Test Loss: 4.592994 Acc: 1.89%


  5%|▌         | 101/2000 [00:33<52:28,  1.66s/it]

Test Loss: 4.511084 Acc: 2.11%


  8%|▊         | 151/2000 [00:47<50:35,  1.64s/it]

Test Loss: 4.424550 Acc: 2.63%


 10%|█         | 201/2000 [01:01<55:23,  1.85s/it]

Test Loss: 4.481052 Acc: 3.13%


 13%|█▎        | 251/2000 [01:16<1:01:04,  2.10s/it]

Test Loss: 4.223954 Acc: 4.89%


 15%|█▌        | 301/2000 [01:29<46:39,  1.65s/it]

Test Loss: 4.177453 Acc: 5.20%


 18%|█▊        | 351/2000 [01:42<44:23,  1.62s/it]

Test Loss: 4.058074 Acc: 6.54%


 20%|██        | 401/2000 [01:55<40:13,  1.51s/it]

Test Loss: 3.936386 Acc: 9.08%


 23%|██▎       | 451/2000 [02:08<36:02,  1.40s/it]

Test Loss: 3.911915 Acc: 9.02%


 25%|██▌       | 501/2000 [02:20<32:07,  1.29s/it]

Test Loss: 3.813553 Acc: 10.67%


 28%|██▊       | 551/2000 [02:33<30:28,  1.26s/it]

Test Loss: 3.742543 Acc: 11.63%


 30%|███       | 601/2000 [02:45<29:29,  1.26s/it]

Test Loss: 3.712796 Acc: 11.63%


 33%|███▎      | 651/2000 [02:59<29:13,  1.30s/it]

Test Loss: 3.640003 Acc: 13.62%


 35%|███▌      | 701/2000 [03:11<27:56,  1.29s/it]

Test Loss: 3.503747 Acc: 15.53%


 38%|███▊      | 751/2000 [03:24<27:10,  1.31s/it]

Test Loss: 3.555487 Acc: 14.77%


 40%|████      | 801/2000 [03:38<25:45,  1.29s/it]

Test Loss: 3.506569 Acc: 14.87%


 43%|████▎     | 851/2000 [03:51<24:59,  1.30s/it]

Test Loss: 3.356865 Acc: 18.66%


 45%|████▌     | 901/2000 [04:04<23:56,  1.31s/it]

Test Loss: 3.319912 Acc: 19.61%


 48%|████▊     | 951/2000 [04:17<24:35,  1.41s/it]

Test Loss: 3.315273 Acc: 19.96%


 50%|█████     | 1001/2000 [04:30<25:05,  1.51s/it]

Test Loss: 3.289889 Acc: 19.89%


 53%|█████▎    | 1051/2000 [04:43<24:39,  1.56s/it]

Test Loss: 3.300867 Acc: 18.98%


 55%|█████▌    | 1101/2000 [04:56<24:46,  1.65s/it]

Test Loss: 3.191100 Acc: 21.93%


 58%|█████▊    | 1151/2000 [05:10<23:39,  1.67s/it]

Test Loss: 3.140942 Acc: 22.54%


 60%|██████    | 1201/2000 [05:24<22:39,  1.70s/it]

Test Loss: 3.232154 Acc: 20.87%


 63%|██████▎   | 1251/2000 [05:37<20:59,  1.68s/it]

Test Loss: 3.038176 Acc: 25.23%


 65%|██████▌   | 1301/2000 [05:51<19:10,  1.65s/it]

Test Loss: 3.059211 Acc: 25.15%


 68%|██████▊   | 1351/2000 [06:04<17:48,  1.65s/it]

Test Loss: 3.072062 Acc: 23.56%


 70%|███████   | 1401/2000 [06:17<15:33,  1.56s/it]

Test Loss: 2.985109 Acc: 26.35%


 73%|███████▎  | 1451/2000 [06:30<13:43,  1.50s/it]

Test Loss: 3.011678 Acc: 25.12%


 75%|███████▌  | 1501/2000 [06:44<12:11,  1.47s/it]

Test Loss: 3.071455 Acc: 24.61%


 78%|███████▊  | 1551/2000 [06:57<10:17,  1.38s/it]

Test Loss: 3.032217 Acc: 25.28%


 80%|████████  | 1601/2000 [07:10<08:27,  1.27s/it]

Test Loss: 2.912180 Acc: 27.85%


 83%|████████▎ | 1651/2000 [07:22<07:33,  1.30s/it]

Test Loss: 2.906576 Acc: 27.97%


 85%|████████▌ | 1701/2000 [07:35<06:31,  1.31s/it]

Test Loss: 2.919141 Acc: 27.60%


 88%|████████▊ | 1751/2000 [07:48<05:20,  1.29s/it]

Test Loss: 2.908475 Acc: 28.46%


 90%|█████████ | 1801/2000 [08:01<04:16,  1.29s/it]

Test Loss: 2.858901 Acc: 28.46%


 93%|█████████▎| 1851/2000 [08:14<03:11,  1.29s/it]

Test Loss: 2.823965 Acc: 29.57%


 95%|█████████▌| 1901/2000 [08:26<02:08,  1.30s/it]

Test Loss: 2.908374 Acc: 28.10%


 98%|█████████▊| 1951/2000 [08:39<01:04,  1.32s/it]

Test Loss: 2.887194 Acc: 28.80%


100%|██████████| 2000/2000 [08:52<00:00,  3.76it/s]

Test Loss: 2.925799 Acc: 28.24%



