# Notebook for distillation training on the CIFAR10 dataset
This notebook trains MobileNetV2 on the CIFAR10 dataset, using a ViT fine-tuned on the same dataset as the teacher model.

MobileNetV2 is used in three settings: trained from random initialization, trained with only the classification head updated on an ImageNet-pretrained MobileNetV2, and trained in full starting from the same pretrained weights. Each of the three settings is trained both in the standard way and with distillation from the teacher model mentioned above.

Distillation uses the precomputed logits from the precompute_logits notebook.
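
As a rough illustration of what logit-based distillation computes, the sketch below blends a cross-entropy term on the true label with a temperature-softened KL term between teacher and student logits. It is written in plain Python so the arithmetic stays visible. The loss actually used by the trainers in `base` is not shown in this notebook, so `distillation_loss`, `temperature` and `alpha` are hypothetical names and values chosen for the example.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax of a list of logits at a given temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, label, temperature=2.0, alpha=0.5):
    """Hinton-style KD loss for one sample (assumed form, not taken from `base`)."""
    # Hard-label term: cross-entropy of the student against the true class.
    ce = -math.log(softmax(student_logits)[label])
    # Soft-label term: KL(teacher || student) on temperature-softened distributions.
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(p * math.log(p / q) for p, q in zip(p_teacher, p_student) if p > 0)
    # The T^2 factor keeps the soft-term gradients on a comparable scale.
    return alpha * ce + (1 - alpha) * temperature ** 2 * kl

student = [2.0, 0.5, -1.0]   # made-up student logits for three classes
teacher = [3.0, 1.0, -2.0]   # made-up teacher logits for the same sample
loss = distillation_loss(student, teacher, label=0)
print(f"{loss:.4f}")
```

With precomputed logits, `teacher_logits` would simply be read from disk for each sample instead of being produced by a forward pass of the teacher.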

## Library imports and method definitions

In [2]:
from transformers import AutoModelForImageClassification
from torch.utils.data import ConcatDataset
import pandas as pd
import torch
import base
import os

In [3]:
dataset_part = base.get_dataset_part()

Resetting the random seed so that the results are reproducible.

In [4]:
base.reset_seed()
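
`base.reset_seed()` is not shown in this notebook. Below is a minimal sketch of what such a helper commonly does. The name `reset_seed_sketch` and the default seed `42` are assumptions, and the NumPy and PyTorch calls run only if those libraries are installed.

```python
import os
import random

def reset_seed_sketch(seed: int = 42) -> None:
    """Seed the common random number generators (hypothetical stand-in for base.reset_seed)."""
    random.seed(seed)
    # Only affects hash randomization in Python processes started after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # default PyTorch generator
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)  # explicitly seed every visible GPU
    except ImportError:
        pass

reset_seed_sketch()
first = random.random()
reset_seed_sketch()
second = random.random()
print(first == second)  # True
```

Calling the helper before each model construction and each `trainer.train()`, as this notebook does, makes weight initialization and data shuffling start from the same generator state in every run.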

In [5]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


Applying the transformations to the dataset.

In [6]:
DATASET = "cifar10"

In [7]:
transform = base.base_transforms()

# The last train batch is used as the eval split...
test = base.CustomCIFAR10L(root=f"{os.path.expanduser('~')}/data/10-logits", dataset_part=dataset_part.TEST, transform=transform)
train = base.CustomCIFAR10L(root=f"{os.path.expanduser('~')}/data/10-logits", dataset_part=dataset_part.TRAIN, transform=transform)
eval = base.CustomCIFAR10L(root=f"{os.path.expanduser('~')}/data/10-logits", dataset_part=dataset_part.EVAL, transform=transform)

In [8]:
augment_transform = base.aug_transforms()
train_aug = base.CustomCIFAR10L(root=f"{os.path.expanduser('~')}/data/10-logits", dataset_part=dataset_part.TRAIN, transform=augment_transform)

In [9]:
train_aug = base.remove_diff_pred_class(train, train_aug, pytorch_dataset=True)
train_combo = ConcatDataset([train, train_aug])

Removing entries from augmented dataset that are different from the base one - based on saved logits:   0%|   …
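
The progress message suggests that `base.remove_diff_pred_class` drops augmented samples whose saved teacher prediction no longer matches the prediction for the un-augmented image. A minimal sketch of that idea on plain lists follows. The name `keep_consistent_indices` and the assumption that sample `i` refers to the same image in both datasets are illustrative, not taken from `base`.

```python
def argmax(values):
    """Index of the largest value in a sequence."""
    return max(range(len(values)), key=values.__getitem__)

def keep_consistent_indices(base_logits, aug_logits):
    """Indices where the saved logits of the base and augmented sample agree on the class."""
    return [
        i
        for i, (b, a) in enumerate(zip(base_logits, aug_logits))
        if argmax(b) == argmax(a)
    ]

# Made-up logits for three samples and three classes.
base_logits = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 3.0]]
aug_logits = [[0.2, 1.7, 0.1], [0.3, 0.9, 0.2], [0.1, 0.0, 2.5]]
kept = keep_consistent_indices(base_logits, aug_logits)
print(kept)  # [0, 2]
```

With PyTorch datasets, the surviving indices could then be wrapped with `torch.utils.data.Subset(train_aug, kept)` before concatenation.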

In [10]:
train[0]["labels"]

tensor(6)

In [11]:
# Check the class distribution --> good enough
df = pd.DataFrame(eval.labels)
print(df.value_counts())

0
5    1025
9    1022
3    1016
0    1014
1    1014
8    1003
4     997
6     980
7     977
2     952
Name: count, dtype: int64
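
To put a number on "good enough", the counts printed above can be compared with the perfectly balanced case. The snippet below is only a worked check on those printed values.

```python
# Class counts copied from the value_counts() output above.
counts = {5: 1025, 9: 1022, 3: 1016, 0: 1014, 1: 1014,
          8: 1003, 4: 997, 6: 980, 7: 977, 2: 952}

total = sum(counts.values())
expected = total / len(counts)
# Largest relative deviation of any class from the balanced count.
max_dev = max(abs(c - expected) / expected for c in counts.values())

print(total)             # 10000
print(f"{max_dev:.1%}")  # 4.8%
```

The eval split therefore holds 10,000 images, and no class deviates from the balanced 1,000 by more than about 5 %.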


### Standard training of the randomly initialized model

In [56]:
student_model = base.get_mobilenet(10)

In [57]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random", logging_dir=f"~/logs/{DATASET}/cifar10-random", remove_unused_columns=False)

In [58]:
base.reset_seed()

In [59]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [60]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.5619,0.340989,0.8822,0.887258,0.882247,0.881975
2,0.2814,0.294031,0.9089,0.913302,0.908771,0.90954
3,0.2274,0.274416,0.9108,0.912861,0.91083,0.910804
4,0.2025,0.248416,0.9226,0.924363,0.922807,0.9227
5,0.1881,0.251953,0.9215,0.922356,0.921818,0.921473


TrainOutput(global_step=1565, training_loss=0.2922747042232428, metrics={'train_runtime': 378.4098, 'train_samples_per_second': 528.527, 'train_steps_per_second': 4.136, 'total_flos': 4.040199217152e+17, 'train_loss': 0.2922747042232428, 'epoch': 5.0})
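
The numbers in this `TrainOutput` can be cross-checked against each other. The arithmetic below recovers the training-set size and the effective batch size from the reported throughput. The batch size of 128 is inferred from these figures, not read from `base.get_training_args`.

```python
import math

# Values copied from the TrainOutput above.
runtime_s = 378.4098
samples_per_s = 528.527
global_step = 1565
epochs = 5

# Samples seen per epoch, i.e. the size of the training split.
train_size = round(samples_per_s * runtime_s / epochs)
steps_per_epoch = global_step // epochs
# Smallest batch size that covers train_size in steps_per_epoch steps.
batch_size = math.ceil(train_size / steps_per_epoch)

print(train_size, steps_per_epoch, batch_size)  # 40000 313 128
```

The 40,000 training images are consistent with the 50,000-image CIFAR10 training set minus the 10,000 images held out as the eval split.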

In [61]:
base.reset_seed()

In [62]:
student_model = base.get_mobilenet(10)
teacher_model = AutoModelForImageClassification.from_pretrained(
    "aaraki/vit-base-patch16-224-in21k-finetuned-cifar10",
    num_labels=10,
)
teacher_model.eval()
teacher_model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [19]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random-KD", logging_dir=f"~/logs/{DATASET}/cifar10-random-KD", remove_unused_columns=False)

In [20]:
base.reset_seed()

In [21]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [22]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.5631,0.34403,0.8782,0.885102,0.878158,0.878197
2,0.2806,0.286562,0.9088,0.912392,0.908664,0.909394
3,0.227,0.271677,0.9143,0.91605,0.914327,0.914416
4,0.2012,0.249703,0.9198,0.921998,0.92001,0.919921
5,0.1873,0.251812,0.9191,0.920426,0.919408,0.919206


TrainOutput(global_step=1565, training_loss=0.29186008928682855, metrics={'train_runtime': 820.5936, 'train_samples_per_second': 243.726, 'train_steps_per_second': 1.907, 'total_flos': 4.040199217152e+17, 'train_loss': 0.29186008928682855, 'epoch': 5.0})

In [23]:
student_model = base.get_mobilenet(10)

In [24]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random", logging_dir=f"~/logs/{DATASET}/cifar10-random", remove_unused_columns=False)

In [25]:
base.reset_seed()

In [26]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [27]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.5479,0.304709,0.8974,0.899663,0.897617,0.897191
2,0.3125,0.249679,0.9241,0.925305,0.924307,0.924212
3,0.2597,0.244111,0.9265,0.927182,0.926955,0.926168
4,0.2318,0.237831,0.9285,0.930272,0.928582,0.928577
5,0.2166,0.226526,0.9313,0.932195,0.931433,0.931362


TrainOutput(global_step=2665, training_loss=0.3136786598649302, metrics={'train_runtime': 588.8747, 'train_samples_per_second': 578.867, 'train_steps_per_second': 4.526, 'total_flos': 6.886115545713869e+17, 'train_loss': 0.3136786598649302, 'epoch': 5.0})
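
The throughput figures of this run also indicate roughly how many augmented samples survived the filtering step. The estimate is derived from rounded values, so it should be read as approximate. The base training size of 40,000 is an assumption (the 50,000-image CIFAR10 training set minus the 10,000-image eval split).

```python
# Values copied from the TrainOutput above.
runtime_s = 588.8747
samples_per_s = 578.867
epochs = 5

# Assumed: CIFAR10 training set minus the eval split.
base_train_size = 50_000 - 10_000

# Samples seen per epoch, i.e. the size of the combined dataset.
combo_size = round(samples_per_s * runtime_s / epochs)
# Augmented samples that were kept after filtering.
aug_kept = combo_size - base_train_size

print(combo_size, aug_kept)
```

Under these assumptions, a bit more than two thirds of the 40,000 augmented images were kept in `train_combo`.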

In [28]:
base.reset_seed()

In [29]:
student_model = base.get_mobilenet(10)

In [30]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random-KD", logging_dir=f"~/logs/{DATASET}/cifar10-random-KD", remove_unused_columns=False)

In [31]:
base.reset_seed()

In [32]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [33]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.5451,0.298021,0.8998,0.901568,0.899982,0.899661
2,0.3187,0.243472,0.9255,0.926705,0.925721,0.9256
3,0.2704,0.236763,0.9282,0.928749,0.928695,0.927802
4,0.2454,0.231667,0.9278,0.92982,0.927939,0.927953
5,0.2316,0.221609,0.9345,0.93529,0.934651,0.934527


TrainOutput(global_step=2665, training_loss=0.32222727476767704, metrics={'train_runtime': 1302.0067, 'train_samples_per_second': 261.811, 'train_steps_per_second': 2.047, 'total_flos': 6.886115545713869e+17, 'train_loss': 0.32222727476767704, 'epoch': 5.0})

In [34]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=10, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([10, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [35]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random", logging_dir=f"~/logs/{DATASET}/cifar10-random", remove_unused_columns=False)

In [36]:
base.reset_seed()

In [37]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [38]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.4174,0.189456,0.9509,0.95294,0.9509,0.951353
2,0.1567,0.163709,0.9629,0.963876,0.96304,0.963123
3,0.1303,0.155766,0.9654,0.965561,0.96551,0.9655
4,0.1192,0.150646,0.9685,0.968681,0.968635,0.968623
5,0.1145,0.14901,0.9694,0.969656,0.969563,0.969531


TrainOutput(global_step=1565, training_loss=0.18759201829806685, metrics={'train_runtime': 422.6402, 'train_samples_per_second': 473.216, 'train_steps_per_second': 3.703, 'total_flos': 9.167108235264e+17, 'train_loss': 0.18759201829806685, 'epoch': 5.0})

In [39]:
base.reset_seed()

In [40]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=10, ignore_mismatched_sizes=True)


In [41]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random-KD", logging_dir=f"~/logs/{DATASET}/cifar10-random-KD", remove_unused_columns=False)

In [42]:
base.reset_seed()

In [43]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [44]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.4018,0.188737,0.9483,0.949496,0.948419,0.948627
2,0.1563,0.161348,0.9609,0.961397,0.961062,0.961114
3,0.1301,0.156382,0.9644,0.964838,0.964578,0.964591
4,0.119,0.153967,0.9658,0.966014,0.965989,0.965932
5,0.1145,0.152464,0.9665,0.966777,0.966693,0.966662


TrainOutput(global_step=1565, training_loss=0.18433443989616613, metrics={'train_runtime': 872.1228, 'train_samples_per_second': 229.326, 'train_steps_per_second': 1.794, 'total_flos': 9.167108235264e+17, 'train_loss': 0.18433443989616613, 'epoch': 5.0})

In [45]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=10, ignore_mismatched_sizes=True)


In [46]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random", logging_dir=f"~/logs/{DATASET}/cifar10-random", remove_unused_columns=False)

In [47]:
base.reset_seed()

In [48]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
)

In [49]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.3702,0.173078,0.9584,0.95854,0.958553,0.958462
2,0.1741,0.158141,0.9635,0.963699,0.963729,0.963606
3,0.1485,0.152017,0.9686,0.968772,0.968753,0.968724
4,0.1382,0.149792,0.9703,0.970588,0.970429,0.970433
5,0.1335,0.149073,0.9695,0.969806,0.96961,0.969647


TrainOutput(global_step=2665, training_loss=0.19291731957870398, metrics={'train_runtime': 662.5422, 'train_samples_per_second': 514.503, 'train_steps_per_second': 4.022, 'total_flos': 1.5624419276183962e+18, 'train_loss': 0.19291731957870398, 'epoch': 5.0})

In [50]:
base.reset_seed()

In [51]:
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=10, ignore_mismatched_sizes=True)


In [52]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/cifar10-random-KD", logging_dir=f"~/logs/{DATASET}/cifar10-random-KD", remove_unused_columns=False)

In [53]:
base.reset_seed()

In [54]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics
)

In [55]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.3765,0.168763,0.9586,0.958967,0.95882,0.95873
2,0.1936,0.15475,0.9653,0.965452,0.965497,0.965386
3,0.1708,0.149874,0.9694,0.969623,0.969548,0.969538
4,0.1606,0.145816,0.9715,0.971699,0.971625,0.971635
5,0.1561,0.145137,0.9709,0.971132,0.971039,0.971059


TrainOutput(global_step=2665, training_loss=0.21153791187851784, metrics={'train_runtime': 1370.9898, 'train_samples_per_second': 248.638, 'train_steps_per_second': 1.944, 'total_flos': 1.5624419276183962e+18, 'train_loss': 0.21153791187851784, 'epoch': 5.0})
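
For a side-by-side view, the final-epoch accuracies and runtimes reported in the outputs above can be collected in one place. The run labels are descriptive names chosen here, not identifiers from the notebook, and each number comes from a single run.

```python
# (student, training data, method) -> (epoch-5 accuracy, train_runtime in seconds),
# copied from the training outputs above.
runs = {
    ("MobileNetV2", "train", "standard"): (0.9215, 378.4098),
    ("MobileNetV2", "train", "KD"): (0.9191, 820.5936),
    ("MobileNetV2", "train+aug", "standard"): (0.9313, 588.8747),
    ("MobileNetV2", "train+aug", "KD"): (0.9345, 1302.0067),
    ("TinyViT-5M", "train", "standard"): (0.9694, 422.6402),
    ("TinyViT-5M", "train", "KD"): (0.9665, 872.1228),
    ("TinyViT-5M", "train+aug", "standard"): (0.9695, 662.5422),
    ("TinyViT-5M", "train+aug", "KD"): (0.9709, 1370.9898),
}

# For every KD run, compare against the matching standard run.
deltas = {}
for (student, data, method), (acc, runtime) in runs.items():
    if method != "KD":
        continue
    base_acc, base_runtime = runs[(student, data, "standard")]
    deltas[(student, data)] = (acc - base_acc, runtime / base_runtime)
    print(f"{student:12s} {data:10s} KD - standard: {acc - base_acc:+.4f} accuracy, "
          f"{runtime / base_runtime:.2f}x runtime")
```

In these runs distillation ends slightly below standard training on the plain training set and slightly above it on the combined set, with differences of a few tenths of a percentage point and roughly twice the runtime.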