# ***Tarea 3***
## Asignatura: Transformer y Diffusers

#### Autor: Victor M. Fonte Chavez

#### Ejercicio 2: Destilar el conocimiento del Vision Transformer (ViT)

In [9]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
torch.set_float32_matmul_precision('medium')

import torchvision as tv
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

from torchmetrics import MetricCollection
from torchmetrics.classification.accuracy import Accuracy
from torchmetrics.classification.stat_scores import StatScores

from transformers import ViTForImageClassification, DeiTForImageClassification

import lightning as pl

### Trainer Class: Comun para todos los ejercicios

In [4]:
class TrainerFineTune(pl.LightningModule):
    def __init__(self, model, save_dir = None):
        super(TrainerFineTune, self).__init__()

        self.n_classes = 102
        
        self.net = model
        for name, param in self.net.named_parameters():
            if "classifier" not in name:
                param.requires_grad = False
        
        self.save_dir = save_dir

        self.train_metrics = MetricCollection(
            {
                "train_acc": Accuracy(num_classes=self.n_classes, task="multiclass", top_k=1),
                "train_acc_top5": Accuracy(
                    num_classes=self.n_classes,
                    task="multiclass",
                    top_k=min(5, self.n_classes),
                ),
            }
        )
        self.val_metrics = MetricCollection(
            {
                "val_acc": Accuracy(num_classes=self.n_classes, task="multiclass", top_k=1),
                "val_acc_top5": Accuracy(
                    num_classes=self.n_classes,
                    task="multiclass",
                    top_k=min(5, self.n_classes),
                ),
            }
        )
        self.train_progress = {"loss": [], "acc": []}
        self.val_progress = {"loss": [], "acc": []}
        self.train_loss = []
        self.train_acc = []
        self.val_loss = []
        self.val_acc = []

    def forward(self, x):
        x = self.net(x).logits
        return x

    def training_step(self, batch, _):
        images, y = batch
        y = F.one_hot(y, num_classes=self.n_classes).float()

        y_hat = self.forward(images)
        loss = F.cross_entropy(y_hat, y.argmax(1))
        mets = self.train_metrics(y_hat, y.argmax(1))

        self.train_progress["loss"].append(loss.item())
        self.train_progress["acc"].append(100*mets["train_acc"].item())

        self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"], prog_bar=True)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
        with torch.inference_mode():
            for k, v in mets.items():
                self.log(k, 100*v, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)

        return loss

    @torch.no_grad()
    def validation_step(self, batch, _):
        images, y = batch
        y = F.one_hot(y, num_classes=self.n_classes).float()

        y_hat = self.forward(images)
        loss = F.cross_entropy(y_hat, y.argmax(1))
        mets = self.val_metrics(y_hat, y.argmax(1))

        self.val_progress["loss"].append(loss.item())
        self.val_progress["acc"].append(100*mets["val_acc"].item())

        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        for k, v in mets.items():
            self.log(k, 100*v, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)

    @torch.no_grad()
    def on_train_epoch_end(self):
        self.train_loss.append(np.mean(self.train_progress["loss"]))
        self.val_loss.append(np.mean(self.val_progress["loss"]))
        self.train_acc.append(np.mean(self.train_progress["acc"]))
        self.val_acc.append(np.mean(self.val_progress["acc"]))

        self.train_progress["loss"].clear()
        self.val_progress["loss"].clear()
        self.train_progress["acc"].clear()
        self.val_progress["acc"].clear()

        print(f"\nEpoch {self.current_epoch+1}/{self.trainer.max_epochs}", f"val_loss: {self.val_loss[-1]:.4f}", f"val_acc: {self.val_acc[-1]:.4f}")

        if self.save_dir is not None:
            torch.save(self.net, os.path.join(self.save_dir, f"model_{self.current_epoch}.ckpt"))

    def configure_optimizers(self):
        optimizer = optim.Adam(self.net.parameters())
        scheduler = CosineAnnealingLR(
                    optimizer,
                    T_max=410,
                    eta_min=5e-5
                )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            }
        }

