In [1]:
import os
import sys

import numpy as np
import torch
from torch import nn, autograd, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.models import efficientnet_v2_l, EfficientNet_V2_L_Weights
from torchvision.datasets import CIFAR10
import torchvision.transforms as T
from sklearn.metrics import accuracy_score
import tqdm
import wandb

sys.path.append(os.path.dirname(os.getcwd()))
from pyfed.models.cnns import SimpleCNN

In [2]:
name = 'cnn-scratch-simpleaug'
EPOCH = 100
LR = 1e-3
FREEZE = 0

In [3]:
net = efficientnet_v2_l(weights=EfficientNet_V2_L_Weights.DEFAULT)
net.classifier[-1] = nn.Linear(net.classifier[-1].in_features, 10)
for module in net.features[:-FREEZE]:
    module.requires_grad_(False)
net = net.cuda()

In [3]:
net = SimpleCNN().cuda()

In [4]:
mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]
trainset = CIFAR10('~/data', train=True, transform=T.Compose([
    # T.Resize((224,224)),
    T.RandomRotation(90),
    T.RandomHorizontalFlip(),
    # T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),
    T.ToTensor(),
    T.Normalize(mean, std),
]))
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=4)
testset = CIFAR10('~/data', train=False, transform=T.Compose([
    # T.Resize((224,224)),
    T.ToTensor(),
    T.Normalize(mean,std),
]))
testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=4)

In [5]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda x: x.requires_grad, net.parameters()), lr=LR)
scaler = torch.cuda.amp.GradScaler()
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)

In [6]:
wandb.init(job_type='train', project='kd', name=name)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdev-jahn[0m ([33miislab-official[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
for e in tqdm.trange(EPOCH):
    preds = []
    targets = []
    losses = []
    for data, target in trainloader:
        with torch.cuda.amp.autocast():
            output = net.train()(data.cuda())
            loss = criterion(output, target.cuda())
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        preds+=output.cpu().argmax(dim=1).tolist()
        targets+=target.tolist()
        losses.append(loss.item())
    scheduler.step()
    wandb.log({
        'epoch':e+1,
        'train_acc':accuracy_score(preds, targets)*100,
        'train_loss':np.mean(losses),
        'lr': scheduler.get_last_lr()[0],
    })
    preds = []
    targets = []
    losses = []
    with torch.no_grad():
        for data, target in testloader:
            output = net.eval()(data.cuda())
            loss = criterion(output, target.cuda())
            preds+=output.cpu().argmax(dim=1).tolist()
            targets+=target.tolist()
            losses.append(loss.item())
    wandb.log({
        'epoch':e+1,
        'test_acc':accuracy_score(preds, targets)*100,
        'test_loss':np.mean(losses)
    })

100%|█████████████████████████████████████████████████████████████████████████████████| 100/100 [11:10<00:00,  6.71s/it]


In [8]:
torch.save(net, f'{name}.pth')
wandb.finish()

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
lr,████▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁
test_acc,▁▃▄▅▅▆▆▆▇▆▇▇▇▇▇▇▇█▇██▇▇▇▇█▇████▇███▇████
test_loss,█▆▅▄▃▃▃▃▂▃▃▂▂▂▂▂▂▁▂▁▁▂▂▂▂▁▂▁▁▁▁▂▂▁▁▂▁▁▁▁
train_acc,▁▃▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████
train_loss,█▆▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,100.0
lr,0.00035
test_acc,60.86
test_loss,1.12621
train_acc,66.158
train_loss,0.95414


# Distill

In [20]:
name = 'eff-cnn-distill-simpleaug'
EPOCH = 100
LR = 1e-3

In [21]:
teacher = torch.load('./eff-V2-L-finetune-nofreeze.pth').cuda()
student = SimpleCNN().cuda()
criterion = torch.dist
optimizer = optim.Adam(filter(lambda x: x.requires_grad, student.parameters()), lr=LR)
scaler = torch.cuda.amp.GradScaler()
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)

In [22]:
wandb.init(job_type='distill', project='kd', name=name)

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666932018318524, max=1.0)…

In [23]:
for e in tqdm.trange(EPOCH):
    preds = []
    targets = []
    losses = []
    for data, target in trainloader:
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                output_t = teacher.eval()(data.cuda())
            output_s = student.train()(T.Resize((32,32))(data.cuda()))
            loss = criterion(output_t, output_s)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        preds+=output_s.cpu().argmax(dim=1).tolist()
        targets+=target.tolist()
        losses.append(loss.item())
    scheduler.step()
    wandb.log({
        'epoch':e+1,
        'train_acc':accuracy_score(preds, targets)*100,
        'train_loss':np.mean(losses),
        'lr': scheduler.get_last_lr()[0],
    })
    preds = []
    targets = []
    losses = []
    with torch.no_grad(), torch.cuda.amp.autocast():
        for data, target in testloader:
            output_t = teacher.eval()(data.cuda())
            output_s = student.eval()(T.Resize((32,32))(data.cuda()))
            loss = criterion(output_t, output_s)
            preds+=output_s.cpu().argmax(dim=1).tolist()
            targets+=target.tolist()
            losses.append(loss.item())
    wandb.log({
        'epoch':e+1,
        'test_acc':accuracy_score(preds, targets)*100,
        'test_loss':np.mean(losses)
    })

100%|██████████████████████████████████████████████████████████████████████████████| 100/100 [3:08:43<00:00, 113.24s/it]


In [24]:
torch.save(student, f'{name}.pth')
wandb.finish()

VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
lr,████▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁
test_acc,▁▃▄▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███▇███████████
test_loss,█▆▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁
train_acc,▁▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇███▇███████████████
train_loss,█▆▅▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,100.0
lr,0.00035
test_acc,70.16
test_loss,110.75859
train_acc,49.1
train_loss,79.98911
