# Notebook pro trénink s destilací nad datasetem CIFAR100
V tomto notebooku je trénován MobileNetV2 nad datasetem CIFAR100; jako učitelský model je využíván ViT dotrénovaný (fine-tuned) nad stejným datasetem. 

MobileNetV2 je používán ve třech variantách: s náhodnou inicializací, s tréninkem pouze klasifikační hlavy modelu předtrénovaného nad ImageNetem a s tréninkem celého předtrénovaného modelu. Tyto tři úlohy jsou trénovány běžným způsobem a také s pomocí destilace výše zmíněného učitelského modelu.  

Při destilaci je využíváno předpočítaných logitů ze sešitu precompute_logits.

In [1]:
%pip install transformers[torch] huggingface_hub datasets evaluate torchvision

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


## Import knihoven a definice metod

In [24]:
from transformers import Trainer, TrainingArguments, MobileNetV2Config, MobileNetV2ForImageClassification, EarlyStoppingCallback
from torchvision.transforms import v2 as transformsv2
from torch.utils.data import Dataset
import torch.nn.functional as F
from PIL import Image
import torch.nn as nn
from enum import Enum
import numpy as np
import evaluate
import random
import pickle
import torch
import os

Resetování náhodného seedu pro replikovatelnost výsledků.
Zřejmě je možné části odebrat.

TODO: Odebrat zbytečná nastavení.

In [25]:
def reset_seed(seed=42):
    """Seed every random number generator used in the notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), switches
    cuDNN to deterministic mode with benchmarking disabled, and stores the
    seed in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Integer seed applied to all generators (default 42).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # All seeding functions take the same single integer argument.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

Nový wrapper, který pracuje přímo se soubory staženého a upraveného datasetu CIFAR100.
Využití načtení pomocí metody jako dříve není možné kvůli jiné checksum. 

Zároveň se již dotahují logity přímo z datasetu.

In [26]:
class dataset_part(Enum):
    """Selector for the dataset split to load."""

    TRAIN = 1
    TEST = 2
    EVAL = 3

In [130]:
class CustomCIFAR100(Dataset):
    """CIFAR-100 split read from a pickled file that also stores precomputed logits.

    The split is loaded from ``<root>/cifar-100-python/{train,test,eval}``. The
    pickled mapping must provide the byte keys ``b'data'``, ``b'fine_labels'``
    and ``b'logits'``; the TRAIN split additionally provides ``b'logits_aug'``.

    Args:
        root: Directory that contains the ``cifar-100-python`` folder.
        dataset_part: Member of the ``dataset_part`` enum. TRAIN and TEST load
            the ``train`` / ``test`` file, any other member loads ``eval``.
        transform: Optional callable applied to the PIL image of every item.
        use_aug_logits: ``True`` / ``False`` forces items to return
            ``logits_aug`` / ``logits``. ``None`` (default) keeps the legacy
            rule: ``logits_aug`` is used when the transform's ``extra_repr()``
            is at least 300 characters long.
    """

    # Legacy rule for telling the augmenting pipeline from the plain one.
    _AUG_REPR_THRESHOLD = 300
    # Enum member name -> file name; every other member falls back to 'eval'.
    _SPLIT_FILES = {'TRAIN': 'train', 'TEST': 'test'}

    def __init__(self, root, dataset_part=dataset_part.TRAIN, transform=None, use_aug_logits=None):
        self.root = root
        self.dataset_part = dataset_part
        self.transform = transform
        self.use_aug_logits = use_aug_logits

        self.targets = []
        self.logits = []
        self.logits_aug = []

        # The parameter shadows the module-level enum inside this method, so the
        # split is identified by member name instead of reaching for
        # ``dataset_part.TRAIN`` through another member.
        split_name = dataset_part.name
        data_file = os.path.join(
            self.root, 'cifar-100-python', self._SPLIT_FILES.get(split_name, 'eval')
        )

        # NOTE: pickle can execute arbitrary code - only load trusted dataset files.
        with open(data_file, 'rb') as fo:
            entry = pickle.load(fo, encoding='bytes')

        self.data = np.concatenate([entry[b'data']], axis=0)
        self.targets.extend(entry[b'fine_labels'])
        self.logits.extend(entry[b'logits'])
        if split_name == 'TRAIN':
            self.logits_aug.extend(entry[b'logits_aug'])

    def __len__(self):
        return len(self.data)

    def _use_augmented_logits(self):
        """Return True when items should carry ``logits_aug`` instead of ``logits``."""
        forced = getattr(self, 'use_aug_logits', None)
        if forced is not None:
            return bool(forced)
        if not self.transform:
            return False
        extra_repr = getattr(self.transform, 'extra_repr', None)
        if extra_repr is None:
            # Plain callables have no repr to inspect; treat them as non-augmenting.
            return False
        return len(extra_repr()) >= self._AUG_REPR_THRESHOLD

    def __getitem__(self, index):
        """Return a dict with ``pixel_values``, ``labels`` and ``logits`` for one sample."""
        # Each row is a flat vector reshaped to 3x32x32 and moved to HWC for PIL.
        image = self.data[index].reshape(3, 32, 32).transpose(1, 2, 0)
        image = Image.fromarray(image.astype('uint8'), 'RGB')
        label = self.targets[index]

        # Built unconditionally so that a dataset without a transform no longer
        # fails with an unbound ``logit``.
        source = self.logits_aug if self._use_augmented_logits() else self.logits
        logit = torch.tensor(source[index], dtype=torch.float)

        if self.transform:
            # Re-seeds the *global* torch RNG so random transforms are a function
            # of the index. NOTE(review): presumably this mirrors how the stored
            # augmented logits were generated - confirm against precompute_logits.
            torch.manual_seed(index % 10000)
            image = self.transform(image)

        return {
            'pixel_values': image,
            'labels': label,
            'logits': logit,
        }

    def remove_entries(self, remove_list):
        """Delete the samples at the given indices from all per-sample arrays."""
        self.data = np.delete(self.data, remove_list, axis=0)
        self.targets = np.delete(self.targets, remove_list, axis=0)
        self.logits = np.delete(self.logits, remove_list, axis=0)
        # Only the TRAIN split has augmented logits; deleting from an empty
        # container would raise IndexError.
        if len(self.logits_aug):
            self.logits_aug = np.delete(self.logits_aug, remove_list, axis=0)

    @property
    def labels(self):
        """Fine-grained class labels of all samples."""
        return self.targets

Definice metrik (accuracy, precision, recall a F1) pro vyhodnocení modelu při tréninku.

In [105]:
# Metric objects are loaded once and reused by every evaluation call.
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")


def compute_metrics(eval_pred):
    """Compute accuracy and macro-averaged precision, recall and F1.

    Args:
        eval_pred: Pair ``(scores, labels)``; the predicted class is the
            argmax of ``scores`` along axis 1.

    Returns:
        Dict with the keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    scores, references = eval_pred
    predicted = np.argmax(scores, axis=1)
    common = {"predictions": predicted, "references": references}

    return {
        "accuracy": accuracy_metric.compute(**common)["accuracy"],
        "precision": precision_metric.compute(average='macro', zero_division=0, **common)["precision"],
        "recall": recall_metric.compute(average='macro', zero_division=0, **common)["recall"],
        "f1": f1_metric.compute(average='macro', **common)["f1"],
    }

