In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from tqdm import tqdm
import wandb

In [10]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(2048, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def client_update(model, optimizer, train_loader, epoch=5):
    model.train()
    for e in range(epoch):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
    return loss.item()

def server_aggregate(global_model, client_models):
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([client_models[i].state_dict()[k] for i in range(len(client_models))], 0).mean(0)
    global_model.load_state_dict(global_dict)
    for model in client_models:
        model.load_state_dict(global_model.state_dict())

def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    acc = correct / len(test_loader.dataset)

    return test_loss, acc

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mnilm[0m (use `wandb login --relogin` to force relogin)


True

In [5]:
# Hyperparameters

hyperparameters = {
    "num_clients":100,
    "num_selected":10,
    "num_rounds":5,
    "epochs":5,
    "batch_size":32
}

In [15]:
# IID case: all the clients have images of all the classes

# Creating decentralized datasets

traindata = datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
                       )
traindata_split = torch.utils.data.random_split(traindata, [int(traindata.data.shape[0] / hyperparameters['num_clients']) for _ in range(hyperparameters['num_clients'])])
train_loader = [torch.utils.data.DataLoader(x, batch_size=hyperparameters['batch_size'], shuffle=True) for x in traindata_split]

test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
        ), batch_size=hyperparameters['batch_size'], shuffle=True)

# Instantiate models and optimizers

with wandb.init(project="FL toy example", config=hyperparameters):
    wandb.run.name = "Trial 1"
    
    global_model = Net().to(device)
    client_models = [Net().to(device) for _ in range(hyperparameters["num_selected"])]
    for model in client_models:
        model.load_state_dict(global_model.state_dict())

    opt = [optim.SGD(model.parameters(), lr=0.1) for model in client_models]
    
    wandb.watch(global_model, log="all", log_freq=10)

    # Runnining FL
    client_idx = np.random.permutation(hyperparameters["num_clients"])[:hyperparameters["num_selected"]]
    for r in range(hyperparameters["num_rounds"]):
        # select random clients
        #client_idx = np.random.permutation(hyperparameters["num_clients"])[:hyperparameters["num_selected"]]
        print(client_idx)

        # client update
        loss = 0
        for i in range(hyperparameters["num_selected"]):
            loss += client_update(client_models[i], opt[i], train_loader[client_idx[i]], epoch=hyperparameters["epochs"])

        # serer aggregate
        server_aggregate(global_model, client_models)
        test_loss, acc = test(global_model, test_loader)

        print('%d-th round' % r)
        print('average train loss %0.3g | test loss %0.3g | test acc: %0.3f' % (loss / hyperparameters["num_selected"], test_loss, acc))

[96 46 51 52 58 30 93  3 71 11]
0-th round
average train loss 0.679 | test loss 0.527 | test acc: 0.856
[96 46 51 52 58 30 93  3 71 11]
1-th round
average train loss 0.141 | test loss 0.256 | test acc: 0.925
[96 46 51 52 58 30 93  3 71 11]
2-th round
average train loss 0.0584 | test loss 0.206 | test acc: 0.945
[96 46 51 52 58 30 93  3 71 11]
3-th round
average train loss 0.0135 | test loss 0.184 | test acc: 0.952
[96 46 51 52 58 30 93  3 71 11]
4-th round
average train loss 0.00707 | test loss 0.173 | test acc: 0.956


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

In [12]:
traindata = datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
                       )

train_loader = torch.utils.data.DataLoader(traindata, batch_size=hyperparameters['batch_size'], shuffle=True)

test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
        ), batch_size=hyperparameters['batch_size'], shuffle=True)

global_model = Net().cuda()

opt = optim.SGD(global_model.parameters(), lr=0.1)

loss = 0

loss += client_update(global_model, opt, train_loader, epoch=hyperparameters['epochs'])

# serer aggregate
test_loss, acc = test(global_model, test_loader)

print('average train loss %0.3g | test loss %0.3g | test acc: %0.3f' % (loss , test_loss, acc))

average train loss 0.000904 | test loss 0.0398 | test acc: 0.989
