# Notebook for distillation training on the CIFAR10 dataset
This notebook trains MobileNetV2 on the CIFAR10 dataset. The teacher model is a ViT fine-tuned on the same dataset.

MobileNetV2 is trained in three setups:
- from a random initialization,
- with only the classification head trained, on top of a MobileNetV2 initialized with ImageNet-pretrained weights,
- with the whole pretrained model fine-tuned.

Each of the three setups is trained both in the standard way and with knowledge distillation from the teacher model above.

Distillation uses the teacher logits precomputed in the precompute_logits notebook.

## Importing libraries and defining methods

In [2]:
from transformers import Trainer, EarlyStoppingCallback
from torch.utils.data import ConcatDataset
import torch

import base


In [3]:
dataset_part = base.get_dataset_part()

Resetting the random seed so that results are reproducible.

In [4]:
base.reset_seed()
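`base.reset_seed()` is defined in `base` and is not shown here. A minimal sketch of what such a helper typically does, assuming it reseeds Python's `random`, NumPy and PyTorch (the function body and the default seed of 42 are assumptions, not taken from `base`):

```python
import random


def reset_seed(seed: int = 42) -> None:
    """Hypothetical reset_seed: reseed every RNG a training run may draw from."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed, nothing to seed
    try:
        import torch
        torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
        torch.backends.cudnn.deterministic = True  # prefer deterministic cuDNN kernels
        torch.backends.cudnn.benchmark = False     # disable autotuning, which may pick different kernels per run
    except ImportError:
        pass  # PyTorch not installed, nothing to seed


reset_seed()
first = [random.random() for _ in range(3)]
reset_seed()
second = [random.random() for _ in range(3)]
print(first == second)  # True
```

`transformers.set_seed(seed)` covers the same generators in a single call. Even with every seed fixed, some CUDA kernels remain nondeterministic, so repeated GPU runs can still differ slightly.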

In [5]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


Defining the transformations and loading the dataset splits.

In [6]:
transform = base.base_transforms()

# The last training batch is used as the eval split
test = base.CustomCIFAR10L(root='./data/10-logits', dataset_part=dataset_part.TEST, transform=transform)
train = base.CustomCIFAR10L(root='./data/10-logits', dataset_part=dataset_part.TRAIN, transform=transform)
eval = base.CustomCIFAR10L(root='./data/10-logits', dataset_part=dataset_part.EVAL, transform=transform)
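`base.CustomCIFAR10L` is also defined in `base`. From the cells in this notebook, it is only known that it takes a `root`, a `dataset_part` and a `transform`, that it exposes a `labels` attribute, and that each item is a dict with a `"labels"` key. A minimal pure-Python sketch of a dataset that pairs each image with its precomputed teacher logits follows. The class name `LogitsDataset` and the keys `"pixel_values"` and `"teacher_logits"` are hypothetical:

```python
from typing import Any, Callable, Dict, List, Optional, Sequence


class LogitsDataset:
    """Hypothetical sketch of a dataset that returns stored teacher logits with each sample."""

    def __init__(self, images: Sequence[Any], labels: Sequence[int],
                 teacher_logits: Sequence[Sequence[float]],
                 transform: Optional[Callable[[Any], Any]] = None) -> None:
        if not (len(images) == len(labels) == len(teacher_logits)):
            raise ValueError("images, labels and teacher_logits must have equal length")
        self.images = images
        self.labels: List[int] = list(labels)
        self.teacher_logits = teacher_logits
        self.transform = transform

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)  # e.g. resizing and normalization
        # "labels" matches train[0]["labels"] in this notebook; the other keys are assumptions.
        return {
            "pixel_values": image,
            "labels": self.labels[idx],
            "teacher_logits": list(self.teacher_logits[idx]),
        }


# Toy usage with strings standing in for images.
ds = LogitsDataset(["img0", "img1"], [6, 9],
                   [[0.1, 2.0, -1.0], [1.0, -1.0, 0.5]], transform=str.upper)
print(ds[0])
```

The extra logits key is not an argument of the student model's `forward`, so the Hugging Face `Trainer` would drop it by default. This is presumably why the distillation runs pass `remove_unused_columns=False`.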

In [7]:
train[0]["labels"]

6

In [8]:
augment_transform = base.aug_transforms()

train_aug = base.CustomCIFAR10L(root='./data/10-logits', dataset_part=dataset_part.TRAIN, transform=augment_transform)

In [9]:
train_aug = base.remove_diff_pred_class(train, train_aug, pytorch_dataset=True)

Removing entries from augmented dataset that are different from the base one - based on saved logits
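Judging by its progress message, `base.remove_diff_pred_class` drops augmented samples for which the saved teacher logits predict a different class than for the corresponding original image. Below is a sketch of that filtering step. It assumes both datasets are index-aligned and that their logits are available as lists; the function names are hypothetical:

```python
from typing import List, Sequence


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def consistent_indices(orig_logits: Sequence[Sequence[float]],
                       aug_logits: Sequence[Sequence[float]]) -> List[int]:
    """Indices i where the teacher predicts the same class for original and augmented sample i.

    Hypothetical helper; assumes both logit lists are aligned by sample index.
    """
    if len(orig_logits) != len(aug_logits):
        raise ValueError("datasets must be index-aligned and of equal length")
    return [i for i, (o, a) in enumerate(zip(orig_logits, aug_logits))
            if argmax(o) == argmax(a)]


orig_example = [[0.1, 0.9], [2.0, 1.0], [0.0, 3.0]]
aug_example = [[0.2, 0.8], [0.5, 1.5], [1.0, 4.0]]
print(consistent_indices(orig_example, aug_example))  # [0, 2]
```

With PyTorch datasets, the kept indices could then be wrapped as `torch.utils.data.Subset(train_aug, kept)`.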

In [10]:
train_combo = ConcatDataset([train, train_aug])

In [11]:
# Check the class distribution of the eval split --> good enough
import pandas as pd
df = pd.DataFrame(eval.labels)
print(df.value_counts())

0
5    1025
9    1022
3    1016
0    1014
1    1014
8    1003
4     997
6     980
7     977
2     952
Name: count, dtype: int64
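The counts printed above make the "good enough" judgement concrete. The short worked check below copies the numbers from the output; the variable names are illustrative:

```python
# Class counts of the eval split, copied from the value_counts() output above.
counts = {5: 1025, 9: 1022, 3: 1016, 0: 1014, 1: 1014,
          8: 1003, 4: 997, 6: 980, 7: 977, 2: 952}

total = sum(counts.values())
uniform = total / len(counts)  # expected count per class if perfectly balanced
# Relative deviation of each class from the uniform count.
deviation = {label: (n - uniform) / uniform for label, n in counts.items()}
imbalance_ratio = max(counts.values()) / min(counts.values())

print(total)                     # 10000
print(f"{imbalance_ratio:.3f}")  # 1.077
print(f"{min(deviation.values()):+.3f} {max(deviation.values()):+.3f}")  # -0.048 +0.025
```

The largest class (5, with 1025 samples) is 2.5% above the uniform 1000 per class. The smallest (2, with 952 samples) is 4.8% below it. `df.value_counts(normalize=True)` would return the class shares directly.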


### Standard training of the randomly initialized model

In [12]:
base.reset_seed()

In [13]:
training_args = base.get_training_args(output_dir="./results/cifar10-random", logging_dir='./logs/cifar10-random', lr=0.0005,  epochs=30)
model = base.get_random_init_mobilenet(10)
model.to(device)

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [14]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 4)]
)

In [15]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 1.4566 | 1.085385 | 0.6207 | 0.630447 | 0.620568 | 0.617004 |
| 2 | 0.9678 | 0.759064 | 0.7326 | 0.73541 | 0.732139 | 0.732459 |
| 3 | 0.7644 | 0.627911 | 0.7846 | 0.785555 | 0.78517 | 0.782673 |
| 4 | 0.6321 | 0.536573 | 0.8174 | 0.817306 | 0.817084 | 0.815561 |
| 5 | 0.5372 | 0.528521 | 0.8202 | 0.820634 | 0.820247 | 0.819502 |
| 6 | 0.4635 | 0.501323 | 0.827 | 0.830054 | 0.827411 | 0.826516 |
| 7 | 0.393 | 0.501257 | 0.8353 | 0.840354 | 0.835944 | 0.834849 |
| 8 | 0.3288 | 0.452091 | 0.8492 | 0.850775 | 0.849456 | 0.848926 |
| 9 | 0.2764 | 0.487182 | 0.8442 | 0.853845 | 0.844097 | 0.846081 |
| 10 | 0.2321 | 0.50161 | 0.8505 | 0.854877 | 0.850649 | 0.851107 |


TrainOutput(global_step=12259, training_loss=0.3063449862730893, metrics={'train_runtime': 5498.3766, 'train_samples_per_second': 371.979, 'train_steps_per_second': 2.908, 'total_flos': 3.1676131510283796e+18, 'train_loss': 0.3063449862730893, 'epoch': 23.0})
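The run was configured for 30 epochs but ended at epoch 23 (`'epoch': 23.0`). It was therefore most likely stopped by `EarlyStoppingCallback` with `early_stopping_patience=4`. The exported log shows only the first 10 epochs. Below is a simplified sketch of patience-based early stopping. It assumes the monitored metric is the validation loss; the actual metric is set by `metric_for_best_model` inside `base.get_training_args`, which is not shown:

```python
from typing import Optional, Sequence


def early_stop_epoch(val_losses: Sequence[float], patience: int,
                     threshold: float = 0.0) -> Optional[int]:
    """Epoch (1-based) at which patience-based early stopping would halt, else None.

    Simplified model: the metric is lower-is-better and is checked once per epoch.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - threshold:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Validation losses of the first 10 logged epochs of the random-init run above.
logged = [1.085385, 0.759064, 0.627911, 0.536573, 0.528521,
          0.501323, 0.501257, 0.452091, 0.487182, 0.50161]
print(early_stop_epoch(logged, patience=4))  # None
```

On the ten logged validation losses alone, the sketch does not stop. The best value is at epoch 8, followed by only two non-improving epochs, which is consistent with the run continuing past epoch 10.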

In [16]:
model.eval()

In [17]:
trainer.evaluate(test)

{'eval_loss': 0.6681524515151978,
 'eval_accuracy': 0.8631,
 'eval_precision': 0.8664824830150939,
 'eval_recall': 0.8631,
 'eval_f1': 0.8636231878553907,
 'eval_runtime': 28.5516,
 'eval_samples_per_second': 350.243,
 'eval_steps_per_second': 2.767,
 'epoch': 23.0}
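`base.compute_metrics` is not shown. The cached `evaluate-metric--recall` module mentioned in a later log suggests it uses the Hugging Face `evaluate` metrics. The CIFAR10 test set has exactly 1000 images per class, and the reported `eval_recall` equals `eval_accuracy`. That is what macro averaging gives on a class-balanced set: the mean of the per-class recalls tp_c / 1000 equals the sum of tp_c divided by 10000. A pure-Python sketch of macro-averaged metrics under that assumption:

```python
from typing import Dict, List, Sequence


def macro_metrics(y_true: Sequence[int], y_pred: Sequence[int],
                  num_classes: int) -> Dict[str, float]:
    """Accuracy plus macro-averaged precision, recall and F1 (assumed averaging mode)."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    fn = [0] * num_classes
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # class p was predicted, but wrongly
            fn[t] += 1  # true class t was missed
    precisions: List[float] = []
    recalls: List[float] = []
    f1s: List[float] = []
    for c in range(num_classes):
        prec = tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0
        rec = tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(f1)
    return {
        "accuracy": sum(tp) / len(y_true),
        "precision": sum(precisions) / num_classes,
        "recall": sum(recalls) / num_classes,
        "f1": sum(f1s) / num_classes,  # mean of per-class F1 scores
    }


# Two balanced classes (two samples each): macro recall equals accuracy.
m = macro_metrics([0, 0, 1, 1], [0, 1, 1, 1], num_classes=2)
print(m)
```

Macro F1 is the mean of per-class F1 scores, not the harmonic mean of macro precision and macro recall. This is why `eval_f1` is not exactly 2PR/(P+R) of the reported precision and recall.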

## Distillation training setup

`base.DistilTrainer` is a class that adapts the Hugging Face `Trainer` for knowledge distillation. This version does not run the teacher during training; it reads the teacher logits stored in the dataset.
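`base.DistilTrainer` itself lives in `base`. A common formulation, used for example in the Hugging Face knowledge-distillation tutorial for image classification, combines two terms: cross-entropy on the hard labels, and a temperature-scaled KL divergence to the teacher. The terms are weighted by `lambda_param` and the temperature (the `temp` argument below). Assuming `DistilTrainer` follows this formulation, the per-sample loss can be sketched in pure Python:

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax of the logits divided by the temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits: Sequence[float],
                      teacher_logits: Sequence[float],
                      label: int,
                      temperature: float,
                      lambda_param: float) -> float:
    """Per-sample KD loss in the assumed (1 - lambda) * CE + lambda * T^2 * KL form."""
    # Hard-label term: ordinary cross-entropy of the student at temperature 1.
    hard = -math.log(softmax(student_logits)[label])
    # Soft-label term: KL(teacher_T || student_T) between temperature-softened distributions.
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(pt * (math.log(pt) - math.log(ps))
             for pt, ps in zip(p_teacher, p_student) if pt > 0)
    # T^2 keeps the soft-term gradients on roughly the same scale as the hard term.
    soft = kl * temperature ** 2
    return (1.0 - lambda_param) * hard + lambda_param * soft


# Example: one 10-class sample with the settings of the random-init KD run.
student = [0.5, 2.0, -1.0, 0.0, 0.3, -0.2, 0.1, 0.0, 1.0, -0.5]
teacher = [0.0, 6.0, -2.0, 0.5, 0.2, -1.0, 0.0, 0.1, 2.5, -1.5]
print(distillation_loss(student, teacher, label=1, temperature=6, lambda_param=0.75))
```

In PyTorch, the soft term is usually written as `F.kl_div(F.log_softmax(s / T, dim=-1), F.softmax(t / T, dim=-1), reduction="batchmean") * T**2`. Under this formulation, `lambda_param=1` (used for the whole pretrained model at the end of this notebook) removes the hard-label term entirely.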

### Trénink náhodně inicializovaného modelu s pomocí destilace znalostí

In [18]:
base.reset_seed()

In [19]:
student_model = base.get_random_init_mobilenet(10)

In [20]:
training_args = base.get_training_args(output_dir="./results/cifar10-random-KD", logging_dir='./logs/cifar10-random-KD', remove_unused_columns=False, epochs=30, lr=0.00047, lambda_param=.75, temp=6)

In [21]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 5)]
)

In [22]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 0.738 | 0.557416 | 0.6395 | 0.642289 | 0.638583 | 0.63096 |
| 2 | 0.4952 | 0.403076 | 0.7502 | 0.76426 | 0.749433 | 0.752833 |
| 3 | 0.3998 | 0.34251 | 0.7948 | 0.798838 | 0.795608 | 0.791812 |
| 4 | 0.346 | 0.310966 | 0.8219 | 0.823681 | 0.821637 | 0.819998 |
| 5 | 0.3091 | 0.279215 | 0.8355 | 0.841818 | 0.835432 | 0.835 |
| 6 | 0.2785 | 0.266389 | 0.8444 | 0.845622 | 0.84465 | 0.844451 |
| 7 | 0.2547 | 0.249271 | 0.8613 | 0.864555 | 0.861605 | 0.86137 |
| 8 | 0.2328 | 0.236179 | 0.8654 | 0.868529 | 0.865652 | 0.865553 |
| 9 | 0.214 | 0.251809 | 0.8555 | 0.867711 | 0.855401 | 0.858369 |
| 10 | 0.1986 | 0.248183 | 0.8603 | 0.864496 | 0.860768 | 0.860053 |


TrainOutput(global_step=15990, training_loss=0.20749242542832252, metrics={'train_runtime': 7894.7113, 'train_samples_per_second': 259.07, 'train_steps_per_second': 2.025, 'total_flos': 4.1316693274283213e+18, 'train_loss': 0.20749242542832252, 'epoch': 30.0})

In [23]:
student_model.eval()

In [24]:
trainer.evaluate(test)

{'eval_loss': 0.20338153839111328,
 'eval_accuracy': 0.8838,
 'eval_precision': 0.8863326211911587,
 'eval_recall': 0.8837999999999999,
 'eval_f1': 0.8841969103507468,
 'eval_runtime': 24.2148,
 'eval_samples_per_second': 412.97,
 'eval_steps_per_second': 3.262,
 'epoch': 30.0}

## Loading the ImageNet-pretrained MobileNetV2 model

In [25]:
base.reset_seed()

In [26]:
model_pretrained = base.get_mobilenet(10)

In [27]:
print(model_pretrained)

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [28]:
model_pretrained = base.freeze_model(model_pretrained)

In [29]:
training_args = base.get_training_args(output_dir="./results/cifar10-pretrained-head", logging_dir='./logs/cifar10-pretrained-head')

In [30]:
trainer = Trainer(
    model=model_pretrained,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 2)]
)

In [31]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 1.879 | 1.420145 | 0.6643 | 0.663692 | 0.663245 | 0.66209 |
| 2 | 1.4109 | 1.144414 | 0.7035 | 0.705178 | 0.702888 | 0.702857 |
| 3 | 1.2503 | 1.043337 | 0.7108 | 0.713902 | 0.710497 | 0.709513 |
| 4 | 1.1816 | 0.951599 | 0.7235 | 0.726 | 0.722782 | 0.722154 |
| 5 | 1.1542 | 0.96246 | 0.7265 | 0.724567 | 0.72579 | 0.723991 |


TrainOutput(global_step=2665, training_loss=1.375192861932751, metrics={'train_runtime': 786.304, 'train_samples_per_second': 433.522, 'train_steps_per_second': 3.389, 'total_flos': 6.886115545713869e+17, 'train_loss': 1.375192861932751, 'epoch': 5.0})

In [32]:
model_pretrained.eval()

In [33]:
trainer.evaluate(test)

{'eval_loss': 0.9762470722198486,
 'eval_accuracy': 0.7176,
 'eval_precision': 0.7157769346330052,
 'eval_recall': 0.7175999999999999,
 'eval_f1': 0.715879420988425,
 'eval_runtime': 21.5781,
 'eval_samples_per_second': 463.433,
 'eval_steps_per_second': 3.661,
 'epoch': 5.0}

### Training the whole pretrained MobileNetV2

In [34]:
base.reset_seed()

In [35]:
model_pretrained_whole = base.get_mobilenet(10)

In [36]:
training_args = base.get_training_args(output_dir="./results/cifar10-pretrained", logging_dir='./logs/cifar10-pretrained')

In [37]:
trainer = Trainer(
    model=model_pretrained_whole,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [38]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 0.6933 | 0.301634 | 0.8947 | 0.897118 | 0.894932 | 0.894389 |
| 2 | 0.2802 | 0.22807 | 0.9216 | 0.9226 | 0.921798 | 0.92183 |
| 3 | 0.1815 | 0.227875 | 0.9267 | 0.927004 | 0.927143 | 0.926419 |
| 4 | 0.1253 | 0.225906 | 0.927 | 0.928416 | 0.927152 | 0.927101 |
| 5 | 0.0944 | 0.211628 | 0.9303 | 0.930654 | 0.930394 | 0.930318 |



TrainOutput(global_step=2665, training_loss=0.2749346233890383, metrics={'train_runtime': 670.5109, 'train_samples_per_second': 508.388, 'train_steps_per_second': 3.975, 'total_flos': 6.886115545713869e+17, 'train_loss': 0.2749346233890383, 'epoch': 5.0})

In [39]:
model_pretrained_whole.eval()

In [40]:
trainer.evaluate(test)

{'eval_loss': 0.22986409068107605,
 'eval_accuracy': 0.928,
 'eval_precision': 0.9280443794202833,
 'eval_recall': 0.9279999999999999,
 'eval_f1': 0.9277933650945471,
 'eval_runtime': 12.9704,
 'eval_samples_per_second': 770.988,
 'eval_steps_per_second': 6.091,
 'epoch': 5.0}

## Knowledge-distillation training of the pretrained MobileNetV2

### Training only the classification head of the pretrained model with distillation

In [41]:
base.reset_seed()

In [42]:
student_model_pretrained = base.get_mobilenet(10)

In [43]:
student_model_pretrained = base.freeze_model(student_model_pretrained)

In [44]:
training_args = base.get_training_args(output_dir="./results/cifar10-pretrained-head-KD", logging_dir='./logs/cifar10-pretrained-head-KD', remove_unused_columns=False, temp=6, lambda_param=.8)

In [45]:
trainer = base.DistilTrainer(
    student_model=student_model_pretrained,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 2)]
)

In [46]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 0.8026 | 0.643857 | 0.6635 | 0.671676 | 0.662443 | 0.661529 |
| 2 | 0.6254 | 0.558374 | 0.7083 | 0.714589 | 0.707622 | 0.707785 |
| 3 | 0.5853 | 0.53414 | 0.7136 | 0.723572 | 0.713251 | 0.711829 |
| 4 | 0.5698 | 0.507588 | 0.7295 | 0.738154 | 0.728758 | 0.728316 |
| 5 | 0.5639 | 0.509392 | 0.7313 | 0.731489 | 0.730568 | 0.729314 |


TrainOutput(global_step=2665, training_loss=0.6293962145835776, metrics={'train_runtime': 410.5196, 'train_samples_per_second': 830.362, 'train_steps_per_second': 6.492, 'total_flos': 6.886115545713869e+17, 'train_loss': 0.6293962145835776, 'epoch': 5.0})

In [47]:
student_model_pretrained.eval()

In [48]:
trainer.evaluate(test)

{'eval_loss': 0.5133946537971497,
 'eval_accuracy': 0.7264,
 'eval_precision': 0.7269993923247379,
 'eval_recall': 0.7264000000000002,
 'eval_f1': 0.7250795615194547,
 'eval_runtime': 13.6754,
 'eval_samples_per_second': 731.239,
 'eval_steps_per_second': 5.777,
 'epoch': 5.0}

### Training the whole pretrained model with distillation

In [49]:
base.reset_seed()

In [50]:
student_model_pretrained_whole = base.get_mobilenet(10)

In [51]:
training_args = base.get_training_args(output_dir="./results/cifar10-pretrained-KD", logging_dir='./logs/cifar10-pretrained-KD', remove_unused_columns=False, temp=6, lambda_param=1)

In [52]:
trainer = base.DistilTrainer(
    student_model=student_model_pretrained_whole.to(device),
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [53]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 0.2231 | 0.109142 | 0.8927 | 0.895638 | 0.893001 | 0.892645 |
| 2 | 0.107 | 0.079007 | 0.9144 | 0.917583 | 0.914703 | 0.914564 |
| 3 | 0.0832 | 0.077427 | 0.9131 | 0.91659 | 0.913799 | 0.912279 |
| 4 | 0.0712 | 0.074837 | 0.915 | 0.919454 | 0.915232 | 0.915101 |
| 5 | 0.0647 | 0.068708 | 0.923 | 0.925496 | 0.923362 | 0.923091 |


TrainOutput(global_step=2665, training_loss=0.1098458037814772, metrics={'train_runtime': 589.2059, 'train_samples_per_second': 578.541, 'train_steps_per_second': 4.523, 'total_flos': 6.886115545713869e+17, 'train_loss': 0.1098458037814772, 'epoch': 5.0})

In [54]:
student_model_pretrained_whole.eval()

In [55]:
trainer.evaluate(test)

{'eval_loss': 0.06962764263153076,
 'eval_accuracy': 0.9185,
 'eval_precision': 0.9205598074577601,
 'eval_recall': 0.9185000000000001,
 'eval_f1': 0.9183121257023202,
 'eval_runtime': 13.2719,
 'eval_samples_per_second': 753.47,
 'eval_steps_per_second': 5.952,
 'epoch': 5.0}
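The test accuracies reported in this notebook can be compared side by side. Each configuration was run once. The KD runs also differ from their standard counterparts in some hyperparameters, for example the learning rate and patience of the randomly initialized run. The differences are therefore indicative only:

```python
# Test-set accuracies reported by trainer.evaluate(test) in this notebook.
test_accuracy = {
    "random init": {"standard": 0.8631, "KD": 0.8838},
    "pretrained, head only": {"standard": 0.7176, "KD": 0.7264},
    "pretrained, whole model": {"standard": 0.928, "KD": 0.9185},
}

# Accuracy change from distillation, rounded to 4 decimals.
deltas = {setup: round(acc["KD"] - acc["standard"], 4)
          for setup, acc in test_accuracy.items()}

for setup, acc in test_accuracy.items():
    print(f"{setup:<24} standard={acc['standard']:.4f} "
          f"KD={acc['KD']:.4f} delta={deltas[setup]:+.4f}")
```

In these runs, distillation improves the randomly initialized model by about 2.1 percentage points and the head-only model by about 0.9 points. The fully fine-tuned pretrained model scores about 1 point lower with distillation.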