In [None]:
import copy
import random

import numpy as np
import torch
import torch.nn as nn
from scipy.optimize import LinearConstraint, minimize
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR100
from tqdm import tqdm

from client_selector import ClientSelector
from data_splitter import DataSplitter

In [None]:
import wandb
# Authenticate with Weights & Biases so the wandb.init / wandb.log calls later in
# the notebook can upload results (this may prompt for an API key).
wandb.login()

In [None]:
K = 100  # number of clients (also passed to DataSplitter)
SEED = 0  # change for a different, but still reproducible, run

# Seed every global RNG up front so Restart & Run All reproduces the same run
# (torch-based weight init, augmentation and DataLoader shuffling). The
# project-local DataSplitter / ClientSelector presumably draw from one of these
# as well — confirm.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

params = {
    'K': K,
    'C': 0.1,                    # presumably fraction of clients sampled per round (ClientSelector) — confirm
    'B': 64,                     # local batch size
    'J': 4,                      # local SGD steps per client per round
    'lr_server': 1,              # server SGD learning rate
    'lr_client': 1e-1,           # client SGD learning rate
    'participation': 'uniform',  # client-sampling scheme, assumed to be read by ClientSelector
    'gamma': 1,                  # NOTE(review): not read in this notebook; presumably used by ClientSelector
    'rounds': 2000               # number of federated rounds
}

In [None]:
# Per-channel CIFAR-100 normalisation shared by the train and test pipelines.
cifar100_normalize = transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])

# Training pipeline: random 28x28 crop + random horizontal flip (data augmentation).
preprocess = transforms.Compose([
    transforms.RandomCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    cifar100_normalize,
])

# Evaluation pipeline: deterministic. The model expects the same 28x28 input
# size as the training crop, so take the centre crop and never flip; random
# augmentation at test time would make every evaluation noisy.
test_preprocess = transforms.Compose([
    transforms.CenterCrop((28, 28)),
    transforms.ToTensor(),
    cifar100_normalize,
])

train_dataset = CIFAR100('datasets/cifar100', train=True, transform=preprocess, download=True)
test_dataset = CIFAR100('datasets/cifar100', train=False, transform=test_preprocess, download=True)

test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
# Split the CIFAR-100 training set across K clients. DataSplitter is project-local;
# with split_method 'non-iid' and n_labels=10 each client presumably holds samples
# from only ~10 classes — confirm against data_splitter.
data_split_params = dict(
    K=K,
    split_method='non-iid',
    n_labels=10,
)

data_splitter = DataSplitter(data_split_params, train_dataset)
client_datasets = data_splitter.split()  # indexed by client id k in client_update

In [None]:
# Project-local sampler; train() calls client_selector.sample() once per round to get
# the ids of the clients that participate. Presumably driven by params['K'], params['C']
# and params['participation'] — confirm against client_selector.
client_selector = ClientSelector(params)

In [None]:
class LeNet5_circa(nn.Module):
    """LeNet-5-style CNN mapping 3x28x28 images to 100 class logits.

    Spatial sizes: 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4, which is why
    ``fc1`` takes 4*4*64 inputs. Submodules are registered in a fixed order
    (conv1, conv2, pool, fc1, fc2, fc3); ``parameters()`` follows that order, and
    the notebook's gradient flattening / scattering relies on it.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        # Classifier head.
        flat_features = 64 * 4 * 4
        self.fc1 = nn.Linear(flat_features, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 100)

    def forward(self, x):
        """Return unnormalised logits of shape (batch, 100)."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(torch.relu(conv(x)))
        hidden = x.flatten(start_dim=1)
        for fc in (self.fc1, self.fc2):
            hidden = torch.relu(fc(hidden))
        return self.fc3(hidden)


# Global (server) model on the GPU; a single .cuda() is enough — the extra
# model.to('cuda') that used to follow was a redundant second move.
model = LeNet5_circa().cuda()

# No class weights are passed, so the loss holds no tensors that need moving.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Start a Weights & Biases run. The run name encodes the key hyper-parameters and
# the merged training + data-split parameters are logged as its config.
wandb.init(
    project='fl',
    name=(
        f'fed mgda {data_split_params["split_method"]}, '
        f'J={params["J"]}, '
        f'lr={params["lr_client"]}, '
        f'lr_server={params["lr_server"]}, '
        f'n_labels={data_split_params["n_labels"]}'
    ),
    config=dict(params, **data_split_params),
)

