Paper: Reddi et al., "Adaptive Federated Optimization" (FedAdagrad / FedYogi / FedAdam) — https://arxiv.org/abs/2003.00295

In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.ops import MLP
from PIL import Image
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100, MNIST
from torch.optim import SGD
import copy
from tqdm import tqdm
from collections import OrderedDict
import random
import math

## Data

In [2]:
# K = number of federated clients; S = the client ids 0..K-1 that each
# training round samples from.
K = 100
S = np.arange(K)

In [3]:
# Per-channel normalisation shared by the train and test pipelines.
cifar100_normalize = transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])

# Training-time augmentation: random 24x24 crop + random horizontal flip.
preprocess = transforms.Compose([
    transforms.RandomCrop((24, 24)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    cifar100_normalize,
])

# Evaluation must be deterministic: same 24x24 input size, but a center crop and
# no flipping. Reusing the random training transform here would make the reported
# test accuracy change from one evaluation to the next.
preprocess_test = transforms.Compose([
    transforms.CenterCrop((24, 24)),
    transforms.ToTensor(),
    cifar100_normalize,
])

train_dataset = CIFAR100('datasets/cifar100', train=True, transform=preprocess, download=True)
test_dataset = CIFAR100('datasets/cifar100', train=False, transform=preprocess_test, download=True)

test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to datasets/cifar100/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:03<00:00, 53558421.32it/s]


Extracting datasets/cifar100/cifar-100-python.tar.gz to datasets/cifar100
Files already downloaded and verified


In [4]:
# Partitioning configuration:
#   iid                -- True: uniform random split; False: label-sorted shards.
#   shards_per_client  -- shards each client receives in the non-IID split.
#   samples_per_client -- training set size divided evenly across the K clients.
iid = True
shards_per_client = 2
samples_per_client = len(train_dataset) // K

def split_data(dataset, iid=True, num_clients=None, num_shards=None):
    """Partition ``dataset`` into one sub-dataset per client.

    IID: a uniformly random permutation of the sample indices is cut into
    ``num_clients`` equal parts.
    Non-IID: sample indices are sorted by label, cut into
    ``num_clients * num_shards`` contiguous shards, the shards are shuffled,
    and each client receives ``num_shards`` of them.

    Only indices are partitioned. Every returned item is a ``Subset`` view of
    ``dataset``, so the dataset's transform (e.g. random augmentation) is still
    applied fresh on every access.

    Args:
        dataset: map-style dataset of ``(input, label)`` pairs. Labels are read
            from ``dataset.targets`` when that attribute exists, otherwise by
            indexing each sample.
        iid: choose the IID (True) or label-sorted shard (False) split.
        num_clients: number of clients; defaults to the notebook-level ``K``.
        num_shards: shards per client (non-IID only); defaults to the
            notebook-level ``shards_per_client``.

    Returns:
        A list of ``num_clients`` datasets. Each has size
        ``len(dataset) // num_clients`` (IID) or
        ``num_shards * (len(dataset) // num_clients // num_shards)`` (non-IID).
        Leftover samples that do not fill a whole client/shard are dropped.
    """
    if num_clients is None:
        num_clients = K
    if num_shards is None:
        num_shards = shards_per_client
    per_client = len(dataset) // num_clients

    if iid:
        perm = torch.randperm(len(dataset)).tolist()
        return [
            torch.utils.data.Subset(dataset, perm[c * per_client:(c + 1) * per_client])
            for c in range(num_clients)
        ]

    # Read labels without decoding/transforming every image when the dataset
    # exposes them (torchvision's CIFAR/MNIST have a `targets` attribute).
    targets = getattr(dataset, 'targets', None)
    if targets is None:
        targets = [dataset[i][1] for i in range(len(dataset))]
    labels = [int(t) for t in targets]
    order = sorted(range(len(dataset)), key=labels.__getitem__)

    shard_size = per_client // num_shards
    shards = [
        order[s * shard_size:(s + 1) * shard_size]
        for s in range(num_clients * num_shards)
    ]
    random.shuffle(shards)

    return [
        torch.utils.data.Subset(
            dataset,
            [idx for shard in shards[c * num_shards:(c + 1) * num_shards] for idx in shard],
        )
        for c in range(num_clients)
    ]


client_datasets = split_data(train_dataset, iid)

# Sanity checks on the partition: one dataset per client, each of the expected size.
assert len(client_datasets) == K
assert all(len(d) == samples_per_client for d in client_datasets)
# In the non-IID split a shard may straddle a class boundary, so a client that
# holds `shards_per_client` shards sees at most 2 * shards_per_client labels.
assert iid or all(
    0 < len({label for _, label in d}) <= 2 * shards_per_client
    for d in client_datasets
)

