# Notebook for training with distillation on the CIFAR10 dataset
This notebook trains MobileNetV2 on the CIFAR10 dataset, using a ViT fine-tuned on the same dataset as the teacher model.

MobileNetV2 is used in three setups: training from random initialization, training only the classification head of a MobileNetV2 pretrained on ImageNet, and training the whole pretrained model. Each of the three setups is trained both in the standard way and with distillation from the teacher model mentioned above.

Distillation uses logits precomputed in the precompute_logits notebook.

## Library imports and method definitions

In [2]:
from transformers import Trainer, EarlyStoppingCallback
from torch.utils.data import ConcatDataset
import torch

import base



In [3]:
dataset_part = base.get_dataset_part()

Reset the random seed so the results are reproducible.

In [4]:
base.reset_seed()
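
The body of `base.reset_seed` is not shown in this notebook. The following is only a minimal sketch of what such a helper typically does, under the assumption that it seeds the Python, NumPy and PyTorch generators. The name `reset_seed_sketch` and the seed value 42 are hypothetical.

```python
import random


def reset_seed_sketch(seed: int = 42) -> None:
    """Seed every random number generator available in this environment."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy is optional for this sketch
    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass  # PyTorch is optional for this sketch


# Resetting the seed twice yields the same random sequence.
reset_seed_sketch()
first = [random.random() for _ in range(3)]
reset_seed_sketch()
second = [random.random() for _ in range(3)]
print(first == second)  # True
```

Calling the helper before each experiment, as the notebook does, makes every run start from the same generator state, so the runs do not depend on the order in which cells were executed.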

In [5]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


Apply transformations to the dataset.

In [6]:
transform = base.base_transforms()

# The last training batch is used as the evaluation split.
test = base.CustomCIFAR10L(root='./data/10-logits', dataset_part=dataset_part.TEST, transform=transform)
train = base.CustomCIFAR10L(root='./data/10-logits', dataset_part=dataset_part.TRAIN, transform=transform)
eval = base.CustomCIFAR10L(root='./data/10-logits', dataset_part=dataset_part.EVAL, transform=transform)

In [7]:
train[0]["labels"]

6

In [8]:
augment_transform = base.aug_transforms()

train_aug = base.CustomCIFAR10L(root='./data/10-logits', dataset_part=dataset_part.TRAIN, transform=augment_transform)

In [9]:
train_aug = base.remove_diff_pred_class(train, train_aug, pytorch_dataset=True)

Removing entries from augmented dataset that are different from the base one - based on saved logits:   0%|   …
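
The progress message suggests that `base.remove_diff_pred_class` drops augmented samples whose saved teacher logits no longer agree with those of the unaugmented image. Its implementation is not shown here, so the sketch below only illustrates one plausible reading: keep a sample when the argmax class of both logit vectors matches. The function names and the toy logits are hypothetical.

```python
def argmax(values):
    """Index of the largest element in a list."""
    return max(range(len(values)), key=values.__getitem__)


def consistent_indices(base_logits, aug_logits):
    """Indices where base and augmented logits predict the same class."""
    return [
        i
        for i, (b, a) in enumerate(zip(base_logits, aug_logits))
        if argmax(b) == argmax(a)
    ]


# Toy teacher logits for three images and their augmented versions.
base_logits = [[2.0, 0.1, -1.0], [0.2, 1.5, 0.3], [0.0, 0.1, 3.0]]
aug_logits = [[1.8, 0.3, -0.5], [1.9, 0.4, 0.1], [0.2, 0.0, 2.5]]

kept = consistent_indices(base_logits, aug_logits)
print(kept)  # [0, 2]
```

The second sample is dropped because augmentation changed the teacher's predicted class from 1 to 0, which would give the student a contradictory target.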

In [10]:
train_combo = ConcatDataset([train, train_aug])

In [11]:
# Check the class distribution --> good enough
import pandas as pd
df = pd.DataFrame(eval.labels)
print(df.value_counts())

0
5    1025
9    1022
3    1016
0    1014
1    1014
8    1003
4     997
6     980
7     977
2     952
Name: count, dtype: int64


### Standard training of the randomly initialized model

In [12]:
base.reset_seed()

In [13]:
training_args = base.get_training_args(output_dir="./results/cifar10-random", logging_dir='./logs/cifar10-random', lr=0.0005,  epochs=30)
model = base.get_random_init_mobilenet(10)
model.to(device)

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [14]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 4)]
)
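
The runs in this notebook pass `EarlyStoppingCallback` with patience values between 2 and 5. Which metric the callback monitors is configured inside `base.get_training_args`, which is not shown. The sketch below therefore only illustrates the patience rule itself, assuming the monitored value is a validation loss where lower is better. The function name and the loss values are made up for illustration.

```python
def stopping_epoch(eval_losses, patience):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("inf")
    evals_without_improvement = 0
    for epoch, loss in enumerate(eval_losses, start=1):
        if loss < best:
            best = loss
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return epoch
    return None


# Synthetic validation losses, one per epoch.
losses = [1.00, 0.80, 0.70, 0.72, 0.71, 0.73, 0.74, 0.60]
print(stopping_epoch(losses, patience=4))  # 7
print(stopping_epoch(losses, patience=5))  # None
```

With a patience of 4 the run stops at epoch 7 and never sees the improvement at epoch 8, while a patience of 5 lets it continue. This is the trade-off behind choosing a larger patience for longer runs.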

In [15]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 1.5511 | 1.185193 | 0.572 | 0.594277 | 0.57201 | 0.564728 |
| 2 | 0.9851 | 0.828357 | 0.7032 | 0.707348 | 0.703069 | 0.70269 |
| 3 | 0.7642 | 0.741043 | 0.7522 | 0.755381 | 0.751711 | 0.749523 |
| 4 | 0.6188 | 0.590698 | 0.7944 | 0.804327 | 0.794533 | 0.794986 |
| 5 | 0.5236 | 0.578956 | 0.7995 | 0.805582 | 0.799959 | 0.798998 |
| 6 | 0.4427 | 0.587013 | 0.8018 | 0.818735 | 0.801472 | 0.802597 |
| 7 | 0.3675 | 0.597391 | 0.8114 | 0.815665 | 0.811934 | 0.808441 |
| 8 | 0.3117 | 0.497279 | 0.8362 | 0.839014 | 0.836433 | 0.836298 |
| 9 | 0.2608 | 0.591331 | 0.8257 | 0.836817 | 0.824929 | 0.825767 |
| 10 | 0.2083 | 0.50583 | 0.8398 | 0.842174 | 0.840338 | 0.838622 |


TrainOutput(global_step=9390, training_loss=0.22946450757586792, metrics={'train_runtime': 5242.7742, 'train_samples_per_second': 228.886, 'train_steps_per_second': 1.791, 'total_flos': 2.4241195302912e+18, 'train_loss': 0.22946450757586792, 'epoch': 30.0})
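
The `TrainOutput` above reports 30 completed epochs. As a rough consistency check, the logged throughput can be turned back into the training-set size and the batch size. The sketch assumes a single device and no gradient accumulation, neither of which is confirmed by the notebook.

```python
import math

# Values copied from the TrainOutput above.
train_runtime = 5242.7742        # seconds
samples_per_second = 228.886
global_step = 9390
epochs = 30

# Samples seen per epoch, recovered from runtime and throughput.
samples_per_epoch = round(train_runtime * samples_per_second / epochs)
steps_per_epoch = global_step // epochs
# Smallest batch size that covers one epoch in that many steps.
batch_size = math.ceil(samples_per_epoch / steps_per_epoch)

print(samples_per_epoch, steps_per_epoch, batch_size)  # 40000 313 128
```

The 40000 training samples are consistent with holding out one 10000-image CIFAR10 training batch for evaluation, which matches the class counts printed for `eval` earlier.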

In [16]:
model.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [17]:
trainer.evaluate(test)

{'eval_loss': 0.7575582265853882,
 'eval_accuracy': 0.8647,
 'eval_precision': 0.8659174284226413,
 'eval_recall': 0.8647,
 'eval_f1': 0.8641716293655083,
 'eval_runtime': 29.4759,
 'eval_samples_per_second': 339.261,
 'eval_steps_per_second': 2.68,
 'epoch': 30.0}

## Definition of distillation training

A class that adapts the Hugging Face Trainer for knowledge distillation. It now works with logits stored in the dataset.
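
The loss used by `base.DistilTrainer` is not shown in this notebook. The sketch below is a worked example of a commonly used formulation and only an assumption about what the class computes: a KL divergence between temperature-softened teacher and student distributions, scaled by the squared temperature and mixed with the hard-label cross-entropy. The parameter names `temp` and `lambda_param` are taken from the `get_training_args` calls in this notebook; everything else is hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of logits divided by a temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(student_logits, label):
    """Hard-label cross-entropy at temperature 1."""
    return -math.log(softmax(student_logits)[label])


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, label, temp, lambda_param):
    """(1 - lambda) * CE + lambda * T^2 * KL(teacher || student)."""
    soft_teacher = softmax(teacher_logits, temp)
    soft_student = softmax(student_logits, temp)
    kd_term = kl_divergence(soft_teacher, soft_student) * temp ** 2
    ce_term = cross_entropy(student_logits, label)
    return (1.0 - lambda_param) * ce_term + lambda_param * kd_term


# Toy logits for one image with three classes; the true label is 0.
teacher = [4.0, 1.0, 0.5]
student = [2.0, 1.5, 0.2]
loss = distillation_loss(student, teacher, label=0, temp=6, lambda_param=0.75)
print(f"combined loss: {loss:.4f}")
```

Under this weighting, `lambda_param=1` would drop the hard-label term entirely and train the student on the teacher's soft targets alone. It would also mean that loss values from distillation runs are on a different scale than those from standard runs.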

### Training the randomly initialized model with knowledge distillation

In [18]:
base.reset_seed()

In [19]:
student_model = base.get_random_init_mobilenet(10)

In [20]:
training_args = base.get_training_args(output_dir="./results/cifar10-random-KD", logging_dir='./logs/cifar10-random-KD', remove_unused_columns=False, epochs=30, lr=0.00047, lambda_param=.75, temp=6)

In [21]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 5)]
)

In [22]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 0.8304 | 0.61108 | 0.5988 | 0.620624 | 0.598449 | 0.590508 |
| 2 | 0.5314 | 0.441133 | 0.7228 | 0.731297 | 0.722259 | 0.721603 |
| 3 | 0.4177 | 0.399523 | 0.7588 | 0.765465 | 0.758029 | 0.757509 |
| 4 | 0.3554 | 0.320996 | 0.8026 | 0.811291 | 0.802362 | 0.802266 |
| 5 | 0.3105 | 0.304771 | 0.8193 | 0.829683 | 0.81886 | 0.821035 |
| 6 | 0.2775 | 0.318304 | 0.8088 | 0.828012 | 0.808011 | 0.810201 |
| 7 | 0.2477 | 0.309018 | 0.8244 | 0.834661 | 0.825062 | 0.821531 |
| 8 | 0.2246 | 0.276398 | 0.8415 | 0.846507 | 0.841988 | 0.840742 |
| 9 | 0.2047 | 0.279324 | 0.8448 | 0.854129 | 0.844438 | 0.84555 |
| 10 | 0.1855 | 0.272511 | 0.8433 | 0.849767 | 0.843931 | 0.841404 |


TrainOutput(global_step=9390, training_loss=0.2010892837573164, metrics={'train_runtime': 5242.7032, 'train_samples_per_second': 228.89, 'train_steps_per_second': 1.791, 'total_flos': 2.4241195302912e+18, 'train_loss': 0.2010892837573164, 'epoch': 30.0})

In [23]:
student_model.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [24]:
trainer.evaluate(test)

{'eval_loss': 0.2124260514974594,
 'eval_accuracy': 0.8775,
 'eval_precision': 0.8796237748698346,
 'eval_recall': 0.8775000000000001,
 'eval_f1': 0.8774100994643572,
 'eval_runtime': 28.8978,
 'eval_samples_per_second': 346.047,
 'eval_steps_per_second': 2.734,
 'epoch': 30.0}

## Obtaining the pretrained MobileNetV2 model

In [25]:
base.reset_seed()

In [26]:
model_pretrained = base.get_mobilenet(10)

In [27]:
print(model_pretrained)

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [28]:
model_pretrained = base.freeze_model(model_pretrained)
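
`base.freeze_model` is not defined in this notebook. Since this setup trains only the classification head, a plausible sketch is to disable gradients everywhere except the head. The sketch assumes the head is exposed as an attribute named `classifier`, which is what the Hugging Face `MobileNetV2ForImageClassification` class uses. The helper name and the `TinyNet` stand-in model are hypothetical.

```python
import torch.nn as nn


def freeze_all_but_head(model: nn.Module, head_name: str = "classifier") -> nn.Module:
    """Disable gradients for every parameter except those of the head."""
    for param in model.parameters():
        param.requires_grad = False
    for param in getattr(model, head_name).parameters():
        param.requires_grad = True
    return model


class TinyNet(nn.Module):
    """Stand-in for a backbone with a classification head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.classifier = nn.Linear(4, 10)

    def forward(self, x):
        return self.classifier(self.backbone(x))


tiny = freeze_all_but_head(TinyNet())
trainable = [name for name, p in tiny.named_parameters() if p.requires_grad]
print(trainable)  # ['classifier.weight', 'classifier.bias']
```

Note that setting `requires_grad = False` does not stop BatchNorm layers from updating their running statistics while the model is in training mode.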

In [29]:
training_args = base.get_training_args(output_dir="./results/cifar10-pretrained-head", logging_dir='./logs/cifar10-pretrained-head')

In [30]:
trainer = Trainer(
    model=model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 2)]
)

In [31]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 1.967 | 1.642708 | 0.6006 | 0.605226 | 0.600302 | 0.597025 |
| 2 | 1.4838 | 1.349267 | 0.6649 | 0.668122 | 0.663935 | 0.661215 |
| 3 | 1.2713 | 1.206598 | 0.6849 | 0.682437 | 0.68414 | 0.678568 |
| 4 | 1.1771 | 1.14976 | 0.6988 | 0.700366 | 0.698829 | 0.696635 |
| 5 | 1.1373 | 1.154203 | 0.692 | 0.699147 | 0.691623 | 0.693513 |


TrainOutput(global_step=1565, training_loss=1.407314234724441, metrics={'train_runtime': 553.1527, 'train_samples_per_second': 361.564, 'train_steps_per_second': 2.829, 'total_flos': 4.040199217152e+17, 'train_loss': 1.407314234724441, 'epoch': 5.0})

In [32]:
model_pretrained.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [33]:
trainer.evaluate(test)

{'eval_loss': 1.1549885272979736,
 'eval_accuracy': 0.6912,
 'eval_precision': 0.6930555270604424,
 'eval_recall': 0.6911999999999999,
 'eval_f1': 0.6887658694953495,
 'eval_runtime': 27.4138,
 'eval_samples_per_second': 364.779,
 'eval_steps_per_second': 2.882,
 'epoch': 5.0}

### Training the pretrained MobileNetV2 (whole model)

In [34]:
base.reset_seed()

In [35]:
model_pretrained_whole = base.get_mobilenet(10)

In [36]:
training_args = base.get_training_args(output_dir="./results/cifar10-pretrained", logging_dir='./logs/cifar10-pretrained')

In [37]:
trainer = Trainer(
    model=model_pretrained_whole,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [38]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 0.7184 | 0.353336 | 0.8776 | 0.882056 | 0.877694 | 0.87737 |
| 2 | 0.2408 | 0.282614 | 0.9051 | 0.90969 | 0.904914 | 0.905916 |
| 3 | 0.1493 | 0.261676 | 0.9094 | 0.911356 | 0.909361 | 0.909485 |
| 4 | 0.1011 | 0.239741 | 0.9204 | 0.921434 | 0.920613 | 0.920459 |
| 5 | 0.0758 | 0.247335 | 0.9167 | 0.91799 | 0.917032 | 0.916728 |


TrainOutput(global_step=1565, training_loss=0.25708472523064657, metrics={'train_runtime': 843.3831, 'train_samples_per_second': 237.14, 'train_steps_per_second': 1.856, 'total_flos': 4.040199217152e+17, 'train_loss': 0.25708472523064657, 'epoch': 5.0})

In [39]:
model_pretrained_whole.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [40]:
trainer.evaluate(test)

{'eval_loss': 0.24592812359333038,
 'eval_accuracy': 0.9193,
 'eval_precision': 0.9202001142503461,
 'eval_recall': 0.9193,
 'eval_f1': 0.919114594687519,
 'eval_runtime': 21.9373,
 'eval_samples_per_second': 455.845,
 'eval_steps_per_second': 3.601,
 'epoch': 5.0}

## Knowledge-distillation training of the pretrained MobileNetV2

### Training the pretrained model with distillation (classification head only)

In [41]:
base.reset_seed()

In [42]:
student_model_pretrained = base.get_mobilenet(10)

In [43]:
student_model_pretrained = base.freeze_model(student_model_pretrained)

In [44]:
training_args = base.get_training_args(output_dir="./results/cifar10-pretrained-head-KD", logging_dir='./logs/cifar10-pretrained-head-KD', remove_unused_columns=False, temp=6, lambda_param=.8)

In [45]:
trainer = base.DistilTrainer(
    student_model=student_model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 2)]
)

In [46]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 0.895 | 0.728218 | 0.5972 | 0.613117 | 0.597011 | 0.59424 |
| 2 | 0.6719 | 0.616265 | 0.6707 | 0.675367 | 0.669734 | 0.668132 |
| 3 | 0.601 | 0.576232 | 0.6899 | 0.684775 | 0.689169 | 0.684294 |
| 4 | 0.5749 | 0.55831 | 0.6982 | 0.702963 | 0.698162 | 0.695983 |
| 5 | 0.5645 | 0.563195 | 0.6945 | 0.702647 | 0.694184 | 0.695959 |


TrainOutput(global_step=1565, training_loss=0.6614379492811502, metrics={'train_runtime': 466.4128, 'train_samples_per_second': 428.805, 'train_steps_per_second': 3.355, 'total_flos': 4.040199217152e+17, 'train_loss': 0.6614379492811502, 'epoch': 5.0})

In [47]:
student_model_pretrained.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [48]:
trainer.evaluate(test)

{'eval_loss': 0.558600902557373,
 'eval_accuracy': 0.6936,
 'eval_precision': 0.6980308409071143,
 'eval_recall': 0.6936,
 'eval_f1': 0.6910105542765418,
 'eval_runtime': 23.849,
 'eval_samples_per_second': 419.305,
 'eval_steps_per_second': 3.313,
 'epoch': 5.0}

### Training the pretrained model with distillation (whole model)

In [49]:
base.reset_seed()

In [50]:
student_model_pretrained_whole = base.get_mobilenet(10)

In [51]:
training_args = base.get_training_args(output_dir="./results/cifar10-pretrained-KD", logging_dir='./logs/cifar10-pretrained-KD', remove_unused_columns=False, temp=6, lambda_param=1)

In [52]:
trainer = base.DistilTrainer(
    student_model=student_model_pretrained_whole.to(device),
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [53]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|---|---|---|---|---|---|---|
| 1 | 0.254 | 0.136995 | 0.8641 | 0.877378 | 0.864093 | 0.864102 |
| 2 | 0.103 | 0.10127 | 0.9046 | 0.908879 | 0.904415 | 0.90567 |
| 3 | 0.0755 | 0.093441 | 0.9048 | 0.907944 | 0.904907 | 0.904977 |
| 4 | 0.0631 | 0.081664 | 0.9063 | 0.912641 | 0.906672 | 0.906659 |
| 5 | 0.0565 | 0.081574 | 0.9129 | 0.915123 | 0.913242 | 0.913046 |


TrainOutput(global_step=1565, training_loss=0.11042473887483152, metrics={'train_runtime': 641.3483, 'train_samples_per_second': 311.843, 'train_steps_per_second': 2.44, 'total_flos': 4.040199217152e+17, 'train_loss': 0.11042473887483152, 'epoch': 5.0})

In [54]:
student_model_pretrained_whole.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [55]:
trainer.evaluate(test)

{'eval_loss': 0.08232155442237854,
 'eval_accuracy': 0.9105,
 'eval_precision': 0.912371019020204,
 'eval_recall': 0.9104999999999999,
 'eval_f1': 0.9104505478792004,
 'eval_runtime': 23.1993,
 'eval_samples_per_second': 431.048,
 'eval_steps_per_second': 3.405,
 'epoch': 5.0}
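
To compare the six runs side by side, the test accuracies reported by the `trainer.evaluate(test)` cells above can be collected in one place. The snippet only restates those numbers and computes the difference between the distilled and the standard run for each setup. The setup labels are shorthand introduced here.

```python
# Test accuracy per setup: (standard training, knowledge distillation).
test_accuracy = {
    "random init": (0.8647, 0.8775),
    "pretrained, head only": (0.6912, 0.6936),
    "pretrained, whole model": (0.9193, 0.9105),
}

deltas = {}
for setup, (standard, distilled) in test_accuracy.items():
    deltas[setup] = round(distilled - standard, 4)
    print(
        f"{setup:<24} standard={standard:.4f}  "
        f"KD={distilled:.4f}  delta={deltas[setup]:+.4f}"
    )
```

Each number comes from a single run, so small differences between the standard and distilled variants should be read with caution.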