In [21]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.models as models
from torchvision import transforms, datasets
from catalyst import dl
from catalyst.contrib.callbacks import WandbLogger
from catalyst.utils import set_global_seed
from dataclasses import dataclass
from tqdm.auto import tqdm
from pathlib import Path

from torchvision.models.resnet import ResNet, BasicBlock
from typing import Dict, List

In [2]:
class DistillationRunner(dl.Runner):

    def __init__(
        self,
        teacher_weight: float,
        temperature: float,
        model = None,
        device = None
    ):
        self.teacher_weight = teacher_weight
        self.temperature = temperature
        super().__init__(model=model, device=device)

    def _handle_batch(self, batch):
        features, labels = batch

        teacher = self.model['teacher']
        student = self.model['student']

        with torch.no_grad():
            teacher.eval()
            teacher_logits = teacher(features)

        student_logits = student(features)

        teacher_loss = prob_cross_entropy(
            out=student_logits / self.temperature,
            target=teacher_logits / self.temperature
        )
        scale = self.temperature ** 2
        ce_loss = F.cross_entropy(
            student_logits, labels
        )
        total_loss = (
            (1 - self.teacher_weight) * ce_loss
            + self.teacher_weight * scale * teacher_loss
        )

        self.batch_metrics.update({
            'teacher_loss': teacher_loss,
            'ce_loss': ce_loss,
            'loss': total_loss
        })


def prob_cross_entropy(out, target):
    out_log_probs = F.log_softmax(out, dim=-1)
    target_probs = F.softmax(target, dim=-1)
    return torch.mean(-1 * torch.sum(target_probs * out_log_probs, dim=-1))


In [26]:
@dataclass
class Config:

    experiment_name: str = "distill-tuned-teacher"

    flip_prob: float = 0.5
    rotation_degrees: float = 15

    logdir: str = 'logdir_tune'
        
    teacher_checkpoint: str = None

    student_layers: str = "[2, 2]"
    student_to_teacher_layers_map: str = """{
        'layer1': 'layer2',
        'layer2': 'layer4'
    }"""

    teacher_weight: float = 0.85
    temperature: float = 1.5

    max_lr: float = 5e-5
    weight_decay: float = 1e-5

    num_epochs: int = 10
    batch_size: int = 32
    patience: int = 2

    seed: int = 21


    def to_dict(self):
        as_dict = {}
        for key, val in self.__dict__.items():
            if key in ["student_layers", "student_to_teacher_layers_map"]:
                val = eval(val)
            as_dict[key] = val
        return as_dict


config = Config()
assert len(eval(config.student_layers)) == len(eval(config.student_to_teacher_layers_map))
set_global_seed(config.seed)

In [27]:
config.to_dict()

{'experiment_name': 'finetune-vgg-on-cifar-10',
 'flip_prob': 0.5,
 'rotation_degrees': 15,
 'logdir': 'logdir_tune',
 'teacher_checkpoint': None,
 'student_layers': [2, 2],
 'student_to_teacher_layers_map': {'layer1': 'layer2', 'layer2': 'layer4'},
 'teacher_weight': 0.85,
 'temperature': 1.5,
 'max_lr': 5e-05,
 'weight_decay': 1e-05,
 'num_epochs': 10,
 'batch_size': 32,
 'patience': 2,
 'seed': 21}

In [18]:
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(config.flip_prob),
    transforms.RandomRotation(config.rotation_degrees),
    transforms.ToTensor(),
    normalize
])

test_transforms = transforms.Compose([
    transforms.ToTensor(),
    normalize
])

I will be first finetuning teacher on CIFAR10, since all pretrained torchvision models are for ImageNet.

And the ImageNet itself is too large to handle with computational resources I have.

In [6]:
train = datasets.CIFAR10('data', train=True, download=True, transform=train_transforms)
test = datasets.CIFAR10('data', train=False, download=True, transform=test_transforms)
print(len(train))
print(len(test))


N_CLASSES = 10

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
50000
10000


In [7]:
teacher = models.resnet18()




In [7]:
def freeze_module(module: nn.Module):
    for param in module.parameters():
        param.requires_grad = False


def unfreeze_module(module: nn.Module):
    for param in module.parameters():
        param.requires_grad = True


