# Notebook pro trénink s destilací nad datasetem CIFAR10
V tomto notebooku je trénován MobileNetV2 nad datasetem CIFAR10, jako učitelsý model je využíván finetunued ViT nad stejným datasetem. 

MobileNetV2 je používán s náhodnou inicializací, tréninkem pouze klasifikační hlavy inicializovaného (předtrénovaného nad ImageNetem) MobileNetuV2 a trénink celého modelu, taktéž inicializovaného. Tyto tři úlohy jsou trénovány bězným způsobem a také s pomocí destilace výše zmíněného modelu.  

Při destilaci je využíváno předpočítaných logitů ze sešitu precompute_logits.

## Import knihoven a definice metod

In [1]:
from transformers import Trainer, EarlyStoppingCallback, AutoModelForImageClassification
from torch.utils.data import DataLoader, ConcatDataset
import pandas as pd
import torch
import base
import os

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/jovyan/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/jovyan/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /home/jovyan/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


In [None]:
dataset_part = base.get_dataset_part()

Resetování náhodného seedu pro replikovatelnost výsledků.

In [None]:
base.reset_seed()

In [4]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA H100 PCIe


Provedení transformací nad datasetem.

In [None]:
DATASET = "cifar10"

In [6]:
transform = base.base_transforms()

#Poslední train batch použijeme jako eval část...
test = base.CustomCIFAR10L(root=f"{os.path.expanduser('~')}/data/10-logits", dataset_part=dataset_part.TEST, transform=transform)
train = base.CustomCIFAR10L(root=f"{os.path.expanduser('~')}/data/10-logits", dataset_part=dataset_part.TRAIN, transform=transform)
eval = base.CustomCIFAR10L(root=f"{os.path.expanduser('~')}/data/10-logits", dataset_part=dataset_part.EVAL, transform=transform)

In [7]:
augment_transform = base.aug_transforms()
train_aug = base.CustomCIFAR10L(root=f"{os.path.expanduser('~')}/data/10-logits", dataset_part=dataset_part.TRAIN, transform=augment_transform)

In [8]:
train_aug = base.remove_diff_pred_class(train, train_aug, pytorch_dataset=True)
train_combo = ConcatDataset([train, train_aug])

Removing entries from augmented dataset that are different from the base one - based on saved logits:   0%|   …

In [9]:
train[0]["labels"]

tensor(6)

In [10]:
# Test rozložení --> Good Enough
df = pd.DataFrame(eval.labels)
print(df.value_counts())

0
5    1025
9    1022
3    1016
0    1014
1    1014
8    1003
4     997
6     980
7     977
2     952
Name: count, dtype: int64


### Standardní trénink náhodně inicializovaného modelu. 

In [11]:
student_model = base.get_mobilenet(10)

config.json:   0%|          | 0.00/69.8k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/14.2M [00:00<?, ?B/s]

In [13]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random", logging_dir=f"~/logs/{DATASET}/cifar10-random", remove_unused_columns=False)

In [14]:
base.reset_seed()

In [17]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [18]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.5742,0.347547,0.8771,0.883398,0.877196,0.876988
2,0.2851,0.29566,0.9049,0.909231,0.904809,0.905567
3,0.2294,0.274181,0.9116,0.913168,0.911627,0.911539
4,0.2032,0.250855,0.9196,0.922027,0.919773,0.91974
5,0.189,0.254556,0.9188,0.919912,0.91915,0.918799


TrainOutput(global_step=1565, training_loss=0.29617781959021816, metrics={'train_runtime': 245.2355, 'train_samples_per_second': 815.542, 'train_steps_per_second': 6.382, 'total_flos': 4.040199217152e+17, 'train_loss': 0.29617781959021816, 'epoch': 5.0})

In [19]:
base.reset_seed()

In [20]:
student_model = base.get_mobilenet(10)
teacher_model = AutoModelForImageClassification.from_pretrained(
    "aaraki/vit-base-patch16-224-in21k-finetuned-cifar10",
    num_labels=10,
)
teacher_model.eval()
teacher_model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [21]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random-KD", logging_dir=f"~/logs/{DATASET}/cifar10-random-KD", remove_unused_columns=False)

In [22]:
base.reset_seed()

In [23]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [24]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.562,0.343917,0.8801,0.886331,0.880106,0.87997
2,0.2803,0.288551,0.9078,0.911852,0.907647,0.908578
3,0.2277,0.274579,0.9126,0.91461,0.912606,0.912678
4,0.2023,0.250524,0.9193,0.921502,0.9195,0.919455
5,0.1883,0.252376,0.9206,0.9219,0.920891,0.920712


TrainOutput(global_step=1565, training_loss=0.29210466622544556, metrics={'train_runtime': 274.3557, 'train_samples_per_second': 728.981, 'train_steps_per_second': 5.704, 'total_flos': 4.040199217152e+17, 'train_loss': 0.29210466622544556, 'epoch': 5.0})

In [25]:
student_model = base.get_mobilenet(10)

In [28]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random", logging_dir=f"~/logs/{DATASET}/cifar10-random", remove_unused_columns=False)

In [29]:
base.reset_seed()

In [30]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [31]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.5496,0.28729,0.9086,0.909263,0.908559,0.908476
2,0.3145,0.271343,0.911,0.915552,0.910453,0.911408
3,0.2617,0.234006,0.9304,0.931797,0.930296,0.930676
4,0.2335,0.254683,0.9183,0.921042,0.918829,0.918204
5,0.2187,0.242656,0.9251,0.929101,0.925108,0.925319