In [5]:
class TrainerDistillation(pl.LightningModule):
    def __init__(self, teacher, student, save_dir = None, alpha = 0.25, tau = 2):
        
        super(TrainerDistillation, self).__init__()

        self.n_classes = 102

        self.teacher = teacher.eval().requires_grad_(False)
        self.student = student
        self.freeze_gradients(self.student)
        
        self.save_dir = save_dir
        self.alpha = alpha
        self.tau = tau

        self.train_metrics = MetricCollection(
            {
                "train_acc": Accuracy(num_classes=self.n_classes, task="multiclass", top_k=1),
                "train_acc_top5": Accuracy(
                    num_classes=self.n_classes,
                    task="multiclass",
                    top_k=min(5, self.n_classes),
                ),
            }
        )
        self.val_metrics = MetricCollection(
            {
                "val_acc": Accuracy(num_classes=self.n_classes, task="multiclass", top_k=1),
                "val_acc_top5": Accuracy(
                    num_classes=self.n_classes,
                    task="multiclass",
                    top_k=min(5, self.n_classes),
                ),
            }
        )
        self.train_progress = {"loss": [], "acc": []}
        self.val_progress = {"loss": [], "acc": []}
        self.train_loss = []
        self.train_acc = []
        self.val_loss = []
        self.val_acc = []
    
    def freeze_gradients(self, net):
        for name, param in net.named_parameters():
            if "classifier" not in name:
                param.requires_grad = False

    def student_loss(self, y, y_hat):
        return F.cross_entropy(y_hat, y)

    def distillation_loss(self, student_logits, teacher_logits):
        soft_targets = F.softmax(teacher_logits/self.tau, dim=1)
        soft_prob = F.log_softmax(student_logits/self.tau, dim=1)
        return -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0] * (self.tau**2)

    def forward(self, x, choice = "teacher"):
        if choice == "teacher":
            return self.teacher(x).logits
        else:
            return self.student(x).logits

    def training_step(self, batch, _):
        images, y = batch
        y = F.one_hot(y, num_classes=self.n_classes).float()
        
        teacher_pred = self.forward(images, "teacher")
        student_pred = self.forward(images, "student")

        student_loss = self.student_loss(y, student_pred)
        distillation_loss = self.distillation_loss(student_pred, teacher_pred)
        loss = (1 - self.alpha) * student_loss + self.alpha * distillation_loss

        mets = self.train_metrics(student_pred, y.argmax(1))

        self.train_progress["loss"].append(loss.item())
        self.train_progress["acc"].append(100*mets["train_acc"].item())

        self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"], prog_bar=True)
        self.log("train_student_loss", student_loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
        self.log("train_distillation_loss", distillation_loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
        with torch.inference_mode():
            for k, v in mets.items():
                self.log(k, 100*v, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)

        return loss

    @torch.no_grad()
    def validation_step(self, batch, _):
        images, y = batch
        y = F.one_hot(y, num_classes=self.n_classes).float()

        teacher_pred = self.forward(images, "teacher")
        student_pred = self.forward(images, "student")

        student_loss = self.student_loss(y, student_pred)
        distillation_loss = self.distillation_loss(student_pred, teacher_pred)
        loss = (1 - self.alpha) * student_loss + self.alpha * distillation_loss

        mets = self.val_metrics(student_pred, y.argmax(1))

        self.val_progress["loss"].append(loss.item())
        self.val_progress["acc"].append(100*mets["val_acc"].item())
        
        self.log("val_student_loss", student_loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_distillation_loss", distillation_loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        for k, v in mets.items():
            self.log(k, 100*v, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)

    @torch.no_grad()
    def on_train_epoch_end(self):
        self.train_loss.append(np.mean(self.train_progress["loss"]))
        self.val_loss.append(np.mean(self.val_progress["loss"]))
        self.train_acc.append(np.mean(self.train_progress["acc"]))
        self.val_acc.append(np.mean(self.val_progress["acc"]))

        self.train_progress["loss"].clear()
        self.val_progress["loss"].clear()
        self.train_progress["acc"].clear()
        self.val_progress["acc"].clear()

        print(f"\nEpoch {self.current_epoch+1}/{self.trainer.max_epochs}", \
              f"val_loss: {self.val_loss[-1]:.4f}", f"val_acc: {self.val_acc[-1]:.4f}")

        if self.save_dir is not None:
            torch.save(self.student, os.path.join(self.save_dir, f"model_{self.current_epoch}.ckpt"))

    def configure_optimizers(self):
        optimizer = optim.Adam(self.student.parameters())
        scheduler = CosineAnnealingLR(
                    optimizer,
                    T_max=410,
                    eta_min=5e-5
                )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            }
        }


### Cargado de los datos

In [6]:
data_dir = "./_data"
train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "valid")