teacher.fc = nn.Linear(512, N_CLASSES)
checkpoint = torch.load(
    config.teacher_checkpoint, map_location=torch.device("cpu")
)
teacher.load_state_dict(checkpoint['model_state_dict'])
teacher.eval()

In [11]:
eval(config.student_layers)

[2, 2]

In [None]:
student = ResNet(
    block=BasicBlock,
    layers=eval(config.student_layers),
    num_classes=N_CLASSES
)



In [8]:
model = {
    'teacher': teacher,
    'student': student
}

loaders = {
    'train': DataLoader(train, batch_size=config.batch_size, shuffle=True),
    'valid': DataLoader(test, batch_size=config.batch_size)
}

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(teacher.parameters(), weight_decay=config.weight_decay)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config.max_lr,
    epochs=config.num_epochs,
    steps_per_epoch=len(loaders['train'])
)

callbacks = [
    dl.OptimizerCallback(metric_key='loss'),
    dl.SchedulerCallback(mode='batch'),
    dl.EarlyStoppingCallback(config.patience),
    dl.AccuracyCallback(topk_args=[1, 3, 5], num_classes=N_CLASSES),
    WandbLogger(
        project='dl-course',
        entity='dimaorekhov',
        group='distillation',
        name=config.experiment_name,
        config=dict(config.__dict__)
    )
]

In [9]:
Path(config.logdir).mkdir(exist_ok=True)

In [10]:
runner = DistillationRunner(
    teacher_weight=config.teacher_weight,
    temperature=config.temperature
)

runner.train(
    model=teacher,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=callbacks,
    num_epochs=config.num_epochs,
    verbose=True,
    logdir=config.to_dict,
    check=True
)

[34m[1mwandb[0m: Currently logged in as: [33mdimaorekhov[0m (use `wandb login --relogin` to force relogin)


1/10 * Epoch (train):   0% 4/1563 [00:01<09:45,  2.66it/s, accuracy01=0.125, accuracy03=0.281, accuracy05=0.594, loss=2.434]
1/10 * Epoch (valid):   1% 4/313 [00:00<00:20, 15.15it/s, accuracy01=0.094, accuracy03=0.375, accuracy05=0.562, loss=2.406]
[2020-11-30 01:36:42,850] 
1/10 * Epoch 1 (_base): lr=2.000e-06 | momentum=0.9500
1/10 * Epoch 1 (train): accuracy01=0.1484 | accuracy03=0.3359 | accuracy05=0.5156 | loss=2.5009
1/10 * Epoch 1 (valid): accuracy01=0.1250 | accuracy03=0.4141 | accuracy05=0.6016 | loss=2.3263



To get the last learning rate computed by the scheduler, please use `get_last_lr()`.



2/10 * Epoch (train):   0% 4/1563 [00:01<08:05,  3.21it/s, accuracy01=0.062, accuracy03=0.500, accuracy05=0.625, loss=2.531]
2/10 * Epoch (valid):   1% 4/313 [00:00<00:20, 14.94it/s, accuracy01=0.156, accuracy03=0.406, accuracy05=0.531, loss=2.430]
[2020-11-30 01:36:44,651] 
2/10 * Epoch 2 (_base): lr=2.000e-06 | momentum=0.9500
2/10 * Epoch 2 (train): accuracy01=0.1094 | accuracy03=0.3906 | accuracy05=0.5234 | loss=2.4674
2/10 * Epoch 2 (valid): accuracy01=0.1172 | accuracy03=0.4062 | accuracy05=0.5781 | loss=2.3539


VBox(children=(Label(value=' 0.01MB of 0.01MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
accuracy01/train,0.10938
accuracy03/train,0.39062
accuracy05/train,0.52344
loss/train,2.46739
accuracy01/valid,0.11719
accuracy03/valid,0.40625
accuracy05/valid,0.57812
loss/valid,2.35387
lr/_base,0.0
momentum/_base,0.95


0,1
accuracy01/train,█▁
accuracy03/train,▁█
accuracy05/train,▁█
loss/train,█▁
accuracy01/valid,█▁
accuracy03/valid,█▁
accuracy05/valid,█▁
loss/valid,▁█
lr/_base,▁█
momentum/_base,█▁


Top best models:
logdir_tune/checkpoints/train.1.pth	2.3263
