In [1]:
# Mount Google Drive into the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install wandb
!pip install catalyst



In [3]:
# Authenticate the Weights & Biases CLI; a WandbLogger callback is attached to training below.
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mdimaorekhov[0m (use `wandb login --relogin` to force relogin)


In [4]:
import ast
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List

import torch
import torch.nn.functional as F
import torchvision.models as models
from catalyst import dl
from catalyst.contrib.callbacks import WandbLogger
from catalyst.utils import set_global_seed
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.models.resnet import ResNet, BasicBlock
from tqdm.auto import tqdm


# Prefer the GPU, but fall back to CPU so the notebook still runs on a runtime without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
class DistillationRunner(dl.Runner):
    """Runner that trains ``self.model`` on hard labels plus a teacher's soft targets.

    For every batch the recorded ``loss`` is::

        (1 - teacher_weight) * cross_entropy(student_logits, labels)
        + teacher_weight * temperature**2
          * prob_cross_entropy(student_logits / temperature,
                               teacher_logits / temperature)
    """

    def __init__(
        self,
        teacher: torch.nn.Module,
        teacher_weight: float,
        temperature: float,
        model = None,
        device = None
    ):
        """Store the distillation settings, then initialise the base runner.

        Args:
            teacher: network whose logits provide the soft targets.
            teacher_weight: weight of the soft-target term; the hard-label
                term is weighted by ``1 - teacher_weight``.
            temperature: divisor applied to student and teacher logits
                before the soft-target loss is computed.
            model: forwarded unchanged to ``dl.Runner``.
            device: forwarded unchanged to ``dl.Runner``.
        """
        self.teacher, self.teacher_weight, self.temperature = (
            teacher, teacher_weight, temperature
        )
        super().__init__(model=model, device=device)

    def _handle_batch(self, batch):
        """Run teacher and student on one batch and record losses and outputs.

        Args:
            batch: a ``(features, labels)`` pair.
        """
        inputs, targets = batch

        # The teacher only supplies targets: eval mode, no autograd graph.
        self.teacher.eval()
        with torch.no_grad():
            soft_targets = self.teacher(inputs)

        logits = self.model(inputs)
        temp = self.temperature

        distill_term = prob_cross_entropy(
            out=logits / temp, target=soft_targets / temp
        )
        hard_term = F.cross_entropy(logits, targets)

        # The soft-target term is additionally scaled by temperature squared.
        hard_weight = 1 - self.teacher_weight
        soft_weight = self.teacher_weight * temp ** 2
        combined = hard_weight * hard_term + soft_weight * distill_term

        metrics = {
            'teacher_loss': distill_term,
            'ce_loss': hard_term,
            'loss': combined,
        }
        self.batch_metrics.update(metrics)

        # Expose CPU copies of the labels and the detached student logits.
        self.input = {'targets': targets.cpu()}
        self.output = {'logits': logits.detach().cpu()}


def prob_cross_entropy(out, target):
    """Cross-entropy between two sets of logits, averaged over leading dims.

    Args:
        out: predicted logits; log-softmax is taken over the last dimension.
        target: target logits; softmax over the last dimension turns them
            into the reference distribution.

    Returns:
        Scalar tensor: the mean of ``-sum(softmax(target) * log_softmax(out))``
        where the sum runs over the last dimension.
    """
    reference_probs = target.softmax(dim=-1)
    predicted_log_probs = out.log_softmax(dim=-1)
    per_item = -(reference_probs * predicted_log_probs).sum(dim=-1)
    return per_item.mean()


In [6]:
@dataclass
class Config:
    """Settings for the distillation experiment, with the defaults used for this run."""

    # Run name reported to the experiment logger.
    experiment_name: str = "distill"

    # Training-time augmentation: horizontal-flip probability and rotation range.
    flip_prob: float = 0.5
    rotation_degrees: float = 25

    # Directory handed to the runner as its log directory.
    logdir: str = 'logdir_distillation'

    # Checkpoint from which the teacher's weights are loaded.
    teacher_checkpoint: str = 'drive/MyDrive/logdir_tune/checkpoints/best.pth'

    # Student ResNet layer spec, stored as the string form of a Python list literal.
    student_layers: str = "[1, 1, 1, 1]"

    # Distillation loss settings: soft-target weight and logit temperature.
    teacher_weight: float = 0.85
    temperature: float = 1.25

    # Optimisation: peak learning rate of the one-cycle schedule and Adam weight decay.
    max_lr: float = 1e-3
    weight_decay: float = 1e-6

    # Training length, batch size, and early-stopping patience.
    num_epochs: int = 100
    batch_size: int = 32
    patience: int = 2

    # Seed passed to the global seeding helper.
    seed: int = 21


    def to_dict(self):
        """Return the settings as a plain dict, with ``student_layers`` parsed.

        ``student_layers`` is converted from its string form to the Python
        literal it spells out. ``ast.literal_eval`` is used instead of
        ``eval`` so a config string can never execute arbitrary code.

        Returns:
            dict mapping each attribute name to its value, in attribute order.

        Raises:
            ValueError / SyntaxError: if ``student_layers`` is not a valid
                Python literal.
        """
        return {
            key: ast.literal_eval(val) if key == "student_layers" else val
            for key, val in vars(self).items()
        }


# Instantiate the default configuration and pass its seed to Catalyst's set_global_seed
# before any model, transform, or loader is created.
config = Config()
set_global_seed(config.seed)

In [7]:
# Per-channel normalisation applied to every image after conversion to a tensor.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Target size passed to transforms.Resize for both splits.
RESIZE_BY = 256

# Training pipeline: resize, random flip and rotation, then tensor conversion + normalisation.
train_transforms = transforms.Compose(
    [
        transforms.Resize(RESIZE_BY),
        transforms.RandomHorizontalFlip(config.flip_prob),
        transforms.RandomRotation(config.rotation_degrees),
        transforms.ToTensor(),
        normalize,
    ]
)

# Evaluation pipeline: the same steps without the random augmentations.
test_transforms = transforms.Compose(
    [transforms.Resize(RESIZE_BY), transforms.ToTensor(), normalize]
)

In [8]:
# CIFAR-10 train/test splits stored under ./data, each with its own transform pipeline.
train = datasets.CIFAR10(root='data', train=True, transform=train_transforms, download=True)
test = datasets.CIFAR10(root='data', train=False, transform=test_transforms, download=True)

# Report the number of samples in each split.
for split in (train, test):
    print(len(split))


# Number of target classes used for the model heads and the accuracy callback.
N_CLASSES = 10

Files already downloaded and verified
Files already downloaded and verified
50000
10000


In [9]:
# Rebuild the teacher architecture: ResNet-18 with its final layer resized to N_CLASSES outputs.
teacher = models.resnet18()
teacher.fc = nn.Linear(teacher.fc.in_features, N_CLASSES)

# Deserialize the checkpoint onto the CPU; the model is moved to DEVICE afterwards.
checkpoint = torch.load(config.teacher_checkpoint, map_location="cpu")
teacher.load_state_dict(checkpoint['model_state_dict'])

# Move the teacher to the target device and switch it to inference mode.
teacher = teacher.to(DEVICE).eval()

In [10]:
# Student network: a BasicBlock ResNet whose layer spec comes from the config.
# The spec is stored as a string, so it is parsed as a literal rather than eval()-ed,
# which would execute whatever expression the string contains.
student = ResNet(
    block=BasicBlock,
    layers=ast.literal_eval(config.student_layers),
    num_classes=N_CLASSES
)

In [11]:
# Data loaders: only the training split is shuffled. The 'valid' loader iterates `test`.
train_loader = DataLoader(train, batch_size=config.batch_size, shuffle=True)
valid_loader = DataLoader(test, batch_size=config.batch_size)
loaders = {'train': train_loader, 'valid': valid_loader}

# Adam over the student's parameters; the learning rate is driven by the scheduler below.
optimizer = torch.optim.Adam(student.parameters(), weight_decay=config.weight_decay)

# One-cycle schedule sized to cover every optimizer step of the planned run.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config.max_lr,
    epochs=config.num_epochs,
    steps_per_epoch=len(train_loader),
)