## Model

In [6]:
# ResNet-18 trained from scratch, with the ImageNet head replaced by a CIFAR-100
# head. The head must be swapped *before* moving the model to the GPU: a layer
# assigned after `.cuda()` keeps its weights on the CPU, and the first forward
# pass fails with a device mismatch.
# NOTE(review): the paper's CIFAR ResNet-18 may use GroupNorm rather than
# BatchNorm; confirm before comparing numbers against it.
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
model.fc = nn.Linear(model.fc.in_features, len(train_dataset.classes))
model = model.cuda()

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


## Utils

In [14]:
def reduce_w(w_list, f):
    """Combine several state dicts key by key.

    For every key of the first dict in ``w_list`` (in its order), ``f`` is
    called with the list of that key's values across all dicts, and the result
    is stored under the same key.

    Returns:
        An ``OrderedDict`` mapping each key to ``f([d[key] for d in w_list])``.
    """
    first = w_list[0]
    combined = OrderedDict()
    for key in first:
        combined[key] = f([state[key] for state in w_list])
    return combined


def fed_adagrad(v, delta, params):
    """FedAdagrad second-moment update, applied coordinate-wise.

        v_t = v_{t-1} + delta_t ** 2

    Args:
        v: previous second-moment estimate (tensor, or a scalar such as 0).
        delta: pseudo-gradient tensor for the same parameter.
        params: unused; kept so all update rules share one signature.

    Returns:
        A tensor (broadcast of ``v`` and ``delta``). The previous version used
        the squared global norm of ``delta``, which collapses the per-coordinate
        adaptivity the method relies on.
    """
    return v + torch.square(delta)


def fed_yogi(v, delta, params):
    """FedYogi second-moment update, applied coordinate-wise.

        v_t = v_{t-1} - (1 - beta2) * delta_t**2 * sign(v_{t-1} - delta_t**2)

    Args:
        v: previous second-moment estimate (tensor, or a scalar such as 0).
        delta: pseudo-gradient tensor for the same parameter.
        params: reads ``params['beta2']``.

    Returns:
        A tensor (broadcast of ``v`` and ``delta``). Squaring is element-wise
        rather than using the global norm of ``delta``.
    """
    delta_sq = torch.square(delta)
    return v - (1 - params['beta2']) * delta_sq * torch.sign(v - delta_sq)


def fed_adam(v, delta, params):
    """FedAdam second-moment update, applied coordinate-wise.

        v_t = beta2 * v_{t-1} + (1 - beta2) * delta_t ** 2

    Args:
        v: previous second-moment estimate (tensor, or a scalar such as 0).
        delta: pseudo-gradient tensor for the same parameter.
        params: reads ``params['beta2']``.

    Returns:
        A tensor (broadcast of ``v`` and ``delta``). Squaring is element-wise
        rather than using the global norm of ``delta``.
    """
    return params['beta2'] * v + (1 - params['beta2']) * torch.square(delta)


# Server-side second-moment update rules, selected by params['method'].
methods = dict(
    adagrad=fed_adagrad,
    yogi=fed_yogi,
    adam=fed_adam,
)

## Training

In [8]:
# Number of communication rounds, and how often (in rounds) to evaluate on the test set.
T, test_freq = 300, 40

In [None]:
criterion = nn.CrossEntropyLoss().cuda()  # shared training loss (default mean reduction)

def test(model):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    test_loss = test_loss / len(test_loader)
    test_accuracy = 100. * correct / total
    print(f'Test Loss: {test_loss:.6f} Acc: {test_accuracy:.2f}%')
    return test_accuracy


def client_update(model, k, w, params):
    """Run local SGD on client ``k`` and return the resulting state dict.

    Args:
        model: a private copy of the global model. It is modified in place.
        k: index into ``client_datasets``.
        w: global state dict. It is loaded into ``model`` so that local training
            starts from the current global weights.
        params: reads ``'lr_client'`` (SGD learning rate), ``'B'`` (batch size)
            and ``'E'`` (number of local epochs).

    Returns:
        ``model.state_dict()`` after ``E`` local epochs.
    """
    model.load_state_dict(w)
    model.train()
    device = next(model.parameters()).device
    optimizer = torch.optim.SGD(model.parameters(), lr=params['lr_client'])
    # Reshuffle every epoch. In the non-IID split a client's samples are grouped
    # by label, so an unshuffled loader would produce single-class batches.
    loader = DataLoader(client_datasets[k], batch_size=params['B'], shuffle=True)

    for _ in range(params['E']):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

    return model.state_dict()