cat_to_name_file = os.path.join(data_dir, 'cat_to_name.json')

with open(cat_to_name_file, 'r') as f:
    cat_to_name = json.load(f)

In [7]:
size = 224
min_scale = 0.08
max_scale = 1.0
flip_prob = 0.5
rand_aug_n = 0
rand_aug_m = 9
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

batch_size = 32
n_classes = len(os.listdir(train_dir))

transforms_train = transforms.Compose(
    [
        transforms.RandomResizedCrop(
            (size, size),
            scale=(min_scale, max_scale),
        ),
        transforms.RandomHorizontalFlip(flip_prob),
        transforms.RandAugment(rand_aug_n, rand_aug_m),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)
transforms_val = transforms.Compose(
    [
        transforms.Resize(
            (size, size),
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

train_dataset = ImageFolder(root=train_dir, transform=transforms_train)
valid_dataset = ImageFolder(root=valid_dir, transform=transforms_val)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

Tomo un modelo Destilado de ViT y realizo finetune sobre el conjunto de datos de Flowers102. 

In [None]:
distilled_vit = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224', num_labels=102)
model_trainer = TrainerFineTune(distilled_vit)

trainer = pl.Trainer(
    max_epochs = 10,
    devices=1,
    accelerator='gpu',
)
trainer.fit(model_trainer, train_dataloader, valid_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                       | Params
-------------------------------------------------------------
0 | net           | DeiTForImageClassification | 85.9 M
1 | train_metrics | MetricCollection           | 0     
2 | val_metrics   | MetricCollection           | 0     
-------------------------------------------------------------
78.4 K    Trainable params
85.8 M    Non-trainable params
85.9 M    Total params
343.515   Total estimated model params size (MB)


Sanity Checking:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0: 100%|██████████| 205/205 [01:26<00:00,  2.37it/s, v_num=54, lr=0.000529, train_loss=1.200, train_acc=79.20, train_acc_top5=95.80, val_loss=1.300, val_acc=80.60, val_acc_top5=96.20]
Epoch 1/10 val_loss: 1.5360 val_acc: 75.0248
Epoch 1: 100%|██████████| 205/205 [01:46<00:00,  1.92it/s, v_num=54, lr=5e-5, train_loss=0.835, train_acc=95.80, train_acc_top5=95.80, val_loss=1.050, val_acc=85.50, val_acc_top5=97.60]    
Epoch 2/10 val_loss: 1.0465 val_acc: 85.6036
Epoch 2: 100%|██████████| 205/205 [02:01<00:00,  1.69it/s, v_num=54, lr=0.000521, train_loss=1.070, train_acc=79.20, train_acc_top5=95.80, val_loss=0.899, val_acc=85.90, val_acc_top5=98.00]
Epoch 3/10 val_loss: 0.8981 val_acc: 85.9909
Epoch 3: 100%|██████████| 205/205 [01:46<00:00,  1.92it/s, v_num=54, lr=0.001, train_loss=0.596, train_acc=87.50, train_acc_top5=100.0, val_loss=0.597, val_acc=89.90, val_acc_top5=98.50]   
Epoch 4/10 val_loss: 0.5960 val_acc: 89.9306
Epoch 4: 100%|██████████| 205/205 [01:56<00:00,  1.76it/s, v

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 205/205 [02:02<00:00,  1.67it/s, v_num=54, lr=5e-5, train_loss=0.227, train_acc=100.0, train_acc_top5=100.0, val_loss=0.314, val_acc=94.00, val_acc_top5=98.70]


Tomo un modelo ViT y destilo el conocimiento del ViT entrenado anteriormente. 

In [12]:
teacher = torch.load('./_checkpoints/teacher/epoch_9.pth')
student = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224', num_labels=102)

model_trainer = TrainerDistillation(teacher, student)

trainer = pl.Trainer(
    max_epochs = 10,
    devices=1,
    accelerator='gpu',
)
trainer.fit(model_trainer, train_dataloader, valid_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                       | Params
-------------------------------------------------------------
0 | teacher       | ViTForImageClassification  | 85.9 M
1 | student       | DeiTForImageClassification | 85.9 M
2 | train_metrics | MetricCollection           | 0     
3 | val_metrics   | MetricCollection           | 0   

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0: 100%|██████████| 205/205 [02:35<00:00,  1.31it/s, v_num=55, lr=0.000529, train_student_loss=1.060, train_distillation_loss=19.50, train_loss=5.660, train_acc=95.80, train_acc_top5=100.0, val_student_loss=1.430, val_distillation_loss=19.40, val_loss=5.910, val_acc=81.50, val_acc_top5=96.70]
Epoch 1/10 val_loss: 6.0768 val_acc: 76.0293
Epoch 1: 100%|██████████| 205/205 [02:44<00:00,  1.25it/s, v_num=55, lr=5e-5, train_student_loss=1.340, train_distillation_loss=19.50, train_loss=5.890, train_acc=83.30, train_acc_top5=91.70, val_student_loss=1.220, val_distillation_loss=19.50, val_loss=5.780, val_acc=87.00, val_acc_top5=97.70]    
Epoch 2/10 val_loss: 5.7744 val_acc: 87.1661
Epoch 2: 100%|██████████| 205/205 [02:51<00:00,  1.19it/s, v_num=55, lr=0.000521, train_student_loss=1.130, train_distillation_loss=19.70, train_loss=5.790, train_acc=83.30, train_acc_top5=91.70, val_student_loss=1.100, val_distillation_loss=19.50, val_loss=5.710, val_acc=87.40, val_acc_top5=98.20]
Epoch 3/10

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 205/205 [02:11<00:00,  1.56it/s, v_num=55, lr=5e-5, train_student_loss=0.492, train_distillation_loss=19.70, train_loss=5.310, train_acc=95.80, train_acc_top5=100.0, val_student_loss=0.632, val_distillation_loss=19.70, val_loss=5.390, val_acc=94.50, val_acc_top5=99.10]


Creo mi propio modelo destilado a partir de quitarle la mitad de las capas a un ViT grande y completamente preentrenado 

In [35]:
class DistillViT(nn.Module):
    def __init__(self, vit_model, delete_weights=True):
        super(DistillViT, self).__init__()
        self.vit_model = vit_model
        if delete_weights:
            self.vit_model.apply(self.reset_parameters)

        num_layers_to_remove = len(self.vit_model.vit.encoder.layer) // 2
        self.vit_model.vit.encoder.layer = self.vit_model.vit.encoder.layer[:num_layers_to_remove]
    
    def reset_parameters(self, layer):
        if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
            layer.reset_parameters()

    def forward(self, x):
        x = self.vit_model(x)
        return x

In [14]:
vit = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224-in21k', 
            num_labels=102
        )
distilled_vit = DistillViT(vit, delete_weights=False)
model_trainer = TrainerFineTune(distilled_vit)

trainer = pl.Trainer(
    max_epochs = 10,
    devices=1,
    accelerator='gpu',
)
trainer.fit(model_trainer, train_dataloader, valid_dataloader)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | net           | DistillViT       | 43.3 M
1 | train_metrics | MetricCollection | 0     
2 | val_metrics   | MetricCollection | 0     
---------------------------------------------------
78.4 K    Trainable params
43.3 M    Non-trainable params
43.3 M    Total params
173.399   Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0: 100%|██████████| 205/205 [00:54<00:00,  3.76it/s, v_num=65, lr=0.000529, train_loss=2.860, train_acc=54.20, train_acc_top5=79.20, val_loss=2.950, val_acc=40.70, val_acc_top5=75.10]
Epoch 1/10 val_loss: 3.0733 val_acc: 37.8596
Epoch 1: 100%|██████████| 205/205 [00:50<00:00,  4.08it/s, v_num=65, lr=5e-5, train_loss=2.410, train_acc=50.00, train_acc_top5=87.50, val_loss=2.670, val_acc=49.90, val_acc_top5=81.40]    
Epoch 2/10 val_loss: 2.6746 val_acc: 50.1603
Epoch 2: 100%|██████████| 205/205 [00:50<00:00,  4.06it/s, v_num=65, lr=0.000521, train_loss=2.010, train_acc=66.70, train_acc_top5=100.0, val_loss=2.440, val_acc=56.50, val_acc_top5=85.80]
Epoch 3/10 val_loss: 2.4438 val_acc: 56.7441
Epoch 3: 100%|██████████| 205/205 [00:50<00:00,  4.06it/s, v_num=65, lr=0.001, train_loss=1.730, train_acc=70.80, train_acc_top5=91.70, val_loss=1.830, val_acc=70.50, val_acc_top5=92.40]   
Epoch 4/10 val_loss: 1.8302 val_acc: 70.8467
Epoch 4: 100%|██████████| 205/205 [01:10<00:00,  2.89it/s, v

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 205/205 [01:05<00:00,  3.14it/s, v_num=65, lr=5e-5, train_loss=0.903, train_acc=87.50, train_acc_top5=95.80, val_loss=0.951, val_acc=85.50, val_acc_top5=96.80]


In [36]:
teacher = torch.load('./_checkpoints/teacher/epoch_9.pth')
vit = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224-in21k', 
            num_labels=102
        )
student = DistillViT(vit, delete_weights=False)

model_trainer = TrainerDistillation(teacher, student)

trainer = pl.Trainer(
    max_epochs = 10,
    devices=1,
    accelerator='gpu',
)
trainer.fit(model_trainer, train_dataloader, valid_dataloader)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                      | Params
------------------------------------------------------------
0 | teacher       | ViTForImageClassification | 85.9 M
1 | student       | DistillViT                | 43.3 M
2 | train_metrics | MetricCollection          | 0     
3 | val_metrics   | MetricCollection          | 0     
------------------------------------------------------------
78.4 K    Trainable params
129 M     Non-trainable params
129 M     Total params
516.908   Total estima

Epoch 0: 100%|██████████| 205/205 [01:42<00:00,  2.00it/s, v_num=68, lr=0.000529, train_student_loss=3.210, train_distillation_loss=15.50, train_loss=6.290, train_acc=33.30, train_acc_top5=54.20, val_student_loss=2.920, val_distillation_loss=15.00, val_loss=5.940, val_acc=38.30, val_acc_top5=72.00]
Epoch 1/10 val_loss: 6.0985 val_acc: 35.7639
Epoch 1: 100%|██████████| 205/205 [01:41<00:00,  2.03it/s, v_num=68, lr=5e-5, train_student_loss=2.480, train_distillation_loss=14.20, train_loss=5.400, train_acc=58.30, train_acc_top5=75.00, val_student_loss=2.650, val_distillation_loss=14.40, val_loss=5.570, val_acc=45.50, val_acc_top5=77.50]    
Epoch 2/10 val_loss: 5.5814 val_acc: 45.7399
Epoch 2: 100%|██████████| 205/205 [01:41<00:00,  2.02it/s, v_num=68, lr=0.000521, train_student_loss=1.930, train_distillation_loss=12.70, train_loss=4.620, train_acc=66.70, train_acc_top5=91.70, val_student_loss=2.410, val_distillation_loss=13.80, val_loss=5.250, val_acc=51.00, val_acc_top5=81.20]
Epoch 3/10

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 205/205 [01:44<00:00,  1.97it/s, v_num=68, lr=5e-5, train_student_loss=0.667, train_distillation_loss=8.190, train_loss=2.550, train_acc=95.80, train_acc_top5=95.80, val_student_loss=0.886, val_distillation_loss=8.670, val_loss=2.830, val_acc=82.60, val_acc_top5=96.00]
