# Notebook for training with knowledge distillation on the CIFAR100 dataset
In this notebook, TinyViT (`timm/tiny_vit_5m_224.in1k`) is trained on the CIFAR100 dataset; a ViT fine-tuned on the same dataset serves as the teacher model.

The student starts from ImageNet-1k pretrained weights with a newly initialized 100-class classification head, and the whole model is trained. It is trained both in the standard way and with knowledge distillation from the teacher mentioned above, in each case once on the base training set and once on the training set extended with augmented samples.

Distillation uses teacher logits precomputed in the `precompute_logits` notebook, so the teacher does not have to be run during training.
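For reference, a common (Hinton-style) form of the distillation objective is shown below. It assumes that `lambda_param` and `temp`, which are passed to `base.get_training_args` in the distillation runs, play the roles of $\lambda$ and $T$; `base.DistilTrainer` itself is not shown in this notebook, so this is an assumption rather than its confirmed definition. Here $z_s$ and $z_t$ are the student and teacher logits, $y$ is the true label and $\sigma$ is the softmax.

$$
\mathcal{L} = (1 - \lambda)\,\mathrm{CE}\big(y,\ \sigma(z_s)\big)
  + \lambda\, T^{2}\, \mathrm{KL}\big(\sigma(z_t / T)\ \big\|\ \sigma(z_s / T)\big)
$$

The $T^2$ factor keeps the gradient scale of the soft-target term roughly independent of the chosen temperature.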

## Importing libraries and defining methods

In [1]:
from transformers import Trainer, EarlyStoppingCallback, AutoModelForImageClassification
from torch.utils.data import ConcatDataset, DataLoader
import torch
import base
import os



In [2]:
dataset_part = base.get_dataset_part()
DATASET = "cifar100"

In [3]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 1g.10gb


Loading the training, validation and test splits together with the precomputed logits, and applying the transforms.

In [4]:
transform = base.base_transforms()

# Train / validation / test splits that also carry the teacher logits precomputed in precompute_logits.
# Note: the name `eval` shadows Python's built-in eval() from here on.
train = base.CustomCIFAR100L(root=f"{os.path.expanduser('~')}/data/100-logits", dataset_part=dataset_part.TRAIN, transform=transform)
eval = base.CustomCIFAR100L(root=f"{os.path.expanduser('~')}/data/100-logits", dataset_part=dataset_part.EVAL, transform=transform)
test = base.CustomCIFAR100L(root=f"{os.path.expanduser('~')}/data/100-logits", dataset_part=dataset_part.TEST, transform=transform)
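How `base.CustomCIFAR100L` stores the logits is not shown here. As a minimal sketch, assuming each item pairs an image and its label with the teacher logits precomputed for it, such a dataset could look like the hypothetical `LogitsDataset` below (the key `teacher_logits` is an illustrative name, not taken from `base`). An extra key like this is presumably why the distillation runs pass `remove_unused_columns=False`: with the default `True`, the Hugging Face `Trainer` drops inputs that the model's `forward` does not accept.

```python
from typing import Any, Callable, Dict, Optional, Sequence


class LogitsDataset:
    """Hypothetical dataset pairing each sample with precomputed teacher logits."""

    def __init__(
        self,
        images: Sequence[Any],
        labels: Sequence[int],
        teacher_logits: Sequence[Sequence[float]],
        transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        if not (len(images) == len(labels) == len(teacher_logits)):
            raise ValueError("images, labels and teacher_logits must have the same length")
        self.images = images
        self.labels = labels
        self.teacher_logits = teacher_logits
        self.transform = transform

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)  # e.g. resize/normalize, applied on each access
        return {
            "pixel_values": image,
            "labels": self.labels[idx],
            # Computed offline by the teacher, so no teacher forward pass is needed here.
            "teacher_logits": list(self.teacher_logits[idx]),
        }
```

Because the logits are looked up by index, the augmented split can reuse the same stored logits only if its sample order matches the one used when they were computed.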


In [5]:
augment_transform = base.aug_transforms()

train_aug = base.CustomCIFAR100L(root=f"{os.path.expanduser('~')}/data/100-logits", dataset_part=dataset_part.TRAIN, transform=augment_transform)

In [6]:
# Single-sample (batch_size=1) loaders over the plain CIFAR100 training split, used for the CPU and GPU latency benchmarks.
train_part_cpu = base.CustomCIFAR100(root=f"{os.path.expanduser('~')}/data/100", train=True, transform=transform, device="cpu")
cpu_data_loader = DataLoader(train_part_cpu, batch_size=1, shuffle=False)
train_part_gpu = base.CustomCIFAR100(root=f"{os.path.expanduser('~')}/data/100", train=True, transform=transform, device="cuda")
gpu_data_loader = DataLoader(train_part_gpu, batch_size=1, shuffle=False)

In [7]:
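# Drop augmented samples whose prediction (from the saved logits) differs from that of the corresponding base sample.
train_aug = base.remove_diff_pred_class(train, train_aug, pytorch_dataset=True)
print(len(train_aug))
# Training set for the "-aug" runs: all base samples plus the remaining augmented ones.
train_combo = ConcatDataset([train, train_aug])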

Removing entries from augmented dataset that are different from the base one - based on saved logits

25912
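The progress message above indicates that `base.remove_diff_pred_class` drops augmented samples based on the saved logits, leaving 25,912 augmented samples. A minimal sketch of such a filter, assuming augmented sample *i* is a transformed copy of base sample *i* and that "different" means a different top-1 class in the saved logits (the function names below are hypothetical):

```python
from typing import List, Sequence


def top1(logits: Sequence[float]) -> int:
    """Index of the largest logit, i.e. the predicted class."""
    return max(range(len(logits)), key=lambda i: logits[i])


def consistent_augmentation_indices(
    base_logits: Sequence[Sequence[float]],
    aug_logits: Sequence[Sequence[float]],
) -> List[int]:
    """Indices of augmented samples whose predicted class matches their base sample's."""
    if len(base_logits) != len(aug_logits):
        raise ValueError("base and augmented logits must be aligned one-to-one")
    return [
        i
        for i, (base, aug) in enumerate(zip(base_logits, aug_logits))
        if top1(base) == top1(aug)
    ]


base = [[2.0, 0.1], [0.3, 1.2], [1.0, 0.9]]
aug = [[1.5, 0.2], [0.9, 0.1], [1.1, 0.4]]
print(consistent_augmentation_indices(base, aug))
```

Here the second augmented sample flips from class 1 to class 0 and is dropped. With a PyTorch dataset, the kept indices could then be passed to `torch.utils.data.Subset`.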


### Standard training (ImageNet-1k pretrained backbone, newly initialized classification head)

In [85]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/vit-basetrain", logging_dir=f"~/logs/{DATASET}/vit-basetrain", lr=0.0002, warmup_steps=25, epochs=20)
model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=100, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([100, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
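`base.get_training_args` is not shown, so its learning-rate scheduler is not known for certain. Assuming it keeps the Hugging Face default (`lr_scheduler_type="linear"`), the learning rate rises linearly over `warmup_steps` and then decays linearly to zero at the last step. A sketch of that rule with this run's values (`lr=0.0002`, `warmup_steps=25`, and the 6260 total steps reported by `TrainOutput` below):

```python
def linear_warmup_then_decay(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at `step` for linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


for step in (0, 10, 25, 3130, 6260):
    print(step, linear_warmup_then_decay(step, base_lr=2e-4, warmup_steps=25, total_steps=6260))
```

With 313 steps per epoch in this run (6260 / 20), the 25 warmup steps cover less than a tenth of the first epoch.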


In [86]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)
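`EarlyStoppingCallback(early_stopping_patience=3)` stops training once the monitored metric has failed to improve for 3 consecutive evaluations. The sketch below mimics that counting rule; it is not the callback's source code. Fed the first 10 epochs of the training log that follows, validation loss would trigger a stop at epoch 6 while accuracy would not. Since this run reached epoch 20, `base.get_training_args` presumably monitors accuracy or another score rather than the loss, but that setting is not shown.

```python
from typing import Optional, Sequence


def early_stop_epoch(
    history: Sequence[float],
    patience: int = 3,
    greater_is_better: bool = True,
    threshold: float = 0.0,
) -> Optional[int]:
    """Return the 1-based epoch at which training would stop, or None if it never stops."""
    best: Optional[float] = None
    evals_without_improvement = 0
    for epoch, value in enumerate(history, start=1):
        if best is None:
            improved = True
        elif greater_is_better:
            improved = value - best > threshold
        else:
            improved = best - value > threshold
        if improved:
            best = value
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return epoch
    return None


# First 10 epochs of the standard training run (validation split)
val_loss = [0.907664, 0.687753, 0.603011, 0.649604, 0.631279,
            0.691468, 0.693256, 0.711615, 0.74216, 0.764705]
val_acc = [0.7471, 0.7992, 0.8237, 0.8138, 0.8277,
           0.8215, 0.83, 0.8312, 0.8323, 0.8326]
print(early_stop_epoch(val_loss, greater_is_better=False))
print(early_stop_epoch(val_acc, greater_is_better=True))
```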

In [87]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 1.8891 | 0.907664 | 0.7471 | 0.767294 | 0.7471 | 0.746593 |
| 2 | 0.5953 | 0.687753 | 0.7992 | 0.809146 | 0.7992 | 0.797973 |
| 3 | 0.3653 | 0.603011 | 0.8237 | 0.833367 | 0.8237 | 0.824719 |
| 4 | 0.2321 | 0.649604 | 0.8138 | 0.824856 | 0.8138 | 0.814173 |
| 5 | 0.1582 | 0.631279 | 0.8277 | 0.835012 | 0.8277 | 0.827885 |
| 6 | 0.1086 | 0.691468 | 0.8215 | 0.826127 | 0.8215 | 0.82063 |
| 7 | 0.0799 | 0.693256 | 0.83 | 0.837371 | 0.83 | 0.830322 |
| 8 | 0.0587 | 0.711615 | 0.8312 | 0.835736 | 0.8312 | 0.831258 |
| 9 | 0.0421 | 0.74216 | 0.8323 | 0.836948 | 0.8323 | 0.83227 |
| 10 | 0.0334 | 0.764705 | 0.8326 | 0.839053 | 0.8326 | 0.833081 |


TrainOutput(global_step=6260, training_loss=0.1815786790662109, metrics={'train_runtime': 2853.1255, 'train_samples_per_second': 280.394, 'train_steps_per_second': 2.194, 'total_flos': 3.6877173129216e+18, 'train_loss': 0.1815786790662109, 'epoch': 20.0})
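The throughput figures in `TrainOutput` allow a quick sanity check of the split sizes. Assuming the Trainer computes `train_samples_per_second` as (training-set size × planned epochs) / runtime, the numbers imply 40,000 training samples and 313 optimizer steps per epoch, which matches a batch size of 128 (40,000 / 128 = 312.5, rounded up). The batch size is inferred here, not stated in the notebook.

```python
import math

train_runtime = 2853.1255        # seconds, from TrainOutput
samples_per_second = 280.394     # train_samples_per_second
planned_epochs = 20
global_step = 6260

samples_per_epoch = samples_per_second * train_runtime / planned_epochs
steps_per_epoch = global_step / planned_epochs
assumed_batch_size = 128         # inferred, not stated in the notebook

print(round(samples_per_epoch))  # estimated training-set size
print(steps_per_epoch)           # optimizer steps per epoch
print(math.ceil(round(samples_per_epoch) / assumed_batch_size))  # batches per epoch at batch size 128
```

The same calculation on the augmented runs gives 65,912 samples per epoch, i.e. the 40,000 base samples plus the 25,912 augmented samples that were kept.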

In [88]:
model.eval()

TimmWrapperForImageClassification(
  (timm_model): TinyVit(
    (patch_embed): PatchEmbed(
      (conv1): ConvNorm(
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (act): GELU(approximate='none')
      (conv2): ConvNorm(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (stages): Sequential(
      (0): ConvLayer(
        (blocks): Sequential(
          (0): MBConv(
            (conv1): ConvNorm(
              (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (act1): GELU(approximate='none')
            (conv2): ConvNorm(
              (c

In [89]:
trainer.evaluate(test)

{'eval_loss': 0.8002419471740723,
 'eval_accuracy': 0.8539,
 'eval_precision': 0.855637204074962,
 'eval_recall': 0.8539,
 'eval_f1': 0.8540311959806233,
 'eval_runtime': 14.5973,
 'eval_samples_per_second': 685.057,
 'eval_steps_per_second': 5.412,
 'epoch': 20.0}
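In the training log and test results above, recall always equals accuracy. `base.compute_metrics` is not shown, but this is exactly what support-weighted averaging produces (for example `average="weighted"` in scikit-learn's `precision_recall_fscore_support`): weighting each class's recall $TP_c / n_c$ by $n_c / N$ gives $\sum_c TP_c / N$, which is the accuracy. Macro averaging gives the same value only on a perfectly class-balanced split. A small check in plain Python:

```python
from collections import Counter
from typing import Dict, Sequence, Tuple


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def per_class_recall(y_true: Sequence[int], y_pred: Sequence[int]) -> Tuple[Dict[int, float], Counter]:
    support = Counter(y_true)                                  # n_c for each class
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)  # TP_c for each class
    return {c: hits[c] / n for c, n in support.items()}, support


def weighted_recall(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    recall, support = per_class_recall(y_true, y_pred)
    total = len(y_true)
    return sum(recall[c] * support[c] / total for c in support)


def macro_recall(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    recall, _ = per_class_recall(y_true, y_pred)
    return sum(recall.values()) / len(recall)


y_true = [0, 0, 0, 1, 1, 2]   # deliberately imbalanced toy labels
y_pred = [0, 1, 0, 1, 2, 2]
print(accuracy(y_true, y_pred), weighted_recall(y_true, y_pred), macro_recall(y_true, y_pred))
```

Weighted precision and F1 do not collapse to accuracy in the same way, which is why they differ from it in the tables.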

In [13]:
torch.save(model.state_dict(), f"{os.path.expanduser('~')}/models/{DATASET}/vit-basetrain.pth")

## Distillation training definition

A class (`base.DistilTrainer`) that adapts the Hugging Face `Trainer` for knowledge distillation. This version works with the teacher logits stored in the dataset.
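`base.DistilTrainer` is not shown in this notebook. As a hedged sketch, assuming it follows the common Hinton-style recipe in which `lambda_param` weights the soft-target term and `temp` is the softmax temperature, the per-sample loss could be computed as below. Plain Python is used so the arithmetic is visible; a real `Trainer.compute_loss` override would do the same with PyTorch tensors (e.g. `torch.nn.functional.kl_div` on the student's log-probabilities) and average over the batch.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    scaled = [x / temperature for x in logits]
    shift = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q); terms with p_i == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(
    student_logits: Sequence[float],
    teacher_logits: Sequence[float],
    label: int,
    lambda_param: float = 0.8,
    temperature: float = 4.0,
) -> float:
    # Hard-label term: ordinary cross-entropy of the student at temperature 1.
    hard = -math.log(softmax(student_logits)[label])
    # Soft-target term: match the teacher's softened distribution, rescaled by T^2.
    soft = kl_divergence(
        softmax(teacher_logits, temperature), softmax(student_logits, temperature)
    ) * temperature ** 2
    return (1.0 - lambda_param) * hard + lambda_param * soft


student = [2.0, 0.5, -1.0]
teacher = [3.0, 1.0, -2.0]
print(distillation_loss(student, teacher, label=0, lambda_param=0.8, temperature=4.0))
```

With `lambda_param=.8`, as in the runs below, 80 % of the loss comes from matching the teacher and 20 % from the true labels.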

### Knowledge-distillation training (ImageNet-1k pretrained backbone, newly initialized classification head)

In [17]:
base.reset_seed()

In [18]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/vit-distilltrain", logging_dir=f"~/logs/{DATASET}/vit-distilltrain", remove_unused_columns=False, epochs=20, lr=0.00026, warmup_steps=25, weight_decay=0.002, lambda_param=.8, temp=4)
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=100, ignore_mismatched_sizes=True)

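The distillation runs use `temp=4`. Assuming this is the softmax temperature applied to both sets of logits, its effect is easy to see on a toy example: dividing the logits by a larger $T$ keeps the ranking of the classes but moves probability mass from the top class to the others, so the student also sees which wrong classes the teacher considers similar. The logits below are made up for illustration.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    scaled = [x / temperature for x in logits]
    shift = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy in nats; higher means a softer distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


toy_teacher_logits = [8.0, 4.0, 2.0, 0.0]  # hypothetical values
for temperature in (1.0, 4.0):
    probs = softmax(toy_teacher_logits, temperature)
    print(temperature, [round(p, 3) for p in probs], round(entropy(probs), 3))
```

With these toy logits the top-class probability drops from about 0.98 at $T = 1$ to about 0.58 at $T = 4$.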


In [19]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [20]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 2.2494 | 1.135772 | 0.7437 | 0.76645 | 0.7437 | 0.74417 |
| 2 | 0.7936 | 0.902559 | 0.7825 | 0.797615 | 0.7825 | 0.782237 |
| 3 | 0.5445 | 0.78406 | 0.8042 | 0.815616 | 0.8042 | 0.804493 |
| 4 | 0.4013 | 0.721304 | 0.8163 | 0.827228 | 0.8163 | 0.817295 |
| 5 | 0.3254 | 0.704685 | 0.8214 | 0.831154 | 0.8214 | 0.822258 |
| 6 | 0.2731 | 0.698496 | 0.8246 | 0.833959 | 0.8246 | 0.824671 |
| 7 | 0.2354 | 0.661701 | 0.8359 | 0.843434 | 0.8359 | 0.836983 |
| 8 | 0.2054 | 0.646265 | 0.8325 | 0.840137 | 0.8325 | 0.833397 |
| 9 | 0.1868 | 0.613126 | 0.8418 | 0.847736 | 0.8418 | 0.842513 |
| 10 | 0.1708 | 0.614429 | 0.8386 | 0.845444 | 0.8386 | 0.839597 |


TrainOutput(global_step=6260, training_loss=0.3249329813753073, metrics={'train_runtime': 2851.7587, 'train_samples_per_second': 280.529, 'train_steps_per_second': 2.195, 'total_flos': 3.6877173129216e+18, 'train_loss': 0.3249329813753073, 'epoch': 20.0})

In [21]:
student_model.eval()


In [22]:
trainer.evaluate(test)

{'eval_loss': 0.4296223223209381,
 'eval_accuracy': 0.857,
 'eval_precision': 0.8619413810627935,
 'eval_recall': 0.8570000000000003,
 'eval_f1': 0.8575483875362276,
 'eval_runtime': 14.3963,
 'eval_samples_per_second': 694.622,
 'eval_steps_per_second': 5.488,
 'epoch': 20.0}

In [23]:
torch.save(student_model.state_dict(), f"{os.path.expanduser('~')}/models/{DATASET}/vit-distilltrain.pth")

### Standard training on the training set extended with augmented samples

In [47]:
base.reset_seed()

In [48]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/vit-base-aug", logging_dir=f"~/logs/{DATASET}/vit-base-aug", lr=0.0002, weight_decay=0.006, warmup_steps=25, epochs=20)
model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=100, ignore_mismatched_sizes=True)



In [49]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [50]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 1.4794 | 0.80427 | 0.7674 | 0.78308 | 0.7674 | 0.765975 |
| 2 | 0.4696 | 0.64772 | 0.8076 | 0.817776 | 0.8076 | 0.806484 |
| 3 | 0.2749 | 0.68239 | 0.8117 | 0.82202 | 0.8117 | 0.810421 |
| 4 | 0.1784 | 0.642784 | 0.827 | 0.832708 | 0.827 | 0.826682 |
| 5 | 0.1218 | 0.73099 | 0.8164 | 0.824934 | 0.8164 | 0.816781 |
| 6 | 0.0895 | 0.734367 | 0.8207 | 0.830008 | 0.8207 | 0.81998 |
| 7 | 0.0663 | 0.755954 | 0.8263 | 0.833348 | 0.8263 | 0.827101 |
| 8 | 0.0507 | 0.755929 | 0.8342 | 0.83745 | 0.8342 | 0.833732 |
| 9 | 0.0396 | 0.796329 | 0.832 | 0.836739 | 0.832 | 0.831415 |
| 10 | 0.0281 | 0.820643 | 0.8356 | 0.839777 | 0.8356 | 0.835426 |


TrainOutput(global_step=10300, training_loss=0.1429132774942419, metrics={'train_runtime': 4491.7074, 'train_samples_per_second': 293.483, 'train_steps_per_second': 2.293, 'total_flos': 6.076620588232212e+18, 'train_loss': 0.1429132774942419, 'epoch': 20.0})

In [51]:
model.eval()


In [52]:
trainer.evaluate(test)

{'eval_loss': 0.8661243319511414,
 'eval_accuracy': 0.854,
 'eval_precision': 0.8560095977657519,
 'eval_recall': 0.8539999999999999,
 'eval_f1': 0.8541807597168135,
 'eval_runtime': 14.5049,
 'eval_samples_per_second': 689.421,
 'eval_steps_per_second': 5.446,
 'epoch': 20.0}

In [53]:
torch.save(model.state_dict(), f"{os.path.expanduser('~')}/models/{DATASET}/vit-basetrain-aug.pth")

### Knowledge-distillation training on the training set extended with augmented samples

In [79]:
base.reset_seed()

In [80]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/vit-distill-aug", logging_dir=f"~/logs/{DATASET}/vit-distill-aug", remove_unused_columns=False, weight_decay=0.007, epochs=20, lr=0.00023, lambda_param=.8, temp=4)
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=100, ignore_mismatched_sizes=True)



In [81]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [82]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 1.7407 | 0.950765 | 0.7773 | 0.785016 | 0.7773 | 0.77655 |
| 2 | 0.6704 | 0.78859 | 0.8061 | 0.816434 | 0.8061 | 0.806857 |
| 3 | 0.4696 | 0.729935 | 0.8207 | 0.826967 | 0.8207 | 0.82008 |
| 4 | 0.3669 | 0.6827 | 0.8273 | 0.835293 | 0.8273 | 0.827602 |
| 5 | 0.3059 | 0.679056 | 0.8302 | 0.837416 | 0.8302 | 0.830512 |
| 6 | 0.2636 | 0.638735 | 0.8364 | 0.842191 | 0.8364 | 0.837108 |
| 7 | 0.2358 | 0.651164 | 0.8353 | 0.842438 | 0.8353 | 0.836101 |
| 8 | 0.2131 | 0.621165 | 0.8391 | 0.845348 | 0.8391 | 0.839817 |
| 9 | 0.1939 | 0.605474 | 0.8441 | 0.848942 | 0.8441 | 0.844689 |
| 10 | 0.1786 | 0.583228 | 0.8462 | 0.851931 | 0.8462 | 0.847011 |


TrainOutput(global_step=9785, training_loss=0.303062785542858, metrics={'train_runtime': 4270.5001, 'train_samples_per_second': 308.685, 'train_steps_per_second': 2.412, 'total_flos': 5.772789558820602e+18, 'train_loss': 0.303062785542858, 'epoch': 19.0})

In [83]:
student_model.eval()


In [84]:
trainer.evaluate(test)

{'eval_loss': 0.45350322127342224,
 'eval_accuracy': 0.854,
 'eval_precision': 0.8582488555626767,
 'eval_recall': 0.8540000000000001,
 'eval_f1': 0.8544667534154797,
 'eval_runtime': 14.7511,
 'eval_samples_per_second': 677.916,
 'eval_steps_per_second': 5.356,
 'epoch': 19.0}

In [43]:
torch.save(student_model.state_dict(), f"{os.path.expanduser('~')}/models/{DATASET}/vit-distilltrain_aug.pth")

In [44]:
base.count_parameters(student_model)

model size: 21.339MB.
Total Trainable Params: 5103864.


| Index | Modules | Parameters |
|---|---|---|
| 0 | timm_model.patch_embed.conv1.conv.weight | 864 |
| 1 | timm_model.patch_embed.conv1.bn.weight | 32 |
| 2 | timm_model.patch_embed.conv1.bn.bias | 32 |
| 3 | timm_model.patch_embed.conv2.conv.weight | 18432 |
| 4 | timm_model.patch_embed.conv2.bn.weight | 64 |
| ... | ... | ... |
| 210 | timm_model.stages.3.blocks.1.local_conv.bn.bias | 320 |
| 211 | timm_model.head.norm.weight | 320 |
| 212 | timm_model.head.norm.bias | 320 |
| 213 | timm_model.head.fc.weight | 32000 |
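A few rows of the table can be checked by hand: a bias-free `Conv2d` has out_channels × in_channels × k_h × k_w weights, and a `Linear` layer has in_features × out_features weights. The layer shapes come from the model printout above (3→32 and 32→64 channels with 3×3 kernels in `patch_embed`) and from the loading warning (a 320→100 classification head). The same arithmetic gives the size of the parameters alone in fp32; `base.count_parameters` is not shown, so it is unclear what else its 21.339 MB figure includes.

```python
from typing import Tuple


def conv2d_weight_count(in_ch: int, out_ch: int, kernel: Tuple[int, int], groups: int = 1) -> int:
    """Number of weights in a bias-free Conv2d layer."""
    k_h, k_w = kernel
    return out_ch * (in_ch // groups) * k_h * k_w


def linear_weight_count(in_features: int, out_features: int) -> int:
    """Number of weights in a Linear layer, excluding the bias."""
    return in_features * out_features


print(conv2d_weight_count(3, 32, (3, 3)))   # patch_embed.conv1 → 864
print(conv2d_weight_count(32, 64, (3, 3)))  # patch_embed.conv2 → 18432
print(linear_weight_count(320, 100))        # head.fc → 32000

total_params = 5103864                      # "Total Trainable Params" above
fp32_bytes = total_params * 4               # 4 bytes per float32 parameter
print(fp32_bytes / 1e6, "MB (decimal)", fp32_bytes / 2**20, "MiB")
```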


In [45]:
cpu_benchmark = base.BenchMarkRunner(student_model, cpu_data_loader, "cpu", 1000)
print(cpu_benchmark.run_benchmark())

<torch.utils.benchmark.utils.common.Measurement object at 0x76a8e5fc8b80>
self.infer_speed_comp()
  38.48 ms
  1 measurement, 1000 runs , 6 threads


In [46]:
gpu_benchmark = base.BenchMarkRunner(student_model, gpu_data_loader, "cuda", 1000)
print(gpu_benchmark.run_benchmark())

<torch.utils.benchmark.utils.common.Measurement object at 0x76a8f2139b70>
self.infer_speed_comp()
  10.04 ms
  1 measurement, 1000 runs , 6 threads
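`base.BenchMarkRunner` is not shown; its output has the format of a `torch.utils.benchmark` `Measurement`, i.e. the mean time per run over 1000 runs at batch size 1. A stdlib-only sketch of the same idea, timing a callable many times and reporting the mean latency, is below. It is only suitable for CPU-bound work: CUDA kernels run asynchronously, and `torch.utils.benchmark.Timer` synchronizes the device for that reason, whereas a plain `timeit` loop does not.

```python
import timeit
from typing import Callable


def mean_latency_ms(fn: Callable[[], object], runs: int = 1000, warmup: int = 10) -> float:
    """Mean wall-clock time of one call to `fn`, in milliseconds."""
    for _ in range(warmup):
        fn()  # warm-up calls are excluded from the measurement
    total_seconds = timeit.timeit(fn, number=runs)
    return total_seconds / runs * 1000.0


def throughput_per_second(latency_ms: float) -> float:
    """Items per second implied by a per-item latency (batch size 1)."""
    return 1000.0 / latency_ms


# Latencies reported above for the distilled model at batch size 1
print(round(throughput_per_second(38.48), 1))  # CPU
print(round(throughput_per_second(10.04), 1))  # GPU
```

The reported latencies correspond to roughly 26 images per second on the CPU and about 100 on the A100 MIG slice, a speed-up of about 3.8×.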
