In [10]:
import numpy as np
import matplotlib.pyplot as plt
import copy

%matplotlib inline

import torch
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [15]:
# Preprocessing shared by the train and test splits: convert each image to a
# tensor, then normalize every one of the three channels with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Mini-batch sizes for training and for evaluation.
batch_size = 64
batch_size_test = 64

# Training split (downloaded into ./data if it is not already there).
# No DataLoader over the full training set is created in this cell.
trainset = CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Held-out test split, iterated in a fixed order in the main process.
testset = CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = DataLoader(
    dataset=testset,
    batch_size=batch_size_test,
    shuffle=False,
    num_workers=0,
)

# Human-readable class names (presumably in CIFAR-10 label-index order -- confirm).
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [16]:
class SmallCNN(nn.Module):
    """Small convolutional classifier producing 10 logits per image.

    The feature extractor is three identical stages (3x3 convolution ->
    batch norm -> ReLU -> 2x2 max-pool) with 32, 64 and 128 output channels.
    Each pooling step halves the spatial size, so a 3 x 32 x 32 input ends up
    as 128 x 4 x 4, which is what the first fully connected layer expects.
    The classifier head is Linear(2048, 256) -> ReLU -> Dropout(0.2) ->
    Linear(256, 10).
    """

    def __init__(self):
        super().__init__()

        # Channel count before the first stage and after each stage.
        channel_plan = [3, 32, 64, 128]
        stage_io = list(zip(channel_plan[:-1], channel_plan[1:]))

        # padding=1 with a 3x3 kernel keeps height/width unchanged.
        self.ConvLayers = nn.ModuleList(
            [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1) for c_in, c_out in stage_io]
        )

        # One batch-norm layer per convolution, matching its output channels.
        self.BatchNorms = nn.ModuleList(
            [nn.BatchNorm2d(c_out) for _, c_out in stage_io]
        )

        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, 10)

        # Parameter-free layers shared by all stages / the head.
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Map a batch of images (N x 3 x 32 x 32) to class logits (N x 10)."""
        for stage_conv, stage_norm in zip(self.ConvLayers, self.BatchNorms):
            normalized = stage_norm(stage_conv(x))
            activated = self.relu(normalized)
            x = self.pool(activated)          # spatial size halves in every stage

        # Flatten N x 128 x 4 x 4 feature maps into N x 2048 vectors.
        features = x.view(x.shape[0], -1)

        hidden = self.relu(self.fc1(features))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

# Seed PyTorch's random number generators before the model is built, so the
# initial weights and the random draws made afterwards in a fresh
# "Restart & Run All" are repeatable. (GPU kernels may still introduce small
# non-deterministic differences.)
SEED = 0
torch.manual_seed(SEED)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")
net = SmallCNN().to(device)

# Independent snapshot of the freshly initialised parameters and buffers;
# deep-copied so that training `net` later does not modify it.
init_weights = copy.deepcopy(net.state_dict())

device: cuda:0


In [None]:
# --- Federated-learning configuration ---------------------------------------
NUM_CLIENTS = 100   # number of simulated clients the training set is split across
IID_DATA = True     # only the IID partition is implemented below

data_size = len(trainset)

# Create the sample indices owned by each client.
if IID_DATA:
    # Shuffle all sample indices once and cut them into NUM_CLIENTS disjoint,
    # (nearly) equally sized shards.
    client_indices = torch.tensor_split(torch.randperm(data_size), NUM_CLIENTS)
else:
    raise NotImplementedError("Non-IID data not implemented")

# An empty shard would give a client an empty DataLoader and break the
# per-client loss average later on, so fail early with a clear message.
assert all(len(indices) > 0 for indices in client_indices), \
    "NUM_CLIENTS exceeds the number of training samples"

print(f"Samples per client: {len(client_indices[0]):d}")

# One shuffling DataLoader per client. The indices are handed to Subset as plain
# Python ints (not 0-d tensors) so that dataset indexing does not depend on
# implicit tensor-to-int conversion.
clientloaders = [
    DataLoader(Subset(trainset, indices.tolist()), batch_size=batch_size, shuffle=True)
    for indices in client_indices
]

NUM_ROUNDS = 20         # communication rounds
NUM_LOCAL_EPOCHS = 10   # local passes over a client's data per round
C = 0.1                 # fraction of clients participating in each round

# Round to the nearest integer instead of truncating (floating-point error can
# make NUM_CLIENTS * C land just below the intended integer), and clamp to
# [1, NUM_CLIENTS] so every round trains at least one client.
CLIENTS_PER_ROUND = min(NUM_CLIENTS, max(1, int(NUM_CLIENTS * C + 0.5)))

print(f"Clients per round: {CLIENTS_PER_ROUND:d}")

LR = 0.001   # learning rate of each client's local optimizer

criterion = nn.CrossEntropyLoss()

Samples per client: 500
Clients per round: 10


In [18]:
# Federated Averaging: in every round a random subset of clients starts from the
# current global weights, trains locally, and the resulting local weights are
# averaged into the new global weights. Test accuracy is printed every 5th round.

# NOTE: despite its name, `avg_test_loss` stores the mean *training* loss of each
# round; the name is kept so that code referring to it keeps working.
avg_test_loss = []

current_weights = copy.deepcopy(init_weights)

if CLIENTS_PER_ROUND < 1:
    raise ValueError("CLIENTS_PER_ROUND must be at least 1")

# `round_idx` (not `round`) so the builtin round() is not shadowed for the rest
# of the kernel session.
for round_idx in range(NUM_ROUNDS):     # iterate thru rounds

    # random selection of clients to participate (plain ints for list indexing)
    clients = torch.randperm(NUM_CLIENTS)[:CLIENTS_PER_ROUND].tolist()

    local_weights = []
    temp_avg_loss = 0.0
    for client in clients:   # iterate thru clients

        net.load_state_dict(current_weights)    # load global weights
        net.train()                             # set model to train mode

        # A fresh optimizer per client: no optimizer state is carried over
        # between clients or rounds.
        optimizer = optim.Adam(net.parameters(), lr=LR, weight_decay=0)

        clientloader = clientloaders[client]
        epoch_loss = 0.0    # stays defined even if NUM_LOCAL_EPOCHS == 0
        for local_epoch in range(NUM_LOCAL_EPOCHS):     # iterate thru local epochs

            epoch_loss = 0.0
            for Xlocal, Ylocal in clientloader:     # iterate thru local data
                Xlocal, Ylocal = Xlocal.to(device), Ylocal.to(device)

                outputs = net(Xlocal)

                optimizer.zero_grad()
                loss = criterion(outputs, Ylocal)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

        # Mean batch loss of this client's LAST local epoch.
        temp_avg_loss += epoch_loss / len(clientloader)
        local_weights.append(copy.deepcopy(net.state_dict()))

    # Average over the clients that actually trained in this round.
    avg_test_loss.append(temp_avg_loss / len(local_weights))

    print(f"Round {round_idx+1} done")
    print(f"training loss: {avg_test_loss[-1]:.3f}")

    # Average local weights entry by entry (every client has the same weight).
    new_weights = {}
    for key in current_weights:
        stacked = torch.stack([weights[key] for weights in local_weights])
        if stacked.is_floating_point():
            new_weights[key] = stacked.mean(dim=0)
        else:
            # Integer buffers (e.g. BatchNorm's num_batches_tracked): average in
            # floating point, then cast back so the global state keeps the
            # original dtype instead of silently becoming float.
            new_weights[key] = stacked.float().mean(dim=0).to(stacked.dtype)

    print("Federated Averaging done")

    current_weights = new_weights

    if round_idx % 5 == 4:
        # Evaluate the new global model on the test set.
        net.load_state_dict(current_weights)
        net.eval()

        correct = 0
        total = 0

        with torch.no_grad():
            for (images, labels) in testloader:
                images, labels = images.to(device), labels.to(device)
                # calculate outputs by running images through the network
                outputs = net(images)
                # the class with the highest score is chosen as the prediction
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Integer division: the percentage is rounded down to a whole number.
        print(f'Test accuracy in round {round_idx+1:d}: {100 * correct // total} %')

Round 1 done
training loss: 0.470
Federated Averaging done
Round 2 done
training loss: 0.271
Federated Averaging done
Round 3 done
training loss: 0.145
Federated Averaging done
Round 4 done
training loss: 0.105
Federated Averaging done
Round 5 done
training loss: 0.089
Federated Averaging done
Test accuracy in round 5: 64 %
Round 6 done
training loss: 0.083
Federated Averaging done
Round 7 done
training loss: 0.064
Federated Averaging done
Round 8 done
training loss: 0.074
Federated Averaging done
Round 9 done
training loss: 0.071
Federated Averaging done
Round 10 done
training loss: 0.067
Federated Averaging done
Test accuracy in round 10: 69 %
Round 11 done
training loss: 0.061
Federated Averaging done
Round 12 done
training loss: 0.057
Federated Averaging done
Round 13 done
training loss: 0.065
Federated Averaging done
Round 14 done
training loss: 0.041
Federated Averaging done
Round 15 done
training loss: 0.055
Federated Averaging done
Test accuracy in round 15: 69 %
Round 16 done
