# Notebook for distillation training on the CIFAR100 dataset
This notebook trains MobileNetV2 on the CIFAR100 dataset. The teacher model is a ViT fine-tuned on the same dataset.

MobileNetV2 is trained in three settings: with random initialization, with an initialized model (pretrained on ImageNet) where only the classification head is trained, and with an initialized model where the whole network is trained. Each of the three tasks is trained in the standard way and also with distillation from the teacher model mentioned above.

Distillation uses precomputed logits from the precompute_logits notebook.

In [1]:
%pip install transformers[torch] huggingface_hub datasets evaluate torchvision

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


## Library imports and method definitions

In [2]:
from transformers import Trainer, TrainingArguments, MobileNetV2Config, MobileNetV2ForImageClassification, EarlyStoppingCallback
from torchvision.transforms import v2 as transformsv2
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, Subset
import torch.nn.functional as F
from PIL import Image
import torch.nn as nn
import numpy as np
import evaluate
import random
import pickle
import torch
import os



Reset of the random seed for reproducibility of results.
Some of these settings can probably be removed.

TODO: Remove unnecessary settings.

In [3]:
def reset_seed(seed=42):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) 
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
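
As a small illustration of why the seed is reset before each experiment, the sketch below uses only Python's standard `random` module. The notebook's `reset_seed` additionally covers NumPy, PyTorch and cuDNN; the helper name `draw` is hypothetical and exists only for this example.

```python
import random

def draw(seed, n=3):
    """Reseed the generator and return the next n pseudo-random numbers."""
    random.seed(seed)
    return [random.random() for _ in range(n)]

first = draw(42)
second = draw(42)   # same seed, so the sequence repeats
other = draw(43)    # a different seed gives a different sequence

print(first == second)
print(first == other)
```

Reseeding right before a model is created is what makes two runs start from the same initial weights.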

A new wrapper that works directly with the files of the downloaded and modified CIFAR100 dataset.
Loading the data with the same method as before is not possible because the checksum differs.

The logits are now also read directly from the dataset.

In [4]:
class CustomCIFAR100(Dataset):
    def __init__(self, root, train=True, transform=None):
        self.root = root
        self.train = train
        self.transform = transform


        self.data = []
        self.targets = []
        self.logits = []
        
        if self.train:
            data_file = os.path.join(self.root, 'cifar-100-python', 'train')
        else:
            data_file = os.path.join(self.root, 'cifar-100-python', 'test')
            
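        # Each split is a single pickled dictionary with byte-string keys.
        with open(data_file, 'rb') as fo:
            entry = pickle.load(fo, encoding='bytes')  # avoid shadowing the built-in dict
            self.data.append(entry[b'data'])
            self.targets.extend(entry[b'fine_labels'])
            self.logits.extend(entry[b'logits'])  # precomputed teacher logits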

        self.data = np.concatenate(self.data, axis=0)
        self.targets = np.array(self.targets)
        self.logits = np.array(self.logits)


    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = self.data[index].reshape(3, 32, 32).transpose(1, 2, 0)
        label = self.targets[index]
        logit = self.logits[index]
        
        image = Image.fromarray(image.astype('uint8'), 'RGB')
        logit = torch.tensor(logit, dtype=torch.float)
        if self.transform:
            image = self.transform(image)

        return {
            'pixel_values': image,
            'labels': label,
            'logits': logit
        }
    
    @property
    def labels(self):
        return self.targets
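
The `reshape(3, 32, 32).transpose(1, 2, 0)` step in `__getitem__` can be illustrated without NumPy. The sketch below assumes the standard CIFAR python layout (1024 red values, then 1024 green, then 1024 blue, each plane stored row by row); the helper names `flat_index` and `to_hwc` are hypothetical.

```python
def flat_index(row, col, channel, size=32):
    """Position of pixel (row, col, channel) inside one flat CIFAR record."""
    return channel * size * size + row * size + col

def to_hwc(flat, size=32):
    """Rebuild a height x width x channel nested list from a flat record."""
    return [[[flat[flat_index(r, c, k, size)] for k in range(3)]
             for c in range(size)]
            for r in range(size)]

# Synthetic record: the value stored at each position is its own index.
record = list(range(3 * 32 * 32))
image = to_hwc(record)

print(image[0][0])    # first pixel: one value from each colour plane
print(image[31][31])  # last pixel
```

Each pixel therefore collects three values that lie 1024 positions apart in the stored record, which is the channel-last layout that `Image.fromarray` expects.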


Definition of the metrics (accuracy, precision, recall and F1) used when training the model.

In [5]:
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    pred, labels = eval_pred
    predictions = np.argmax(pred, axis=1)
    
    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
    precision = precision_metric.compute(predictions=predictions, references=labels, average='macro', zero_division = 0)
    recall = recall_metric.compute(predictions=predictions, references=labels, average='macro', zero_division = 0)
    f1 = f1_metric.compute(predictions=predictions, references=labels, average='macro')

    return {
        "accuracy": accuracy["accuracy"],
        "precision": precision["precision"],
        "recall": recall["recall"],
        "f1": f1["f1"]
    }
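
To make the macro averaging concrete, here is a minimal pure-Python sketch of the same four metrics. It is not the `evaluate` implementation; the function name `macro_metrics` and the toy labels are assumptions made for illustration, and the number of classes is passed in explicitly for simplicity.

```python
def macro_metrics(predictions, references, num_classes):
    """Accuracy plus macro-averaged precision, recall and F1 (zero division -> 0)."""
    pairs = list(zip(predictions, references))
    precisions, recalls, f1s = [], [], []
    for cls in range(num_classes):
        tp = sum(p == cls and r == cls for p, r in pairs)
        fp = sum(p == cls and r != cls for p, r in pairs)
        fn = sum(p != cls and r == cls for p, r in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    return {
        "accuracy": sum(p == r for p, r in pairs) / len(pairs),
        "precision": sum(precisions) / num_classes,
        "recall": sum(recalls) / num_classes,
        "f1": sum(f1s) / num_classes,
    }

# Toy data with two samples per class (balanced, like CIFAR100).
references = [0, 0, 1, 1, 2, 2]
predictions = [0, 1, 1, 1, 2, 0]
scores = macro_metrics(predictions, references, num_classes=3)
print(scores)
```

When every class has the same number of reference samples, macro recall equals accuracy, which is consistent with the identical Accuracy and Recall columns in the training logs of this notebook.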

Training arguments for the trainer.

In [6]:
class Custom_training_args(TrainingArguments):
    def __init__(self, lambda_param, temperature, *args, **kwargs):
        super().__init__(*args, **kwargs)    
        self.lambda_param = lambda_param
        self.temperature = temperature

In [7]:
def get_training_args(output_dir, logging_dir, remove_unused_columns=True, lr=5e-5, epochs=5, weight_decay=0, lambda_param=.5, temp=5):
    return (
        Custom_training_args(
        output_dir=output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="epoch",
        learning_rate=lr,  # Default value
        per_device_train_batch_size=128,
        per_device_eval_batch_size=128,
        num_train_epochs=epochs,
        weight_decay=weight_decay,
        seed = 42,  # Default value
        metric_for_best_model="f1",
        load_best_model_at_end = True,
        fp16=True, 
        logging_dir=logging_dir,
        remove_unused_columns=remove_unused_columns,
        lambda_param = lambda_param, 
        temperature = temp
    ))

Randomly initialized MobileNetV2.

In [8]:
def get_random_init_mobilenet():
    reset_seed(42)
    student_config = MobileNetV2Config()
    student_config.num_labels = 100
    return MobileNetV2ForImageClassification(student_config)

Freezing the model so that only the classification head is trained.

In [9]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False

    for param in model.classifier.parameters():
        param.requires_grad = True

Initialized (ImageNet-pretrained) MobileNetV2.

In [10]:
def get_mobilenet():
    model_pretrained = MobileNetV2ForImageClassification.from_pretrained("google/mobilenet_v2_1.0_224")
    in_features = model_pretrained.classifier.in_features

    model_pretrained.classifier = nn.Linear(in_features, 100)  # Replace the classification head
    model_pretrained.num_labels = 100
    model_pretrained.config.num_labels = 100

    return model_pretrained

In [11]:
reset_seed(42)

In [12]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A10


Applying transformations to the dataset.

In [13]:
transform = transformsv2.Compose([
    transformsv2.Resize((224, 224)), 
    transformsv2.ToImage(),
    transformsv2.ToDtype(torch.float32, scale=True),
    transformsv2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])



test = CustomCIFAR100(root='../data/100-logits', train=False, transform=transform)
train_whole = CustomCIFAR100(root='../data/100-logits', train=True, transform=transform)
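
The value range produced by the last two transforms can be checked by hand. The sketch below reproduces only the arithmetic of `ToDtype(torch.float32, scale=True)` followed by `Normalize` with mean 0.5 and standard deviation 0.5; the helper name `normalize` is hypothetical.

```python
def normalize(value, mean=0.5, std=0.5):
    """Channel-wise normalisation: (x - mean) / std."""
    return (value - mean) / std

for pixel in (0, 128, 255):
    scaled = pixel / 255   # scale=True maps uint8 values to [0, 1]
    print(pixel, round(normalize(scaled), 4))
```

With these settings the inputs passed to the model lie in the interval [-1, 1].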


A somewhat better way of splitting the data: a stratified train/validation split.

In [14]:
train_idx, validation_idx = train_test_split(np.arange(len(train_whole)),
                                             test_size=0.2,
                                             random_state=42,
                                             shuffle=True,
                                             stratify=train_whole.labels)

In [15]:
eval = Subset(train_whole, validation_idx)
train = Subset(train_whole, train_idx)
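
The idea behind `stratify=` can be sketched in pure Python. This is a simplified stand-in, not scikit-learn's algorithm: it shuffles the indices of each class separately and moves the same fraction of every class into the validation part. The function name `stratified_split` and the toy labels are assumptions.

```python
import random
from collections import Counter, defaultdict

def stratified_split(labels, test_size=0.2, seed=42):
    """Split indices so that each class keeps the same share in both parts."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    train_idx, val_idx = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        cut = round(len(indices) * test_size)
        val_idx.extend(indices[:cut])
        train_idx.extend(indices[cut:])
    return train_idx, val_idx

# Toy labels: 5 classes with 10 samples each.
labels = [cls for cls in range(5) for _ in range(10)]
train_idx, val_idx = stratified_split(labels)

print(len(train_idx), len(val_idx))
print(Counter(labels[i] for i in val_idx))
```

Because every class contributes the same number of validation samples, the validation set stays balanced, which keeps the macro-averaged metrics comparable across classes.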

### Standard training of the randomly initialized model

In [16]:
training_args = get_training_args(output_dir="./results/cifar100-random", logging_dir='./logs/cifar100-random', lr=0.0005, weight_decay=0.008, epochs=20)
model = get_random_init_mobilenet()

In [17]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 4)]
)
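
The effect of `early_stopping_patience` can be illustrated with a simplified sketch of patience-based early stopping. It is not the library's implementation, and the F1 values below are hypothetical, not taken from the runs in this notebook.

```python
def stopping_epoch(scores, patience):
    """Return the 1-based epoch at which training would stop, or None.

    `scores` holds one validation score per epoch (higher is better).
    Training stops once the score has failed to improve on the best value
    seen so far for `patience` consecutive evaluations.
    """
    best = float("-inf")
    stale = 0
    for epoch, score in enumerate(scores, start=1):
        if score > best:
            best, stale = score, 0
        else:
            stale += 1
            if stale >= patience:
                return epoch
    return None

# Hypothetical validation F1 per epoch.
f1_per_epoch = [0.30, 0.42, 0.47, 0.46, 0.47, 0.45, 0.44, 0.50]
print(stopping_epoch(f1_per_epoch, patience=4))
print(stopping_epoch(f1_per_epoch, patience=5))
```

A larger patience tolerates longer plateaus: in this toy sequence a patience of 4 stops before the late improvement in epoch 8, while a patience of 5 does not.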

In [18]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 4.1515 | 3.665846 | 0.1196 | 0.105591 | 0.1196 | 0.086739 |
| 2 | 3.5133 | 3.091202 | 0.2275 | 0.222321 | 0.2275 | 0.201253 |
| 3 | 3.0162 | 2.660194 | 0.306 | 0.297025 | 0.306 | 0.277242 |
| 4 | 2.6473 | 2.534648 | 0.3267 | 0.359282 | 0.3267 | 0.304568 |
| 5 | 2.3703 | 2.209965 | 0.4016 | 0.41496 | 0.4016 | 0.387414 |
| 6 | 2.1014 | 2.075731 | 0.4331 | 0.450639 | 0.4331 | 0.421803 |
| 7 | 1.892 | 1.948526 | 0.4704 | 0.476048 | 0.4704 | 0.461536 |
| 8 | 1.6885 | 1.928553 | 0.4823 | 0.504688 | 0.4823 | 0.478651 |
| 9 | 1.4927 | 1.99344 | 0.4741 | 0.496955 | 0.4741 | 0.468927 |
| 10 | 1.3048 | 1.816158 | 0.508 | 0.521405 | 0.508 | 0.504065 |


TrainOutput(global_step=5634, training_loss=1.614692786842701, metrics={'train_runtime': 3041.9242, 'train_samples_per_second': 262.991, 'train_steps_per_second': 2.058, 'total_flos': 1.52944261005312e+18, 'train_loss': 1.614692786842701, 'epoch': 18.0})
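
The `global_step` value can be reproduced with a short calculation. The sketch assumes the standard 50,000-image CIFAR100 training set, a single device, no gradient accumulation, and that the last partial batch of each epoch is kept.

```python
import math

total_train = 50_000                  # standard CIFAR100 training set size
validation = total_train // 5         # test_size=0.2
train_samples = total_train - validation
batch_size = 128                      # per_device_train_batch_size

steps_per_epoch = math.ceil(train_samples / batch_size)
print(steps_per_epoch)
print(steps_per_epoch * 18)           # steps after 18 epochs
```

Under the same assumptions, 19 epochs give 5947 steps and 5 epochs give 1565 steps, which matches the other `TrainOutput` lines in this notebook.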

In [19]:
model.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [20]:
trainer.evaluate(test)

{'eval_loss': 1.8767671585083008,
 'eval_accuracy': 0.5153,
 'eval_precision': 0.5321961826850576,
 'eval_recall': 0.5153,
 'eval_f1': 0.5155597800513738,
 'eval_runtime': 30.5945,
 'eval_samples_per_second': 326.856,
 'eval_steps_per_second': 2.582,
 'epoch': 18.0}

## Definition of distillation training

A class that adapts the Hugging Face Trainer for knowledge distillation. It now works with the logits stored in the dataset.

In [21]:
class ImageDistilTrainer(Trainer):
    def __init__(self, student_model=None, *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.student = student_model
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        self.temperature = self.args.temperature
        self.lambda_param = self.args.lambda_param



    def compute_loss(self, student, inputs, return_outputs=False, num_items_in_batch=None):
        logits = inputs.pop("logits")

        student_output = student(**inputs)

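        # Soften both distributions with the temperature before comparing them.
        soft_teacher = F.softmax(logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        # KL divergence between teacher and student, rescaled by T^2.
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        # Cross-entropy on the ground-truth labels, computed by the model itself.
        student_target_loss = student_output.loss

        # Weighted combination; lambda_param = 1 uses the distillation term only.
        loss = ((1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss)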
        return (loss, student_output) if return_outputs else loss
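
The loss computed in `compute_loss` can be followed for a single sample in pure Python. This is a sketch of the same formula, not a replacement for the PyTorch code; the function names and the example logits are hypothetical.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax of logits divided by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, label, temperature, lambda_param):
    """Blend of hard-label cross-entropy and temperature-scaled KL divergence."""
    soft_teacher = softmax(teacher_logits, temperature)
    soft_student = softmax(student_logits, temperature)
    # KL(teacher || student), summed over classes for one sample.
    kl = sum(t * (math.log(t) - math.log(s))
             for t, s in zip(soft_teacher, soft_student) if t > 0)
    kd = kl * temperature ** 2
    ce = -math.log(softmax(student_logits)[label])
    return (1.0 - lambda_param) * ce + lambda_param * kd

teacher = [4.0, 1.0, 0.5]
student = [2.0, 1.5, 0.2]
print(distillation_loss(student, teacher, label=0, temperature=6, lambda_param=0.8))
print(distillation_loss(teacher, teacher, label=0, temperature=6, lambda_param=1.0))
```

With `lambda_param=1` the hard-label term drops out, so a student whose logits equal the teacher's has zero loss. Multiplying by the squared temperature is commonly used to keep the gradient magnitude of the soft term comparable across temperatures.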

### Training the randomly initialized model with knowledge distillation

In [22]:
reset_seed(42)

In [23]:
student_model = get_random_init_mobilenet()

In [24]:
training_args = get_training_args(output_dir="./results/cifar100-random-KD", logging_dir='./logs/cifar100-random-KD', remove_unused_columns=False, epochs=20, lr=0.00045, lambda_param=1, temp=6)

In [25]:
trainer = ImageDistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 4)]
)

In [26]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 3.214 | 2.9748 | 0.1004 | 0.056157 | 0.1004 | 0.052941 |
| 2 | 2.8871 | 2.698806 | 0.196 | 0.187306 | 0.196 | 0.146708 |
| 3 | 2.6336 | 2.435085 | 0.2642 | 0.255461 | 0.2642 | 0.214999 |
| 4 | 2.4049 | 2.272758 | 0.314 | 0.343229 | 0.314 | 0.267332 |
| 5 | 2.215 | 2.094159 | 0.3616 | 0.388128 | 0.3616 | 0.320316 |
| 6 | 2.0504 | 1.937629 | 0.4194 | 0.442721 | 0.4194 | 0.3937 |
| 7 | 1.908 | 1.872917 | 0.4387 | 0.466657 | 0.4387 | 0.412218 |
| 8 | 1.7778 | 1.79174 | 0.4636 | 0.485604 | 0.4636 | 0.442349 |
| 9 | 1.6516 | 1.77806 | 0.4645 | 0.494373 | 0.4645 | 0.44108 |
| 10 | 1.5385 | 1.656386 | 0.5035 | 0.520974 | 0.5035 | 0.486511 |


TrainOutput(global_step=5947, training_loss=1.6951558786779968, metrics={'train_runtime': 2998.4342, 'train_samples_per_second': 266.806, 'train_steps_per_second': 2.088, 'total_flos': 1.61441164394496e+18, 'train_loss': 1.6951558786779968, 'epoch': 19.0})

In [27]:
student_model.eval()


In [28]:
trainer.evaluate(test)

{'eval_loss': 1.3772859573364258,
 'eval_accuracy': 0.531,
 'eval_precision': 0.5580247452820734,
 'eval_recall': 0.5309999999999999,
 'eval_f1': 0.5285947047145109,
 'eval_runtime': 17.9155,
 'eval_samples_per_second': 558.176,
 'eval_steps_per_second': 4.41,
 'epoch': 19.0}

## Obtaining the initialized MobileNetV2 model

In [29]:
reset_seed(42)

In [30]:
model_pretrained = get_mobilenet()

In [31]:
print(model_pretrained)


In [32]:
freeze_model(model_pretrained)

In [33]:
training_args = get_training_args(output_dir="./results/cifar100-pretrained-head", logging_dir='./logs/cifar100-pretrained-head')

In [34]:
trainer = Trainer(
    model=model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 2)]
)

In [35]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 4.3666 | 4.12372 | 0.1839 | 0.208631 | 0.1839 | 0.177796 |
| 2 | 3.9275 | 3.765841 | 0.3024 | 0.325632 | 0.3024 | 0.293932 |
| 3 | 3.6436 | 3.570344 | 0.3498 | 0.359684 | 0.3498 | 0.333879 |
| 4 | 3.4753 | 3.458225 | 0.3587 | 0.382333 | 0.3587 | 0.344111 |
| 5 | 3.3925 | 3.428162 | 0.3769 | 0.382731 | 0.3769 | 0.359658 |


TrainOutput(global_step=1565, training_loss=3.761084686376797, metrics={'train_runtime': 469.9788, 'train_samples_per_second': 425.551, 'train_steps_per_second': 3.33, 'total_flos': 4.248451694592e+17, 'train_loss': 3.761084686376797, 'epoch': 5.0})

In [36]:
model_pretrained.eval()


In [37]:
trainer.evaluate(test)

{'eval_loss': 3.432467222213745,
 'eval_accuracy': 0.3773,
 'eval_precision': 0.3774894959232473,
 'eval_recall': 0.3772999999999999,
 'eval_f1': 0.35815750868611934,
 'eval_runtime': 18.3126,
 'eval_samples_per_second': 546.073,
 'eval_steps_per_second': 4.314,
 'epoch': 5.0}

### Training the whole initialized MobileNetV2

In [38]:
reset_seed(42)

In [39]:
model_pretrained_whole = get_mobilenet()

In [40]:
training_args = get_training_args(output_dir="./results/cifar100-pretrained", logging_dir='./logs/cifar100-pretrained')

In [41]:
trainer = Trainer(
    model=model_pretrained_whole,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [42]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 3.2751 | 2.070399 | 0.5244 | 0.543151 | 0.5244 | 0.508957 |
| 2 | 1.6438 | 1.473536 | 0.6097 | 0.625245 | 0.6097 | 0.605971 |
| 3 | 1.2066 | 1.26809 | 0.6559 | 0.6618 | 0.6559 | 0.651263 |
| 4 | 1.0111 | 1.259078 | 0.6489 | 0.670386 | 0.6489 | 0.648706 |
| 5 | 0.9242 | 1.154022 | 0.6805 | 0.685636 | 0.6805 | 0.674981 |


TrainOutput(global_step=1565, training_loss=1.61216064063124, metrics={'train_runtime': 585.1473, 'train_samples_per_second': 341.794, 'train_steps_per_second': 2.675, 'total_flos': 4.248451694592e+17, 'train_loss': 1.61216064063124, 'epoch': 5.0})

In [43]:
model_pretrained_whole.eval()


In [44]:
trainer.evaluate(test)

{'eval_loss': 1.1403955221176147,
 'eval_accuracy': 0.6805,
 'eval_precision': 0.6896003843930464,
 'eval_recall': 0.6805000000000001,
 'eval_f1': 0.6767723010859202,
 'eval_runtime': 17.6611,
 'eval_samples_per_second': 566.216,
 'eval_steps_per_second': 4.473,
 'epoch': 5.0}

## Knowledge-distillation training of the initialized MobileNetV2

### Training the initialized model with distillation: classification head only

In [45]:
reset_seed(42)

In [46]:
student_model_pretrained = get_mobilenet()

In [47]:
freeze_model(student_model_pretrained)

In [48]:
training_args = get_training_args(output_dir="./results/cifar100-pretrained-head-KD", logging_dir='./logs/cifar100-pretrained-head-KD', remove_unused_columns=False, temp=6, lambda_param=.8)

In [49]:
trainer = ImageDistilTrainer(
    student_model=student_model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 2)]
)

In [50]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 3.5178 | 3.355098 | 0.1797 | 0.220875 | 0.1797 | 0.168106 |
| 2 | 3.2409 | 3.142756 | 0.2831 | 0.346824 | 0.2831 | 0.279037 |
| 3 | 3.0797 | 3.033154 | 0.3392 | 0.379779 | 0.3392 | 0.325473 |
| 4 | 2.99 | 2.978981 | 0.3436 | 0.406908 | 0.3436 | 0.332064 |
| 5 | 2.9475 | 2.958369 | 0.3593 | 0.402591 | 0.3593 | 0.34664 |


TrainOutput(global_step=1565, training_loss=3.155180383261781, metrics={'train_runtime': 433.0658, 'train_samples_per_second': 461.824, 'train_steps_per_second': 3.614, 'total_flos': 4.248451694592e+17, 'train_loss': 3.155180383261781, 'epoch': 5.0})

In [51]:
student_model_pretrained.eval()


In [52]:
trainer.evaluate(test)

{'eval_loss': 2.81398868560791,
 'eval_accuracy': 0.3537,
 'eval_precision': 0.39045062851507895,
 'eval_recall': 0.35369999999999996,
 'eval_f1': 0.33894118699147735,
 'eval_runtime': 17.613,
 'eval_samples_per_second': 567.761,
 'eval_steps_per_second': 4.485,
 'epoch': 5.0}

### Training the whole initialized model with distillation

In [53]:
reset_seed(42)

In [54]:
student_model_pretrained_whole = get_mobilenet()

In [55]:
training_args = get_training_args("./results/cifar100-pretrained-KD", './logs/cifar100-pretrained-KD', remove_unused_columns=False, temp=6, lambda_param=1)

In [56]:
trainer = ImageDistilTrainer(
    student_model=student_model_pretrained_whole.to(device),
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [57]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 2.6861 | 2.036413 | 0.4577 | 0.530295 | 0.4577 | 0.426668 |
| 2 | 1.7625 | 1.579123 | 0.566 | 0.607231 | 0.566 | 0.547062 |
| 3 | 1.4056 | 1.371883 | 0.619 | 0.641029 | 0.619 | 0.603205 |
| 4 | 1.2316 | 1.354033 | 0.6157 | 0.645628 | 0.6157 | 0.605793 |
| 5 | 1.1537 | 1.25149 | 0.6543 | 0.673847 | 0.6543 | 0.64309 |


TrainOutput(global_step=1565, training_loss=1.647921840679912, metrics={'train_runtime': 533.0318, 'train_samples_per_second': 375.212, 'train_steps_per_second': 2.936, 'total_flos': 4.248451694592e+17, 'train_loss': 1.647921840679912, 'epoch': 5.0})

In [58]:
student_model_pretrained_whole.eval()


In [59]:
trainer.evaluate(test)

{'eval_loss': 1.0503957271575928,
 'eval_accuracy': 0.651,
 'eval_precision': 0.6701687624603764,
 'eval_recall': 0.6510000000000001,
 'eval_f1': 0.6393124124046187,
 'eval_runtime': 13.9925,
 'eval_samples_per_second': 714.669,
 'eval_steps_per_second': 5.646,
 'epoch': 5.0}
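
The test-set accuracies reported by the six `trainer.evaluate(test)` calls in this notebook can be collected for a side-by-side view. The dictionary keys are labels chosen for this sketch. The runs are not a strictly controlled comparison: for the randomly initialized model, for example, the distilled run used a different learning rate and no weight decay.

```python
# Test-set accuracy as reported by trainer.evaluate(test) in this notebook.
results = {
    "random init":      {"baseline": 0.5153, "distilled": 0.531},
    "pretrained, head": {"baseline": 0.3773, "distilled": 0.3537},
    "pretrained, full": {"baseline": 0.6805, "distilled": 0.651},
}

# Difference in accuracy between the distilled and the standard run.
deltas = {name: round(r["distilled"] - r["baseline"], 4)
          for name, r in results.items()}

for name, delta in deltas.items():
    print(f"{name:18s} {delta:+.4f}")
```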