def _aggregate_clients(model, clients, w, params):
    """Train each client from ``w``; return the dataset-size-weighted mean of their states.

    Floating-point entries are averaged with weights n_k / sum(n_k), where n_k
    is ``len(client_datasets[k])``. Non-floating entries (e.g. BatchNorm's
    integer ``num_batches_tracked``) cannot be averaged and are taken from the
    first client.
    """
    sizes = [len(client_datasets[k]) for k in clients]
    total = float(sum(sizes))
    avg = None
    for k, n_k in zip(clients, sizes):
        state = client_update(copy.deepcopy(model), k, w, params)
        weight = n_k / total
        if avg is None:
            # `val * weight` allocates new tensors, so `avg` never aliases a client's model.
            avg = OrderedDict(
                (key, val * weight if torch.is_floating_point(val) else val.clone())
                for key, val in state.items()
            )
        else:
            for key, val in state.items():
                if torch.is_floating_point(val):
                    avg[key] += val * weight
    return avg


def _server_step(w, avg, m, v, param_names, params):
    """Apply one adaptive server-optimizer step and return the new global state dict.

    For every model parameter, with the pseudo-gradient delta = avg - w:
        m <- beta1 * m + (1 - beta1) * delta
        v <- methods[params['method']](v, delta, params)
        w <- w + lr_server * m / (sqrt(v) + tau)
    ``m`` and ``v`` are updated in place. Non-parameter entries (buffers such as
    BatchNorm running statistics) are set to the client average.
    """
    update_v = methods[params['method']]
    new_w = OrderedDict()
    for key, value in w.items():
        if key not in param_names:
            new_w[key] = avg[key]
            continue
        # Client minus server, so adding the step moves *toward* the client models.
        delta = avg[key] - value
        m[key] = params['beta1'] * m[key] + (1 - params['beta1']) * delta
        v[key] = update_v(v[key], delta, params)
        new_w[key] = value + params['lr_server'] * m[key] / (torch.sqrt(v[key]) + params['tau'])
    return new_w


def train(model, params):
    """Run T rounds of adaptive federated optimization and return test accuracies.

    Each round samples ``int(max(C * K, 1))`` distinct clients from ``S`` and
    trains each one locally (``client_update``) from the current global weights.
    The client states are averaged weighted by dataset size, and the server
    optimizer chosen by ``params['method']`` is applied.

    Args:
        model: global model. Its weights are overwritten with the new global
            state after every round.
        params: hyper-parameter dict. Reads 'C', 'method', 'beta1', 'lr_server',
            'tau' here, plus whatever ``client_update`` and the update rule in
            ``methods`` read.

    Returns:
        A list of test accuracies (percent), recorded every ``test_freq`` rounds
        and after the final round.
    """
    param_names = {name for name, _ in model.named_parameters()}
    # Clone so the global state is independent of the model's own tensors.
    w = OrderedDict((key, val.clone()) for key, val in model.state_dict().items())
    m = OrderedDict((key, torch.zeros_like(val)) for key, val in w.items() if key in param_names)
    v = OrderedDict((key, torch.zeros_like(val)) for key, val in m.items())
    n_sampled = int(max(params['C'] * K, 1))

    accuracies = []
    for t in tqdm(range(T)):
        clients = np.random.choice(S, n_sampled, replace=False)
        avg = _aggregate_clients(model, clients, w, params)
        w = _server_step(w, avg, m, v, param_names, params)
        model.load_state_dict(w)

        if t % test_freq == 0 or t == T - 1:
            accuracies.append(test(model))

    return accuracies

# Hyper-parameters:
#   C          -- fraction of the K clients sampled per round (here 10 of them)
#   B, E       -- client batch size and number of local epochs
#   lr_client  -- client SGD learning rate
#   lr_server  -- server step size
#   method     -- key into `methods` selecting the second-moment update rule
#   beta1      -- server momentum coefficient
#   beta2      -- second-moment decay (used by 'yogi' and 'adam')
#   tau        -- term added to sqrt(v) in the server step's denominator
params = dict(
    C=10 / K,
    B=20,
    E=1,
    lr_server=1e-1,
    lr_client=1e-1,
    method='adagrad',
    beta1=0,
    beta2=0,
    tau=1e-1,
)

accuracies = train(model, params)