# Notebook for training with distillation on the CIFAR100 dataset
This notebook trains MobileNetV2 on the CIFAR100 dataset; the teacher model is a ViT fine-tuned on the same dataset.

MobileNetV2 is used in three setups: with random initialization, with only the classification head trained on an initialized model (pretrained on ImageNet), and with the whole model trained, again starting from the pretrained initialization. Each of these three tasks is trained both in the standard way and with distillation from the teacher model mentioned above.

Distillation uses precomputed logits from the precompute_logits notebook.
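
The loss implemented by `base.DistilTrainer` is not shown in this notebook. As a rough illustration of what distillation from teacher logits commonly looks like, the framework-free sketch below mixes the cross-entropy on the hard label with a temperature-softened KL divergence between teacher and student. The function names and the `temperature` and `alpha` values are assumptions made for this example, not values taken from `base`.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, label,
                      temperature=2.0, alpha=0.5):
    """Hypothetical KD loss: alpha * CE + (1 - alpha) * T^2 * KL(teacher || student)."""
    # Hard-label term: cross-entropy of the student at temperature 1.
    hard = -math.log(softmax(student_logits)[label])
    # Soft-label term: KL divergence between the softened distributions.
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    soft = sum(pt * math.log(pt / ps)
               for pt, ps in zip(p_teacher, p_student) if pt > 0.0)
    # T^2 keeps the soft-term gradients on the same scale as the hard term.
    return alpha * hard + (1.0 - alpha) * temperature ** 2 * soft


# Toy 3-class example with made-up logits.
loss = distillation_loss([2.0, 0.5, -1.0], [3.0, 1.0, -2.0], label=0)
print(f"{loss:.4f}")
```

When the teacher logits are precomputed, only the student needs a forward pass during training; the teacher side of the loss is read from disk.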

## Library imports and method definitions

In [1]:
from transformers import Trainer, EarlyStoppingCallback, AutoModelForImageClassification
from torch.utils.data import DataLoader, ConcatDataset
import pandas as pd
import torch
import base
import os



In [2]:
dataset_part = base.get_dataset_part()

Reset the random seed so that the results are reproducible.

In [3]:
base.reset_seed()
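
The body of `base.reset_seed()` is not part of this notebook. A minimal sketch of what such a helper might do is given below; the seed value `42` and the set of libraries being seeded are assumptions. NumPy and PyTorch are seeded only when they are installed, so the sketch also runs in a bare Python environment.

```python
import os
import random


def reset_seed(seed=42):
    """Hypothetical stand-in for base.reset_seed(); the seed value is assumed."""
    # Only affects hash randomization in subprocesses started after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy is not installed; nothing to seed.
    try:
        import torch
        torch.manual_seed(seed)  # seeds the CPU and CUDA generators
    except ImportError:
        pass  # PyTorch is not installed; nothing to seed.


reset_seed()
first = [random.random() for _ in range(3)]
reset_seed()
second = [random.random() for _ in range(3)]
print(first == second)
```

Calling the helper again before each model is created and before each trainer is built, as the notebook does, is one way to keep weight initialization and data shuffling comparable across runs.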

In [4]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA H100 PCIe


Apply the transforms to the dataset.

In [5]:
DATASET = "cifar100"

In [6]:
transform = base.base_transforms()

# The last train batch is used as the eval split...
test = base.CustomCIFAR100L(root=f"{os.path.expanduser('~')}/data/100-logits", dataset_part=dataset_part.TEST, transform=transform)
train = base.CustomCIFAR100L(root=f"{os.path.expanduser('~')}/data/100-logits", dataset_part=dataset_part.TRAIN, transform=transform)
eval = base.CustomCIFAR100L(root=f"{os.path.expanduser('~')}/data/100-logits", dataset_part=dataset_part.EVAL, transform=transform)

In [7]:
augment_transform = base.aug_transforms()
train_aug = base.CustomCIFAR100L(root=f"{os.path.expanduser('~')}/data/100-logits", dataset_part=dataset_part.TRAIN, transform=augment_transform)

In [8]:
train_aug = base.remove_diff_pred_class(train, train_aug, pytorch_dataset=True)
train_combo = ConcatDataset([train, train_aug])

Removing entries from augmented dataset that are different from the base one - based on saved logits:   0%|   …
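
`base.remove_diff_pred_class` is not shown here. Judging only by its progress message, it drops augmented samples whose class predicted from the saved logits differs from that of the corresponding base sample. The sketch below illustrates that idea on plain lists; the helper name `keep_consistent_indices` and the list-of-logits layout are assumptions.

```python
def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def keep_consistent_indices(base_logits, aug_logits):
    """Return indices whose predicted class is unchanged by augmentation.

    Hypothetical helper: both arguments are lists of per-sample logit lists,
    aligned so that entry i of each refers to the same source image.
    """
    if len(base_logits) != len(aug_logits):
        raise ValueError("base and augmented logits must be aligned")
    return [i for i, (b, a) in enumerate(zip(base_logits, aug_logits))
            if argmax(b) == argmax(a)]


# Made-up logits for three images and three classes.
base_logits = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.0, 0.4, 3.0]]
aug_logits = [[1.8, 0.2, 0.1], [1.1, 0.9, 0.0], [0.1, 0.2, 2.5]]
kept = keep_consistent_indices(base_logits, aug_logits)
print(kept)  # [0, 2]
```

With a PyTorch dataset, the retained indices could then be wrapped with `torch.utils.data.Subset` before concatenation.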

In [9]:
train[0]["labels"]

tensor(34)

In [10]:
# Check the class distribution --> good enough
df = pd.DataFrame(eval.labels)
print(df.value_counts())

0 
0     100
63    100
73    100
72    100
71    100
     ... 
30    100
29    100
28    100
27    100
99    100
Name: count, Length: 100, dtype: int64


### Standard training of the randomly initialized model

In [11]:
student_model = base.get_mobilenet(100)

In [12]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-random", logging_dir=f"~/logs/{DATASET}/cifar100-random", remove_unused_columns=False)

In [13]:
base.reset_seed()

In [14]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [15]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 3.1372 | 2.192527 | 0.5113 | 0.545492 | 0.5113 | 0.493532 |
| 2 | 1.8018 | 1.610577 | 0.6148 | 0.631019 | 0.6148 | 0.608108 |
| 3 | 1.3908 | 1.400766 | 0.6538 | 0.659259 | 0.6538 | 0.647266 |
| 4 | 1.2098 | 1.400598 | 0.6441 | 0.666674 | 0.6441 | 0.642694 |
| 5 | 1.1269 | 1.275834 | 0.6777 | 0.683856 | 0.6777 | 0.671405 |


TrainOutput(global_step=1565, training_loss=1.7332860587123102, metrics={'train_runtime': 267.3891, 'train_samples_per_second': 747.974, 'train_steps_per_second': 5.853, 'total_flos': 4.248451694592e+17, 'train_loss': 1.7332860587123102, 'epoch': 5.0})
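
`base.compute_metrics` is not shown in this notebook. In every epoch reported above, Recall is identical to Accuracy, which is what support-weighted averaging over classes produces. The sketch below demonstrates that identity on toy labels; that the notebook actually uses weighted averaging is an inference from the numbers, not something the source states.

```python
from collections import Counter


def weighted_metrics(y_true, y_pred):
    """Accuracy plus support-weighted precision, recall and F1."""
    n = len(y_true)
    support = Counter(y_true)
    predicted = Counter(y_pred)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    precision = recall = f1 = 0.0
    for cls, count in support.items():
        p = hits[cls] / predicted[cls] if predicted[cls] else 0.0
        r = hits[cls] / count
        f = 2 * p * r / (p + r) if p + r else 0.0
        weight = count / n  # each class is weighted by its share of true labels
        precision += weight * p
        recall += weight * r
        f1 += weight * f
    accuracy = sum(hits.values()) / n
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1": f1}


# Toy labels, not taken from the notebook.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
m = weighted_metrics(y_true, y_pred)
print(m)
```

Weighted recall sums `hits / n` over the classes, which is exactly the accuracy, so the two columns carry the same information under this averaging.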

In [16]:
base.reset_seed()

In [17]:
student_model = base.get_mobilenet(100)
teacher_model = AutoModelForImageClassification.from_pretrained(
    "Ahmed9275/Vit-Cifar100",
    num_labels=100,
)
teacher_model.eval()
teacher_model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [18]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-random-KD", logging_dir=f"~/logs/{DATASET}/cifar100-random-KD", remove_unused_columns=False)

In [19]:
base.reset_seed()

In [20]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [21]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 3.1152 | 2.151023 | 0.5153 | 0.555202 | 0.5153 | 0.500679 |
| 2 | 1.7816 | 1.612269 | 0.6122 | 0.633029 | 0.6122 | 0.607108 |
| 3 | 1.3841 | 1.401302 | 0.655 | 0.661947 | 0.655 | 0.648693 |
| 4 | 1.202 | 1.385344 | 0.6474 | 0.667519 | 0.6474 | 0.646152 |
| 5 | 1.1191 | 1.275649 | 0.6814 | 0.687685 | 0.6814 | 0.675182 |


TrainOutput(global_step=1565, training_loss=1.7204134712584864, metrics={'train_runtime': 275.4416, 'train_samples_per_second': 726.107, 'train_steps_per_second': 5.682, 'total_flos': 4.248451694592e+17, 'train_loss': 1.7204134712584864, 'epoch': 5.0})

In [22]:
student_model = base.get_mobilenet(100)

In [23]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-random", logging_dir=f"~/logs/{DATASET}/cifar100-random", remove_unused_columns=False)

In [24]:
base.reset_seed()

In [25]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [26]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 2.752 | 1.748866 | 0.5887 | 0.59865 | 0.5887 | 0.577558 |
| 2 | 1.5066 | 1.327883 | 0.6688 | 0.677723 | 0.6688 | 0.665669 |
| 3 | 1.189 | 1.17643 | 0.6996 | 0.705596 | 0.6996 | 0.69655 |
| 4 | 1.0426 | 1.135068 | 0.7077 | 0.715674 | 0.7077 | 0.705803 |
| 5 | 0.9708 | 1.110362 | 0.7132 | 0.719548 | 0.7132 | 0.710693 |


TrainOutput(global_step=2575, training_loss=1.4921817520289744, metrics={'train_runtime': 409.0285, 'train_samples_per_second': 805.738, 'train_steps_per_second': 6.295, 'total_flos': 7.000811124933427e+17, 'train_loss': 1.4921817520289744, 'epoch': 5.0})
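
The `TrainOutput` lines allow a rough consistency check on the dataset sizes. The sketch below recovers the number of samples per epoch from the reported throughput and runtime, then recomputes the step count. The batch size of 128 is an assumption (it is roughly the reported samples per second divided by steps per second), as is the absence of gradient accumulation; neither is stated in the notebook.

```python
import math

EPOCHS = 5
BATCH_SIZE = 128  # assumption: inferred from samples/s divided by steps/s


def samples_per_epoch(samples_per_second, runtime_seconds, epochs=EPOCHS):
    """Recover the dataset size from the throughput the trainer reports."""
    return round(samples_per_second * runtime_seconds / epochs)


def total_steps(n_samples, batch_size=BATCH_SIZE, epochs=EPOCHS):
    """Optimizer steps when the last partial batch of each epoch is kept."""
    return math.ceil(n_samples / batch_size) * epochs


base_size = samples_per_epoch(747.974, 267.3891)   # run on `train`
combo_size = samples_per_epoch(805.738, 409.0285)  # run on `train_combo`

print(base_size, total_steps(base_size))
print(combo_size, total_steps(combo_size))
print(combo_size - base_size)
```

Under these assumptions the sketch yields 40,000 samples for `train` and 65,914 for `train_combo`, with step counts of 1565 and 2575 that match the reported `global_step` values. This would mean roughly 25,914 augmented samples were kept after filtering.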

In [27]:
base.reset_seed()

In [28]:
student_model = base.get_mobilenet(100)

In [29]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-random-KD", logging_dir=f"~/logs/{DATASET}/cifar100-random-KD", remove_unused_columns=False)

In [30]:
base.reset_seed()

In [31]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [32]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 2.6389 | 1.704158 | 0.6007 | 0.612029 | 0.6007 | 0.590272 |
| 2 | 1.4331 | 1.290551 | 0.6823 | 0.690474 | 0.6823 | 0.679117 |
| 3 | 1.1336 | 1.144578 | 0.7076 | 0.715188 | 0.7076 | 0.705132 |
| 4 | 0.9955 | 1.10077 | 0.7142 | 0.721721 | 0.7142 | 0.712157 |
| 5 | 0.9281 | 1.072174 | 0.7229 | 0.727874 | 0.7229 | 0.720492 |


TrainOutput(global_step=2575, training_loss=1.4258579883760618, metrics={'train_runtime': 428.1161, 'train_samples_per_second': 769.814, 'train_steps_per_second': 6.015, 'total_flos': 7.000811124933427e+17, 'train_loss': 1.4258579883760618, 'epoch': 5.0})

In [33]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=100, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([100, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [34]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-random", logging_dir=f"~/logs/{DATASET}/cifar100-random", remove_unused_columns=False)

In [35]:
base.reset_seed()

In [36]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [37]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 2.8672 | 1.776593 | 0.6664 | 0.678417 | 0.6664 | 0.654188 |
| 2 | 1.38 | 1.183323 | 0.7649 | 0.767052 | 0.7649 | 0.761423 |
| 3 | 0.965 | 0.980394 | 0.7866 | 0.789237 | 0.7866 | 0.784953 |
| 4 | 0.7715 | 0.882884 | 0.8003 | 0.801994 | 0.8003 | 0.799398 |
| 5 | 0.6768 | 0.853583 | 0.8053 | 0.806176 | 0.8053 | 0.804274 |


TrainOutput(global_step=1565, training_loss=1.3320983399217503, metrics={'train_runtime': 297.8359, 'train_samples_per_second': 671.511, 'train_steps_per_second': 5.255, 'total_flos': 9.219293282304e+17, 'train_loss': 1.3320983399217503, 'epoch': 5.0})

In [38]:
base.reset_seed()

In [39]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=100, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([100, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [40]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-random-KD", logging_dir=f"~/logs/{DATASET}/cifar100-random-KD", remove_unused_columns=False)

In [41]:
base.reset_seed()

In [42]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [43]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 2.8694 | 1.779445 | 0.6604 | 0.674785 | 0.6604 | 0.648297 |
| 2 | 1.3813 | 1.189529 | 0.7618 | 0.763692 | 0.7618 | 0.759341 |
| 3 | 0.965 | 0.979173 | 0.7925 | 0.793821 | 0.7925 | 0.790915 |
| 4 | 0.7715 | 0.886743 | 0.8 | 0.801213 | 0.8 | 0.798998 |
| 5 | 0.6779 | 0.856561 | 0.8044 | 0.804471 | 0.8044 | 0.80329 |


TrainOutput(global_step=1565, training_loss=1.333033932816868, metrics={'train_runtime': 300.186, 'train_samples_per_second': 666.254, 'train_steps_per_second': 5.213, 'total_flos': 9.219293282304e+17, 'train_loss': 1.333033932816868, 'epoch': 5.0})

In [44]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=100, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([100, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [45]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-random", logging_dir=f"~/logs/{DATASET}/cifar100-random", remove_unused_columns=False)

In [46]:
base.reset_seed()

In [47]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [48]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 2.3424 | 1.367719 | 0.7212 | 0.728263 | 0.7212 | 0.717428 |
| 2 | 1.0198 | 0.927009 | 0.7927 | 0.795879 | 0.7927 | 0.791397 |
| 3 | 0.6971 | 0.774762 | 0.8115 | 0.813618 | 0.8115 | 0.810379 |
| 4 | 0.5468 | 0.711697 | 0.8203 | 0.822647 | 0.8203 | 0.820045 |
| 5 | 0.4721 | 0.69184 | 0.8235 | 0.825281 | 0.8235 | 0.823051 |


TrainOutput(global_step=2575, training_loss=1.0156465637799605, metrics={'train_runtime': 422.6329, 'train_samples_per_second': 779.802, 'train_steps_per_second': 6.093, 'total_flos': 1.5192012435244646e+18, 'train_loss': 1.0156465637799605, 'epoch': 5.0})

In [49]:
base.reset_seed()

In [50]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=100, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([100, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [51]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-random-KD", logging_dir=f"~/logs/{DATASET}/cifar100-random-KD", remove_unused_columns=False)

In [52]:
base.reset_seed()

In [53]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [54]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 2.2752 | 1.336749 | 0.7293 | 0.734611 | 0.7293 | 0.726724 |
| 2 | 0.9831 | 0.900092 | 0.8005 | 0.804688 | 0.8005 | 0.799618 |
| 3 | 0.6831 | 0.760107 | 0.8188 | 0.821899 | 0.8188 | 0.818186 |
| 4 | 0.5476 | 0.695434 | 0.8281 | 0.830291 | 0.8281 | 0.827864 |
| 5 | 0.4805 | 0.677319 | 0.8299 | 0.831768 | 0.8299 | 0.829535 |


TrainOutput(global_step=2575, training_loss=0.9939036248957069, metrics={'train_runtime': 473.1663, 'train_samples_per_second': 696.52, 'train_steps_per_second': 5.442, 'total_flos': 1.5192012435244646e+18, 'train_loss': 0.9939036248957069, 'epoch': 5.0})
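
To compare the eight runs side by side, the final-epoch validation accuracies reported above can be collected and the effect of distillation expressed in percentage points. The accuracy values are copied from the notebook outputs; the row labels are shorthand chosen for this sketch.

```python
# Final-epoch validation accuracy copied from the outputs above:
# (student, training set) -> (baseline, distilled)
final_accuracy = {
    ("MobileNetV2", "train"): (0.6777, 0.6814),
    ("MobileNetV2", "train_combo"): (0.7132, 0.7229),
    ("tiny_vit_5m", "train"): (0.8053, 0.8044),
    ("tiny_vit_5m", "train_combo"): (0.8235, 0.8299),
}

gains = {}
for (student, data), (baseline, distilled) in final_accuracy.items():
    # Gain in percentage points; rounding hides float representation noise.
    gains[(student, data)] = round((distilled - baseline) * 100, 2)
    print(f"{student:<12} {data:<12} {baseline:.4f} -> {distilled:.4f} "
          f"({gains[(student, data)]:+.2f} pp)")
```

The gains range from -0.09 to +0.97 percentage points, with the larger ones on the augmented `train_combo` set. Each configuration was run once for five epochs, so differences of this size are best read as indicative.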