# Notebook for distillation training on the CIFAR100 dataset
This notebook trains MobileNetV2 on the CIFAR100 dataset; a ViT fine-tuned on the same dataset serves as the teacher model.

MobileNetV2 is used in three setups: with random initialization, with only the classification head trained on an initialized (ImageNet-pretrained) MobileNetV2, and with the whole initialized model trained. Each of the three setups is trained in the standard way and also with distillation from the teacher model mentioned above.

Distillation uses precomputed logits from the precompute_logits notebook.
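The dataset classes that carry the precomputed logits live in the `base` module and are not shown here. As a minimal sketch of the idea, assuming each sample is returned as a dictionary holding the image, its label, and the stored teacher logits, a dataset could look roughly like this. The class name `LogitDataset` and the dictionary keys are hypothetical, not the notebook's actual implementation.

```python
import torch
from torch.utils.data import Dataset


class LogitDataset(Dataset):
    """Hypothetical dataset pairing each image with its label and stored teacher logits."""

    def __init__(self, images, labels, teacher_logits, transform=None):
        # All three sources must describe the same samples in the same order.
        assert len(images) == len(labels) == len(teacher_logits)
        self.images = images
        self.labels = labels
        self.teacher_logits = teacher_logits
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        # The teacher logits travel with the sample, so no teacher forward pass
        # is needed during student training.
        return {
            "pixel_values": image,
            "labels": self.labels[idx],
            "logits": self.teacher_logits[idx],
        }


# Toy data standing in for real images and real teacher outputs.
images = torch.rand(5, 3, 32, 32)
labels = torch.randint(0, 10, (5,))
teacher_logits = torch.randn(5, 10)

dataset = LogitDataset(images, labels, teacher_logits)
item = dataset[0]
print(sorted(item.keys()))
```

Storing the logits once avoids running the large teacher model in every epoch, at the cost of fixing the teacher's view of each sample ahead of time.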

## Library imports and method definitions

In [1]:
from transformers import Trainer, EarlyStoppingCallback, AutoModelForImageClassification
from torch.utils.data import ConcatDataset, DataLoader
import torch
import base
import os


In [2]:
dataset_part = base.get_dataset_part()
DATASET = "cifar10"

Initialized MobileNetV2.

In [None]:
base.reset_seed()

In [4]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


Applying the transforms to the dataset.

In [5]:
transform = base.base_transforms()

train = base.CustomCIFAR10L(root=f"{os.path.expanduser('~')}/data/10-logits", dataset_part=dataset_part.TRAIN, transform=transform)
eval = base.CustomCIFAR10L(root=f"{os.path.expanduser('~')}/data/10-logits", dataset_part=dataset_part.EVAL, transform=transform)
test = base.CustomCIFAR10L(root=f"{os.path.expanduser('~')}/data/10-logits", dataset_part=dataset_part.TEST, transform=transform)


In [6]:
augment_transform = base.aug_transforms()

train_aug = base.CustomCIFAR10L(root=f"{os.path.expanduser('~')}/data/10-logits", dataset_part=dataset_part.TRAIN, transform=augment_transform)

In [7]:
train_part_cpu = base.CustomCIFAR10(root=f"{os.path.expanduser('~')}/data/10", train=True, batch=1, transform=transform, device="cpu")
cpu_data_loader = DataLoader(train_part_cpu, batch_size=1, shuffle=False)
train_part_gpu = base.CustomCIFAR10(root=f"{os.path.expanduser('~')}/data/10", train=True, batch=1, transform=transform, device="cuda")
gpu_data_loader = DataLoader(train_part_gpu, batch_size=1, shuffle=False)

In [8]:
train_aug = base.remove_diff_pred_class(train, train_aug, pytorch_dataset=True)
print(len(train_aug))
train_combo = ConcatDataset([train, train_aug])

Removing entries from augmented dataset that are different from the base one - based on saved logits:   0%|   …

28176
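The progress message above says that augmented entries are dropped when they differ from the base ones, judged by the saved logits. The body of `base.remove_diff_pred_class` is not shown, so the following is only a sketch of one plausible reading: an augmented sample is kept when the teacher's predicted class (the argmax of the stored logits) matches the prediction for the corresponding base sample. The helper name `keep_consistent_indices` and the toy tensors are assumptions.

```python
import torch
from torch.utils.data import Subset, TensorDataset


def keep_consistent_indices(base_logits, aug_logits):
    """Return indices where the predicted class agrees between base and augmented samples."""
    same = base_logits.argmax(dim=-1) == aug_logits.argmax(dim=-1)
    return torch.nonzero(same, as_tuple=True)[0].tolist()


# Toy teacher logits for three samples and two classes.
base_logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [1.0, 0.2]])
aug_logits = torch.tensor([[1.8, 0.4], [0.9, 0.2], [0.7, 0.1]])
aug_dataset = TensorDataset(torch.arange(3), aug_logits)

kept = keep_consistent_indices(base_logits, aug_logits)
filtered = Subset(aug_dataset, kept)  # sample 1 is dropped: its prediction flipped
print(kept, len(filtered))
```

Filtering this way discards augmentations that changed the teacher's decision, so the student is not trained on soft targets that contradict the base sample.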


### Standard training of the randomly initialized model

In [97]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/vit-basetrain", logging_dir=f"~/logs/{DATASET}/vit-basetrain", lr=0.0001, weight_decay=0.01, epochs=20, warmup_steps=30)
model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=10, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([10, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [98]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [99]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.5262,0.212952,0.9309,0.94098,0.930924,0.932803
2,0.106,0.128141,0.9575,0.957898,0.95769,0.957615
3,0.0556,0.135928,0.9584,0.95928,0.958556,0.958653
4,0.0344,0.137671,0.9608,0.961214,0.960912,0.960863
5,0.0229,0.1427,0.9644,0.965399,0.964493,0.964742
6,0.0165,0.161795,0.9628,0.963063,0.963032,0.96285
7,0.0095,0.151796,0.9671,0.967476,0.96718,0.967291
8,0.0073,0.179141,0.9655,0.965649,0.965744,0.965524
9,0.0041,0.174385,0.9674,0.967808,0.96756,0.967587
10,0.0028,0.16369,0.9716,0.971812,0.971734,0.971747


TrainOutput(global_step=4069, training_loss=0.06093742174198716, metrics={'train_runtime': 1079.7631, 'train_samples_per_second': 740.903, 'train_steps_per_second': 5.798, 'total_flos': 2.38344814116864e+18, 'train_loss': 0.06093742174198716, 'epoch': 13.0})

In [100]:
model.eval()

TimmWrapperForImageClassification(
  (timm_model): TinyVit(
    (patch_embed): PatchEmbed(
      (conv1): ConvNorm(
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (act): GELU(approximate='none')
      (conv2): ConvNorm(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (stages): Sequential(
      (0): ConvLayer(
        (blocks): Sequential(
          (0): MBConv(
            (conv1): ConvNorm(
              (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (act1): GELU(approximate='none')
            (conv2): ConvNorm(
              (c

In [101]:
trainer.evaluate(test)

{'eval_loss': 0.17828308045864105,
 'eval_accuracy': 0.9697,
 'eval_precision': 0.9697211758351025,
 'eval_recall': 0.9697000000000001,
 'eval_f1': 0.9696868106456424,
 'eval_runtime': 12.9244,
 'eval_samples_per_second': 773.733,
 'eval_steps_per_second': 6.112,
 'epoch': 13.0}
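The metrics above come from `base.compute_metrics`, which is defined outside this notebook. A minimal sketch of such a function, assuming macro-averaged precision, recall, and F1 over the argmax predictions, might look as follows. The averaging mode and the function name `compute_metrics_sketch` are assumptions; the Hugging Face `Trainer` adds the `eval_` prefix to the returned keys on its own.

```python
import numpy as np


def compute_metrics_sketch(eval_pred):
    """Accuracy plus macro-averaged precision, recall and F1 from (logits, labels)."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    classes = np.unique(np.concatenate([labels, preds]))

    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = np.sum((preds == c) & (labels == c))
        fp = np.sum((preds == c) & (labels != c))
        fn = np.sum((preds != c) & (labels == c))
        p = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        r = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f)

    return {
        "accuracy": float(np.mean(preds == labels)),
        "precision": float(np.mean(precisions)),
        "recall": float(np.mean(recalls)),
        "f1": float(np.mean(f1s)),
    }


# Four samples, two classes; the second sample is misclassified.
logits = np.array([[2.0, 0.0], [0.0, 1.0], [0.1, 0.9], [0.2, 0.8]])
labels = np.array([0, 0, 1, 1])
metrics = compute_metrics_sketch((logits, labels))
print(metrics)
```

With macro averaging on a class-balanced split, recall coincides with accuracy, which is consistent with the nearly identical `eval_accuracy` and `eval_recall` values reported above.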

In [14]:
torch.save(model.state_dict(), f"{os.path.expanduser('~')}/models/{DATASET}/vit-basetrain.pth")

## Definition of distillation training

A class that adapts the Hugging Face trainer for knowledge distillation. It now works with logits stored in the dataset.
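The `base.DistilTrainer` class itself is not shown in this notebook. As a rough sketch of a loss over stored teacher logits, the following assumes the common formulation: a temperature-softened KL term weighted by `lambda_param`, plus a cross-entropy term on the hard labels weighted by `1 - lambda_param`. The function name `distillation_loss` and the exact weighting are assumptions; only the parameter names `lambda_param` and `temp` appear in the notebook.

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, labels, lambda_param=0.75, temp=3.5):
    """Hypothetical KD loss: lambda * soft-target KL + (1 - lambda) * hard-label CE."""
    # Soften both distributions with the temperature.
    soft_student = F.log_softmax(student_logits / temp, dim=-1)
    soft_teacher = F.softmax(teacher_logits / temp, dim=-1)
    # KL divergence from teacher to student; temp**2 keeps the gradient scale
    # comparable across temperatures.
    kd = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * temp ** 2
    # Ordinary cross-entropy against the ground-truth labels.
    ce = F.cross_entropy(student_logits, labels)
    return lambda_param * kd + (1.0 - lambda_param) * ce


torch.manual_seed(0)
student = torch.randn(4, 10)
teacher = torch.randn(4, 10)  # stands in for logits read from the dataset
labels = torch.tensor([1, 0, 3, 9])

loss = distillation_loss(student, teacher, labels)
print(loss.item())
```

In a custom trainer, a function like this would typically be called from an overridden `compute_loss`, with the teacher logits taken from the batch rather than from a teacher forward pass.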

### Training the randomly initialized model with knowledge distillation

In [91]:
base.reset_seed()

In [92]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/vit-distilltrain", logging_dir=f"~/logs/{DATASET}/vit-distilltrain", remove_unused_columns=False, epochs=20, lr=0.00015, weight_decay=0.008, warmup_steps=20, lambda_param=.75, temp=3.5)
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=10, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([10, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [93]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [94]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.2572,0.131818,0.9464,0.947826,0.946772,0.946416
2,0.1149,0.127965,0.9509,0.952269,0.951148,0.951262
3,0.1015,0.114026,0.9627,0.963092,0.962885,0.962813
4,0.0935,0.108415,0.9655,0.966041,0.965762,0.965577
5,0.0888,0.102565,0.9662,0.966967,0.966358,0.966431
6,0.0855,0.101118,0.9692,0.969573,0.969395,0.969327
7,0.0828,0.099718,0.9688,0.969619,0.969036,0.969001
8,0.081,0.09715,0.9688,0.969951,0.969034,0.969102
9,0.0796,0.096051,0.9726,0.973034,0.97281,0.97277
10,0.0784,0.094639,0.974,0.974395,0.974162,0.974156


TrainOutput(global_step=4069, training_loss=0.09946667307218307, metrics={'train_runtime': 1076.6686, 'train_samples_per_second': 743.033, 'train_steps_per_second': 5.814, 'total_flos': 2.38344814116864e+18, 'train_loss': 0.09946667307218307, 'epoch': 13.0})

In [95]:
student_model.eval()


In [96]:
trainer.evaluate(test)

{'eval_loss': 0.0974622517824173,
 'eval_accuracy': 0.9724,
 'eval_precision': 0.972546661207717,
 'eval_recall': 0.9724,
 'eval_f1': 0.9723919824347605,
 'eval_runtime': 12.8876,
 'eval_samples_per_second': 775.937,
 'eval_steps_per_second': 6.13,
 'epoch': 13.0}

In [24]:
torch.save(student_model.state_dict(), f"{os.path.expanduser('~')}/models/{DATASET}/vit-distilltrain.pth")

In [60]:
base.reset_seed()

In [61]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/vit-base-aug", logging_dir=f"~/logs/{DATASET}/vit-base-aug", lr=0.0001, weight_decay=0.005, warmup_steps=30, epochs=20)
model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=10, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([10, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [62]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [63]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.4219,0.144259,0.9495,0.950644,0.949736,0.949566
2,0.0985,0.14225,0.9541,0.955062,0.954236,0.954116
3,0.0494,0.123086,0.9634,0.963751,0.963463,0.963555
4,0.0304,0.136995,0.9628,0.963264,0.962913,0.962893
5,0.0213,0.165441,0.9625,0.963126,0.962644,0.962587
6,0.0169,0.147528,0.9665,0.966779,0.966642,0.96668
7,0.0108,0.174346,0.9665,0.967026,0.966601,0.966664
8,0.0109,0.170566,0.9677,0.967893,0.967844,0.967828
9,0.0072,0.188392,0.9676,0.96784,0.967711,0.967747
10,0.0064,0.180799,0.9659,0.966163,0.966026,0.966029


TrainOutput(global_step=10127, training_loss=0.03606330337952129, metrics={'train_runtime': 2488.5802, 'train_samples_per_second': 547.911, 'train_steps_per_second': 4.284, 'total_flos': 5.937279324949905e+18, 'train_loss': 0.03606330337952129, 'epoch': 19.0})

In [64]:
model.eval()


In [65]:
trainer.evaluate(test)

{'eval_loss': 0.18284110724925995,
 'eval_accuracy': 0.9726,
 'eval_precision': 0.9726192666951757,
 'eval_recall': 0.9725999999999999,
 'eval_f1': 0.9725995211048846,
 'eval_runtime': 12.0414,
 'eval_samples_per_second': 830.467,
 'eval_steps_per_second': 6.561,
 'epoch': 19.0}

In [34]:
torch.save(model.state_dict(), f"{os.path.expanduser('~')}/models/{DATASET}/vit-basetrain-aug.pth")

Distillation

In [108]:
base.reset_seed()

In [109]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/vit-distill-aug", logging_dir=f"~/logs/{DATASET}/vit-distill-aug", remove_unused_columns=False, epochs=20, lr=0.00013, weight_decay=0.002, warmup_steps=30, lambda_param=.4, temp=5.5)
student_model = AutoModelForImageClassification.from_pretrained("timm/tiny_vit_5m_224.in1k", num_labels=10, ignore_mismatched_sizes=True)

Some weights of TimmWrapperForImageClassification were not initialized from the model checkpoint at timm/tiny_vit_5m_224.in1k and are newly initialized because the shapes did not match:
- head.fc.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- head.fc.weight: found shape torch.Size([1000, 320]) in the checkpoint and torch.Size([10, 320]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [110]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [111]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.3716,0.193846,0.9522,0.953134,0.95242,0.95242
2,0.1857,0.185898,0.9571,0.957366,0.95733,0.957102
3,0.161,0.171625,0.9631,0.963111,0.96329,0.963124
4,0.1497,0.176887,0.9628,0.963029,0.962952,0.96286
5,0.1428,0.167942,0.9672,0.967807,0.967202,0.967411
6,0.1393,0.171771,0.965,0.965319,0.965229,0.965093
7,0.1358,0.166049,0.9654,0.965655,0.965503,0.965545
8,0.1344,0.164976,0.9687,0.969086,0.968775,0.968804
9,0.1321,0.159503,0.9724,0.972594,0.972502,0.97249
10,0.131,0.157075,0.972,0.972074,0.972154,0.972086


TrainOutput(global_step=10660, training_loss=0.1480934873083519, metrics={'train_runtime': 2628.5444, 'train_samples_per_second': 518.736, 'train_steps_per_second': 4.055, 'total_flos': 6.249767710473585e+18, 'train_loss': 0.1480934873083519, 'epoch': 20.0})

In [112]:
student_model.eval()


In [113]:
trainer.evaluate(test)

{'eval_loss': 0.14835616946220398,
 'eval_accuracy': 0.9755,
 'eval_precision': 0.9755307694251802,
 'eval_recall': 0.9754999999999999,
 'eval_f1': 0.9754889147441126,
 'eval_runtime': 13.0586,
 'eval_samples_per_second': 765.781,
 'eval_steps_per_second': 6.05,
 'epoch': 20.0}

In [78]:
torch.save(student_model.state_dict(), f"{os.path.expanduser('~')}/models/{DATASET}/vit-distilltrain_aug.pth")

In [None]:
base.count_parameters(student_model)

model size: 21.229MB.
Total Trainable Params: 5074974.


Unnamed: 0,Modules,Parameters
0,timm_model.patch_embed.conv1.conv.weight,864
1,timm_model.patch_embed.conv1.bn.weight,32
2,timm_model.patch_embed.conv1.bn.bias,32
3,timm_model.patch_embed.conv2.conv.weight,18432
4,timm_model.patch_embed.conv2.bn.weight,64
...,...,...
210,timm_model.stages.3.blocks.1.local_conv.bn.bias,320
211,timm_model.head.norm.weight,320
212,timm_model.head.norm.bias,320
213,timm_model.head.fc.weight,3200
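`base.count_parameters` is defined outside this notebook. A minimal sketch of how the two headline numbers could be obtained, assuming the model size counts both parameters and buffers, is shown below. The helper name `count_parameters_sketch` and the toy model are illustrative only, and the size formula is an assumption that may differ from the one behind the figures above.

```python
import torch.nn as nn


def count_parameters_sketch(model):
    """Return (trainable parameter count, model size in MB including buffers)."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())
    size_mb = (param_bytes + buffer_bytes) / 1024 ** 2
    return trainable, size_mb


# Toy model: a 320 -> 10 linear head followed by batch normalisation.
toy = nn.Sequential(nn.Linear(320, 10), nn.BatchNorm1d(10))
trainable, size_mb = count_parameters_sketch(toy)
print(f"model size: {size_mb:.3f}MB.")
print(f"Total Trainable Params: {trainable}.")
```

The toy linear layer has a 10 x 320 weight matrix, that is 3200 weights, the same count as the `head.fc.weight` row in the table above.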


In [None]:
cpu_benchmark = base.BenchMarkRunner(student_model, cpu_data_loader, "cpu", 1000)
print(cpu_benchmark.run_benchmark())

<torch.utils.benchmark.utils.common.Measurement object at 0x7af7aae9b040>
self.infer_speed_comp()
  36.94 ms
  1 measurement, 1000 runs , 4 threads


In [None]:
gpu_benchmark = base.BenchMarkRunner(student_model, gpu_data_loader, "cuda", 1000)
print(gpu_benchmark.run_benchmark())

<torch.utils.benchmark.utils.common.Measurement object at 0x7af7aaf3ef20>
self.infer_speed_comp()
  9.86 ms
  1 measurement, 1000 runs , 4 threads
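The printed `Measurement` objects indicate that `base.BenchMarkRunner` relies on `torch.utils.benchmark`. The class itself is not shown, so the following is only a sketch of timing single-sample inference with `benchmark.Timer`. The toy model, the helper `infer`, and the run count of 20 are assumptions chosen to keep the example quick.

```python
import torch
import torch.nn as nn
from torch.utils import benchmark

# Small stand-in model; the notebook benchmarks the trained student instead.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).eval()
sample = torch.randn(1, 3, 32, 32)


def infer(model, sample):
    """Run one forward pass without tracking gradients."""
    with torch.no_grad():
        return model(sample)


timer = benchmark.Timer(
    stmt="infer(model, sample)",
    globals={"infer": infer, "model": model, "sample": sample},
    num_threads=1,
)
measurement = timer.timeit(20)  # 20 runs, reported as a single measurement
print(measurement)
```

`benchmark.Timer` handles warm-up and, for CUDA tensors, synchronisation, which makes it more reliable for GPU timing than wrapping the call in `time.time()`.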