Trénovací argumenty pro trainer. 

In [106]:
class Custom_training_args(TrainingArguments):
    """``TrainingArguments`` carrying two extra distillation hyper-parameters.

    Args:
        lambda_param: Stored as ``self.lambda_param``.
        temperature: Stored as ``self.temperature``.
        *args, **kwargs: Forwarded unchanged to ``TrainingArguments``.
    """

    def __init__(self, lambda_param, temperature, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.temperature, self.lambda_param = temperature, lambda_param

In [107]:
def get_training_args(output_dir, logging_dir, remove_unused_columns=True, lr=5e-5, epochs=5, weight_decay=0, lambda_param=.5, temp=5):
    """Build the ``Custom_training_args`` shared by all experiments.

    Evaluation, checkpointing and logging happen once per epoch; the best
    checkpoint according to ``f1`` is reloaded at the end of training.

    Args:
        output_dir: Directory for checkpoints.
        logging_dir: Directory for logs.
        remove_unused_columns: Forwarded to the trainer arguments; the
            distillation runs pass ``False`` so the ``logits`` entry of each
            sample reaches ``compute_loss``.
        lr: Learning rate.
        epochs: Number of training epochs.
        weight_decay: Weight decay.
        lambda_param: Stored on the arguments as the distillation weight.
        temp: Stored on the arguments as the distillation temperature.

    Returns:
        A configured ``Custom_training_args`` instance.
    """
    return Custom_training_args(
        output_dir=output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="epoch",
        learning_rate=lr,  # default value
        per_device_train_batch_size=128,
        per_device_eval_batch_size=128,
        num_train_epochs=epochs,
        weight_decay=weight_decay,
        seed=42,  # default value
        metric_for_best_model="f1",
        load_best_model_at_end=True,
        # The notebook falls back to CPU when no GPU is present; mixed precision
        # is only requested when CUDA is actually available.
        fp16=torch.cuda.is_available(),
        logging_dir=logging_dir,
        remove_unused_columns=remove_unused_columns,
        lambda_param=lambda_param,
        temperature=temp,
    )

Náhodně inicializovaný MobileNetV2.

In [108]:
def get_random_init_mobilenet():
    """Return a randomly initialised MobileNetV2 classifier with 100 labels.

    The seed is reset to 42 first so the initial weights are reproducible.
    """
    reset_seed(42)
    config = MobileNetV2Config()
    config.num_labels = 100
    student = MobileNetV2ForImageClassification(config)
    return student

Zamražení modelu a trénink pouze klasifikační hlavy.

In [109]:
def freeze_model(model):
    """Freeze every parameter of ``model`` except those of ``model.classifier``.

    Args:
        model: Module exposing a ``classifier`` sub-module; modified in place.
    """
    # Switch gradients off everywhere, then re-enable them for the head only.
    model.requires_grad_(False)
    model.classifier.requires_grad_(True)

Inicializovaný MobileNetV2.

In [110]:
def get_mobilenet():
    """Load the pretrained ``google/mobilenet_v2_1.0_224`` and give it a 100-class head.

    Returns:
        The model with ``classifier`` replaced by a new ``nn.Linear`` layer and
        ``num_labels`` set to 100 on both the model and its config.
    """
    num_classes = 100
    net = MobileNetV2ForImageClassification.from_pretrained("google/mobilenet_v2_1.0_224")

    # Replace the classification head, keeping its input width.
    head_width = net.classifier.in_features
    net.classifier = nn.Linear(head_width, num_classes)

    net.config.num_labels = num_classes
    net.num_labels = num_classes
    return net

In [111]:
reset_seed(42)

In [112]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


Provedení transformací nad datasetem.

In [131]:
# Plain (non-augmenting) preprocessing: resize to 224x224, convert to a float32
# image tensor with value scaling (scale=True) and normalise every channel with
# mean 0.5 / std 0.5.
# NOTE(review): CustomCIFAR100.__getitem__ chooses between `logits` and
# `logits_aug` from len(transform.extra_repr()) < 300, so adding or removing
# transforms here can silently change which logits are returned.
transform = transformsv2.Compose([
    transformsv2.Resize((224, 224)), 
    transformsv2.ToImage(),
    transformsv2.ToDtype(torch.float32, scale=True),
    transformsv2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])



# The three splits share the same root directory and preprocessing.
# NOTE(review): the name `eval` shadows the Python builtin for the rest of the notebook.
train = CustomCIFAR100(root='./data/100-logits', dataset_part=dataset_part.TRAIN, transform=transform)
eval = CustomCIFAR100(root='./data/100-logits', dataset_part=dataset_part.EVAL, transform=transform)
test = CustomCIFAR100(root='./data/100-logits', dataset_part=dataset_part.TEST, transform=transform)


In [114]:
# Augmenting pipeline: random horizontal/vertical flips and a rotation of up to
# 15 degrees on top of the resize to 224x224.
# NOTE(review): normalisation here uses mean [0.485, 0.456, 0.406] and
# std [0.229, 0.224, 0.225], whereas `transform` uses 0.5 / 0.5 - confirm this
# matches the preprocessing used when `logits_aug` were precomputed.
# NOTE(review): the length of this pipeline's extra_repr() (>= 300 characters)
# is what makes CustomCIFAR100 return `logits_aug`; keep that in mind when editing.
augment_transform = transformsv2.Compose([ 
    transformsv2.ToImage(),
    transformsv2.ToDtype(torch.float32, scale=True),
    transformsv2.Resize(size=(224, 224), antialias=True),
    transformsv2.RandomHorizontalFlip(),
    transformsv2.RandomVerticalFlip(),
    transformsv2.RandomRotation(15),
    transformsv2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Second copy of the TRAIN split, served through the augmenting pipeline.
train_aug = CustomCIFAR100(root='./data/100-logits', dataset_part=dataset_part.TRAIN, transform=augment_transform)


In [115]:
# Drop augmented samples whose top-1 class under the stored augmented logits
# differs from the top-1 class under the stored logits of the original image.
# The logits are read directly from the datasets: iterating the datasets would
# run the full image pipeline on every sample twice just to fetch them.
top1_aug = np.argmax(np.asarray(train_aug.logits_aug), axis=1)
top1_orig = np.argmax(np.asarray(train.logits), axis=1)
rem_ls = np.flatnonzero(top1_aug != top1_orig).tolist()
train_aug.remove_entries(rem_ls)

In [116]:
train_combo = torch.utils.data.ConcatDataset([train, train_aug])

### Standardní trénink náhodně inicializovaného modelu. 

In [117]:
training_args = get_training_args(output_dir="./results/cifar100-random", logging_dir='./logs/cifar100-random', lr=0.0005, weight_decay=0.008, epochs=20)
model = get_random_init_mobilenet()

In [118]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 4)]
)