TrainOutput(global_step=2665, training_loss=0.3155986649905092, metrics={'train_runtime': 415.0748, 'train_samples_per_second': 821.226, 'train_steps_per_second': 6.421, 'total_flos': 6.885913535753011e+17, 'train_loss': 0.3155986649905092, 'epoch': 5.0})

In [32]:
base.reset_seed()

In [33]:
student_model = base.get_mobilenet(10)

In [34]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random-KD", logging_dir=f"~/logs/{DATASET}/cifar10-random-KD", remove_unused_columns=False)

In [35]:
base.reset_seed()

In [36]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [37]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.5445,0.284731,0.9125,0.91348,0.912592,0.912544
2,0.3187,0.26919,0.9113,0.915941,0.910811,0.911818
3,0.2709,0.229222,0.9296,0.93145,0.92952,0.93005
4,0.2454,0.246386,0.9216,0.923437,0.922095,0.921487
5,0.2329,0.235258,0.9266,0.931581,0.926628,0.926994


TrainOutput(global_step=2665, training_loss=0.322469545976306, metrics={'train_runtime': 438.44, 'train_samples_per_second': 777.461, 'train_steps_per_second': 6.078, 'total_flos': 6.885913535753011e+17, 'train_loss': 0.322469545976306, 'epoch': 5.0})

In [38]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=10, ignore_mismatched_sizes=True)

config.json:   0%|          | 0.00/583 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/21.6M [00:00<?, ?B/s]

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([10, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [39]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random", logging_dir=f"~/logs/{DATASET}/cifar10-random", remove_unused_columns=False)

In [40]:
base.reset_seed()

In [41]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [42]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.4168,0.181341,0.9541,0.955021,0.954239,0.954347
2,0.1541,0.165034,0.9624,0.962915,0.962488,0.962566
3,0.1291,0.155416,0.9655,0.965687,0.96562,0.965611
4,0.1184,0.151969,0.967,0.967206,0.967194,0.967142
5,0.1142,0.150379,0.9659,0.966206,0.966078,0.966071


TrainOutput(global_step=1565, training_loss=0.18649732205814448, metrics={'train_runtime': 269.2038, 'train_samples_per_second': 742.932, 'train_steps_per_second': 5.813, 'total_flos': 9.167108235264e+17, 'train_loss': 0.18649732205814448, 'epoch': 5.0})

In [43]:
base.reset_seed()

In [44]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=10, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([10, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [45]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random-KD", logging_dir=f"~/logs/{DATASET}/cifar10-random-KD", remove_unused_columns=False)

In [46]:
base.reset_seed()

In [47]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [48]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.4058,0.191524,0.9469,0.948398,0.947149,0.947096
2,0.1563,0.164814,0.9596,0.960609,0.959817,0.959843
3,0.1302,0.155662,0.9648,0.964981,0.964892,0.964884
4,0.1188,0.153296,0.9652,0.965403,0.965378,0.965351
5,0.1142,0.151308,0.966,0.966318,0.966161,0.966148


TrainOutput(global_step=1565, training_loss=0.18506409032657123, metrics={'train_runtime': 277.9549, 'train_samples_per_second': 719.541, 'train_steps_per_second': 5.63, 'total_flos': 9.167108235264e+17, 'train_loss': 0.18506409032657123, 'epoch': 5.0})

In [49]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=10, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([10, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [51]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random", logging_dir=f"~/logs/{DATASET}/cifar10-random", remove_unused_columns=False)

In [52]:
base.reset_seed()

In [53]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [54]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.371,0.181518,0.953,0.953867,0.953154,0.9532
2,0.178,0.15731,0.9652,0.965439,0.965381,0.965337
3,0.1505,0.154486,0.9666,0.966675,0.966802,0.966722
4,0.1393,0.151521,0.9679,0.968015,0.968052,0.968016
5,0.1344,0.150385,0.9687,0.96886,0.968848,0.968827


TrainOutput(global_step=2665, training_loss=0.1946171544058909, metrics={'train_runtime': 455.7883, 'train_samples_per_second': 747.869, 'train_steps_per_second': 5.847, 'total_flos': 1.5623960920772198e+18, 'train_loss': 0.1946171544058909, 'epoch': 5.0})

In [55]:
base.reset_seed()

In [56]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=10, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([10, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [57]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random-KD", logging_dir=f"~/logs/{DATASET}/cifar10-random-KD", remove_unused_columns=False)

In [58]:
base.reset_seed()

In [59]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [60]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.3764,0.170028,0.9592,0.960216,0.959266,0.959481
2,0.1949,0.153462,0.9678,0.967953,0.967922,0.96791
3,0.1725,0.151122,0.969,0.969162,0.969125,0.969139
4,0.1623,0.147421,0.97,0.970238,0.970173,0.97017
5,0.1574,0.146599,0.9707,0.970914,0.970859,0.970848


TrainOutput(global_step=2665, training_loss=0.21267932941944917, metrics={'train_runtime': 448.3294, 'train_samples_per_second': 760.312, 'train_steps_per_second': 5.944, 'total_flos': 1.5623960920772198e+18, 'train_loss': 0.21267932941944917, 'epoch': 5.0})