# Callbacks: optimisation on 'loss', per-batch scheduler stepping, early stopping,
# top-1/3/5 accuracy, and logging to Weights & Biases.
callbacks = [
    dl.OptimizerCallback(metric_key='loss'),
    dl.SchedulerCallback(mode='batch'),
    dl.EarlyStoppingCallback(config.patience),
    dl.AccuracyCallback(topk_args=[1, 3, 5], num_classes=N_CLASSES),
    WandbLogger(
        project='dl-course',
        entity='dimaorekhov',
        group='distillation',
        name=config.experiment_name,
        config=config.to_dict(),
    ),
]

In [12]:
# Create the log directory up front; parents=True lets config.logdir be a nested path.
Path(config.logdir).mkdir(parents=True, exist_ok=True)

In [None]:
# Wire the teacher and the distillation hyperparameters into the runner.
runner = DistillationRunner(
    teacher=teacher, teacher_weight=config.teacher_weight,
    temperature=config.temperature, device=DEVICE,
)

# Train the student with the optimizer, scheduler, loaders and callbacks prepared above.
runner.train(
    model=student, optimizer=optimizer, scheduler=scheduler,
    loaders=loaders, callbacks=callbacks,
    num_epochs=config.num_epochs, verbose=True,
    logdir=config.logdir,
)

[34m[1mwandb[0m: Currently logged in as: [33mdimaorekhov[0m (use `wandb login --relogin` to force relogin)


1/100 * Epoch (train):   0% 1/1563 [00:00<11:30,  2.26it/s, accuracy01=0.094, accuracy03=0.250, accuracy05=0.531, ce_loss=2.426, loss=3.545, lr=4.000e-05, momentum=0.950, teacher_loss=2.395]


To get the last learning rate computed by the scheduler, please use `get_last_lr()`.



1/100 * Epoch (train): 100% 1563/1563 [09:08<00:00,  2.85it/s, accuracy01=0.438, accuracy03=0.750, accuracy05=0.875, ce_loss=1.605, loss=2.251, lr=4.263e-05, momentum=0.950, teacher_loss=1.513]
1/100 * Epoch (valid): 100% 313/313 [01:08<00:00,  4.58it/s, accuracy01=0.625, accuracy03=0.875, accuracy05=0.875, ce_loss=1.204, loss=1.761, teacher_loss=1.190]
[2020-11-30 15:30:52,522] 
1/100 * Epoch 1 (_base): lr=4.263e-05 | momentum=0.9497
1/100 * Epoch 1 (train): accuracy01=0.4300 | accuracy03=0.7655 | accuracy05=0.8983 | ce_loss=1.5652 | loss=2.3658 | lr=4.088e-05 | momentum=0.9499 | teacher_loss=1.6046
1/100 * Epoch 1 (valid): accuracy01=0.5278 | accuracy03=0.8465 | accuracy05=0.9450 | ce_loss=1.2982 | loss=1.9575 | teacher_loss=1.3273
2/100 * Epoch (train): 100% 1563/1563 [09:01<00:00,  2.89it/s, accuracy01=0.500, accuracy03=0.938, accuracy05=0.938, ce_loss=1.227, loss=1.843, lr=5.049e-05, momentum=0.949, teacher_loss=1.249]
2/100 * Epoch (valid): 100% 313/313 [01:09<00:00,  4.50it/s, a