In [87]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,4.0695,3.530944,0.1466,0.144824,0.1466,0.114114
2,3.3527,2.811494,0.2744,0.272026,0.2744,0.245172
3,2.8164,2.463264,0.349,0.350675,0.349,0.327709
4,2.4428,2.183587,0.4081,0.414208,0.4081,0.39433
5,2.1526,2.022997,0.4443,0.457037,0.4443,0.427724
6,1.92,1.858425,0.4889,0.494161,0.4889,0.477761
7,1.7171,1.770266,0.5099,0.52903,0.5099,0.504634
8,1.5482,1.701913,0.5292,0.533493,0.5292,0.522187
9,1.3748,1.683333,0.5329,0.549421,0.5329,0.531061
10,1.2205,1.627559,0.548,0.552544,0.548,0.543666


TrainOutput(global_step=10300, training_loss=1.4081217534333756, metrics={'train_runtime': 6446.5867, 'train_samples_per_second': 204.487, 'train_steps_per_second': 1.598, 'total_flos': 2.800239480939479e+18, 'train_loss': 1.4081217534333756, 'epoch': 20.0})

In [88]:
model.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [89]:
trainer.evaluate(test)

{'eval_loss': 1.8555642366409302,
 'eval_accuracy': 0.5607,
 'eval_precision': 0.5731154814126223,
 'eval_recall': 0.5607,
 'eval_f1': 0.5632832843293292,
 'eval_runtime': 27.6216,
 'eval_samples_per_second': 362.035,
 'eval_steps_per_second': 2.86,
 'epoch': 20.0}

