In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.ops import MLP
from PIL import Image
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100, CIFAR10, MNIST
from torch.optim import SGD, lr_scheduler
import copy
from tqdm import tqdm
from collections import OrderedDict
import random
import math
import matplotlib.pyplot as plt

In [2]:
!pip install wandb -qU
import wandb
wandb.login()

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.7/6.7 MB[0m [31m20.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m25.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m289.0/289.0 kB[0m [31m35.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25h

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Data

In [13]:
preprocess = transforms.Compose([
    transforms.RandomCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])
])

train_dataset = CIFAR100('datasets/cifar100', train=True, transform=preprocess, download=True)
test_dataset = CIFAR100('datasets/cifar100', train=False, transform=preprocess, download=True)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


## Model

In [18]:
class LeNet5_circa(nn.Module):
    def __init__(self):
        super( LeNet5_circa, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(4 * 4 * 64, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 100)

    def forward(self, x):
        x = self.pool(self.conv1(x).relu())
        x = self.pool(self.conv2(x).relu())
        x = torch.flatten(x, 1)
        x = self.fc1(x).relu()
        x = self.fc2(x).relu()
        x = self.fc3(x)

        return x


model = LeNet5_circa().cuda()
model.to('cuda')

criterion = torch.nn.CrossEntropyLoss().cuda()
params = {
    'lr': 0.001,
    'momentum': 0.9,
    'epochs': 100
}
optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=params['momentum'], weight_decay=4e-4)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, params['lr'])

In [19]:
wandb.init(
    project='fl',
    name='centralized 28x28px',
    config=params
)

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
acc,▁▇████
epoch,▁▂▄▅▇█
loss,█▂▁▁▁▂

0,1
acc,42.99
epoch,50.0
loss,2.45586


In [None]:
path = lambda t: f'/content/drive/My Drive/fl/lenet-{t}.pt'

backup = 0
if backup:
    model.load_state_dict(torch.load(path(backup)))
model

LeNet5_circa(
  (conv1): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=576, out_features=384, bias=True)
  (fc2): Linear(in_features=384, out_features=192, bias=True)
  (fc3): Linear(in_features=192, out_features=100, bias=True)
)

## Training

In [20]:
T = params['epochs']
test_freq = 10
save_freq = 100

In [None]:
def test(model):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    test_loss = test_loss / len(test_loader)
    test_accuracy = 100. * correct / total
    print(f'Test Loss: {test_loss:.6f} Acc: {test_accuracy:.2f}%')
    return test_accuracy, test_loss


def train(model):
    accuracies = []
    losses = []
    for t in tqdm(range(backup, T)):
        model.train()
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        scheduler.step()

        if t % test_freq == 0 or t == T-1:
            acc, loss = test(model)
            accuracies.append(acc)
            losses.append(loss)
            wandb.log({'acc': acc, 'loss': loss, 'epoch': t})

        if t % save_freq == 0 or t == T-1:
            torch.save(model.state_dict(), path(t))


    return accuracies, losses


accuracies, losses = train(model)

  1%|          | 1/100 [00:24<39:42, 24.06s/it]

Test Loss: 4.539940 Acc: 2.46%


 11%|█         | 11/100 [03:47<30:56, 20.86s/it]

Test Loss: 3.095170 Acc: 24.79%


 21%|██        | 21/100 [07:11<27:44, 21.07s/it]

Test Loss: 2.689197 Acc: 32.72%


 31%|███       | 31/100 [10:37<24:32, 21.34s/it]

Test Loss: 2.411912 Acc: 38.00%


 41%|████      | 41/100 [14:03<20:58, 21.33s/it]

Test Loss: 2.254222 Acc: 42.02%


 51%|█████     | 51/100 [17:29<17:24, 21.32s/it]

Test Loss: 2.191563 Acc: 43.97%


 59%|█████▉    | 59/100 [20:17<14:18, 20.94s/it]

In [None]:
plt.xlabel('epoch')
plt.ylabel('accuracy')
xx = np.arange(0, T + test_freq, test_freq)
plt.plot(xx, accuracies, label='centralized', marker='.')
plt.legend()