# Notebook pro trénink s destilací nad datasetem CIFAR100
V tomto notebooku je trénován MobileNetV2 nad datasetem CIFAR100; jako učitelský model je využíván ViT dotrénovaný (fine-tuned) nad stejným datasetem.

MobileNetV2 je používán ve třech variantách: s náhodnou inicializací, s tréninkem pouze klasifikační hlavy inicializovaného (předtrénovaného nad ImageNetem) MobileNetuV2 a s tréninkem celého modelu, taktéž inicializovaného. Tyto tři úlohy jsou trénovány běžným způsobem a také s pomocí destilace výše zmíněného modelu.

Při destilaci je využíváno předpočítaných logitů ze sešitu precompute_logits.

## Import knihoven a definice metod

In [1]:
from transformers import Trainer, EarlyStoppingCallback, AutoModelForImageClassification, AutoImageProcessor
from torch.utils.data import ConcatDataset
import torch
import base
import os

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/jovyan/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/jovyan/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /home/jovyan/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


In [2]:
# Dataset name; it is interpolated into the results/, logs/ and models/ paths
# used by the cells below.
DATASET = "cifar100"
# Split selector from the project helper; its TRAIN / EVAL / TEST members are
# handed to the dataset constructors later in the notebook.
dataset_part = base.get_dataset_part()

Inicializovaný MobileNetV2.

In [3]:
# Re-seed through the project helper before any model or dataset is created
# (presumably fixes the RNG seeds for reproducibility — confirm in base.reset_seed).
base.reset_seed()

In [4]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise,
# and report the choice.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


Provedení transformací nad datasetem.

In [5]:
# NOTE(review): `transform` is created here but NOT handed to the datasets in
# this cell (they receive transform=None); the raw samples are converted to
# tensors by `collate_fn` instead.
transform = base.base_transforms()

# All three splits are read from the same directory under the user's home.
logits_root = f"{os.path.expanduser('~')}/data/100-logits"

# NOTE: `eval` shadows the Python builtin; the name is kept because later
# cells pass it to the trainers as `eval_dataset=eval`.
train, eval, test = [
    base.CustomCIFAR100L(root=logits_root, dataset_part=split, transform=None)
    for split in (dataset_part.TRAIN, dataset_part.EVAL, dataset_part.TEST)
]


### Standardní trénink náhodně inicializovaného modelu. 

In [None]:
# Set-up for the conventional (non-distillation) run on untransformed data.
# NOTE(review): the "~" in output_dir / logging_dir is passed through literally;
# whether it is expanded depends on base.get_training_args — confirm.
checkpoint = "timm/tiny_vit_5m_224.in1k"
training_args = base.get_training_args(
    output_dir=f"~/results/{DATASET}/vit-basetrain-model-transform",
    logging_dir=f"~/logs/{DATASET}/vit-basetrain-model-transform",
    lr=0.0005,
    weight_decay=0.008,
    adam_beta1=.95,
    epochs=30,
)
# ignore_mismatched_sizes=True lets the checkpoint's classification head be
# re-initialised to match num_labels=100.
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=100,
    ignore_mismatched_sizes=True,
)
processor = AutoImageProcessor.from_pretrained(checkpoint)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([100, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [32]:
from torchvision.transforms import ToTensor
def collate_fn(examples):
    """Assemble a list of dataset samples into one batch dictionary.

    Args:
        examples: Iterable-and-reiterable collection of samples; each sample is
            indexed with the keys ``"pixel_values"`` and ``"labels"``.
            ``"pixel_values"`` must be something ``ToTensor`` accepts
            (presumably a PIL image or ndarray — confirm against the dataset).

    Returns:
        dict with ``"pixel_values"`` (the converted images stacked along a new
        first dimension) and ``"labels"`` (a tensor built from the labels).
        Any other keys present on the samples are not forwarded.
    """
    convert = ToTensor()
    image_tensors = []
    for sample in examples:
        image_tensors.append(convert(sample["pixel_values"]))
    batch = {"pixel_values": torch.stack(image_tensors)}
    batch["labels"] = torch.tensor([sample["labels"] for sample in examples])
    return batch

In [33]:
# Early stopping with a patience of 5 evaluation rounds.
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)

# Baseline trainer: the datasets were built with transform=None, so the custom
# collate function is what converts samples to tensors.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    data_collator=collate_fn,
    processing_class=processor,
    compute_metrics=base.compute_metrics,
    callbacks=[early_stopping],
)

