# Notebook for training with distillation on the CIFAR10 dataset
In this notebook, MobileNetV2 is trained on the CIFAR10 dataset; a ViT fine-tuned on the same dataset serves as the teacher model.

MobileNetV2 is used in three settings: with random initialization, with an initialized model (pretrained on ImageNet) where only the classification head is trained, and with an initialized model where the whole network is trained. Each of these three tasks is trained in the standard way and also with distillation from the teacher model mentioned above.

Distillation uses logits precomputed in the precompute_logits notebook.
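The precompute_logits notebook is not part of this file. As a rough sketch of the idea, assuming each CIFAR-10 batch file is a pickled dictionary and that the teacher outputs are stored under an added `b'logits'` key (the key the dataset wrapper in this notebook reads), the logits could be attached as follows. The helper name `add_logits_to_batch` and the `predict_logits` callable are hypothetical stand-ins, not the actual code of that notebook.

```python
import pickle

import numpy as np


def add_logits_to_batch(batch_path, out_path, predict_logits):
    """Attach teacher logits to one pickled CIFAR-10 batch file.

    `predict_logits` is a hypothetical callable that maps an array of
    shape (N, 3, 32, 32) to an array of shape (N, num_classes).
    """
    with open(batch_path, "rb") as fo:
        batch = pickle.load(fo, encoding="bytes")

    # Each row of b'data' is a flat 3072-value image (channels first).
    images = np.asarray(batch[b"data"]).reshape(-1, 3, 32, 32)
    logits = np.asarray(predict_logits(images), dtype=np.float32)
    if logits.shape[0] != images.shape[0]:
        raise ValueError("one logit vector per image is required")

    # Store the logits next to the labels so a dataset class can read them.
    batch[b"logits"] = logits.tolist()
    with open(out_path, "wb") as fo:
        pickle.dump(batch, fo)
    return batch
```

In practice `predict_logits` would wrap the fine-tuned ViT teacher in evaluation mode; that part is left out here because the teacher setup is not shown in this notebook.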

In [None]:
%pip install transformers[torch] huggingface_hub datasets evaluate torchvision

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Note: you may need to restart the kernel to use updated packages.


## Library imports and method definitions

In [2]:
from transformers import Trainer, TrainingArguments, MobileNetV2Config, MobileNetV2ForImageClassification, EarlyStoppingCallback
from torchvision.transforms import v2 as transformsv2
from torch.utils.data import Dataset
import torch.nn.functional as F
from PIL import Image
import torch.nn as nn
from enum import Enum
import numpy as np
import evaluate
import random
import pickle
import torch
import os

In [3]:
dataset_part = Enum('dataset_part', [('TRAIN', 1), ('TEST', 2), ('EVAL', 3)])

Resets the random seed so that results are reproducible.
Some of these settings can probably be removed.

TODO: Remove unnecessary settings.

In [4]:
def reset_seed(seed=42):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) 
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
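A quick way to see what the seed reset buys is to draw random numbers twice, resetting in between. The sketch below covers only the `random` and NumPy generators so that it runs without a GPU; `reset_seed_cpu` is a reduced, hypothetical variant of the function above, not a replacement for it.

```python
import random

import numpy as np


def reset_seed_cpu(seed=42):
    # Reduced variant for illustration: only the Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)


reset_seed_cpu(42)
first_draw = (random.random(), np.random.rand(3))

reset_seed_cpu(42)
second_draw = (random.random(), np.random.rand(3))

# Both draws are identical because the generators were reset in between.
print(first_draw[0] == second_draw[0])
print(np.array_equal(first_draw[1], second_draw[1]))
```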

A new wrapper that works directly with the files of the downloaded and modified CIFAR10 dataset.
Loading the data with the method used previously is not possible because the checksum differs.

The logits are now also read directly from the dataset.

In [5]:
class CustomCIFAR10(Dataset):
    def __init__(self, root, dataset_part = dataset_part.TRAIN, transform=None):
        self.root = root
        self.dataset_part = dataset_part
        self.transform = transform

        self.data = []
        self.targets = []
        self.logits = []
        
        # The `dataset_part` parameter shadows the module-level Enum inside this
        # method, so the split is selected by member name.
        if self.dataset_part.name == 'TRAIN':
            batch_files = [f'data_batch_{i}' for i in range(1, 5)]
        elif self.dataset_part.name == 'TEST':
            batch_files = ['test_batch']
        else:
            # The last training batch serves as the eval split.
            batch_files = ['data_batch_5']

        for batch_file in batch_files:
            data_file = os.path.join(self.root, 'cifar-10-batches-py', batch_file)
            with open(data_file, 'rb') as fo:
                batch = pickle.load(fo, encoding='bytes')
            self.data.append(batch[b'data'])
            self.targets.extend(batch[b'labels'])
            self.logits.extend(batch[b'logits'])

        self.data = np.concatenate(self.data, axis=0)
        self.targets = np.array(self.targets)
        self.logits = np.array(self.logits)


    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = self.data[index].reshape(3, 32, 32).transpose(1, 2, 0)
        label = self.targets[index]
        logit = self.logits[index]
        
        image = Image.fromarray(image.astype('uint8'), 'RGB')
        logit = torch.tensor(logit, dtype=torch.float)
        if self.transform:
            image = self.transform(image)
            
        return {
            'pixel_values': image,
            'labels': label,
            'logits': logit
        }
    
    @property
    def labels(self):
        return self.targets
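The reshape in `__getitem__` relies on the CIFAR-10 file layout: each row holds 3072 values, the first 1024 being the red channel, then green, then blue, each stored row by row. A small NumPy check of that conversion, using a synthetic row rather than real data:

```python
import numpy as np

# Synthetic CIFAR-10 row: 1024 red values, then 1024 green, then 1024 blue.
row = np.concatenate([
    np.full(1024, 10, dtype=np.uint8),   # red plane
    np.full(1024, 20, dtype=np.uint8),   # green plane
    np.full(1024, 30, dtype=np.uint8),   # blue plane
])

# Same conversion as in the dataset class: (C, H, W) -> (H, W, C).
image = row.reshape(3, 32, 32).transpose(1, 2, 0)

print(image.shape)   # (32, 32, 3)
print(image[0, 0])   # [10 20 30]
```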


Definition of the accuracy, precision, recall, and F1 metrics used when training the model.

In [6]:
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    pred, labels = eval_pred
    predictions = np.argmax(pred, axis=1)
    
    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
    precision = precision_metric.compute(predictions=predictions, references=labels, average='macro', zero_division = 0)
    recall = recall_metric.compute(predictions=predictions, references=labels, average='macro', zero_division = 0)
    f1 = f1_metric.compute(predictions=predictions, references=labels, average='macro')

    return {
        "accuracy": accuracy["accuracy"],
        "precision": precision["precision"],
        "recall": recall["recall"],
        "f1": f1["f1"]
    }
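To make the `average='macro'` setting concrete, the following NumPy sketch recomputes the same four quantities by hand on a tiny made-up example. It is an illustration of the definitions (with zero division mapped to 0), not the code path used by the `evaluate` library; `macro_scores` is a hypothetical helper.

```python
import numpy as np


def macro_scores(predictions, references, num_classes):
    """Accuracy plus macro-averaged precision, recall and F1 (zero division -> 0)."""
    predictions = np.asarray(predictions)
    references = np.asarray(references)
    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        tp = np.sum((predictions == c) & (references == c))
        fp = np.sum((predictions == c) & (references != c))
        fn = np.sum((predictions != c) & (references == c))
        p = tp / (tp + fp) if tp + fp > 0 else 0.0
        r = tp / (tp + fn) if tp + fn > 0 else 0.0
        f = 2 * p * r / (p + r) if p + r > 0 else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f)
    # Macro averaging: every class contributes equally, regardless of its size.
    return {
        "accuracy": float(np.mean(predictions == references)),
        "precision": float(np.mean(precisions)),
        "recall": float(np.mean(recalls)),
        "f1": float(np.mean(f1s)),
    }


scores = macro_scores([0, 1, 1, 1, 2, 0], [0, 0, 1, 1, 2, 2], num_classes=3)
print(scores)
```

On a class-balanced split such as the CIFAR10 test set, macro recall coincides with accuracy, which is why those two values agree in the evaluation outputs of this notebook.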

Training arguments for the trainer.

In [7]:
class Custom_training_args(TrainingArguments):
    def __init__(self, lambda_param, temperature, *args, **kwargs):
        super().__init__(*args, **kwargs)    
        self.lambda_param = lambda_param
        self.temperature = temperature

In [8]:
def get_training_args(output_dir, logging_dir, remove_unused_columns=True, lr=5e-5, epochs=5, weight_decay=0, lambda_param=.5, temp=5):
    return (
        Custom_training_args(
        output_dir=output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="epoch",
        learning_rate=lr,  # Default value
        per_device_train_batch_size=128,
        per_device_eval_batch_size=128,
        num_train_epochs=epochs,
        weight_decay=weight_decay,
        seed = 42,  # Default value
        metric_for_best_model="f1",
        load_best_model_at_end = True,
        fp16=True, 
        logging_dir=logging_dir,
        remove_unused_columns=remove_unused_columns,
        lambda_param = lambda_param, 
        temperature = temp
    ))

Randomly initialized MobileNetV2.

In [9]:
def get_random_init_mobilenet():
    reset_seed(42)
    student_config = MobileNetV2Config()
    student_config.num_labels = 10
    return MobileNetV2ForImageClassification(student_config)

Freezing the model and training only the classification head.

In [10]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False

    for param in model.classifier.parameters():
        param.requires_grad = True
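The effect of the two loops can be checked by counting trainable parameters before training. The sketch below applies the same freezing pattern to a toy module with a `classifier` attribute; `TinyClassifier` and `count_parameters` are hypothetical names used only for this illustration.

```python
import torch.nn as nn


def count_parameters(model):
    """Return (trainable, total) parameter counts of a module."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


class TinyClassifier(nn.Module):
    # Toy stand-in for a backbone followed by a classification head.
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.classifier = nn.Linear(4, 10)

    def forward(self, x):
        return self.classifier(self.backbone(x))


toy = TinyClassifier()

# Same pattern as freeze_model: freeze everything, then unfreeze the head.
for param in toy.parameters():
    param.requires_grad = False
for param in toy.classifier.parameters():
    param.requires_grad = True

trainable, total = count_parameters(toy)
print(trainable, total)  # 50 86
```

Setting `requires_grad = False` only stops gradient updates; BatchNorm running statistics in the frozen part are still updated while the model is in training mode.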

Initialized (pretrained) MobileNetV2.

In [11]:
def get_mobilenet():
    model_pretrained = MobileNetV2ForImageClassification.from_pretrained("google/mobilenet_v2_1.0_224")
    in_features = model_pretrained.classifier.in_features

    model_pretrained.classifier = nn.Linear(in_features, 10)  # Replace the classification head
    model_pretrained.num_labels = 10
    model_pretrained.config.num_labels = 10

    return model_pretrained

In [12]:
reset_seed(42)

In [13]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 1g.10gb


Applying transformations to the dataset.

In [14]:
transform = transformsv2.Compose([
    transformsv2.Resize((224, 224)), 
    transformsv2.ToImage(),
    transformsv2.ToDtype(torch.float32, scale=True),
    transformsv2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# The last training batch is used as the eval split.
test = CustomCIFAR10(root='./data/10-logits', dataset_part=dataset_part.TEST, transform=transform)
train = CustomCIFAR10(root='./data/10-logits', dataset_part=dataset_part.TRAIN, transform=transform)
eval = CustomCIFAR10(root='./data/10-logits', dataset_part=dataset_part.EVAL, transform=transform)
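The last two transforms map 8-bit pixel values to the range [-1, 1]: `ToDtype(torch.float32, scale=True)` divides by 255 and `Normalize` with mean 0.5 and standard deviation 0.5 then shifts and rescales. A NumPy sketch of that arithmetic on a few sample values (an illustration of the math, not the torchvision implementation):

```python
import numpy as np

pixels = np.array([0, 64, 128, 255], dtype=np.uint8)

scaled = pixels.astype(np.float32) / 255.0   # what ToDtype(..., scale=True) does for uint8
normalized = (scaled - 0.5) / 0.5            # Normalize(mean=0.5, std=0.5)

print(normalized.min(), normalized.max())    # -1.0 1.0
```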

In [15]:
# Check the class distribution --> good enough
import pandas as pd
df = pd.DataFrame(eval.labels)
print(df.value_counts())

0
5    1025
9    1022
3    1016
0    1014
1    1014
8    1003
4     997
6     980
7     977
2     952
Name: count, dtype: int64
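To put a number on "good enough", the counts printed above can be compared with a perfectly uniform split of 1000 samples per class. A small NumPy sketch using those counts:

```python
import numpy as np

# Per-class counts of the eval split, copied from the output above (classes 0..9).
counts = np.array([1014, 1014, 952, 1016, 997, 1025, 980, 977, 1003, 1022])

expected = counts.sum() / len(counts)                 # samples per class if uniform
relative_deviation = np.abs(counts - expected) / expected

print(int(counts.sum()))                              # 10000
print(int(relative_deviation.argmax()))               # 2
print(round(float(relative_deviation.max()), 3))      # 0.048
```

The largest deviation is 4.8 % (class 2), so the eval split is close to balanced.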


### Standard training of the randomly initialized model

In [16]:
training_args = get_training_args(output_dir="./results/cifar10-random", logging_dir='./logs/cifar10-random', lr=0.0005, weight_decay=0.008, epochs=20)
model = get_random_init_mobilenet()

In [17]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 4)]
)
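The `early_stopping_patience` argument controls how many consecutive evaluations without improvement of the tracked metric are tolerated. The following pure-Python sketch shows the general patience logic on made-up F1 values; it is a simplified illustration under that assumption and not the implementation inside `EarlyStoppingCallback`. The function name and the score list are hypothetical.

```python
def early_stop_epoch(scores, patience):
    """Return the 1-based epoch at which training would stop, or None."""
    best = None
    epochs_without_improvement = 0
    for epoch, score in enumerate(scores, start=1):
        if best is None or score > best:
            # New best value: remember it and reset the counter.
            best = score
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Made-up validation F1 values, one per epoch.
f1_per_epoch = [0.50, 0.55, 0.54, 0.53, 0.56, 0.55, 0.55, 0.54, 0.53]

print(early_stop_epoch(f1_per_epoch, patience=4))  # 9
print(early_stop_epoch(f1_per_epoch, patience=2))  # 4
```

A smaller patience stops earlier but may miss a later improvement, as the epoch-5 value in this example shows.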

In [None]:
trainer.train()

AssertionError: EarlyStoppingCallback requires load_best_model_at_end = True

In [None]:
model.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [None]:
trainer.evaluate(test)

{'eval_loss': 1.0413583517074585,
 'eval_accuracy': 0.6297,
 'eval_precision': 0.6376551865860559,
 'eval_recall': 0.6297,
 'eval_f1': 0.6264225488074469}

## Definition of distillation training

A class that adapts the Hugging Face Trainer for knowledge distillation. It now works with the logits stored in the dataset.

In [44]:
class ImageDistilTrainer(Trainer):
    def __init__(self, student_model=None, *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.student = student_model
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        self.temperature = self.args.temperature
        self.lambda_param = self.args.lambda_param



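    def compute_loss(self, student, inputs, return_outputs=False, num_items_in_batch=None):
        # Precomputed teacher logits arrive with each batch; remove them before the forward pass.
        logits = inputs.pop("logits")

        student_output = student(**inputs)

        # Soften both distributions with the temperature.
        soft_teacher = F.softmax(logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        # KL divergence, scaled by T^2 so gradient magnitudes stay comparable across temperatures.
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        # Cross-entropy against the ground-truth labels, computed by the model itself.
        student_target_loss = student_output.loss

        # Weighted sum of both terms; lambda_param = 1 uses only the distillation term.
        loss = ((1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss)
        return (loss, student_output) if return_outputs else loss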
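For readers who want to check the loss by hand, here is a NumPy sketch of the same computation: a temperature-softened KL divergence averaged over the batch (the `batchmean` reduction), multiplied by the squared temperature and then mixed with the cross-entropy term. The logits below are made up, and the helper names are hypothetical; the class above remains the version used for training.

```python
import numpy as np


def log_softmax(x):
    x = x - x.max(axis=-1, keepdims=True)   # subtract the max for numerical stability
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def distillation_loss(student_logits, teacher_logits, temperature):
    """KL(teacher || student) on softened distributions, averaged over the batch."""
    log_p_student = log_softmax(student_logits / temperature)
    log_p_teacher = log_softmax(teacher_logits / temperature)
    p_teacher = np.exp(log_p_teacher)
    kl_per_sample = (p_teacher * (log_p_teacher - log_p_student)).sum(axis=-1)
    return kl_per_sample.mean() * temperature ** 2


def combined_loss(ce_loss, kd_loss, lambda_param):
    # Same weighting as in compute_loss above.
    return (1.0 - lambda_param) * ce_loss + lambda_param * kd_loss


# Made-up logits for a batch of two samples and three classes.
teacher = np.array([[4.0, 1.0, 0.0], [0.0, 3.0, 1.0]])
student = np.array([[2.0, 2.0, 0.0], [0.0, 1.0, 1.0]])

kd = distillation_loss(student, teacher, temperature=6.0)
print(kd > 0.0)                                # True: the distributions differ
print(combined_loss(2.0, kd, 1.0) == kd)       # True: lambda_param = 1 ignores cross-entropy
```

With `lambda_param=1`, as used in two of the runs in this notebook, the reported training and validation losses consist of the distillation term only, so they are not directly comparable with the cross-entropy losses of the standard runs.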

### Training the randomly initialized model with knowledge distillation

In [45]:
reset_seed(42)

In [46]:
student_model = get_random_init_mobilenet()

In [47]:
training_args = get_training_args(output_dir="./results/cifar10-random-KD", logging_dir='./logs/cifar10-random-KD', remove_unused_columns=False, epochs=20, lr=0.00045, lambda_param=1, temp=6)

In [48]:
trainer = ImageDistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 4)]
)

In [49]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.732,1.098484,0.4264,0.438522,0.4264,0.394528




KeyboardInterrupt: 

In [None]:
student_model.eval()


In [None]:
trainer.evaluate(test)

{'eval_loss': 1.0984840393066406,
 'eval_accuracy': 0.4264,
 'eval_precision': 0.4385223237761998,
 'eval_recall': 0.42639999999999995,
 'eval_f1': 0.3945276383308019}

## Obtaining the initialized MobileNetV2 model

In [26]:
reset_seed(42)

In [27]:
model_pretrained = get_mobilenet()

In [28]:
print(model_pretrained)


In [29]:
freeze_model(model_pretrained)

In [30]:
training_args = get_training_args(output_dir="./results/cifar10-pretrained-head", logging_dir='./logs/cifar10-pretrained-head')

In [31]:
trainer = Trainer(
    model=model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 2)]
)

In [32]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.8904,1.359444,0.6308
2,1.1226,1.207275,0.6241
3,1.0168,0.97121,0.7096
4,0.9041,0.907923,0.7149
5,0.8768,0.984846,0.6807
6,0.8324,0.870077,0.7219
7,0.8274,1.021028,0.6552
8,0.8021,1.105069,0.6294
9,0.7919,0.86678,0.7069
10,0.785,0.874822,0.7126


TrainOutput(global_step=15640, training_loss=0.8626149301943572, metrics={'train_runtime': 3635.8498, 'train_samples_per_second': 275.039, 'train_steps_per_second': 4.302, 'total_flos': 2.020099608576e+18, 'train_loss': 0.8626149301943572, 'epoch': 20.0})

In [33]:
model_pretrained.eval()


In [34]:
trainer.evaluate(test)

{'eval_loss': 0.7993389368057251,
 'eval_accuracy': 0.7356,
 'eval_runtime': 27.842,
 'eval_samples_per_second': 359.17,
 'eval_steps_per_second': 5.639,
 'epoch': 20.0}

### Training the initialized MobileNetV2

In [35]:
reset_seed(42)

In [36]:
model_pretrained_whole = get_mobilenet()

In [37]:
training_args = get_training_args(output_dir="./results/cifar10-pretrained", logging_dir='./logs/cifar10-pretrained')

In [38]:
trainer = Trainer(
    model=model_pretrained_whole,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [39]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.6678,0.352117,0.8815
2,0.1829,0.473985,0.8552
3,0.1099,0.307834,0.9086
4,0.0515,0.306266,0.9136
5,0.0356,0.562306,0.8742
6,0.016,0.446769,0.9061
7,0.0135,0.61808,0.8736
8,0.0084,0.747778,0.854
9,0.0059,0.461412,0.9117
10,0.0038,0.43447,0.9198


TrainOutput(global_step=15640, training_loss=0.048249347503666225, metrics={'train_runtime': 5404.6192, 'train_samples_per_second': 185.027, 'train_steps_per_second': 2.894, 'total_flos': 2.020099608576e+18, 'train_loss': 0.048249347503666225, 'epoch': 20.0})

In [40]:
model_pretrained_whole.eval()


In [41]:
trainer.evaluate(test)

{'eval_loss': 0.4153886139392853,
 'eval_accuracy': 0.9295,
 'eval_runtime': 30.5178,
 'eval_samples_per_second': 327.678,
 'eval_steps_per_second': 5.145,
 'epoch': 20.0}

## Training the initialized MobileNetV2 with knowledge distillation

### Training the initialized model with distillation (classification head only)

In [42]:
reset_seed(42)

In [43]:
student_model_pretrained = get_mobilenet()

In [44]:
freeze_model(student_model_pretrained)

In [45]:
training_args = get_training_args(output_dir="./results/cifar10-pretrained-head-KD", logging_dir='./logs/cifar10-pretrained-head-KD', remove_unused_columns=False, temp=6, lambda_param=.8)

In [46]:
trainer = ImageDistilTrainer(
    student_model=student_model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 2)]
)

In [47]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.1942,1.295173,0.638
2,0.773,1.352618,0.6388
3,0.7344,1.305183,0.7129
4,0.6947,1.334172,0.7149
5,0.6873,1.368952,0.6847
6,0.6717,1.329848,0.7203
7,0.6704,1.372895,0.6647
8,0.6614,1.381742,0.6395
9,0.6585,1.350707,0.7035
10,0.6571,1.364204,0.7202


TrainOutput(global_step=15640, training_loss=0.688703476498499, metrics={'train_runtime': 3721.5092, 'train_samples_per_second': 268.708, 'train_steps_per_second': 4.203, 'total_flos': 2.020099608576e+18, 'train_loss': 0.688703476498499, 'epoch': 20.0})

In [48]:
student_model_pretrained.eval()


In [49]:
trainer.evaluate(test)

{'eval_loss': 1.321333408355713,
 'eval_accuracy': 0.7356,
 'eval_runtime': 30.0257,
 'eval_samples_per_second': 333.048,
 'eval_steps_per_second': 5.229,
 'epoch': 20.0}

### Training the whole initialized model with distillation

In [50]:
reset_seed(42)

In [51]:
student_model_pretrained_whole = get_mobilenet()

In [52]:
training_args = get_training_args(output_dir="./results/cifar10-pretrained-KD", logging_dir='./logs/cifar10-pretrained-KD', remove_unused_columns=False, temp=6, lambda_param=1)

In [53]:
trainer = ImageDistilTrainer(
    student_model=student_model_pretrained_whole.to(device),
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [54]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.5882,1.11303,0.8868
2,0.2465,1.138879,0.8644
3,0.1968,1.028809,0.9216
4,0.1559,0.991875,0.9252
5,0.1398,1.172694,0.8737
6,0.1223,1.053637,0.9134
7,0.118,1.166352,0.8826
8,0.1099,1.100621,0.878
9,0.1057,1.038172,0.9194
10,0.1042,0.995655,0.9324


TrainOutput(global_step=15640, training_loss=0.13773624153088426, metrics={'train_runtime': 5781.8021, 'train_samples_per_second': 172.956, 'train_steps_per_second': 2.705, 'total_flos': 2.020099608576e+18, 'train_loss': 0.13773624153088426, 'epoch': 20.0})

In [55]:
student_model_pretrained_whole.eval()


In [56]:
trainer.evaluate(test)

{'eval_loss': 0.9895767569541931,
 'eval_accuracy': 0.936,
 'eval_runtime': 27.2427,
 'eval_samples_per_second': 367.071,
 'eval_steps_per_second': 5.763,
 'epoch': 20.0}