In [None]:
def w_norm(coeffs, clients_gradients_normalized):
    """Objective and gradient for the MGDA min-norm sub-problem.

    With weights ``c`` and client directions ``g_i``, let ``w = sum_i c_i * g_i``.
    Returns ``(||w||^2, grad)`` where ``grad[i] = d||w||^2 / dc_i = 2 * (w . g_i)``.

    The squared norm is used rather than ``||w||`` because:
    - it has the same minimiser over the simplex;
    - it is smooth everywhere (``||w||`` is not at ``w = 0``);
    - the value and Jacobian must describe the same function when this is
      passed to ``scipy.optimize.minimize(..., jac=True)``.

    Args:
        coeffs: length-n array of combination weights.
        clients_gradients_normalized: n equal-length 1-D arrays (a list or an
            (n, d) array).

    Returns:
        tuple: ``(value, grad)`` with ``value`` a float and ``grad`` a length-n array.
    """
    G = np.asarray(clients_gradients_normalized, dtype=float)  # (n, d); no copy if already float64
    w = np.asarray(coeffs, dtype=float) @ G                    # combined direction, shape (d,)
    return float(w @ w), 2.0 * (G @ w)

def magical_gradient(clients_gradients, device='cuda'):
    """Min-norm convex combination of the clients' update directions (MGDA step).

    Solves ``min_{c in simplex} ||sum_i c_i * g_i||`` using ``w_norm`` as the
    objective, starting from uniform weights. With bounds and a constraint,
    ``scipy.optimize.minimize`` uses SLSQP.

    Args:
        clients_gradients: iterable of n equal-length 1-D NumPy arrays; it is
            materialised once, so a generator is fine.
        device: torch device for the returned tensor. Defaults to ``'cuda'``
            (the original behaviour); pass ``'cpu'`` to run without a GPU.

    Returns:
        torch.Tensor: 1-D float64 tensor ``sum_i c_i* g_i`` on ``device``, i.e. the
        shortest vector in the convex hull of the client directions.

    Raises:
        ValueError: if ``clients_gradients`` is empty.
    """
    import warnings

    grads = [np.asarray(g, dtype=float) for g in clients_gradients]
    if not grads:
        raise ValueError('magical_gradient needs at least one client gradient')
    G = np.stack(grads)  # (n, d): stacked once instead of on every objective evaluation
    n = G.shape[0]

    if n == 1:
        coeffs = np.ones(1)  # the simplex is a single point
    else:
        res = minimize(
            w_norm,
            np.full(n, 1.0 / n),
            args=(G,),
            jac=True,
            bounds=[(0., 1.)] * n,
            constraints=[LinearConstraint(A=np.ones((1, n)), lb=1., ub=1.)],
        )
        if not res.success:
            # Still usable (bounds/constraint are enforced), but worth surfacing.
            warnings.warn(f'min-norm solver did not converge: {res.message}', RuntimeWarning)
        coeffs = res.x

    grad_mo = coeffs @ G
    return torch.from_numpy(grad_mo).to(device)

In [None]:
# T mirrors params['rounds'] (train() reads params['rounds'] directly);
# test_freq: train() evaluates on the test set every `test_freq` rounds and on the last one.
T, test_freq = params['rounds'], 50

In [None]:
def test(model):
    """Evaluate ``model`` on the module-level ``test_loader`` using ``criterion``.

    Runs on whatever device the model's parameters live on and prints a summary.

    Returns:
        tuple: ``(test_accuracy, test_loss)``. ``test_accuracy`` is a percentage.
        ``test_loss`` is the per-sample mean: each batch's mean loss is weighted
        by its size, so a short final batch is not over-weighted.
    """
    model.eval()
    device = next(model.parameters()).device
    loss_sum, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            batch_size = targets.size(0)

            # criterion uses the default 'mean' reduction; scale back to a per-batch sum.
            loss_sum += criterion(outputs, targets).item() * batch_size
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += batch_size

    test_loss = loss_sum / total
    test_accuracy = 100. * correct / total
    print(f'Test Loss: {test_loss:.6f} Acc: {test_accuracy:.2f}%')
    return test_accuracy, test_loss

def _iterate_steps(loader, n_steps):
    """Yield up to ``n_steps`` batches, starting a new (reshuffled) pass over ``loader`` when it runs out."""
    produced = 0
    while produced < n_steps:
        pass_was_empty = True
        for batch in loader:
            pass_was_empty = False
            yield batch
            produced += 1
            if produced >= n_steps:
                return
        if pass_was_empty:
            return  # empty dataset: stop instead of looping forever


def _flat_grad(model):
    """Concatenate every parameter's ``.grad`` (in ``model.parameters()`` order) into one float64 NumPy vector."""
    return torch.cat([p.grad.detach().reshape(-1) for p in model.parameters()]).double().cpu().numpy()