In [34]:
# Run the baseline training with the trainer configured in the previous cell.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,3.0292,2.31022,0.3931,0.418164,0.3931,0.375002
2,1.9006,1.904595,0.4812,0.505944,0.4812,0.475442
3,1.4709,1.786745,0.513,0.548163,0.513,0.513766
4,1.1581,1.729847,0.5356,0.557527,0.5356,0.531547
5,0.9283,1.722173,0.5476,0.56867,0.5476,0.544768


The `save_pretrained` method is disabled for TimmWrapperImageProcessor. The image processor configuration is saved directly in `config.json` when `save_pretrained` is called for saving the model.


KeyboardInterrupt: 

In [None]:
# Switch the baseline model to evaluation mode; as the last expression of the
# cell this also makes the notebook display the module structure.
model.eval()

In [None]:
# Compute the metrics on the TEST split (instead of the EVAL split the trainer
# was constructed with).
trainer.evaluate(test)

In [None]:
# Persist only the baseline model's weights (state_dict) under ~/models/<DATASET>/.
weights_path = f"{os.path.expanduser('~')}/models/{DATASET}/vit-basetrain_token.pth"
torch.save(model.state_dict(), weights_path)

In [35]:
# Second experiment: rebuild the three splits, this time applying the project
# transform inside the dataset, and start again from a fresh model.
transform = base.base_transforms()

logits_root = f"{os.path.expanduser('~')}/data/100-logits"
# NOTE: `eval` shadows the Python builtin; kept because the trainer cells
# reference it as `eval_dataset=eval`.
train, eval, test = [
    base.CustomCIFAR100L(root=logits_root, dataset_part=split, transform=transform)
    for split in (dataset_part.TRAIN, dataset_part.EVAL, dataset_part.TEST)
]

# NOTE(review): "~" in these paths is passed through literally; expansion, if
# any, happens inside base.get_training_args — confirm.
training_args = base.get_training_args(
    output_dir=f"~/results/{DATASET}/vit-basetrain-token-transf",
    logging_dir=f"~/logs/{DATASET}/vit-basetrain-token-transf",
    lr=0.0005,
    weight_decay=0.008,
    adam_beta1=.95,
    epochs=30,
)

checkpoint = "timm/tiny_vit_5m_224.in1k"
# ignore_mismatched_sizes=True lets the checkpoint's classification head be
# re-initialised to match num_labels=100.
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=100,
    ignore_mismatched_sizes=True,
)
processor = AutoImageProcessor.from_pretrained(checkpoint)


Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([100, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Early stopping with a patience of 5 evaluation rounds.
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)

# Trainer for the transformed datasets; no custom data_collator is passed
# here, so the Trainer's default collation is used.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    processing_class=processor,
    compute_metrics=base.compute_metrics,
    callbacks=[early_stopping],
)

In [37]:
# Run the conventional training on the transformed datasets.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.6421,1.027111,0.7031,0.723367,0.7031,0.700154
2,0.7633,0.893031,0.7416,0.760275,0.7416,0.739986
3,0.5508,0.869524,0.7499,0.774201,0.7499,0.749924
4,0.421,0.86798,0.7605,0.781796,0.7605,0.760626
5,0.3172,0.838674,0.7758,0.787255,0.7758,0.775314
6,0.2555,0.790946,0.7871,0.79647,0.7871,0.786925
7,0.2166,0.823603,0.7912,0.803079,0.7912,0.791609
8,0.1629,0.876584,0.7883,0.79762,0.7883,0.78796
9,0.14,0.956399,0.7797,0.790966,0.7797,0.780496
10,0.1169,0.903787,0.7885,0.798007,0.7885,0.788618


TrainOutput(global_step=9390, training_loss=0.1716814164552318, metrics={'train_runtime': 5929.021, 'train_samples_per_second': 202.394, 'train_steps_per_second': 1.584, 'total_flos': 5.5315759693824e+18, 'train_loss': 0.1716814164552318, 'epoch': 30.0})

In [38]:
# Switch the trained model to evaluation mode; as the last expression of the
# cell this also makes the notebook display the module structure.
model.eval()

