In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.ops import MLP
from PIL import Image
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100, CIFAR10, MNIST
from torch.optim import SGD, lr_scheduler
import copy
from tqdm import tqdm
from collections import OrderedDict
import random
import math
import matplotlib.pyplot as plt

In [None]:
!pip install wandb -qU
import wandb
wandb.login()

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.7/6.7 MB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m289.0/289.0 kB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25h

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
# Mount Google Drive so that model checkpoints can be read from and written
# to paths under /content/drive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


## Data

In [None]:
# Per-channel normalisation statistics shared by both splits
# (presumably the CIFAR-100 training-set mean/std -- TODO confirm).
NORM_MEAN = [0.5071, 0.4867, 0.4408]
NORM_STD = [0.2675, 0.2565, 0.2761]

# Training pipeline: random 28x28 crops and horizontal flips act as data
# augmentation before tensor conversion and normalisation.
preprocess = transforms.Compose([
    transforms.RandomCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

# Evaluation pipeline: same 28x28 output size and normalisation, but fully
# deterministic (centre crop, no flip) so that test metrics do not vary from
# one evaluation to the next because of random augmentation.
test_preprocess = transforms.Compose([
    transforms.CenterCrop((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

train_dataset = CIFAR100('datasets/cifar100', train=True, transform=preprocess, download=True)
test_dataset = CIFAR100('datasets/cifar100', train=False, transform=test_preprocess, download=True)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to datasets/cifar100/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:03<00:00, 47244573.22it/s]


Extracting datasets/cifar100/cifar-100-python.tar.gz to datasets/cifar100
Files already downloaded and verified


## Model

In [None]:
class LeNet5_circa(nn.Module):
    """LeNet-5-style CNN: two conv + max-pool stages followed by three linear layers.

    The first linear layer expects a flattened 4x4x64 feature map, which is
    what the convolutional stack yields for 28x28 inputs
    (28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4).

    Args:
        num_classes: Number of output logits. Defaults to 100.
        in_channels: Number of channels of the input images. Defaults to 3.
    """

    def __init__(self, num_classes=100, in_channels=3):
        super().__init__()
        # Attribute names are kept stable so existing state dicts still load.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=5)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(4 * 4 * 64, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, num_classes)

    def forward(self, x):
        """Return the raw class logits (no softmax) for a batch of images `x`."""
        x = self.pool(self.conv1(x).relu())
        x = self.pool(self.conv2(x).relu())
        # Flatten everything except the batch dimension.
        x = torch.flatten(x, 1)
        x = self.fc1(x).relu()
        x = self.fc2(x).relu()
        x = self.fc3(x)

        return x


# `.cuda()` already moves every parameter to the GPU, so no additional
# `.to('cuda')` call is needed.
model = LeNet5_circa().cuda()

criterion = torch.nn.CrossEntropyLoss().cuda()

# Hyper-parameters of this run. The dict is also handed to wandb as the run
# config, so every value the optimiser uses (including weight decay) lives
# here rather than being hard-coded at the call site.
params = {
    'lr': 0.001,
    'momentum': 0.9,
    'epochs': 200,
    'T_max': 20,
    'weight_decay': 4e-4,
}
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=params['lr'],
    momentum=params['momentum'],
    weight_decay=params['weight_decay'],
)

In [None]:
# Start a wandb run in project 'fl'. The run name encodes the learning rate,
# momentum and T_max; the whole `params` dict is stored as the run config.
run_name = f'centralized lr={params["lr"]} m={params["momentum"]} T_max={params["T_max"]}'
wandb.init(project='fl', name=run_name, config=params)

VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [None]:
def path(t):
    """Return the Google Drive checkpoint path for epoch index `t`."""
    return f'/content/drive/My Drive/fl/lenet-{t}-lr{params["lr"]}-m{params["momentum"]}.pt'


# Epoch index to resume from; 0 means "train from scratch" (nothing is loaded).
backup = 0
if backup:
    model.load_state_dict(torch.load(path(backup)))
    # A scheduler built with last_epoch != -1 requires every param group to
    # carry 'initial_lr'; a freshly constructed optimizer does not have it.
    for group in optimizer.param_groups:
        group.setdefault('initial_lr', group['lr'])

# The training loop starts at epoch `backup`, so the last epoch already done is
# `backup - 1`. For a fresh run this is -1, the scheduler's "not resuming"
# value; passing 0 here would be treated as a resume and raise KeyError
# because 'initial_lr' is missing from the optimizer's param groups.
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=params["T_max"], last_epoch=backup - 1)

LeNet5_circa(
  (conv1): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1024, out_features=384, bias=True)
  (fc2): Linear(in_features=384, out_features=192, bias=True)
  (fc3): Linear(in_features=192, out_features=100, bias=True)
)

## Training

In [None]:
# Evaluate every `test_freq` epochs and checkpoint every `save_freq` epochs.
test_freq, save_freq = 10, 50
# Total number of training epochs, taken from the run configuration.
T = params['epochs']

In [None]:
def test(model):
    """Evaluate `model` on `test_loader`, print and return accuracy and loss.

    Args:
        model: Network to evaluate; it is switched to eval mode.

    Returns:
        Tuple `(test_accuracy, test_loss)`: accuracy in percent and the
        per-sample average of `criterion` over all evaluated samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)
            batch_size = targets.size(0)

            # `criterion` yields the batch mean (assumes the default 'mean'
            # reduction); weighting by batch size gives a true per-sample
            # average even when the final batch is smaller than the others.
            loss_sum += criterion(outputs, targets).item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(targets).sum().item()

    test_loss = loss_sum / total
    test_accuracy = 100. * correct / total
    print(f'Test Loss: {test_loss:.6f} Acc: {test_accuracy:.2f}%')
    return test_accuracy, test_loss


def train(model):
    """Train `model` for epochs `backup` .. `T - 1`.

    Each epoch runs one pass over `train_loader` followed by one scheduler
    step. Every `test_freq` epochs (and on the final epoch) the model is
    evaluated with `test` and the result is logged to wandb; every
    `save_freq` epochs (and on the final epoch) the state dict is saved to
    `path(epoch)`.

    Returns:
        Tuple `(accuracies, losses)` holding the test accuracy and test loss
        of every evaluation, in order.
    """
    acc_history, loss_history = [], []
    for epoch in tqdm(range(backup, T)):
        model.train()
        for images, labels in train_loader:
            images = images.cuda()
            labels = labels.cuda()
            optimizer.zero_grad()
            batch_loss = criterion(model(images), labels)
            batch_loss.backward()
            optimizer.step()
        # The learning-rate schedule advances once per epoch, not per batch.
        scheduler.step()

        is_final_epoch = epoch == T - 1

        if epoch % test_freq == 0 or is_final_epoch:
            acc, test_loss = test(model)
            acc_history.append(acc)
            loss_history.append(test_loss)
            wandb.log({'acc': acc, 'loss': test_loss, 'epoch': epoch})

        if epoch % save_freq == 0 or is_final_epoch:
            torch.save(model.state_dict(), path(epoch))

    return acc_history, loss_history


# Run the full training loop; the returned lists hold one test accuracy / test
# loss entry per evaluation performed inside train().
accuracies, losses = train(model)

  0%|          | 1/200 [00:24<1:22:48, 24.97s/it]

Test Loss: 4.576331 Acc: 2.64%


  6%|▌         | 11/200 [03:59<1:10:08, 22.27s/it]

Test Loss: 3.278267 Acc: 21.72%


 10%|█         | 21/200 [07:34<1:06:16, 22.22s/it]

Test Loss: 3.051432 Acc: 26.07%


 16%|█▌        | 31/200 [11:10<1:02:48, 22.30s/it]

Test Loss: 2.953679 Acc: 27.78%


 20%|██        | 41/200 [14:48<1:00:15, 22.74s/it]

Test Loss: 2.576150 Acc: 34.69%


 26%|██▌       | 51/200 [18:24<56:21, 22.69s/it]

Test Loss: 2.356283 Acc: 39.46%


 30%|███       | 61/200 [22:00<52:22, 22.61s/it]

Test Loss: 2.268678 Acc: 41.89%


 36%|███▌      | 71/200 [25:36<48:41, 22.65s/it]

Test Loss: 2.308035 Acc: 41.33%


 40%|████      | 81/200 [29:10<44:18, 22.34s/it]

Test Loss: 2.263652 Acc: 42.37%


In [None]:
# The x positions must mirror the evaluation schedule used in train(): every
# `test_freq`-th epoch plus the final epoch (T - 1), starting from `backup`.
# A fixed arange over [0, T] puts the last point at T instead of T - 1 and has
# the wrong length when T - 1 is a multiple of `test_freq` or when resuming.
xx = [t for t in range(backup, T) if t % test_freq == 0 or t == T - 1]

plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.plot(xx, accuracies, label='centralized', marker='.')
plt.legend()