def client_update(model, k, params):
    """Run ``params['J']`` local SGD steps on client ``k`` and summarise them.

    ``model`` (callers pass a deep copy of the global model) is trained in place
    on ``client_datasets[k]``. Training uses batch size ``params['B']``, learning
    rate ``params['lr_client']`` and weight decay 4e-4. The gradient of every
    local step is recorded, and their average, rescaled to unit L2 norm, is the
    client's update direction.

    Exactly ``J`` steps are taken. A dedicated step counter is used, and the
    loader is re-iterated when the client has fewer than ``J`` batches.

    Returns:
        tuple: ``(client_accuracy, client_loss, client_gradient_norm)``.
        Accuracy (%) and mean loss are measured on the local training batches.
        ``client_gradient_norm`` is a 1-D float64 NumPy array: the unit-norm
        averaged gradient, or all zeros if that average is exactly zero.

    Raises:
        ValueError: if the client's dataset yields no batches.
    """
    model.train()
    device = next(model.parameters()).device
    optimizer = torch.optim.SGD(model.parameters(), lr=params['lr_client'], weight_decay=4e-4)
    loader = DataLoader(client_datasets[k], batch_size=params['B'], shuffle=True)

    client_loss, client_correct, client_total = 0., 0, 0
    step_gradients = []
    for inputs, targets in _iterate_steps(loader, params['J']):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        client_loss += loss.item()
        _, predicted = outputs.max(1)
        client_total += targets.size(0)
        client_correct += predicted.eq(targets).sum().item()

        loss.backward()

        # Recorded before optimizer.step(), so this is the data-loss gradient only
        # (the optimizer's weight decay is not included).
        step_gradients.append(_flat_grad(model))

        optimizer.step()

    n_steps = len(step_gradients)
    if n_steps == 0:
        raise ValueError(f'client {k} has no training samples')

    client_gradient_avg = np.mean(step_gradients, axis=0)  # average over the local steps
    grad_l2 = np.linalg.norm(client_gradient_avg)
    client_gradient_norm = client_gradient_avg / grad_l2 if grad_l2 > 0 else client_gradient_avg

    client_loss = client_loss / n_steps  # mean loss over the steps actually taken
    client_accuracy = 100. * client_correct / client_total
    return client_accuracy, client_loss, client_gradient_norm

def _set_flat_grad(model, flat_grad):
    """Write consecutive slices of the 1-D tensor ``flat_grad`` into each parameter's ``.grad``.

    Parameters are visited in ``model.parameters()`` order, the same order
    ``client_update`` uses to flatten gradients. Each slice is reshaped to the
    parameter's shape and cast to its device and dtype.

    Raises:
        ValueError: if ``flat_grad`` does not have exactly one entry per parameter element.
    """
    parameters = list(model.parameters())
    expected = sum(p.numel() for p in parameters)
    if flat_grad.numel() != expected:
        raise ValueError(f'flat gradient has {flat_grad.numel()} elements, model has {expected}')
    offset = 0
    for par in parameters:
        numel = par.numel()
        par.grad = flat_grad[offset:offset + numel].reshape(par.shape).to(device=par.device, dtype=par.dtype)
        offset += numel


def train(model, params):
    """Federated MGDA training loop.

    Each round:
    1. Sample clients with ``client_selector``.
    2. Let every sampled client compute a unit-norm update direction on a
       private copy of ``model`` (``client_update``).
    3. Combine the directions into their min-norm convex combination
       (``magical_gradient``).
    4. Write that vector into ``model``'s ``.grad`` fields and take one server
       SGD step (lr ``params['lr_server']``, weight decay 4e-4).

    Round-averaged train metrics are logged to wandb every round. The model is
    evaluated with ``test`` every ``test_freq`` rounds and on the last round.

    Returns:
        tuple: ``(test_accuracies, test_losses)``, one entry per evaluation.
    """
    server_optimizer = SGD(model.parameters(), lr=params['lr_server'], weight_decay=4e-4)
    server_optimizer.zero_grad(set_to_none=True)

    test_accuracies, test_losses = [], []
    train_accuracies, train_losses = [], []

    for t in tqdm(range(params['rounds'])):
        round_loss, round_accuracy = 0, 0
        s = client_selector.sample()

        clients_grad_normalized = []
        for k in s:
            # Each client trains a private copy; the global model only changes in the server step.
            client_accuracy, client_loss, client_gradient_norm = client_update(copy.deepcopy(model), k, params)
            round_loss += client_loss
            round_accuracy += client_accuracy
            clients_grad_normalized.append(client_gradient_norm)
        round_loss_avg = round_loss / len(s)
        round_accuracy_avg = round_accuracy / len(s)
        train_accuracies.append(round_accuracy_avg)
        train_losses.append(round_loss_avg)

        grad_mo = magical_gradient(clients_grad_normalized)

        # Use the aggregated direction as the model's gradient; every .grad is overwritten each round.
        _set_flat_grad(model, grad_mo)

        server_optimizer.step()

        wandb.log({'train/acc': round_accuracy_avg, 'train/loss': round_loss_avg, 'round': t})

        if t % test_freq == 0 or t == params['rounds']-1:
            acc, loss = test(model)
            test_accuracies.append(acc)
            test_losses.append(loss)
            wandb.log({'acc': acc, 'loss': loss, 'round': t})

    return test_accuracies, test_losses


# Run the full federated MGDA training. As the cell's last expression, the returned
# (test_accuracies, test_losses) tuple is displayed as the cell output.
train(model, params)