In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision.models as models
from torchvision import transforms, datasets
from catalyst import dl
from catalyst.contrib.callbacks import WandbLogger
from catalyst.utils import set_global_seed
from dataclasses import dataclass
from tqdm.auto import tqdm
from pathlib import Path

In [3]:
@dataclass
class Config:

    experiment_name: str = "finetune-vgg-on-cifar-10"

    flip_prob: float = 0.5
    rotation_degrees: float = 15

    logdir: str = 'drive/MyDrive/logdir_tune'

    n_conv_layers_to_tune: int = 1

    max_lr: float = 1e-4
    weight_decay: float = 1e-6

    num_epochs: int = 20
    batch_size: int = 32
    patience: int = 2

    seed: int = 21


config = Config()
set_global_seed(config.seed)

In [4]:
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(config.flip_prob),
    transforms.RandomRotation(config.rotation_degrees),
    transforms.ToTensor(),
    normalize
])

test_transforms = transforms.Compose([
    transforms.ToTensor(),
    normalize
])

I will be first finetuning teacher on CIFAR10, since all pretrained torchvision models are for ImageNet.

And the ImageNet itself is too large to handle with computational resources I have.

In [5]:
train = datasets.CIFAR10('data', train=True, download=True, transform=train_transforms)
test = datasets.CIFAR10('data', train=False, download=True, transform=test_transforms)
print(len(train))
print(len(test))


N_CLASSES = 10

Files already downloaded and verified
Files already downloaded and verified
50000
10000


In [6]:
teacher = models.resnet18(pretrained=True).eval()

In [7]:
def freeze_module(module: nn.Module):
    for param in module.parameters():
        param.requires_grad = False


def unfreeze_module(module: nn.Module):
    for param in module.parameters():
        param.requires_grad = True


freeze_module(teacher)
teacher.fc = nn.Linear(512, N_CLASSES)


conv_layers_to_tune = [
    getattr(teacher, f"layer{i}")
    for i in range(4, config.n_conv_layers_to_tune, -1)
]
for m in [teacher.fc] + conv_layers_to_tune:
    unfreeze_module(m)
    m.train()


In [8]:
loaders = {
    'train': DataLoader(train, batch_size=config.batch_size, shuffle=True),
    'valid': DataLoader(test, batch_size=config.batch_size)
}

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(teacher.parameters(), weight_decay=config.weight_decay)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config.max_lr,
    epochs=config.num_epochs,
    steps_per_epoch=len(loaders['train'])
)

callbacks = [
    dl.SchedulerCallback(mode='batch'),
    dl.EarlyStoppingCallback(config.patience),
    dl.AccuracyCallback(topk_args=[1, 3, 5], num_classes=N_CLASSES),
    WandbLogger(
        project='dl-course',
        entity='dimaorekhov',
        group='distillation',
        name=config.experiment_name,
        config=dict(config.__dict__)
    )
]

In [9]:
Path(config.logdir).mkdir(parents=True, exist_ok=True)

In [10]:
runner = dl.SupervisedRunner()
runner.train(
    model=teacher,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=callbacks,
    num_epochs=config.num_epochs,
    verbose=True,
    logdir=config.logdir,
    check=True
)

[34m[1mwandb[0m: Currently logged in as: [33mdimaorekhov[0m (use `wandb login --relogin` to force relogin)


1/10 * Epoch (train):   0% 4/1563 [00:01<09:45,  2.66it/s, accuracy01=0.125, accuracy03=0.281, accuracy05=0.594, loss=2.434]
1/10 * Epoch (valid):   1% 4/313 [00:00<00:20, 15.15it/s, accuracy01=0.094, accuracy03=0.375, accuracy05=0.562, loss=2.406]
[2020-11-30 01:36:42,850] 
1/10 * Epoch 1 (_base): lr=2.000e-06 | momentum=0.9500
1/10 * Epoch 1 (train): accuracy01=0.1484 | accuracy03=0.3359 | accuracy05=0.5156 | loss=2.5009
1/10 * Epoch 1 (valid): accuracy01=0.1250 | accuracy03=0.4141 | accuracy05=0.6016 | loss=2.3263



To get the last learning rate computed by the scheduler, please use `get_last_lr()`.



2/10 * Epoch (train):   0% 4/1563 [00:01<08:05,  3.21it/s, accuracy01=0.062, accuracy03=0.500, accuracy05=0.625, loss=2.531]
2/10 * Epoch (valid):   1% 4/313 [00:00<00:20, 14.94it/s, accuracy01=0.156, accuracy03=0.406, accuracy05=0.531, loss=2.430]
[2020-11-30 01:36:44,651] 
2/10 * Epoch 2 (_base): lr=2.000e-06 | momentum=0.9500
2/10 * Epoch 2 (train): accuracy01=0.1094 | accuracy03=0.3906 | accuracy05=0.5234 | loss=2.4674
2/10 * Epoch 2 (valid): accuracy01=0.1172 | accuracy03=0.4062 | accuracy05=0.5781 | loss=2.3539


VBox(children=(Label(value=' 0.01MB of 0.01MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
accuracy01/train,0.10938
accuracy03/train,0.39062
accuracy05/train,0.52344
loss/train,2.46739
accuracy01/valid,0.11719
accuracy03/valid,0.40625
accuracy05/valid,0.57812
loss/valid,2.35387
lr/_base,0.0
momentum/_base,0.95


0,1
accuracy01/train,█▁
accuracy03/train,▁█
accuracy05/train,▁█
loss/train,█▁
accuracy01/valid,█▁
accuracy03/valid,█▁
accuracy05/valid,█▁
loss/valid,▁█
lr/_base,▁█
momentum/_base,█▁


Top best models:
logdir_tune/checkpoints/train.1.pth	2.3263
