In [None]:
import copy
import random
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR100
from tqdm import tqdm

from client_selector import ClientSelector
from data_splitter import DataSplitter

In [None]:
import wandb
# Authenticate with Weights & Biases; the training loop below logs test
# accuracy/loss to the run created by wandb.init.
wandb.login()

## Data

In [None]:
K = 100  # presumably the number of clients the training set is split across
SEED = 0

# Seed the Python, NumPy and torch RNGs before any data splitting, augmentation,
# DataLoader shuffling or weight initialisation happens, so Restart & Run All
# reproduces the same run. torch.manual_seed seeds CPU and all CUDA devices.
# NOTE(review): ClientSelector / DataSplitter are assumed to draw from one of these
# global RNGs — confirm. cuDNN kernels may still be non-deterministic on GPU.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

params = {
    'K': K,                      # number of clients (see K above)
    'C': 0.1,                    # presumably fraction of clients sampled per round — used by ClientSelector
    'B': 64,                     # local mini-batch size in client_update
    'J': 4,                      # local SGD steps per client per round (client_update)
    'lr_client': 1e-1,           # client-side SGD learning rate
    'participation': 'uniform',  # presumably ClientSelector's sampling scheme — confirm
    'gamma': 1.0,                # consumed by ClientSelector; meaning not visible here
    'rounds': 2000               # number of federated rounds
}

In [None]:
# Per-channel normalization constants (values appear to be the CIFAR-100
# training-set channel mean/std).
CIFAR100_MEAN = [0.5071, 0.4867, 0.4408]
CIFAR100_STD = [0.2675, 0.2565, 0.2761]

# Training-time augmentation: random 28x28 crops plus horizontal flips.
# 28x28 is the input size LeNet5_circa's fc1 (4 * 4 * 64) is sized for.
preprocess = transforms.Compose([
    transforms.RandomCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD)
])

# Evaluation must be deterministic: take a fixed center crop of the same size
# and never flip, so repeated evaluations of the same model give the same score.
test_preprocess = transforms.Compose([
    transforms.CenterCrop((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD)
])

train_dataset = CIFAR100('datasets/cifar100', train=True, transform=preprocess, download=True)
test_dataset = CIFAR100('datasets/cifar100', train=False, transform=test_preprocess, download=True)

test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
# Partition the training set across K clients with DataSplitter; the result is
# indexed by client id in client_update (`client_datasets[k]`).
# NOTE(review): 'iid' presumably means a uniform random partition — confirm in data_splitter.
data_split_params = dict(K=K, split_method='iid')
data_splitter = DataSplitter(data_split_params, train_dataset)
client_datasets = data_splitter.split()

In [None]:
# Chooses which client indices take part in each round: train() calls
# client_selector.sample() once per round and trains each returned client k.
# NOTE(review): presumably driven by params['C'], 'participation' and 'gamma' — confirm.
client_selector = ClientSelector(params)

## Model

In [None]:
class LeNet5_circa(nn.Module):
    """LeNet-5-style CNN for 3-channel 28x28 images.

    Two 5x5 convolutions (no padding), each followed by ReLU and 2x2 max-pooling,
    shrink the spatial size 28 -> 24 -> 12 -> 8 -> 4, which is why ``fc1`` takes
    ``4 * 4 * 64`` features. Inputs of any other spatial size will fail at ``fc1``.

    Args:
        num_classes: number of output logits (default 100, for CIFAR-100).
    """

    def __init__(self, num_classes=100):
        super().__init__()
        # Layers are created in the same order as before so that, under a fixed
        # seed, weight initialisation and state_dict key order are unchanged.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(4 * 4 * 64, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, num_classes)

    def forward(self, x):
        """Map a batch of shape (N, 3, 28, 28) to unnormalised logits (N, num_classes)."""
        x = self.pool(self.conv1(x).relu())
        x = self.pool(self.conv2(x).relu())
        x = torch.flatten(x, 1)
        x = self.fc1(x).relu()
        x = self.fc2(x).relu()
        return self.fc3(x)


# test() and client_update() move every batch with .cuda(), so the global model
# must live on the GPU too. Module.cuda() moves parameters in place and returns
# the module itself.
model = LeNet5_circa().cuda()

# CrossEntropyLoss (default mean reduction) has no parameters or buffers,
# so it needs no device placement.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# The run name summarises the split method, local steps J and client learning
# rate; the merged training + data-split hyper-parameters become the run config.
run_name = f'fed {data_split_params["split_method"]}, J={params["J"]}, lr={params["lr_client"]}'
wandb.init(project='fl', name=run_name, config={**params, **data_split_params})

## Utils

In [None]:
def reduce_w(w_list, f):
    """Combine several state dicts key by key.

    For every key of the first dict, ``f`` receives the list of that key's values
    from all dicts (in ``w_list`` order). Its result becomes the value of that
    key in the returned OrderedDict, which keeps the first dict's key order.
    """
    reduced = OrderedDict()
    for name in w_list[0]:
        reduced[name] = f([state[name] for state in w_list])
    return reduced


def tensor_sum(tensors_list, weights=None):
    """Sum equally-shaped tensors, optionally scaling each by a weight.

    Args:
        tensors_list: sequence of tensors with identical shapes.
        weights: optional sequence (list, tuple or 1-D tensor) of per-tensor
            scalar weights. ``None`` or an empty sequence gives a plain sum.

    Returns:
        Tensor of the common shape: ``sum_i w_i * t_i`` (or ``sum_i t_i``).

    Raises:
        ValueError: if weights are given but their count differs from the
            number of tensors (instead of silently ignoring the extras).
    """
    tensors_list = list(tensors_list)
    # Materialize weights as a list: `if weights:` on a multi-element tensor
    # raises "Boolean value of Tensor ... is ambiguous".
    weights = [] if weights is None else list(weights)
    if not weights:
        return torch.sum(torch.stack(tensors_list), dim=0)
    if len(weights) != len(tensors_list):
        raise ValueError(
            f'tensor_sum: got {len(weights)} weights for {len(tensors_list)} tensors'
        )
    return torch.sum(torch.stack([t * w for t, w in zip(tensors_list, weights)]), dim=0)

## Training

In [None]:
# Total number of federated rounds, and the evaluation period (in rounds).
T, test_freq = params['rounds'], 50

In [None]:
def test(model, loader=None):
    """Evaluate ``model`` and print/return its accuracy and mean loss.

    Args:
        model: network to evaluate; it is switched to eval mode and left there
            (client_update switches its copies back to train mode).
        loader: DataLoader to evaluate on; defaults to the notebook's
            ``test_loader``.

    Returns:
        (test_accuracy, test_loss): accuracy in percent and the mean per-sample
        loss over every sample in the loader.
    """
    if loader is None:
        loader = test_loader
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)
            batch_size = targets.size(0)
            # `criterion` returns the batch *mean* (CrossEntropyLoss default), so
            # scale it back to a sum. Averaging batch means over len(loader)
            # would overweight a short final batch.
            loss_sum += criterion(outputs, targets).item() * batch_size
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += batch_size

    test_loss = loss_sum / total
    test_accuracy = 100. * correct / total
    print(f'Test Loss: {test_loss:.6f} Acc: {test_accuracy:.2f}%')
    return test_accuracy, test_loss