## Definice destilačního tréninku

Třída, která upravuje hugging face trenéra pro destilaci znalostí. Nově pracuje s logity uloženými v datasetu.

In [119]:
class ImageDistilTrainer(Trainer):
    """Hugging Face ``Trainer`` that adds a knowledge-distillation loss.

    The teacher is not run here: every batch is expected to carry the teacher's
    logits under the ``"logits"`` key. The temperature and the mixing weight are
    read from ``args.temperature`` and ``args.lambda_param``.

    Args:
        student_model: The model to train (passed to ``Trainer`` as ``model``).
        *args, **kwargs: Forwarded unchanged to ``Trainer``.
    """

    def __init__(self, student_model=None, *args, **kwargs):
        # The student is passed positionally: with ``model=student_model, *args``
        # any extra positional argument would also bind to ``model`` and raise
        # "got multiple values for argument 'model'".
        super().__init__(student_model, *args, **kwargs)
        self.student = student_model
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        self.temperature = self.args.temperature
        self.lambda_param = self.args.lambda_param

    def compute_loss(self, student, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the weighted sum of the student's own loss and the distillation loss.

        Args:
            student: The model being trained.
            inputs: Batch dict; ``"logits"`` (teacher logits) is removed from it
                and the remaining entries are passed to the model.
            return_outputs: When true, return ``(loss, student_output)``.
            num_items_in_batch: Accepted for compatibility with the ``Trainer``
                API; not used.

        Returns:
            ``(1 - lambda) * student_loss + lambda * distillation_loss``, or a
            tuple of that loss and the student output.
        """
        # Teacher logits must not reach the model's forward().
        teacher_logits = inputs.pop("logits")

        student_output = student(**inputs)

        # Temperature-softened distributions; KLDivLoss expects log-probabilities
        # as input and probabilities as target.
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        # Scaled by T^2 to offset the 1/T factors introduced by the softening.
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        # Loss the model computed itself from the labels in the batch.
        student_target_loss = student_output.loss

        loss = ((1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss)
        return (loss, student_output) if return_outputs else loss

### Trénink náhodně inicializovaného modelu s pomocí destilace znalostí

In [120]:
reset_seed(42)

In [121]:
student_model = get_random_init_mobilenet()

In [122]:
training_args = get_training_args(output_dir="./results/cifar100-random-KD", logging_dir='./logs/cifar100-random-KD', remove_unused_columns=False, epochs=30, lr=0.00047, lambda_param=.75, temp=6)

In [123]:
trainer = ImageDistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 4)]
)

In [124]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,3.0399,2.791122,0.1483,0.153566,0.1483,0.107175
2,2.5735,2.30771,0.2798,0.291919,0.2798,0.236826
3,2.2456,2.043069,0.3672,0.382228,0.3672,0.339746
4,2.0052,1.878129,0.4167,0.436576,0.4167,0.392702
5,1.8185,1.709248,0.4679,0.481027,0.4679,0.447591
6,1.6728,1.620318,0.4941,0.503461,0.4941,0.47805
7,1.5506,1.543608,0.5104,0.520614,0.5104,0.497499
8,1.4437,1.453382,0.5421,0.555254,0.5421,0.531065
9,1.3432,1.438919,0.5438,0.557198,0.5438,0.537853
10,1.2519,1.350234,0.5733,0.575506,0.5733,0.565275


TrainOutput(global_step=11330, training_loss=1.3234549971039276, metrics={'train_runtime': 6518.0056, 'train_samples_per_second': 303.369, 'train_steps_per_second': 2.37, 'total_flos': 3.080263429033427e+18, 'train_loss': 1.3234549971039276, 'epoch': 22.0})

In [125]:
student_model.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [132]:
trainer.evaluate(test)

{'eval_loss': 1.1580955982208252,
 'eval_accuracy': 0.6056,
 'eval_precision': 0.6186554197788803,
 'eval_recall': 0.6056,
 'eval_f1': 0.6042473276260995,
 'eval_runtime': 17.7426,
 'eval_samples_per_second': 563.614,
 'eval_steps_per_second': 4.453,
 'epoch': 22.0}

## Získání inicializovaného MobileNetV2 modelu

In [29]:
reset_seed(42)

In [30]:
model_pretrained = get_mobilenet()

In [31]:
print(model_pretrained)

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [32]:
freeze_model(model_pretrained)

In [33]:
training_args = get_training_args(output_dir="./results/cifar100-pretrained-head", logging_dir='./logs/cifar100-pretrained-head')

In [34]:
trainer = Trainer(
    model=model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 2)]
)

In [35]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,4.3666,4.12372,0.1839,0.208631,0.1839,0.177796
2,3.9275,3.765841,0.3024,0.325632,0.3024,0.293932
3,3.6436,3.570344,0.3498,0.359684,0.3498,0.333879
4,3.4753,3.458225,0.3587,0.382333,0.3587,0.344111
5,3.3925,3.428162,0.3769,0.382731,0.3769,0.359658


TrainOutput(global_step=1565, training_loss=3.761084686376797, metrics={'train_runtime': 469.9788, 'train_samples_per_second': 425.551, 'train_steps_per_second': 3.33, 'total_flos': 4.248451694592e+17, 'train_loss': 3.761084686376797, 'epoch': 5.0})

In [36]:
model_pretrained.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [37]:
trainer.evaluate(test)

{'eval_loss': 3.432467222213745,
 'eval_accuracy': 0.3773,
 'eval_precision': 0.3774894959232473,
 'eval_recall': 0.3772999999999999,
 'eval_f1': 0.35815750868611934,
 'eval_runtime': 18.3126,
 'eval_samples_per_second': 546.073,
 'eval_steps_per_second': 4.314,
 'epoch': 5.0}

### Trénink inicializovaného MobileNetV2

In [38]:
reset_seed(42)

In [39]:
model_pretrained_whole = get_mobilenet()

In [40]:
training_args = get_training_args(output_dir="./results/cifar100-pretrained", logging_dir='./logs/cifar100-pretrained')

In [41]:
trainer = Trainer(
    model=model_pretrained_whole,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [42]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,3.2751,2.070399,0.5244,0.543151,0.5244,0.508957
2,1.6438,1.473536,0.6097,0.625245,0.6097,0.605971
3,1.2066,1.26809,0.6559,0.6618,0.6559,0.651263
4,1.0111,1.259078,0.6489,0.670386,0.6489,0.648706
5,0.9242,1.154022,0.6805,0.685636,0.6805,0.674981


TrainOutput(global_step=1565, training_loss=1.61216064063124, metrics={'train_runtime': 585.1473, 'train_samples_per_second': 341.794, 'train_steps_per_second': 2.675, 'total_flos': 4.248451694592e+17, 'train_loss': 1.61216064063124, 'epoch': 5.0})

In [43]:
model_pretrained_whole.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [44]:
trainer.evaluate(test)

{'eval_loss': 1.1403955221176147,
 'eval_accuracy': 0.6805,
 'eval_precision': 0.6896003843930464,
 'eval_recall': 0.6805000000000001,
 'eval_f1': 0.6767723010859202,
 'eval_runtime': 17.6611,
 'eval_samples_per_second': 566.216,
 'eval_steps_per_second': 4.473,
 'epoch': 5.0}

## Trénink s pomocí destilace znalostí inicializovaného MobileNetV2

### Trénink inicializovaného modelu - pouze klasifikační hlavy s pomocí destilace

In [45]:
reset_seed(42)

In [46]:
student_model_pretrained = get_mobilenet()

In [47]:
freeze_model(student_model_pretrained)

In [48]:
training_args = get_training_args(output_dir="./results/cifar100-pretrained-head-KD", logging_dir='./logs/cifar100-pretrained-head-KD', remove_unused_columns=False, temp=6, lambda_param=.8)

In [49]:
trainer = ImageDistilTrainer(
    student_model=student_model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 2)]
)

In [50]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,3.5178,3.355098,0.1797,0.220875,0.1797,0.168106
2,3.2409,3.142756,0.2831,0.346824,0.2831,0.279037
3,3.0797,3.033154,0.3392,0.379779,0.3392,0.325473
4,2.99,2.978981,0.3436,0.406908,0.3436,0.332064
5,2.9475,2.958369,0.3593,0.402591,0.3593,0.34664


TrainOutput(global_step=1565, training_loss=3.155180383261781, metrics={'train_runtime': 433.0658, 'train_samples_per_second': 461.824, 'train_steps_per_second': 3.614, 'total_flos': 4.248451694592e+17, 'train_loss': 3.155180383261781, 'epoch': 5.0})

In [51]:
student_model_pretrained.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [52]:
trainer.evaluate(test)

{'eval_loss': 2.81398868560791,
 'eval_accuracy': 0.3537,
 'eval_precision': 0.39045062851507895,
 'eval_recall': 0.35369999999999996,
 'eval_f1': 0.33894118699147735,
 'eval_runtime': 17.613,
 'eval_samples_per_second': 567.761,
 'eval_steps_per_second': 4.485,
 'epoch': 5.0}

### Trénink inicializovaného modelu s pomocí destilace

In [53]:
reset_seed(42)

In [54]:
student_model_pretrained_whole = get_mobilenet()

In [55]:
training_args = get_training_args("./results/cifar100-pretrained-KD", './logs/cifar100-pretrained-KD', remove_unused_columns=False, temp=6, lambda_param=1)

In [56]:
trainer = ImageDistilTrainer(
    student_model=student_model_pretrained_whole.to(device),
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [57]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.6861,2.036413,0.4577,0.530295,0.4577,0.426668
2,1.7625,1.579123,0.566,0.607231,0.566,0.547062
3,1.4056,1.371883,0.619,0.641029,0.619,0.603205
4,1.2316,1.354033,0.6157,0.645628,0.6157,0.605793
5,1.1537,1.25149,0.6543,0.673847,0.6543,0.64309


TrainOutput(global_step=1565, training_loss=1.647921840679912, metrics={'train_runtime': 533.0318, 'train_samples_per_second': 375.212, 'train_steps_per_second': 2.936, 'total_flos': 4.248451694592e+17, 'train_loss': 1.647921840679912, 'epoch': 5.0})

In [58]:
# This section trains `student_model_pretrained_whole`; switch that model
# (not the head-only student from the previous section) to evaluation mode.
student_model_pretrained_whole.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [59]:
trainer.evaluate(test)

{'eval_loss': 1.0503957271575928,
 'eval_accuracy': 0.651,
 'eval_precision': 0.6701687624603764,
 'eval_recall': 0.6510000000000001,
 'eval_f1': 0.6393124124046187,
 'eval_runtime': 13.9925,
 'eval_samples_per_second': 714.669,
 'eval_steps_per_second': 5.646,
 'epoch': 5.0}