# Notebook for training with distillation on the CIFAR100 dataset
This notebook trains MobileNetV2 on the CIFAR100 dataset; the teacher model is a ViT fine-tuned on the same dataset.

MobileNetV2 is used in three setups: with random initialization, with training of only the classification head of an initialized MobileNetV2 (pretrained on ImageNet), and with training of the whole model, likewise initialized. These three tasks are trained in the usual way and also with distillation from the teacher model mentioned above.

Distillation uses precomputed logits from the precompute_logits notebook.
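
The precompute_logits notebook is not shown here, so the following is only a minimal sketch of how teacher logits could be attached to a CIFAR-100 batch file so that a `b'logits'` key sits next to `b'data'` and `b'fine_labels'`. The helper name `attach_logits`, the toy array shapes, the random "teacher" logits, and the temporary path are all assumptions made for illustration.

```python
import os
import pickle
import tempfile

import numpy as np


def attach_logits(batch_path, logits):
    """Add per-sample teacher logits to a pickled CIFAR-style batch (hypothetical helper)."""
    with open(batch_path, "rb") as fo:
        batch = pickle.load(fo, encoding="bytes")
    if len(logits) != len(batch[b"data"]):
        raise ValueError("need exactly one logit vector per image")
    batch[b"logits"] = np.asarray(logits, dtype=np.float32)
    with open(batch_path, "wb") as fo:
        pickle.dump(batch, fo)
    return batch


# Toy stand-in for the real 'train' file: 4 images, 100 classes.
rng = np.random.default_rng(0)
toy_batch = {
    b"data": rng.integers(0, 256, size=(4, 3072), dtype=np.uint8),
    b"fine_labels": [3, 17, 42, 99],
}
toy_path = os.path.join(tempfile.mkdtemp(), "train")
with open(toy_path, "wb") as fo:
    pickle.dump(toy_batch, fo)

# In practice these would come from the fine-tuned ViT teacher; here they are random.
fake_teacher_logits = rng.normal(size=(4, 100))
stored = attach_logits(toy_path, fake_teacher_logits)
print(stored[b"logits"].shape)
```

Storing the logits once avoids running the teacher during every student training step, at the cost of fixing the teacher's view of each image to the preprocessing used when the logits were computed.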

In [1]:
%pip install transformers[torch] huggingface_hub datasets evaluate torchvision



## Library imports and method definitions

In [1]:
from transformers import Trainer, TrainingArguments, MobileNetV2Config, MobileNetV2ForImageClassification, EarlyStoppingCallback
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, Subset
from torchvision import transforms
import torch.nn.functional as F
from PIL import Image
import torch.nn as nn
import numpy as np
import evaluate
import random
import pickle
import torch
import os

Resetting the random seed so that results are reproducible.
Some of these settings can probably be removed.

TODO: Remove the unnecessary settings.

In [2]:
def reset_seed(seed=42):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) 
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
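
A quick way to see what reseeding buys is to draw random numbers twice after the same seed. The sketch below uses a reduced, hypothetical variant of the function (`reset_basic_seed`, standard library and NumPy only); the torch-specific calls above follow the same pattern for the torch generators.

```python
import random

import numpy as np


def reset_basic_seed(seed=42):
    """Reduced variant of reset_seed: standard library and NumPy only (torch omitted)."""
    random.seed(seed)
    np.random.seed(seed)


def draw():
    """Take one draw from each generator that was seeded."""
    return random.random(), np.random.rand(3).tolist()


reset_basic_seed(42)
first = draw()
reset_basic_seed(42)
second = draw()
reset_basic_seed(7)
third = draw()

print(first == second)  # True: same seed, same draws
print(first == third)   # False: different seed
```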

A new wrapper that works directly with the files of the downloaded and modified CIFAR100 dataset.
Loading the dataset with the previously used method is not possible because the checksum differs.

The logits are now read directly from the dataset as well.

In [31]:
class CustomCIFAR100(Dataset):
    def __init__(self, root, train=True, transform=None):
        self.root = root
        self.train = train
        self.transform = transform


        self.data = []
        self.targets = []
        self.logits = []
        
        if self.train:
            data_file = os.path.join(self.root, 'cifar-100-python', 'train')
        else:
            data_file = os.path.join(self.root, 'cifar-100-python', 'test')
            
        with open(data_file, 'rb') as fo:
            # Named `entry` so that the built-in `dict` is not shadowed
            entry = pickle.load(fo, encoding='bytes')
            self.data.append(entry[b'data'])
            self.targets.extend(entry[b'fine_labels'])
            self.logits.extend(entry[b'logits'])

        self.data = np.concatenate(self.data, axis=0)
        self.targets = np.array(self.targets)
        self.logits = np.array(self.logits)


    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = self.data[index].reshape(3, 32, 32).transpose(1, 2, 0)
        label = self.targets[index]
        logit = self.logits[index]
        
        image = Image.fromarray(image.astype('uint8'), 'RGB')
        logit = torch.tensor(logit, dtype=torch.float)
        if self.transform:
            image = self.transform(image)

        return {
            'pixel_values': image,
            'labels': label,
            'logits': logit
        }
    
    @property
    def labels(self):
        return self.targets
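
The `reshape(3, 32, 32).transpose(1, 2, 0)` step in `__getitem__` relies on the CIFAR row layout: each flat row of 3072 values holds the red plane first, then green, then blue. A small sketch with synthetic values (the constants 10, 20, 30 are arbitrary) shows how one row becomes a height × width × channel image:

```python
import numpy as np

# One flat CIFAR row: 1024 red values, then 1024 green, then 1024 blue.
row = np.concatenate([
    np.full(1024, 10, dtype=np.uint8),  # red plane
    np.full(1024, 20, dtype=np.uint8),  # green plane
    np.full(1024, 30, dtype=np.uint8),  # blue plane
])

chw = row.reshape(3, 32, 32)   # channels first
hwc = chw.transpose(1, 2, 0)   # channels last, the layout PIL expects for RGB

print(hwc.shape)   # (32, 32, 3)
print(hwc[0, 0])   # [10 20 30]
```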


Definition of the evaluation metrics (accuracy and macro-averaged precision, recall, and F1) used during model training.

In [4]:
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    pred, labels = eval_pred
    predictions = np.argmax(pred, axis=1)
    
    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
    precision = precision_metric.compute(predictions=predictions, references=labels, average='macro', zero_division = 0)
    recall = recall_metric.compute(predictions=predictions, references=labels, average='macro', zero_division = 0)
    f1 = f1_metric.compute(predictions=predictions, references=labels, average='macro')

    return {
        "accuracy": accuracy["accuracy"],
        "precision": precision["precision"],
        "recall": recall["recall"],
        "f1": f1["f1"]
    }
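
To make `average='macro'` concrete, the following hand-rolled sketch computes the same kind of scores with NumPy on a toy example. It is an illustration of the definition, not the `evaluate` implementation; the function name `macro_scores` and the toy predictions are assumptions, and undefined ratios are set to 0 in the spirit of `zero_division=0`.

```python
import numpy as np


def macro_scores(predictions, labels):
    """Per-class precision/recall/F1 averaged with equal weight per class."""
    precisions, recalls, f1s = [], [], []
    for c in np.union1d(labels, predictions):
        tp = np.sum((predictions == c) & (labels == c))
        fp = np.sum((predictions == c) & (labels != c))
        fn = np.sum((predictions != c) & (labels == c))
        p = tp / (tp + fp) if tp + fp > 0 else 0.0
        r = tp / (tp + fn) if tp + fn > 0 else 0.0
        f = 2 * p * r / (p + r) if p + r > 0 else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f)
    return {
        "accuracy": float(np.mean(predictions == labels)),
        "precision": float(np.mean(precisions)),
        "recall": float(np.mean(recalls)),
        "f1": float(np.mean(f1s)),
    }


# Toy scores for 6 samples and 3 classes; argmax turns them into predictions.
scores = np.eye(3)[[0, 1, 1, 1, 2, 0]]
predictions = np.argmax(scores, axis=1)
labels = np.array([0, 0, 1, 1, 2, 2])

result = macro_scores(predictions, labels)
print(result)
```

Because every class contributes equally to the macro average, a class that is predicted poorly lowers the score as much as any other class, regardless of how many samples it has.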

Training arguments for the trainer.

In [None]:
class Custom_training_args(TrainingArguments):
    def __init__(self, lambda_param, temperature, *args, **kwargs):
        super().__init__(*args, **kwargs)    
        self.lambda_param = lambda_param
        self.temperature = temperature

In [5]:
def get_training_args(output_dir, logging_dir, remove_unused_columns=True, lr=5e-5, epochs=5, weight_decay=0, lambda_param=.5, temp=5):
    return (
        Custom_training_args(
        output_dir=output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="epoch",
        learning_rate=lr,  # Default value
        per_device_train_batch_size=128,
        per_device_eval_batch_size=128,
        num_train_epochs=epochs,
        weight_decay=weight_decay,
        seed = 42,  # Default value
        metric_for_best_model="f1",
        fp16=True, 
        logging_dir=logging_dir,
        remove_unused_columns=remove_unused_columns,
        lambda_param = lambda_param, 
        temperature = temp
    ))

Randomly initialized MobileNetV2.

In [6]:
def get_random_init_mobilenet():
    reset_seed(42)
    student_config = MobileNetV2Config()
    student_config.num_labels = 100
    return MobileNetV2ForImageClassification(student_config)

Freezing the model so that only the classification head is trained.

In [7]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False

    for param in model.classifier.parameters():
        param.requires_grad = True
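
The effect of freezing can be checked by counting trainable parameters. The sketch below applies the same two loops to a tiny stand-in network; `TinyClassifier`, `freeze_all_but_classifier`, and `count_parameters` are hypothetical names, and the only property shared with MobileNetV2 is that the head is reachable as `model.classifier`.

```python
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in: a backbone plus a `classifier` head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.classifier = nn.Linear(4, 2)

    def forward(self, x):
        return self.classifier(self.backbone(x))


def freeze_all_but_classifier(model):
    # Same logic as freeze_model: switch everything off, then re-enable the head.
    for param in model.parameters():
        param.requires_grad = False
    for param in model.classifier.parameters():
        param.requires_grad = True


def count_parameters(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


toy = TinyClassifier()
freeze_all_but_classifier(toy)
trainable, total = count_parameters(toy)
print(trainable, total)  # 10 46
```

Note that `requires_grad = False` only stops gradient updates; BatchNorm layers still update their running statistics while the model is in training mode.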

Initialized (ImageNet-pretrained) MobileNetV2.

In [8]:
def get_mobilenet():
    model_pretrained = MobileNetV2ForImageClassification.from_pretrained("google/mobilenet_v2_1.0_224")
    in_features = model_pretrained.classifier.in_features

    model_pretrained.classifier = nn.Linear(in_features, 100)  # Replace the classification head
    model_pretrained.num_labels = 100
    model_pretrained.config.num_labels = 100

    return model_pretrained

In [9]:
reset_seed(42)

In [10]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 1g.10gb


Applying the transformations to the dataset.

In [41]:
transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])



test = CustomCIFAR100(root='./data/100-logits', train=False, transform=transform)
train_whole = CustomCIFAR100(root='./data/100-logits', train=True, transform=transform)
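
With a mean and standard deviation of 0.5 per channel, `Normalize` maps the [0, 1] range produced by `ToTensor` onto [-1, 1]. A short NumPy sketch of the same arithmetic (the sample pixel values are arbitrary):

```python
import numpy as np

mean, std = 0.5, 0.5


def normalize(x):
    """Per-channel normalization as applied by Normalize: (x - mean) / std."""
    return (x - mean) / std


pixels_uint8 = np.array([0, 64, 128, 255], dtype=np.uint8)
scaled = pixels_uint8.astype(np.float32) / 255.0  # ToTensor scales 8-bit values to [0, 1]
normalized = normalize(scaled)

print(normalized.min(), normalized.max())  # -1.0 1.0
```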


A somewhat better way of splitting the data: a stratified train/validation split.

In [42]:
train_idx, validation_idx = train_test_split(np.arange(len(train_whole)),
                                             test_size=0.2,
                                             random_state=42,
                                             shuffle=True,
                                             stratify=train_whole.labels)

In [43]:
eval = Subset(train_whole, validation_idx)
train = Subset(train_whole, train_idx)
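
Stratification means every class keeps the same share in both parts. The sketch below is a hand-rolled illustration with NumPy, not scikit-learn's algorithm; `stratified_indices` and the toy label array (100 classes with 50 samples each, whereas the CIFAR-100 training set has 500 per class) are assumptions.

```python
import numpy as np


def stratified_indices(labels, test_size, seed):
    """Split indices so that each class contributes the same fraction to validation."""
    rng = np.random.default_rng(seed)
    train_idx, val_idx = [], []
    for c in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == c))
        n_val = int(round(test_size * len(idx)))
        val_idx.extend(idx[:n_val])
        train_idx.extend(idx[n_val:])
    return np.array(train_idx), np.array(val_idx)


toy_labels = np.repeat(np.arange(100), 50)
toy_train_idx, toy_val_idx = stratified_indices(toy_labels, test_size=0.2, seed=42)

val_counts = np.bincount(toy_labels[toy_val_idx], minlength=100)
print(len(toy_train_idx), len(toy_val_idx), val_counts.min(), val_counts.max())
```

With a plain random split, some of the 100 classes could end up under-represented in the validation part, which would make the macro-averaged metrics noisier.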

### Standard training of the randomly initialized model

In [13]:
training_args = get_training_args(output_dir="./results/cifar100-random", logging_dir='./logs/cifar100-random', lr=0.0005, weight_decay=0.008, epochs=20)
model = get_random_init_mobilenet()

In [14]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 4)]
)
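
The idea behind `early_stopping_patience` can be sketched in plain Python: training stops once the monitored metric has failed to improve for that many consecutive evaluations. This is a simplified illustration, not the library's code; `early_stop_epoch` and the F1 values are hypothetical, and it assumes that higher is better.

```python
def early_stop_epoch(scores, patience, threshold=0.0):
    """Return the 1-based epoch at which training would stop, or None if it never does."""
    best = None
    waited = 0
    for epoch, score in enumerate(scores, start=1):
        if best is None or score - best > threshold:
            best = score   # improvement: remember it and reset the counter
            waited = 0
        else:
            waited += 1    # no improvement at this evaluation
            if waited >= patience:
                return epoch
    return None


# Hypothetical validation F1 values, one per epoch.
toy_f1 = [0.10, 0.20, 0.19, 0.18, 0.21, 0.20, 0.20, 0.19, 0.18, 0.22]
print(early_stop_epoch(toy_f1, patience=4))  # 9
print(early_stop_epoch(toy_f1, patience=5))  # None
```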

In [15]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,4.5772,3.970619,0.0904
2,3.8446,3.571363,0.1543
3,3.601,3.276188,0.2054
4,3.2361,3.045846,0.2412
5,3.0865,2.803951,0.2911
6,2.8025,2.733018,0.3028
7,2.6715,2.589235,0.3393
8,2.4569,2.547741,0.3504
9,2.292,2.430317,0.3697
10,2.1832,2.566315,0.3503


TrainOutput(global_step=15640, training_loss=2.3696027243533706, metrics={'train_runtime': 5653.3542, 'train_samples_per_second': 176.886, 'train_steps_per_second': 2.766, 'total_flos': 2.124225847296e+18, 'train_loss': 2.3696027243533706, 'epoch': 20.0})

In [16]:
model.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [17]:
trainer.evaluate(test)

{'eval_loss': 2.25469708442688,
 'eval_accuracy': 0.4223,
 'eval_runtime': 27.6583,
 'eval_samples_per_second': 361.555,
 'eval_steps_per_second': 5.676,
 'epoch': 20.0}

## Definition of distillation training

A class that adapts the Hugging Face Trainer for knowledge distillation. It now works with the logits stored in the dataset.

In [18]:
class ImageDistilTrainer(Trainer):
    def __init__(self, student_model=None, *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.student = student_model
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        self.temperature = self.args.temperature
        self.lambda_param = self.args.lambda_param



    def compute_loss(self, student, inputs, return_outputs=False, num_items_in_batch=None):
        logits = inputs.pop("logits")

        student_output = student(**inputs)

        soft_teacher = F.softmax(logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)


        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)


        student_target_loss = student_output.loss

        loss = ((1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss)
        return (loss, student_output) if return_outputs else loss
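
The loss in `compute_loss` can be reproduced by hand on a toy batch. The NumPy sketch below computes the batch-averaged KL divergence between temperature-softened teacher and student distributions, scaled by the squared temperature, and then mixes it with a hard-label loss through `lambda_param`. The logit values and the hard loss of 2.0 are arbitrary assumptions.

```python
import numpy as np


def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # shift for numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def distillation_loss(student_logits, teacher_logits, temperature):
    """KL(teacher || student) on softened distributions, batch mean, times T^2."""
    log_p_teacher = log_softmax(teacher_logits / temperature)
    log_p_student = log_softmax(student_logits / temperature)
    p_teacher = np.exp(log_p_teacher)
    kl_per_sample = np.sum(p_teacher * (log_p_teacher - log_p_student), axis=-1)
    return kl_per_sample.mean() * temperature ** 2


def combined_loss(hard_loss, soft_loss, lambda_param):
    return (1.0 - lambda_param) * hard_loss + lambda_param * soft_loss


teacher = np.array([[4.0, 1.0, 0.0], [0.5, 3.0, 0.2]])
student = np.array([[2.0, 1.5, 0.5], [1.0, 1.0, 1.0]])

same = distillation_loss(teacher, teacher, temperature=6.0)
diff = distillation_loss(student, teacher, temperature=6.0)
only_soft = combined_loss(hard_loss=2.0, soft_loss=diff, lambda_param=1.0)
only_hard = combined_loss(hard_loss=2.0, soft_loss=diff, lambda_param=0.0)
print(same, diff > 0, only_soft == diff, only_hard)
```

With `lambda_param=1` the ground-truth labels no longer contribute to the training loss and the student learns from the teacher's distribution alone; with `lambda_param=0` the loss reduces to ordinary supervised training.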

### Training the randomly initialized model with knowledge distillation

In [19]:
reset_seed(42)

In [20]:
student_model = get_random_init_mobilenet()

In [21]:
training_args = get_training_args(output_dir="./results/cifar100-random-KD", logging_dir='./logs/cifar100-random-KD', remove_unused_columns=False, epochs=20, lr=0.00045, lambda_param=1, temp=6)

In [22]:
trainer = ImageDistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 4)]
)

In [23]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,4.5262,4.003145,0.0827
2,3.9152,3.575923,0.1499
3,3.7056,3.366372,0.2043
4,3.4085,3.18104,0.2307
5,3.2774,2.971278,0.2768
6,3.0358,2.905349,0.2888
7,2.9225,2.807522,0.3178
8,2.7287,2.687188,0.341
9,2.5785,2.609439,0.3601
10,2.4869,2.660033,0.3497


TrainOutput(global_step=15640, training_loss=2.6360328176747196, metrics={'train_runtime': 5925.3147, 'train_samples_per_second': 168.767, 'train_steps_per_second': 2.64, 'total_flos': 2.124225847296e+18, 'train_loss': 2.6360328176747196, 'epoch': 20.0})

In [24]:
student_model.eval()

In [25]:
trainer.evaluate(test)

{'eval_loss': 2.2995285987854004,
 'eval_accuracy': 0.4302,
 'eval_runtime': 31.3261,
 'eval_samples_per_second': 319.222,
 'eval_steps_per_second': 5.012,
 'epoch': 20.0}

## Obtaining the initialized MobileNetV2 model

In [26]:
reset_seed(42)

In [27]:
model_pretrained = get_mobilenet()

In [28]:
print(model_pretrained)

In [29]:
freeze_model(model_pretrained)

In [30]:
training_args = get_training_args(output_dir="./results/cifar100-pretrained-head", logging_dir='./logs/cifar100-pretrained-head')

In [31]:
trainer = Trainer(
    model=model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 2)]
)

In [32]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,4.3078,3.696208,0.3016
2,3.3189,3.178281,0.367
3,3.0068,2.887075,0.4013
4,2.5969,2.713907,0.4142
5,2.4835,2.5312,0.426
6,2.2879,2.486408,0.4177
7,2.2204,2.347985,0.4459
8,2.1154,2.374058,0.438
9,2.036,2.180859,0.4698
10,2.015,2.30733,0.4365


TrainOutput(global_step=15640, training_loss=2.2494143278702445, metrics={'train_runtime': 3656.3553, 'train_samples_per_second': 273.496, 'train_steps_per_second': 4.277, 'total_flos': 2.124225847296e+18, 'train_loss': 2.2494143278702445, 'epoch': 20.0})

In [33]:
model_pretrained.eval()

In [34]:
trainer.evaluate(test)

{'eval_loss': 2.1018595695495605,
 'eval_accuracy': 0.4858,
 'eval_runtime': 31.0564,
 'eval_samples_per_second': 321.995,
 'eval_steps_per_second': 5.055,
 'epoch': 20.0}

### Training the initialized MobileNetV2

In [35]:
reset_seed(42)

In [36]:
model_pretrained_whole = get_mobilenet()

In [37]:
training_args = get_training_args(output_dir="./results/cifar100-pretrained", logging_dir='./logs/cifar100-pretrained')

In [38]:
trainer = Trainer(
    model=model_pretrained_whole,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [39]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,3.0576,1.438863,0.6068
2,1.0972,1.123755,0.6781
3,0.8403,1.221768,0.6577
4,0.6009,1.039952,0.7085
5,0.5,0.950684,0.7221
6,0.361,1.074553,0.6993
7,0.3039,0.981014,0.7273
8,0.2127,1.036193,0.7196
9,0.1646,0.98247,0.7333
10,0.127,1.178681,0.7015


TrainOutput(global_step=15640, training_loss=0.35456949499866847, metrics={'train_runtime': 5339.362, 'train_samples_per_second': 187.288, 'train_steps_per_second': 2.929, 'total_flos': 2.124225847296e+18, 'train_loss': 0.35456949499866847, 'epoch': 20.0})

In [40]:
model_pretrained_whole.eval()

In [41]:
trainer.evaluate(test)

{'eval_loss': 0.9824703335762024,
 'eval_accuracy': 0.7333,
 'eval_runtime': 30.8492,
 'eval_samples_per_second': 324.158,
 'eval_steps_per_second': 5.089,
 'epoch': 20.0}

## Training the initialized MobileNetV2 with knowledge distillation

### Training the initialized model (classification head only) with distillation

In [42]:
reset_seed(42)

In [43]:
student_model_pretrained = get_mobilenet()

In [44]:
freeze_model(student_model_pretrained)

In [45]:
training_args = get_training_args(output_dir="./results/cifar100-pretrained-head-KD", logging_dir='./logs/cifar100-pretrained-head-KD', remove_unused_columns=False, temp=6, lambda_param=.8)

In [46]:
trainer = ImageDistilTrainer(
    student_model=student_model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 2)]
)

In [47]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,4.2146,3.570572,0.2982
2,3.463,3.202306,0.3603
3,3.244,3.009376,0.3988
4,2.9709,2.894381,0.4078
5,2.8945,2.769151,0.4247
6,2.7664,2.75109,0.4135
7,2.7212,2.660129,0.4317
8,2.6505,2.663648,0.4313
9,2.5983,2.529032,0.4637
10,2.5829,2.635812,0.4439


TrainOutput(global_step=15640, training_loss=2.7446819354201217, metrics={'train_runtime': 3693.3679, 'train_samples_per_second': 270.756, 'train_steps_per_second': 4.235, 'total_flos': 2.124225847296e+18, 'train_loss': 2.7446819354201217, 'epoch': 20.0})

In [48]:
student_model_pretrained.eval()

In [49]:
trainer.evaluate(test)

{'eval_loss': 2.468466281890869,
 'eval_accuracy': 0.4783,
 'eval_runtime': 27.222,
 'eval_samples_per_second': 367.351,
 'eval_steps_per_second': 5.767,
 'epoch': 20.0}

### Training the initialized model with distillation

In [50]:
reset_seed(42)

In [51]:
student_model_pretrained_whole = get_mobilenet()

In [52]:
training_args = get_training_args("./results/cifar100-pretrained-KD", './logs/cifar100-pretrained-KD', remove_unused_columns=False, temp=6, lambda_param=1)

In [53]:
trainer = ImageDistilTrainer(
    student_model=student_model_pretrained_whole.to(device),
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [54]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,3.1584,1.591682,0.6046
2,1.351,1.252602,0.6774
3,1.0858,1.360459,0.6558
4,0.8271,1.13954,0.7084
5,0.7219,1.033424,0.7252
6,0.5795,1.116371,0.7104
7,0.5176,1.043184,0.7278
8,0.4206,1.0445,0.73
9,0.3699,0.998406,0.7375
10,0.3253,1.149427,0.7108


TrainOutput(global_step=15640, training_loss=0.5431039517492894, metrics={'train_runtime': 5592.0988, 'train_samples_per_second': 178.824, 'train_steps_per_second': 2.797, 'total_flos': 2.124225847296e+18, 'train_loss': 0.5431039517492894, 'epoch': 20.0})

In [55]:
student_model_pretrained_whole.eval()

In [56]:
trainer.evaluate(test)

{'eval_loss': 1.0426266193389893,
 'eval_accuracy': 0.742,
 'eval_runtime': 18.7153,
 'eval_samples_per_second': 534.321,
 'eval_steps_per_second': 8.389,
 'epoch': 20.0}