In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
import wandb
import tqdm
import numpy as np

In [None]:
# Fix the global NumPy and torch RNGs so stochastic steps that draw from them
# (e.g. the random_split of the held-out data below) repeat across runs.
np.random.seed(33)
torch.manual_seed(33)

# Authenticate with Weights & Biases, which is used for metric logging.
wandb.login()

## Data

In [None]:
# Per-channel (RGB) normalization constants shared by the train and eval pipelines.
CIFAR100_MEAN = [0.5071, 0.4867, 0.4408]
CIFAR100_STD = [0.2675, 0.2565, 0.2761]

# Training pipeline: random 28x28 crops and horizontal flips act as data augmentation.
preprocess = transforms.Compose([
    transforms.RandomCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

# Evaluation pipeline: deterministic center crop matching the training crop size,
# no flips. Evaluating with the random augmentations above would make
# validation/test metrics vary from one evaluation to the next.
eval_preprocess = transforms.Compose([
    transforms.CenterCrop((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

train_dataset = CIFAR100('datasets/cifar100', train=True, transform=preprocess, download=True)
# The official CIFAR-100 test split is later divided into test and validation halves.
test_val_dataset = CIFAR100('datasets/cifar100', train=False, transform=eval_preprocess, download=True)

def split_test_val(d):
    """Randomly split ``d`` into two disjoint halves.

    The first half receives ``len(d) // 2`` items and the second half the
    remainder, so for odd lengths the second half is one item larger.
    The permutation comes from ``torch.utils.data.random_split`` with its
    default generator, so it follows ``torch.manual_seed``.

    Returns:
        A ``(first_half, second_half)`` tuple of ``Subset`` objects.
    """
    n_total = len(d)
    n_first = n_total // 2
    first_half, second_half = torch.utils.data.random_split(d, [n_first, n_total - n_first])
    return first_half, second_half

test_dataset, val_dataset = split_test_val(test_val_dataset)

# Only the training split is reshuffled each epoch; the two held-out halves are
# iterated in a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader, val_loader = (
    DataLoader(held_out, shuffle=False, batch_size=64)
    for held_out in (test_dataset, val_dataset)
)

## Model

In [None]:
class LeNet5_circa(nn.Module):
    """LeNet-5-style CNN producing 100 output scores per image.

    Architecture:
        * two 5x5 convolutions (3 -> 64 -> 64 channels), each followed by ReLU
          and 2x2 max pooling;
        * three fully connected layers 1024 -> 384 -> 192 -> 100, with ReLU
          between them and raw scores at the end.

    The 4*4*64 flatten size matches 3-channel 28x28 inputs
    (28 -> conv 24 -> pool 12 -> conv 8 -> pool 4).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(4 * 4 * 64, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 100)

    def forward(self, x):
        # Feature extractor: conv -> ReLU -> pool, twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(torch.relu(conv(x)))

        # Classifier head on the flattened feature map (batch dim preserved).
        features = x.flatten(start_dim=1)
        hidden = torch.relu(self.fc1(features))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)


# Model and loss are placed on the GPU once; the train/eval loops move each
# batch there with .cuda().
model = LeNet5_circa().cuda()

criterion = torch.nn.CrossEntropyLoss().cuda()

# Hyperparameters; this dict is also passed to wandb.init as the run config.
# T_max equals the number of epochs so the cosine schedule (stepped once per
# epoch by the training loop) completes exactly one half-cycle over training.
params = {
    'lr': 0.001,
    'momentum': 0.9,
    'epochs': 200,
    'T_max': 200
}
optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=params['momentum'])
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=params["T_max"])

In [None]:
# Start a wandb run that records every hyperparameter in `params` as config;
# the run name encodes the optimizer/scheduler settings for easy comparison.
wandb.init(
    config=params,
    project='fl',
    name=f'centralized lr={params["lr"]} m={params["momentum"]} T_max={params["T_max"]}',
)

## Training

In [None]:
# Total training epochs, and the evaluation interval in epochs
# (the final epoch is always evaluated as well).
T, test_freq = params['epochs'], 5

In [None]:
def test(model, loader=val_loader):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Switches the model to eval mode, runs every batch through it under
    ``torch.no_grad()``, prints a summary line and returns the metrics.

    Args:
        model: network mapping a batch of inputs to per-class scores; it must
            already be on CUDA because each batch is moved there with ``.cuda()``.
        loader: iterable of ``(inputs, targets)`` batches; defaults to the
            validation loader.

    Returns:
        ``(accuracy, loss)``: top-1 accuracy in percent, and the mean of the
        per-batch losses over the batches of ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    n_batches = 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)
            loss_sum += criterion(outputs, targets).item()
            n_batches += 1

            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    if total == 0:
        raise ValueError('loader yielded no samples; cannot compute accuracy')

    # Normalize by the batches of the loader actually evaluated, not by a
    # hard-coded global loader.
    test_loss = loss_sum / n_batches
    test_accuracy = 100. * correct / total
    print(f'Test Loss: {test_loss:.6f} Acc: {test_accuracy:.2f}%')
    return test_accuracy, test_loss


def train(model):
    """Train ``model`` for ``T`` epochs using the global optimizer and scheduler.

    Each epoch runs one pass over ``train_loader`` and then steps the LR
    scheduler once. Every ``test_freq`` epochs, and after the final epoch, the
    model is evaluated on ``val_loader`` and the metrics are logged to wandb.

    Returns:
        ``(accuracies, losses)``: validation accuracy (percent) and validation
        loss recorded at each evaluation point, in epoch order.
    """
    accuracies = []
    losses = []
    # `tqdm` is imported as a module at the top of the notebook, so the
    # progress-bar callable is `tqdm.tqdm`; calling the module raises TypeError.
    for epoch in tqdm.tqdm(range(T)):
        model.train()
        for inputs, targets in train_loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()
        # One scheduler step per epoch, matching the epoch-based T_max.
        scheduler.step()

        if epoch % test_freq == 0 or epoch == T - 1:
            val_acc, val_loss = test(model, val_loader)
            accuracies.append(val_acc)
            losses.append(val_loss)
            wandb.log({'acc': val_acc, 'loss': val_loss, 'epoch': epoch})

    return accuracies, losses


accuracies, losses = train(model)

# Explicitly end the wandb run started earlier in the notebook so a re-run
# (or another experiment in the same kernel) starts a fresh run.
wandb.finish()