In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.ops import MLP
from PIL import Image
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100, CIFAR10, MNIST
from torch.optim import SGD, lr_scheduler
import copy
from tqdm import tqdm
from collections import OrderedDict
import random
import math
import matplotlib.pyplot as plt

## Data

In [None]:
# Mount Google Drive at /content/drive; later cells read and write model
# checkpoints under '/content/drive/My Drive/'. Only works inside Google Colab.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Per-channel normalisation shared by the training and evaluation pipelines
# (presumably the CIFAR-100 channel statistics -- confirm if the dataset changes).
normalize = transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])

# Training pipeline: random 24x24 crops and horizontal flips act as augmentation.
preprocess = transforms.Compose([
    transforms.RandomCrop((24, 24)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: a deterministic centre crop of the same size, so that test
# metrics are reproducible and not perturbed by random augmentation.
test_preprocess = transforms.Compose([
    transforms.CenterCrop((24, 24)),
    transforms.ToTensor(),
    normalize,
])

train_dataset = CIFAR100('datasets/cifar100', train=True, transform=preprocess, download=True)
test_dataset = CIFAR100('datasets/cifar100', train=False, transform=test_preprocess, download=True)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to datasets/cifar100/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:13<00:00, 12847639.23it/s]


Extracting datasets/cifar100/cifar-100-python.tar.gz to datasets/cifar100
Files already downloaded and verified


## Model

In [None]:
class LeNet5_circa(nn.Module):
    """LeNet-5-style CNN with two conv layers and three fully connected layers.

    Each 5x5 convolution is followed by ReLU and 2x2 max-pooling. `fc1` expects
    576 = 64 * 3 * 3 flattened features, which is what a 3-channel 24x24 input
    produces (24 -> 20 -> 10 -> 6 -> 3). The final layer has 100 outputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(576, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 100)

    def forward(self, x):
        """Return raw class logits of shape (batch, 100).

        No softmax is applied here on purpose: nn.CrossEntropyLoss applies
        log-softmax itself, so feeding it probabilities squashes the loss and
        its gradients. The argmax of the logits equals the argmax of the
        softmax, so predictions are unaffected. Call
        `torch.softmax(logits, dim=1)` on the result if probabilities are needed.
        """
        x = self.pool(self.conv1(x).relu())
        x = self.pool(self.conv2(x).relu())
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.fc1(x).relu()
        x = self.fc2(x).relu()
        return self.fc3(x)


# .cuda() already returns the module on the GPU, so no extra .to('cuda') is needed.
model = LeNet5_circa().cuda()

criterion = torch.nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# T_max is the number of scheduler steps (one per epoch in this notebook) over
# which the learning rate is annealed. It must be the total epoch count, not
# the learning rate. Keep it equal to T in the training configuration cell.
# NOTE(review): when resuming from a checkpoint the scheduler restarts at step 0;
# fast-forward it or restore its state if the schedule must continue exactly.
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=2000)

In [None]:
# Epoch index of the checkpoint to restore; a falsy value (0) skips loading
# and keeps the freshly initialised weights.
backup = 1500
if backup:
    checkpoint_path = f'/content/drive/My Drive/lenet001-{backup}.pt'
    model.load_state_dict(torch.load(checkpoint_path))
model

LeNet5_circa(
  (conv1): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=576, out_features=384, bias=True)
  (fc2): Linear(in_features=384, out_features=192, bias=True)
  (fc3): Linear(in_features=192, out_features=100, bias=True)
)

## Training

In [None]:
# Training schedule; all three values are counted in epochs.
T = 2_000        # exclusive upper bound of the epoch loop
test_freq = 20   # evaluate on the test set every `test_freq` epochs
save_freq = 100  # write a checkpoint every `save_freq` epochs

In [None]:
def test(model):
    """Evaluate `model` on the global `test_loader` and print loss and accuracy.

    Uses the notebook globals `test_loader` and `criterion`. The model is
    switched to eval mode and left in that mode on return.

    Args:
        model: network to evaluate; batches are moved to the device of its
            parameters.

    Returns:
        (test_accuracy, test_loss): top-1 accuracy in percent and the mean
        per-sample loss over the whole test set.

    Raises:
        ValueError: if `test_loader` yields no samples.
    """
    model.eval()
    device = next(model.parameters()).device
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            batch_size = targets.size(0)

            # `criterion` returns the batch mean (default reduction='mean'), so
            # weight it by the batch size; otherwise a smaller final batch
            # would count as much as a full one in the average.
            loss_sum += criterion(outputs, targets).item() * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += predicted.eq(targets).sum().item()

    if total == 0:
        raise ValueError('test_loader yielded no samples')

    test_loss = loss_sum / total
    test_accuracy = 100. * correct / total
    print(f'Test Loss: {test_loss:.6f} Acc: {test_accuracy:.2f}%')
    return test_accuracy, test_loss


def train(model):
    """Train `model` on the global `train_loader`, evaluating and checkpointing periodically.

    Uses the notebook globals `backup`, `T`, `test_freq`, `save_freq`,
    `train_loader`, `criterion`, `optimizer`, `scheduler` and `test`.

    Args:
        model: network to optimise; batches are moved to the device of its
            parameters.

    Returns:
        (accuracies, losses): test accuracy (percent) and test loss, one entry
        per evaluation, in chronological order.
    """
    # Checkpoint `lenet001-{t}.pt` is written after epoch t has finished, so a
    # run resumed from it continues with the next epoch instead of repeating
    # (and overwriting the checkpoint of) epoch `backup`.
    start_epoch = backup + 1 if backup else 0
    device = next(model.parameters()).device

    accuracies = []
    losses = []
    for t in tqdm(range(start_epoch, T)):
        model.train()  # test() leaves the model in eval mode
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        scheduler.step()  # one scheduler step per epoch

        is_last_epoch = t == T - 1
        if t % test_freq == 0 or is_last_epoch:
            acc, test_loss = test(model)
            accuracies.append(acc)
            losses.append(test_loss)

        if t % save_freq == 0 or is_last_epoch:
            torch.save(model.state_dict(), f'/content/drive/My Drive/lenet001-{t}.pt')

    return accuracies, losses


accuracies, losses = train(model)

  x = nn.functional.softmax(x)
  0%|          | 1/500 [00:24<3:20:01, 24.05s/it]

Test Loss: 4.478512 Acc: 14.40%


  4%|▍         | 21/500 [07:02<2:46:10, 20.82s/it]

Test Loss: 4.475781 Acc: 14.71%


  8%|▊         | 41/500 [13:40<2:40:53, 21.03s/it]

Test Loss: 4.477677 Acc: 14.43%


 12%|█▏        | 61/500 [20:21<2:33:47, 21.02s/it]

Test Loss: 4.478799 Acc: 14.35%


 16%|█▌        | 81/500 [27:02<2:27:14, 21.09s/it]

Test Loss: 4.476023 Acc: 14.59%


 20%|██        | 101/500 [33:48<2:20:09, 21.08s/it]

Test Loss: 4.476896 Acc: 14.50%


 24%|██▍       | 121/500 [40:26<2:10:59, 20.74s/it]

Test Loss: 4.479036 Acc: 14.30%


 28%|██▊       | 141/500 [47:05<2:04:37, 20.83s/it]

Test Loss: 4.480172 Acc: 14.16%


 32%|███▏      | 161/500 [53:46<1:58:49, 21.03s/it]

Test Loss: 4.476419 Acc: 14.51%


 35%|███▍      | 174/500 [58:05<1:50:02, 20.25s/it]

In [None]:
# Epochs at which train() evaluates the model: every `test_freq` epochs plus the
# final epoch T-1, counted from the resume point `backup`.
eval_epochs = [t for t in range(backup, T) if t % test_freq == 0 or t == T - 1]
# The last evaluation always happens at epoch T-1, so align the recorded
# accuracies with the tail of the schedule. This keeps x and y the same length
# for both fresh and resumed runs.
xx = np.array(eval_epochs[max(len(eval_epochs) - len(accuracies), 0):])

fig, ax = plt.subplots()
ax.set_title('Test accuracy')
ax.set_xlabel('epoch')
ax.set_ylabel('accuracy')
ax.plot(xx, accuracies, label='centralized', marker='.')
ax.legend()
plt.show()