def client_update(model, k, params):
    """Run exactly ``params['J']`` local SGD steps on client ``k`` and return its weights.

    The client's data is reshuffled on every pass. If it has fewer than J
    mini-batches, the loader is cycled until J optimizer steps have been taken.

    Args:
        model: model to train in place (train() passes a deep copy of the
            global model).
        k: client index into ``client_datasets``.
        params: dict with 'lr_client', 'B' (batch size) and 'J' (number of local
            steps); optional 'weight_decay' (default 4e-4).

    Returns:
        ``model.state_dict()`` after local training.
    """
    model.train()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=params['lr_client'],
        weight_decay=params.get('weight_decay', 4e-4),
    )
    loader = DataLoader(client_datasets[k], batch_size=params['B'], shuffle=True)

    num_steps = params['J']
    if len(loader) == 0:
        # A client with no samples cannot train; cycling would loop forever.
        return model.state_dict()

    # A dedicated step counter: the previous version reused the outer loop
    # variable, which ran more than J steps when a client had < J batches.
    steps_done = 0
    while steps_done < num_steps:
        for inputs, targets in loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            steps_done += 1
            if steps_done >= num_steps:
                break

    return model.state_dict()


def train(model, params):
    """Run FedAvg-style (unweighted) federated training with periodic evaluation.

    Each round:
    1. Sample clients with ``client_selector``.
    2. Train an independent deep copy of the global model on each sampled client.
    3. Replace the global weights with the plain mean of the returned state dicts.
    4. On every ``test_freq``-th round, and on the final round, evaluate and
       log to wandb.

    Args:
        model: global model; updated in place with the aggregated weights.
        params: hyper-parameter dict forwarded to ``client_update``. Its
            'rounds' entry sets the number of rounds and falls back to the
            notebook-level ``T`` when absent.

    Returns:
        (accuracies, losses): test accuracy (%) and test loss, one entry per
        evaluation.
    """
    rounds = params.get('rounds', T)
    accuracies = []
    losses = []
    for t in tqdm(range(rounds)):
        sampled_clients = client_selector.sample()

        w_clients = [
            client_update(copy.deepcopy(model), k, params) for k in sampled_clients
        ]

        # With no participating clients there is nothing to average; keep the
        # current global weights instead of crashing inside reduce_w.
        if w_clients:
            w = reduce_w(
                w_clients,
                lambda x: tensor_sum(x) / len(w_clients)
            )
            model.load_state_dict(w)

        if t % test_freq == 0 or t == rounds - 1:
            acc, loss = test(model)
            accuracies.append(acc)
            losses.append(loss)
            wandb.log({'acc': acc, 'loss': loss, 'round': t})

    return accuracies, losses


# Always close the wandb run, even if training is interrupted, so the run is
# marked finished and any pending logged data is uploaded.
try:
    accuracies, losses = train(model, params)
finally:
    wandb.finish()