# Notebook for training with knowledge distillation on the CIFAR10 dataset
In this notebook, MobileNetV2 is trained on the CIFAR10 dataset. The teacher model is a ViT fine-tuned on the same dataset.

MobileNetV2 is used in three variants: randomly initialized, pretrained on ImageNet with only the classification head trained, and pretrained on ImageNet with the whole model trained. Each of these three setups is trained both in the standard way and with distillation from the teacher model mentioned above.

Distillation uses teacher logits precomputed in the `precompute_logits` notebook.
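
With precomputed logits, every dataset sample has to carry the teacher's output alongside the image and label. Below is a minimal sketch of such a sample layout, assuming dict-style items with the Hugging Face key names `pixel_values` and `labels` plus an extra `teacher_logits` key. `LogitsDataset` is a hypothetical stand-in for illustration; the real `base.CustomCIFAR10L` is not shown and may be organised differently.

```python
from typing import Any, Callable, Optional, Sequence


class LogitsDataset:
    """Hypothetical dataset pairing each image with the teacher's precomputed logits.

    Illustration only; base.CustomCIFAR10L may store and return data differently.
    """

    def __init__(self, images: Sequence[Any], labels: Sequence[int],
                 teacher_logits: Sequence[Sequence[float]],
                 transform: Optional[Callable[[Any], Any]] = None):
        if not len(images) == len(labels) == len(teacher_logits):
            raise ValueError("images, labels and teacher_logits must have the same length")
        self.images = images
        self.labels = labels
        self.teacher_logits = teacher_logits
        self.transform = transform

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> dict:
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        # "pixel_values" and "labels" are the keys Hugging Face image models expect;
        # "teacher_logits" is the extra key a distillation trainer would read.
        return {
            "pixel_values": image,
            "labels": self.labels[idx],
            "teacher_logits": self.teacher_logits[idx],
        }
```

Because `teacher_logits` is not an argument of the model's `forward`, the Hugging Face `Trainer` drops such keys from each batch when `remove_unused_columns=True`; that is presumably why the distillation runs below pass `remove_unused_columns=False`.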

In [1]:
%pip install transformers[torch] huggingface_hub datasets evaluate torchvision

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
Installing collected packages: evaluate
Successfully installed evaluate-0.4.3

[notice] A new release of pip is available: 24.2 -> 25.0
[notice] To update, run: python -m pip install --upgrade pip
Note: you may need to restart the kernel to use updated packages.


## Library imports and method definitions

In [1]:
from transformers import Trainer, EarlyStoppingCallback
from torch.utils.data import ConcatDataset
import torch

import base

In [2]:
dataset_part = base.get_dataset_part()

Resetting the random seed so that the results are reproducible.
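
`base.reset_seed` itself is not shown in this notebook. A minimal sketch of what such a helper typically does, assuming a hypothetical default seed of 42: it seeds Python's `random` module and, if they are installed, NumPy and PyTorch.

```python
import random


def reset_seed(seed: int = 42) -> None:
    """Hypothetical reimplementation of base.reset_seed (the real one lives in base.py).

    Seeds Python's random module and, when available, NumPy and PyTorch.
    """
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:  # NumPy not installed: nothing to seed
        pass
    try:
        import torch
        torch.manual_seed(seed)           # seeds PyTorch's random number generators
        torch.cuda.manual_seed_all(seed)  # explicit for all GPUs; ignored without CUDA
    except ImportError:  # PyTorch not installed: nothing to seed
        pass
```

`transformers.set_seed(seed)` is a ready-made alternative. Bit-for-bit reproducibility on a GPU typically also needs deterministic cuDNN settings such as `torch.backends.cudnn.deterministic = True` and `torch.backends.cudnn.benchmark = False`.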

In [3]:
base.reset_seed()

In [4]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


Loading the dataset splits with the required transformations applied.

In [5]:
transform = base.base_transforms()

# The last training batch is used as the eval split
test = base.CustomCIFAR10L(root='./data/10-logits', dataset_part=dataset_part.TEST, transform=transform)
train = base.CustomCIFAR10L(root='./data/10-logits', dataset_part=dataset_part.TRAIN, transform=transform)
eval = base.CustomCIFAR10L(root='./data/10-logits', dataset_part=dataset_part.EVAL, transform=transform)
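
The Python version of CIFAR-10 ships five training batch files (`data_batch_1` to `data_batch_5`) of 10,000 images each, plus a `test_batch`. A minimal sketch of the split described in the comment in the cell above, assuming `base.CustomCIFAR10L` selects whole batch files (an assumption, since its code is not shown); `split_files` is a hypothetical helper:

```python
# Hypothetical sketch: how the CIFAR-10 batch files could map onto the
# notebook's TRAIN / EVAL / TEST parts (the real logic lives in base.py).
IMAGES_PER_BATCH = 10_000
TRAIN_FILES = [f"data_batch_{i}" for i in range(1, 6)]


def split_files(train_files):
    """Keep all but the last training batch for training; the last one becomes eval."""
    return {
        "train": train_files[:-1],
        "eval": train_files[-1:],
        "test": ["test_batch"],
    }


splits = split_files(TRAIN_FILES)
sizes = {part: len(files) * IMAGES_PER_BATCH for part, files in splits.items()}
print(sizes)  # {'train': 40000, 'eval': 10000, 'test': 10000}
```

The 10,000 eval images are consistent with the class counts printed further down, which sum to exactly 10,000.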

In [6]:
augment_transform = base.aug_transforms()

train_aug = base.CustomCIFAR10L(root='./data/10-logits', dataset_part=dataset_part.TRAIN, transform=augment_transform)

In [7]:
train_aug = base.remove_diff_pred_class(train, train_aug)

In [8]:
train_combo = ConcatDataset([train, train_aug])

In [9]:
# Check the class distribution --> good enough
import pandas as pd
df = pd.DataFrame(eval.labels)
print(df.value_counts())

0
5    1025
9    1022
3    1016
0    1014
1    1014
8    1003
4     997
6     980
7     977
2     952
Name: count, dtype: int64
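
One way to put a number on "good enough" is the ratio between the largest and the smallest class. The snippet below simply re-enters the counts printed above by hand and computes that ratio:

```python
# Class counts copied from the value_counts() output above (eval split).
eval_counts = {5: 1025, 9: 1022, 3: 1016, 0: 1014, 1: 1014,
               8: 1003, 4: 997, 6: 980, 7: 977, 2: 952}

total = sum(eval_counts.values())
largest = max(eval_counts.values())
smallest = min(eval_counts.values())
imbalance_ratio = largest / smallest  # 1.0 would mean a perfectly balanced split

print(total, round(imbalance_ratio, 3))  # 10000 1.077
```

Every class makes up between 9.52 % and 10.25 % of the eval split, so no rebalancing is applied.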


### Standard training of a randomly initialized model

In [10]:
base.reset_seed()

In [11]:
training_args = base.get_training_args(output_dir="./results/cifar10-random", logging_dir='./logs/cifar10-random', lr=0.0005,  epochs=30)
model = base.get_random_init_mobilenet(10)
model.to(device)

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [12]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 4)]
)

In [13]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 1.4638 | 1.03793 | 0.6411 | 0.641416 | 0.640152 | 0.634947 |
| 2 | 0.9601 | 0.749105 | 0.7383 | 0.741586 | 0.737722 | 0.737781 |
| 3 | 0.7591 | 0.639601 | 0.7791 | 0.778725 | 0.779504 | 0.776135 |
| 4 | 0.6277 | 0.555916 | 0.8069 | 0.808455 | 0.806523 | 0.80439 |
| 5 | 0.5319 | 0.502715 | 0.829 | 0.828873 | 0.828787 | 0.827541 |
| 6 | 0.451 | 0.486683 | 0.8313 | 0.831596 | 0.831682 | 0.829763 |
| 7 | 0.3895 | 0.470692 | 0.8402 | 0.846481 | 0.840874 | 0.840412 |
| 8 | 0.3259 | 0.435173 | 0.8526 | 0.856123 | 0.852471 | 0.853282 |
| 9 | 0.269 | 0.453755 | 0.8557 | 0.862448 | 0.855463 | 0.857196 |
| 10 | 0.2197 | 0.490151 | 0.8511 | 0.854547 | 0.851239 | 0.851187 |


TrainOutput(global_step=15990, training_loss=0.23267211390331882, metrics={'train_runtime': 7692.9563, 'train_samples_per_second': 265.864, 'train_steps_per_second': 2.079, 'total_flos': 4.1316693274283213e+18, 'train_loss': 0.23267211390331882, 'epoch': 30.0})

In [14]:
model.eval()

In [15]:
trainer.evaluate(test)

{'eval_loss': 0.793152928352356,
 'eval_accuracy': 0.871,
 'eval_precision': 0.8718504007131591,
 'eval_recall': 0.8709999999999999,
 'eval_f1': 0.8707612282318162,
 'eval_runtime': 20.1856,
 'eval_samples_per_second': 495.402,
 'eval_steps_per_second': 3.914,
 'epoch': 30.0}

## Definition of the distillation training

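The `base.ImageDistilTrainer` class adapts the Hugging Face `Trainer` for knowledge distillation. This version works with the teacher logits stored in the dataset rather than querying the teacher model during training.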
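
The class itself lives in `base.py` and is not shown. As a hedged sketch of the per-sample loss such a trainer typically computes, the snippet below implements Hinton-style knowledge distillation in plain Python: the KL divergence between temperature-softened teacher and student distributions, scaled by T², mixed with ordinary cross-entropy on the hard label. It assumes that `temp` is the temperature T and that `lambda_param` weights the distillation term, as in the Hugging Face image-distillation tutorial; how `base.ImageDistilTrainer` actually combines the two terms is an assumption.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable log-softmax of the logits divided by a temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(z - peak) for z in scaled))
    return [z - log_norm for z in scaled]


def kd_loss(student_logits: Sequence[float], teacher_logits: Sequence[float],
            label: int, temperature: float, lambda_param: float) -> float:
    """Per-sample distillation loss (sketch, not base.ImageDistilTrainer itself).

    loss = (1 - lambda) * CE(label, student) + lambda * T^2 * KL(teacher_T || student_T)
    """
    log_p_teacher = log_softmax(teacher_logits, temperature)
    log_p_student = log_softmax(student_logits, temperature)
    # KL divergence of the softened student distribution from the softened teacher one
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_teacher, log_p_student))
    # T^2 keeps the soft-target gradients on the same scale as the hard-label term
    soft_loss = (temperature ** 2) * kl
    # Ordinary cross-entropy against the true label at temperature 1
    hard_loss = -log_softmax(student_logits)[label]
    return (1.0 - lambda_param) * hard_loss + lambda_param * soft_loss
```

In a PyTorch `Trainer` subclass the same quantity is usually computed batched with `F.kl_div(F.log_softmax(student / T, dim=-1), F.softmax(teacher / T, dim=-1), reduction="batchmean") * T**2`. The `teacher_logits` entry has to be popped from `inputs` before calling the student, because the model's `forward` does not accept it. Under this weighting, `lambda_param=1` drops the hard-label term entirely.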

### Training a randomly initialized model with knowledge distillation

In [16]:
base.reset_seed()

In [17]:
student_model = base.get_random_init_mobilenet(10)

In [18]:
training_args = base.get_training_args(output_dir="./results/cifar10-random-KD", logging_dir='./logs/cifar10-random-KD', remove_unused_columns=False, epochs=30, lr=0.00047, lambda_param=.75, temp=6)

In [20]:
trainer = base.ImageDistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 5)]
)

In [21]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 0.7383 | 0.545457 | 0.6488 | 0.653876 | 0.648107 | 0.644203 |
| 2 | 0.4966 | 0.396176 | 0.754 | 0.767343 | 0.753355 | 0.755172 |
| 3 | 0.4036 | 0.344784 | 0.7898 | 0.796058 | 0.790547 | 0.785987 |
| 4 | 0.3495 | 0.320562 | 0.8053 | 0.806584 | 0.805155 | 0.802286 |
| 5 | 0.3108 | 0.286864 | 0.8359 | 0.838791 | 0.836096 | 0.835011 |
| 6 | 0.281 | 0.276472 | 0.8376 | 0.842309 | 0.83769 | 0.838338 |
| 7 | 0.2567 | 0.255284 | 0.8504 | 0.858105 | 0.851103 | 0.850481 |
| 8 | 0.2357 | 0.239922 | 0.8622 | 0.865662 | 0.862366 | 0.862374 |
| 9 | 0.2164 | 0.244177 | 0.8645 | 0.873414 | 0.864458 | 0.866097 |
| 10 | 0.2009 | 0.248147 | 0.8601 | 0.86313 | 0.860333 | 0.859171 |


TrainOutput(global_step=15990, training_loss=0.2088465210495329, metrics={'train_runtime': 9314.3709, 'train_samples_per_second': 219.583, 'train_steps_per_second': 1.717, 'total_flos': 4.1316693274283213e+18, 'train_loss': 0.2088465210495329, 'epoch': 30.0})

In [22]:
student_model.eval()

In [23]:
trainer.evaluate(test)

{'eval_loss': 0.20578446984291077,
 'eval_accuracy': 0.8831,
 'eval_precision': 0.8854887713656214,
 'eval_recall': 0.8831,
 'eval_f1': 0.8832870382361205,
 'eval_runtime': 32.7221,
 'eval_samples_per_second': 305.604,
 'eval_steps_per_second': 2.414,
 'epoch': 30.0}

## Obtaining the pretrained MobileNetV2 model

In [None]:
base.reset_seed()

In [None]:
model_pretrained = base.get_mobilenet(10)

In [31]:
print(model_pretrained)

In [None]:
model_pretrained = base.freeze_model(model_pretrained)
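
`base.freeze_model` is what restricts training to the classification head, but its code is not shown. A minimal sketch of the usual approach, assuming it disables gradients for the `mobilenet_v2` backbone (the attribute name visible in the printed model) and leaves the `classifier` head (the attribute name used by Hugging Face's `MobileNetV2ForImageClassification`) trainable. `ToyMobileNetClassifier` and `freeze_backbone` are hypothetical; the toy model only mirrors those attribute names so the snippet runs without downloading weights.

```python
import torch
from torch import nn


class ToyMobileNetClassifier(nn.Module):
    """Hypothetical tiny model with the same attribute names as the real one."""

    def __init__(self, num_labels: int = 10):
        super().__init__()
        self.mobilenet_v2 = nn.Sequential(  # stands in for the pretrained backbone
            nn.Conv2d(3, 8, kernel_size=3, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU6(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.classifier = nn.Linear(8, num_labels)  # the head that stays trainable

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.mobilenet_v2(pixel_values))


def freeze_backbone(model: nn.Module) -> nn.Module:
    """Disable gradients for every backbone parameter; only the head keeps learning."""
    for param in model.mobilenet_v2.parameters():
        param.requires_grad = False
    return model


def count_trainable(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = freeze_backbone(ToyMobileNetClassifier())
logits = model(torch.zeros(2, 3, 8, 8))  # smoke test: batch of 2 tiny images
```

Setting `requires_grad = False` stops weight updates but not the BatchNorm running statistics, which still change whenever the backbone runs in training mode. Since `Trainer` switches the model to training mode during training, keeping those statistics fixed would need extra handling.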

In [None]:
training_args = base.get_training_args(output_dir="./results/cifar10-pretrained-head", logging_dir='./logs/cifar10-pretrained-head')

In [None]:
trainer = Trainer(
    model=model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 2)]
)

In [35]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 1.9835 | 1.666508 | 0.5835 | 0.586924 | 0.582902 | 0.578297 |
| 2 | 1.5151 | 1.370681 | 0.655 | 0.658174 | 0.654084 | 0.651459 |
| 3 | 1.3037 | 1.236479 | 0.6755 | 0.671237 | 0.674684 | 0.66926 |
| 4 | 1.2076 | 1.175666 | 0.6895 | 0.690785 | 0.689174 | 0.687818 |
| 5 | 1.1672 | 1.183259 | 0.6826 | 0.688172 | 0.682072 | 0.683982 |


TrainOutput(global_step=1565, training_loss=1.435398460120058, metrics={'train_runtime': 676.8433, 'train_samples_per_second': 295.489, 'train_steps_per_second': 2.312, 'total_flos': 4.040199217152e+17, 'train_loss': 1.435398460120058, 'epoch': 5.0})

In [36]:
model_pretrained.eval()

In [37]:
trainer.evaluate(test)

{'eval_loss': 1.1798793077468872,
 'eval_accuracy': 0.6864,
 'eval_precision': 0.6879813785360978,
 'eval_recall': 0.6864,
 'eval_f1': 0.6848686230961495,
 'eval_runtime': 24.3581,
 'eval_samples_per_second': 410.541,
 'eval_steps_per_second': 3.243,
 'epoch': 5.0}

### Training the whole pretrained MobileNetV2

In [None]:
base.reset_seed()

In [None]:
model_pretrained_whole = base.get_mobilenet(10)

In [None]:
training_args = base.get_training_args(output_dir="./results/cifar10-pretrained", logging_dir='./logs/cifar10-pretrained')

In [None]:
trainer = Trainer(
    model=model_pretrained_whole,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [42]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 0.7171 | 0.357813 | 0.878 | 0.883757 | 0.877975 | 0.878126 |
| 2 | 0.2419 | 0.285486 | 0.9048 | 0.909034 | 0.904613 | 0.90553 |
| 3 | 0.1488 | 0.261888 | 0.909 | 0.910553 | 0.908979 | 0.909042 |
| 4 | 0.1002 | 0.245204 | 0.9167 | 0.91819 | 0.916911 | 0.916785 |
| 5 | 0.0736 | 0.254316 | 0.9164 | 0.917647 | 0.91671 | 0.916418 |


TrainOutput(global_step=1565, training_loss=0.25632537134920064, metrics={'train_runtime': 768.3689, 'train_samples_per_second': 260.292, 'train_steps_per_second': 2.037, 'total_flos': 4.040199217152e+17, 'train_loss': 0.25632537134920064, 'epoch': 5.0})

In [43]:
model_pretrained_whole.eval()

In [44]:
trainer.evaluate(test)

{'eval_loss': 0.2556077837944031,
 'eval_accuracy': 0.9174,
 'eval_precision': 0.9186320015704013,
 'eval_recall': 0.9174000000000001,
 'eval_f1': 0.917269695640841,
 'eval_runtime': 24.0426,
 'eval_samples_per_second': 415.929,
 'eval_steps_per_second': 3.286,
 'epoch': 5.0}

## Training the pretrained MobileNetV2 with knowledge distillation

### Training only the classification head of the pretrained model with distillation

In [None]:
base.reset_seed()

In [None]:
student_model_pretrained = base.get_mobilenet(10)

In [None]:
student_model_pretrained = base.freeze_model(student_model_pretrained)

In [None]:
training_args = base.get_training_args(output_dir="./results/cifar10-pretrained-head-KD", logging_dir='./logs/cifar10-pretrained-head-KD', remove_unused_columns=False, temp=6, lambda_param=.8)

In [None]:
trainer = base.ImageDistilTrainer(
    student_model=student_model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 2)]
)

In [50]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 0.9798 | 0.806478 | 0.5856 | 0.589648 | 0.585169 | 0.580398 |
| 2 | 0.7477 | 0.682494 | 0.6602 | 0.662438 | 0.65933 | 0.655939 |
| 3 | 0.6682 | 0.639946 | 0.6793 | 0.675224 | 0.67869 | 0.672222 |
| 4 | 0.6376 | 0.61699 | 0.6924 | 0.692186 | 0.692019 | 0.689034 |
| 5 | 0.6255 | 0.624481 | 0.6896 | 0.692095 | 0.689086 | 0.689528 |


TrainOutput(global_step=1565, training_loss=0.731778314852486, metrics={'train_runtime': 611.4863, 'train_samples_per_second': 327.072, 'train_steps_per_second': 2.559, 'total_flos': 4.040199217152e+17, 'train_loss': 0.731778314852486, 'epoch': 5.0})

In [51]:
student_model_pretrained.eval()

In [52]:
trainer.evaluate(test)

{'eval_loss': 1.215173602104187,
 'eval_accuracy': 0.688,
 'eval_precision': 0.691842072272977,
 'eval_recall': 0.688,
 'eval_f1': 0.6884649153396096,
 'eval_runtime': 22.6288,
 'eval_samples_per_second': 441.915,
 'eval_steps_per_second': 3.491,
 'epoch': 5.0}

### Training the whole pretrained model with distillation

In [None]:
base.reset_seed()

In [None]:
student_model_pretrained_whole = base.get_mobilenet(10)

In [None]:
training_args = base.get_training_args(output_dir="./results/cifar10-pretrained-KD", logging_dir='./logs/cifar10-pretrained-KD', remove_unused_columns=False, temp=6, lambda_param=1)

In [None]:
trainer = base.ImageDistilTrainer(
    student_model=student_model_pretrained_whole.to(device),
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [57]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 0.2951 | 0.155328 | 0.878 | 0.885381 | 0.877831 | 0.877973 |
| 2 | 0.1159 | 0.112287 | 0.9043 | 0.908469 | 0.904112 | 0.905005 |
| 3 | 0.0813 | 0.101031 | 0.9093 | 0.910739 | 0.909396 | 0.909276 |
| 4 | 0.0655 | 0.087329 | 0.9176 | 0.920174 | 0.917791 | 0.917662 |
| 5 | 0.0571 | 0.087939 | 0.9202 | 0.921565 | 0.920564 | 0.9203 |


TrainOutput(global_step=1565, training_loss=0.1229662374185678, metrics={'train_runtime': 739.084, 'train_samples_per_second': 270.605, 'train_steps_per_second': 2.117, 'total_flos': 4.040199217152e+17, 'train_loss': 0.1229662374185678, 'epoch': 5.0})

In [58]:
student_model_pretrained_whole.eval()

In [59]:
trainer.evaluate(test)

{'eval_loss': 1.5552592277526855,
 'eval_accuracy': 0.9149,
 'eval_precision': 0.9162574545858053,
 'eval_recall': 0.9148999999999999,
 'eval_f1': 0.914701817700383,
 'eval_runtime': 27.3443,
 'eval_samples_per_second': 365.707,
 'eval_steps_per_second': 2.889,
 'epoch': 5.0}