# Notebook for training with distillation on the CIFAR100 dataset
This notebook trains MobileNetV2 on the CIFAR100 dataset, using a ViT fine-tuned on the same dataset as the teacher model. A pretrained TinyViT (`timm/tiny_vit_5m_224.in1k`) is also trained as a second student, and the runs use both the original training split and a split extended with augmented images.

MobileNetV2 is used in three settings: randomly initialized, with only the classification head trained on top of a MobileNetV2 pretrained on ImageNet, and with the whole pretrained model fine-tuned. Each of these three tasks is trained both in the standard way and with distillation from the teacher model above.

Distillation uses the logits precomputed in the precompute_logits notebook.
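The distillation loss itself lives in `base` and is not shown here. As a hedged illustration of the general technique, the classic soft-target loss of Hinton et al. mixes the hard-label cross-entropy with a temperature-softened KL divergence between the teacher's and the student's class distributions. The temperature `T`, the weight `alpha`, and the function names below are assumptions, not the values or code used by `base.DistilTrainer`; the sketch works on a single example in plain Python so the formula stays visible.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax of `logits / temperature`."""
    scaled = [z / temperature for z in logits]
    shift = max(scaled)  # subtract the max to avoid overflow in exp
    exps = [math.exp(z - shift) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, label, temperature=2.0, alpha=0.5):
    """Soft-target distillation loss for one example (hypothetical T and alpha).

    alpha weights the hard-label cross-entropy, (1 - alpha) the KL term; the KL term
    is multiplied by T**2 so its gradient scale stays comparable as T changes.
    """
    # Cross-entropy against the ground-truth label (temperature 1)
    ce = -math.log(softmax(student_logits)[label])
    # KL(teacher || student) between the temperature-softened distributions
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student) if pt > 0)
    return alpha * ce + (1 - alpha) * temperature ** 2 * kl
```

With precomputed logits, `teacher_logits` would simply be read from the dataset instead of produced by a forward pass. In PyTorch the KL term is usually written as `F.kl_div(F.log_softmax(s / T, dim=-1), F.softmax(t / T, dim=-1), reduction="batchmean")`.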

## Importing libraries and defining methods

In [1]:
from transformers import AutoModelForImageClassification
from torch.utils.data import ConcatDataset
import pandas as pd
import torch
import base
import os



In [2]:
dataset_part = base.get_dataset_part()

Resetting the random seed so that the results are reproducible.
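`base.reset_seed` is not shown in this notebook. A typical implementation reseeds every random number generator that training touches; the sketch below is a hypothetical stand-in (the default seed of 42 is an assumption) that skips libraries which are not installed.

```python
import random

def reset_seed(seed: int = 42) -> None:
    """Hypothetical stand-in for base.reset_seed: reseed every RNG used in training."""
    random.seed(seed)  # Python's built-in RNG
    try:
        import numpy as np
        np.random.seed(seed)  # NumPy's global RNG
    except ImportError:
        pass  # NumPy not installed
    try:
        import torch
        torch.manual_seed(seed)  # PyTorch RNG on all devices, including CUDA
    except ImportError:
        pass  # PyTorch not installed
```

`transformers.set_seed(seed)` does essentially the same for Python, NumPy and PyTorch. Note that two runs only start from identical weights if the seed is reset *before* the student model is created, not just before the trainer is built.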

In [3]:
base.reset_seed()

In [4]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


Applying the transformations to the dataset.

In [5]:
DATASET = "cifar100"

In [6]:
transform = base.base_transforms()

# The last train batch is used as the eval split
test = base.CustomCIFAR100L(root=f"{os.path.expanduser('~')}/data/100-logits", dataset_part=dataset_part.TEST, transform=transform)
train = base.CustomCIFAR100L(root=f"{os.path.expanduser('~')}/data/100-logits", dataset_part=dataset_part.TRAIN, transform=transform)
eval = base.CustomCIFAR100L(root=f"{os.path.expanduser('~')}/data/100-logits", dataset_part=dataset_part.EVAL, transform=transform)

In [7]:
augment_transform = base.aug_transforms()
train_aug = base.CustomCIFAR100L(root=f"{os.path.expanduser('~')}/data/100-logits", dataset_part=dataset_part.TRAIN, transform=augment_transform)

In [8]:
train_aug = base.remove_diff_pred_class(train, train_aug, pytorch_dataset=True)
train_combo = ConcatDataset([train, train_aug])

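`base.remove_diff_pred_class` is not shown either. Its progress message suggests that it drops augmented images whose saved teacher logits point to a different class than the logits of the corresponding original image. A minimal sketch of that filter, assuming the two sets of logits are aligned by index (the function name and the list-based layout are hypothetical; the real datasets load their logits from disk):

```python
def keep_consistent_indices(base_logits, aug_logits):
    """Return indices i where argmax(aug_logits[i]) == argmax(base_logits[i]).

    base_logits / aug_logits: sequences of per-sample logit lists, aligned by index.
    """
    def argmax(values):
        return max(range(len(values)), key=values.__getitem__)

    return [
        i for i, (b, a) in enumerate(zip(base_logits, aug_logits))
        if argmax(b) == argmax(a)
    ]
```

`ConcatDataset([train, train_aug])` then simply chains the two datasets, so `len(train_combo) == len(train) + len(train_aug)`.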

In [9]:
train[0]["labels"]

tensor(34)

In [10]:
# Check the label distribution --> good enough
df = pd.DataFrame(eval.labels)
print(df.value_counts())

0 
0     100
63    100
73    100
72    100
71    100
     ... 
30    100
29    100
28    100
27    100
99    100
Name: count, Length: 100, dtype: int64
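The output shows that each of the 100 classes has exactly 100 images in the eval split, i.e. 10,000 images in total. Instead of reading the printout, the check can be turned into an assertion; `is_balanced` is a hypothetical helper built on `collections.Counter`:

```python
from collections import Counter

def is_balanced(labels):
    """True if every class label occurs equally often."""
    counts = Counter(int(label) for label in labels)
    return len(set(counts.values())) <= 1
```

For example, `assert is_balanced(eval.labels)` could close the cell above.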


### Standard training of a randomly initialized model

In [11]:
student_model = base.get_mobilenet(100)

In [12]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-random", logging_dir=f"~/logs/{DATASET}/cifar100-random", remove_unused_columns=False)
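Each of the eight runs in this notebook (two students, with and without KD, on `train` and `train_combo`) is best given its own `output_dir` and `logging_dir`, otherwise checkpoints and logs of different runs end up in the same folder. A small helper can derive both paths from the run settings. The helper and its naming scheme are hypothetical, and `~` is expanded explicitly, as the dataset paths above already do:

```python
import os

def run_dirs(dataset, student, distill=False, augmented=False, root="~"):
    """Build distinct output/logging directories for one experiment (hypothetical helper).

    Every combination of student, KD on/off and augmented data gets its own name,
    e.g. "cifar100-random-aug-KD", so runs never share a directory.
    """
    name = f"{dataset}-{student}"
    if augmented:
        name += "-aug"
    if distill:
        name += "-KD"
    base_dir = os.path.expanduser(root)
    return (
        os.path.join(base_dir, "results", dataset, name),
        os.path.join(base_dir, "logs", dataset, name),
    )
```

For example, `out, log = run_dirs(DATASET, "random", distill=True)` could be passed as `base.get_training_args(output_dir=out, logging_dir=log, remove_unused_columns=False)`.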

In [13]:
base.reset_seed()

In [14]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [15]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 3.1473 | 2.204412 | 0.5057 | 0.538717 | 0.5057 | 0.487075 |
| 2 | 1.813 | 1.614076 | 0.6153 | 0.630102 | 0.6153 | 0.608206 |
| 3 | 1.3974 | 1.405575 | 0.6534 | 0.658984 | 0.6534 | 0.647178 |
| 4 | 1.2141 | 1.407125 | 0.6415 | 0.663386 | 0.6415 | 0.639384 |
| 5 | 1.133 | 1.281165 | 0.6765 | 0.681558 | 0.6765 | 0.670228 |


TrainOutput(global_step=1565, training_loss=1.7409627408646167, metrics={'train_runtime': 379.7494, 'train_samples_per_second': 526.663, 'train_steps_per_second': 4.121, 'total_flos': 4.248451694592e+17, 'train_loss': 1.7409627408646167, 'epoch': 5.0})

In [16]:
base.reset_seed()

In [17]:
student_model = base.get_mobilenet(100)
teacher_model = AutoModelForImageClassification.from_pretrained(
    "Ahmed9275/Vit-Cifar100",
    num_labels=100,
)
teacher_model.eval()
teacher_model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [18]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-random-KD", logging_dir=f"~/logs/{DATASET}/cifar100-random-KD", remove_unused_columns=False)

In [19]:
base.reset_seed()

In [20]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [21]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 3.1068 | 2.150495 | 0.5184 | 0.557492 | 0.5184 | 0.504905 |
| 2 | 1.7756 | 1.608173 | 0.6137 | 0.634833 | 0.6137 | 0.608372 |
| 3 | 1.3797 | 1.400802 | 0.6522 | 0.659143 | 0.6522 | 0.646228 |
| 4 | 1.1984 | 1.384205 | 0.6465 | 0.666467 | 0.6465 | 0.645163 |
| 5 | 1.1153 | 1.272863 | 0.6844 | 0.689867 | 0.6844 | 0.678011 |


TrainOutput(global_step=1565, training_loss=1.7151414962622304, metrics={'train_runtime': 817.9248, 'train_samples_per_second': 244.521, 'train_steps_per_second': 1.913, 'total_flos': 4.248451694592e+17, 'train_loss': 1.7151414962622304, 'epoch': 5.0})
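With the same student and the same number of steps (both runs log an identical `total_flos`, which suggests only the student's operations are counted), distillation more than doubles the training time. Given the `teacher_model` argument, a plausible cause is that `DistilTrainerInfer` runs the ViT teacher on every batch rather than reading precomputed logits; this is an assumption about `base`, which is not shown. The overhead follows directly from the logged runtimes:

```python
# train_runtime (seconds) of the two MobileNetV2 runs on `train` above; both ran 1565 steps
RUNTIME_PLAIN = 379.7494
RUNTIME_KD = 817.9248
STEPS = 1565

slowdown = RUNTIME_KD / RUNTIME_PLAIN
extra_s_per_step = (RUNTIME_KD - RUNTIME_PLAIN) / STEPS
print(f"KD run is {slowdown:.2f}x slower, about {extra_s_per_step:.2f} s extra per step")
```

If the extra time is indeed teacher inference, most of it could be avoided by distilling from the precomputed logits described in the introduction.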

In [22]:
student_model = base.get_mobilenet(100)

In [23]:
# Separate directories per run so checkpoints and logs of different runs do not mix
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-random-aug", logging_dir=f"~/logs/{DATASET}/cifar100-random-aug", remove_unused_columns=False)

In [24]:
base.reset_seed()

In [25]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [26]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 2.7516 | 1.76495 | 0.5888 | 0.595461 | 0.5888 | 0.577706 |
| 2 | 1.5115 | 1.315853 | 0.6698 | 0.673487 | 0.6698 | 0.663694 |
| 3 | 1.1945 | 1.174958 | 0.7021 | 0.704691 | 0.7021 | 0.697374 |
| 4 | 1.0472 | 1.123964 | 0.7102 | 0.716218 | 0.7102 | 0.708246 |
| 5 | 0.9753 | 1.095964 | 0.7143 | 0.71762 | 0.7143 | 0.711574 |


TrainOutput(global_step=2575, training_loss=1.4960278201797634, metrics={'train_runtime': 570.9148, 'train_samples_per_second': 577.249, 'train_steps_per_second': 4.51, 'total_flos': 7.000598702348698e+17, 'train_loss': 1.4960278201797634, 'epoch': 5.0})
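The logged numbers are enough to recover how much data each run actually saw. Throughput times runtime divided by the number of epochs gives the samples per epoch: 40,000 for `train` (consistent with CIFAR-100's 50,000 training images minus the 10,000-image eval split) and 65,912 for `train_combo`, so 25,912 augmented images survived the filtering step. The global step counts are then consistent only with a batch size of 128, assuming one device and no gradient accumulation; that batch size is inferred here, not read from `base.get_training_args`.

```python
import math

def dataset_size(samples_per_second, runtime_s, epochs):
    """Samples per epoch, recovered from the Trainer's logged throughput and runtime."""
    return round(samples_per_second * runtime_s / epochs)

def steps_per_epoch(n_samples, batch_size):
    """Optimizer steps per epoch without gradient accumulation (last partial batch kept)."""
    return math.ceil(n_samples / batch_size)

# Values copied from the TrainOutput of the first `train` run and of the `train_combo` run above
n_train = dataset_size(526.663, 379.7494, 5)
n_combo = dataset_size(577.249, 570.9148, 5)
n_aug_kept = n_combo - n_train
```

For instance, `5 * steps_per_epoch(n_train, 128)` reproduces the logged `global_step=1565`.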

In [27]:
base.reset_seed()

In [28]:
student_model = base.get_mobilenet(100)

In [29]:
# Separate directories per run so checkpoints and logs of different runs do not mix
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-random-aug-KD", logging_dir=f"~/logs/{DATASET}/cifar100-random-aug-KD", remove_unused_columns=False)

In [30]:
base.reset_seed()

In [31]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [32]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 2.6485 | 1.737405 | 0.5965 | 0.607579 | 0.5965 | 0.586606 |
| 2 | 1.442 | 1.282101 | 0.6804 | 0.684411 | 0.6804 | 0.675516 |
| 3 | 1.1397 | 1.147731 | 0.7041 | 0.709171 | 0.7041 | 0.700798 |
| 4 | 1.0004 | 1.091868 | 0.7211 | 0.72737 | 0.7211 | 0.719718 |
| 5 | 0.9325 | 1.064466 | 0.7259 | 0.729126 | 0.7259 | 0.723699 |


TrainOutput(global_step=2575, training_loss=1.4326189059655643, metrics={'train_runtime': 1251.452, 'train_samples_per_second': 263.342, 'train_steps_per_second': 2.058, 'total_flos': 7.000598702348698e+17, 'train_loss': 1.4326189059655643, 'epoch': 5.0})

In [38]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=100, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([100, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
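The warning is expected: `ignore_mismatched_sizes=True` discards the 1000-class ImageNet head of `timm/tiny_vit_5m_224.in1k` and creates a new, randomly initialized 100-class head on top of the 320-dimensional features, while the rest of the network keeps its pretrained weights. The shapes in the warning give the size of the part that starts from scratch:

```python
def linear_params(in_features, out_features):
    """Parameters of a fully connected layer: weight matrix plus bias vector."""
    return in_features * out_features + out_features

old_head = linear_params(320, 1000)  # head.fc in the ImageNet-1k checkpoint
new_head = linear_params(320, 100)   # head.fc re-created for CIFAR-100
```

That is roughly 32 thousand new parameters, a small fraction of a model whose name suggests about 5 million parameters in total.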


In [39]:
# Separate directories per run so checkpoints and logs of different runs do not mix
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-tinyvit", logging_dir=f"~/logs/{DATASET}/cifar100-tinyvit", remove_unused_columns=False)

In [40]:
base.reset_seed()

In [41]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [42]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 2.9039 | 1.798962 | 0.6641 | 0.678632 | 0.6641 | 0.652051 |
| 2 | 1.3889 | 1.188195 | 0.7616 | 0.764834 | 0.7616 | 0.758666 |
| 3 | 0.964 | 0.973595 | 0.7907 | 0.792471 | 0.7907 | 0.789026 |
| 4 | 0.7661 | 0.881513 | 0.8011 | 0.802506 | 0.8011 | 0.800094 |
| 5 | 0.6712 | 0.85028 | 0.8066 | 0.807335 | 0.8066 | 0.805465 |


TrainOutput(global_step=1565, training_loss=1.3388394389289637, metrics={'train_runtime': 417.9989, 'train_samples_per_second': 478.47, 'train_steps_per_second': 3.744, 'total_flos': 9.219293282304e+17, 'train_loss': 1.3388394389289637, 'epoch': 5.0})

In [43]:
base.reset_seed()

In [44]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=100, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([100, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [45]:
# Separate directories per run so checkpoints and logs of different runs do not mix
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-tinyvit-KD", logging_dir=f"~/logs/{DATASET}/cifar100-tinyvit-KD", remove_unused_columns=False)

In [46]:
base.reset_seed()

In [47]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [48]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 2.9673 | 1.950313 | 0.6329 | 0.647448 | 0.6329 | 0.619488 |
| 2 | 1.5237 | 1.287865 | 0.7431 | 0.744981 | 0.7431 | 0.740344 |
| 3 | 1.0597 | 1.04532 | 0.78 | 0.780772 | 0.78 | 0.777969 |
| 4 | 0.8439 | 0.937034 | 0.7932 | 0.793578 | 0.7932 | 0.791511 |
| 5 | 0.7451 | 0.90584 | 0.7989 | 0.799038 | 0.7989 | 0.797425 |


TrainOutput(global_step=1565, training_loss=1.427964621839432, metrics={'train_runtime': 865.1506, 'train_samples_per_second': 231.174, 'train_steps_per_second': 1.809, 'total_flos': 9.219293282304e+17, 'train_loss': 1.427964621839432, 'epoch': 5.0})

In [49]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=100, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([100, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [50]:
# Separate directories per run so checkpoints and logs of different runs do not mix
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-tinyvit-aug", logging_dir=f"~/logs/{DATASET}/cifar100-tinyvit-aug", remove_unused_columns=False)

In [51]:
base.reset_seed()

In [52]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [53]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 2.4325 | 1.450343 | 0.7074 | 0.711554 | 0.7074 | 0.701835 |
| 2 | 1.074 | 0.965786 | 0.7797 | 0.781598 | 0.7797 | 0.778404 |
| 3 | 0.7308 | 0.794486 | 0.8051 | 0.806699 | 0.8051 | 0.803591 |
| 4 | 0.5713 | 0.728414 | 0.8175 | 0.820518 | 0.8175 | 0.817077 |
| 5 | 0.4929 | 0.708025 | 0.819 | 0.820449 | 0.819 | 0.818538 |


TrainOutput(global_step=2575, training_loss=1.0602800269265777, metrics={'train_runtime': 641.612, 'train_samples_per_second': 513.644, 'train_steps_per_second': 4.013, 'total_flos': 1.519155147058053e+18, 'train_loss': 1.0602800269265777, 'epoch': 5.0})

In [54]:
base.reset_seed()

In [55]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=100, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([100, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [56]:
# Separate directories per run so checkpoints and logs of different runs do not mix
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar100-tinyvit-aug-KD", logging_dir=f"~/logs/{DATASET}/cifar100-tinyvit-aug-KD", remove_unused_columns=False)

In [57]:
base.reset_seed()

In [58]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [59]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 2.312 | 1.370719 | 0.7286 | 0.732267 | 0.7286 | 0.724716 |
| 2 | 1.012 | 0.920162 | 0.7986 | 0.80136 | 0.7986 | 0.797943 |
| 3 | 0.704 | 0.766804 | 0.8142 | 0.81528 | 0.8142 | 0.812885 |
| 4 | 0.5665 | 0.704982 | 0.8255 | 0.827208 | 0.8255 | 0.824798 |
| 5 | 0.4976 | 0.684593 | 0.8302 | 0.831809 | 0.8302 | 0.829788 |


TrainOutput(global_step=2575, training_loss=1.0184253522261832, metrics={'train_runtime': 1327.8856, 'train_samples_per_second': 248.184, 'train_steps_per_second': 1.939, 'total_flos': 1.519155147058053e+18, 'train_loss': 1.0184253522261832, 'epoch': 5.0})
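To compare the eight runs side by side, the final-epoch metrics can be collected and the effect of distillation computed per setting. The sketch below only copies the epoch-5 accuracy and F1 from the logs above; the dictionary layout and helper name are illustrative.

```python
# Final-epoch (epoch 5) validation (accuracy, F1) copied from the eight logs above
FINAL = {
    ("MobileNetV2, random init", "train"):       {"plain": (0.6765, 0.670228), "kd": (0.6844, 0.678011)},
    ("MobileNetV2, random init", "train_combo"): {"plain": (0.7143, 0.711574), "kd": (0.7259, 0.723699)},
    ("TinyViT-5M, IN-1k init", "train"):         {"plain": (0.8066, 0.805465), "kd": (0.7989, 0.797425)},
    ("TinyViT-5M, IN-1k init", "train_combo"):   {"plain": (0.8190, 0.818538), "kd": (0.8302, 0.829788)},
}

def kd_gain(runs, metric=0):
    """KD minus plain training for one metric (0 = accuracy, 1 = F1), in percentage points."""
    return round((runs["kd"][metric] - runs["plain"][metric]) * 100, 2)

for (student, data), runs in FINAL.items():
    print(f"{student:26} {data:12} acc {kd_gain(runs):+.2f} pp, F1 {kd_gain(runs, 1):+.2f} pp")
```

Distillation improves accuracy in three of the four settings (about +0.8 to +1.2 percentage points) and slightly lowers it for TinyViT trained on `train` only. These are single-seed runs evaluated on one split, so differences of this size may be partly run-to-run noise.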