TimmWrapperForImageClassification(
  (timm_model): TinyVit(
    (patch_embed): PatchEmbed(
      (conv1): ConvNorm(
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (act): GELU(approximate='none')
      (conv2): ConvNorm(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (stages): Sequential(
      (0): ConvLayer(
        (blocks): Sequential(
          (0): MBConv(
            (conv1): ConvNorm(
              (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (act1): GELU(approximate='none')
            (conv2): ConvNorm(
              (c

In [39]:
# Compute the metrics on the TEST split (instead of the EVAL split the trainer
# was constructed with).
trainer.evaluate(test)

{'eval_loss': 0.9375255703926086,
 'eval_accuracy': 0.8402,
 'eval_precision': 0.8423099663864122,
 'eval_recall': 0.8402,
 'eval_f1': 0.8405098146032582,
 'eval_runtime': 12.3538,
 'eval_samples_per_second': 809.468,
 'eval_steps_per_second': 6.395,
 'epoch': 30.0}

## Definice destilačního tréninku

Třída, která upravuje trenéra (Trainer) z knihovny Hugging Face pro destilaci znalostí. Nově pracuje s logity uloženými v datasetu.

### Trénink náhodně inicializovaného modelu s pomocí destilace znalostí

In [41]:
# Re-seed through the project helper before the distillation run
# (presumably so it starts from the same RNG state as the baseline — confirm in base.reset_seed).
base.reset_seed()

In [42]:
# Arguments for the knowledge-distillation run.
# remove_unused_columns=False is requested so that extra dataset columns are
# kept (presumably the precomputed teacher logits — confirm against
# base.DistilTrainer). lambda_param and temp are presumably the distillation
# loss weight and softmax temperature — TODO confirm in base.
training_args = base.get_training_args(
    output_dir=f"~/results/{DATASET}/vit-distilltrain",
    logging_dir=f"~/logs/{DATASET}/vit-distilltrain",
    remove_unused_columns=False,
    epochs=30,
    lr=0.00047,
    weight_decay=0,
    adam_beta1=.9,
    lambda_param=1,
    temp=6,
)
# Fresh student from the same checkpoint as the baseline, again with the
# classification head re-initialised for 100 labels.
student_model = AutoModelForImageClassification.from_pretrained(
    "timm/tiny_vit_5m_224.in1k",
    num_labels=100,
    ignore_mismatched_sizes=True,
)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([100, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [43]:
# Early stopping with a patience of 5 evaluation rounds.
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)

# Distillation trainer from the project module. No teacher model is passed;
# the datasets and `processor` are the ones created in the earlier cells.
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    processing_class=processor,
    compute_metrics=base.compute_metrics,
    callbacks=[early_stopping],
)

In [44]:
# Run the distillation training of the student model.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.2989,0.831203,0.6838,0.731836,0.6838,0.685063
2,0.6,0.662224,0.7296,0.753612,0.7296,0.73067
3,0.4533,0.622127,0.7555,0.776564,0.7555,0.753489
4,0.3685,0.563021,0.7707,0.78972,0.7707,0.772211
5,0.3103,0.546491,0.7793,0.799878,0.7793,0.779068
6,0.2627,0.532122,0.7826,0.800717,0.7826,0.784513
7,0.232,0.495274,0.797,0.812961,0.797,0.799272
8,0.2032,0.501735,0.7954,0.811297,0.7954,0.797779
9,0.1854,0.486403,0.797,0.811036,0.797,0.797992
10,0.1645,0.481768,0.7979,0.809001,0.7979,0.798922


KeyboardInterrupt: 

In [None]:
# Switch the student to evaluation mode; the cell also displays its structure.
# NOTE(review): the stored output of this cell shows a MobileNetV2 module tree,
# whereas the student is loaded from timm/tiny_vit_5m_224.in1k — the saved
# output looks stale; re-run the cell to refresh it.
student_model.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [None]:
# Compute the distillation trainer's metrics on the TEST split.
trainer.evaluate(test)

{'eval_loss': 1.09884512424469,
 'eval_accuracy': 0.5504,
 'eval_precision': 0.5808980316978757,
 'eval_recall': 0.5504,
 'eval_f1': 0.5546805225416259,
 'eval_runtime': 37.502,
 'eval_samples_per_second': 266.653,
 'eval_steps_per_second': 2.107,
 'epoch': 30.0}

In [None]:
# Persist the weights of the distilled network. This must be `student_model`
# (the module handed to base.DistilTrainer), not `model`, which is the
# conventionally trained baseline from the earlier cells.
torch.save(student_model.state_dict(), f"{os.path.expanduser('~')}/models/{DATASET}/vit-distilltrain.pth")