# Notebook for distillation training on the CIFAR10 dataset
This notebook trains MobileNetV2 on the CIFAR10 dataset, using a ViT fine-tuned on the same dataset as the teacher model.

MobileNetV2 is used in three setups: with random initialization, with training of only the classification head of an initialized (ImageNet-pretrained) MobileNetV2, and with training of the whole model, likewise initialized. Each of the three setups is trained both in the standard way and with distillation from the teacher model mentioned above.

In [1]:
%pip install transformers[torch]==4.45.2 huggingface_hub datasets evaluate torchvision

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com

Note: you may need to restart the kernel to use updated packages.


## Library imports and method definitions

In [2]:
from transformers import MobileNetV2Config, MobileNetV2ForImageClassification, AutoModelForImageClassification, Trainer, TrainingArguments
from torchvision import transforms, datasets as dataset
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import evaluate
import random
import torch
import os

Resets the random seed so that results are reproducible.
Some of these settings can probably be removed.

TODO: Remove the unnecessary settings.

In [3]:
def reset_seed(seed=42):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) 
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
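
As a quick illustration of what reseeding buys, the sketch below uses only the standard-library generator; the same idea applies to the `torch` and `numpy` generators that `reset_seed` covers. The helper `draw` is hypothetical and is not part of the notebook.

```python
import random


def draw(seed, n=3):
    """Reseed the stdlib generator, then draw n floats (hypothetical helper)."""
    random.seed(seed)
    return [random.random() for _ in range(n)]


first = draw(42)
second = draw(42)
other = draw(43)

print(first == second)  # same seed, same sequence
print(first == other)   # different seed, different sequence
```

Calling the reset before each experiment, as this notebook does, means every run starts from the same generator state regardless of what ran earlier.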

A wrapper around the CIFAR10 dataset downloaded through torchvision, which allows it to be passed to the Hugging Face Trainer.

In [4]:
class CIFAR10HFDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):  # Trainer expects a dict with these keys
        image, label = self.dataset[idx]
        return {
            'pixel_values': image,
            'labels': label
        }

    def __len__(self):
        return len(self.dataset)
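
To show what the wrapper does to a single item, here is a minimal sketch without the torch dependency. `DictDataset` is a hypothetical stand-in for the class above, and plain nested lists play the role of image tensors.

```python
class DictDataset:
    """Hypothetical stand-in for the wrapper above, without torch."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        # The wrapped dataset yields (image, label) tuples.
        image, label = self.dataset[idx]
        return {"pixel_values": image, "labels": label}

    def __len__(self):
        return len(self.dataset)


# Fake (image, label) pairs standing in for CIFAR10 samples.
pairs = [([[0.1, 0.2]], 3), ([[0.5, 0.9]], 7)]
wrapped = DictDataset(pairs)
item = wrapped[1]
print(len(wrapped), sorted(item), item["labels"])
```

The key names matter: the Trainer passes the collated batch to the model as keyword arguments, so they have to match the `pixel_values` and `labels` parameters of the model's forward method.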

Definition of the accuracy metric used during model training.

In [5]:
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
    return {"accuracy": acc["accuracy"]}
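
For intuition, the same computation can be sketched in plain Python: take the argmax of each row of logits and count how often it matches the label. The helpers `argmax` and `accuracy_from_logits` and the toy values are illustrative assumptions, not part of the notebook.

```python
def argmax(row):
    """Index of the largest value in a list (first one wins on ties)."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy_from_logits(logits, labels):
    """Fraction of rows whose argmax equals the reference label."""
    preds = [argmax(row) for row in logits]
    correct = sum(p == y for p, y in zip(preds, labels))
    return correct / len(labels)


# Four toy samples with three classes each.
logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5], [0.0, 4.0, 0.0], [1.0, 0.9, 0.8]]
labels = [0, 2, 1, 2]
acc = accuracy_from_logits(logits, labels)
print(acc)  # three of four predictions match, so 0.75
```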

Training arguments for the Trainer.

In [6]:
def get_training_args(output_dir, logging_dir):
    return TrainingArguments(
        output_dir=output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=5e-5,  # default value
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        num_train_epochs=20,
        weight_decay=0.01,
        seed=42,  # default value
        metric_for_best_model="accuracy",
        load_best_model_at_end=True,
        fp16=True,
        logging_dir=logging_dir,
    )
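
These settings determine the number of optimizer steps, which can be checked by hand. The sketch assumes a single device, no gradient accumulation, and that the last partial batch is kept; the 50,000 figure is the size of the CIFAR10 training split.

```python
import math

train_samples = 50_000  # size of the CIFAR10 training split
batch_size = 64         # per_device_train_batch_size
epochs = 20             # num_train_epochs

# The final, smaller batch still counts as a step, hence the ceiling.
steps_per_epoch = math.ceil(train_samples / batch_size)
total_steps = steps_per_epoch * epochs

print(steps_per_epoch, total_steps)
```

Under those assumptions the total is 15640, which agrees with the `global_step=15640` reported by the training runs in this notebook.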


Randomly initialized MobileNetV2.

In [7]:
def get_random_init_mobilenet():
    reset_seed(42)
    student_config = MobileNetV2Config()
    student_config.num_labels = 10
    return MobileNetV2ForImageClassification(student_config)

Freezes the model so that only the classification head is trained.

In [8]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False

    for param in model.classifier.parameters():
        param.requires_grad = True
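
To get a feel for how little is trained once the backbone is frozen, the count below works out the size of a 10-class linear head. The input width of 1280 is an assumption (the usual MobileNetV2 feature width at depth multiplier 1.0); in the notebook the real value is read from `classifier.in_features`. `linear_param_count` is a hypothetical helper.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Number of weights (plus optional biases) in a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


in_features = 1280  # assumption: MobileNetV2 feature width at multiplier 1.0
num_classes = 10    # CIFAR10

head_params = linear_param_count(in_features, num_classes)
print(head_params)
```

One caveat worth keeping in mind: `requires_grad = False` only stops gradient updates. Batch-normalization layers that run in training mode still update their running statistics.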

Initialized (pretrained) MobileNetV2.

In [9]:
def get_mobilenet():
    model_pretrained = MobileNetV2ForImageClassification.from_pretrained("google/mobilenet_v2_1.0_224")
    in_features = model_pretrained.classifier.in_features

    model_pretrained.classifier = nn.Linear(in_features, 10)  # Replace the classification head
    model_pretrained.num_labels = 10
    model_pretrained.config.num_labels = 10

    return model_pretrained

In [10]:
reset_seed(42)

In [11]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


Download the CIFAR10 dataset (or verify that it is already downloaded) and apply the transforms (a resolution that matches ImageNet and the teacher model). The datasets are then wrapped for Hugging Face.

In [12]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


train = dataset.CIFAR10(root='./data/10', train=True, download=True, transform=transform)  # Downloads the dataset or verifies that it is already stored at the given location.
test = dataset.CIFAR10(root='./data/10', train=False, download=True, transform=transform)

train_dataset_hf = CIFAR10HFDataset(train)
test_dataset_hf = CIFAR10HFDataset(test)

Files already downloaded and verified
Files already downloaded and verified
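
The `Normalize` step with mean 0.5 and standard deviation 0.5 maps pixel values from the `[0, 1]` range produced by `ToTensor` to `[-1, 1]`. A one-channel sketch of that arithmetic, with `normalize` as a hypothetical scalar helper:

```python
def normalize(x, mean=0.5, std=0.5):
    """Scalar version of the per-channel (x - mean) / std normalization."""
    return (x - mean) / std


values = [0.0, 0.5, 1.0]  # darkest, mid-gray and brightest pixel after ToTensor
normalized = [normalize(v) for v in values]
print(normalized)
```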


### Standard training of the randomly initialized model

In [13]:
training_args = get_training_args("./results/cifar10-random", './logs/cifar10-random')
model = get_random_init_mobilenet()

In [14]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset_hf,
    eval_dataset=test_dataset_hf,
    compute_metrics=compute_metrics,
)

In [15]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.9994,1.465344,0.4619
2,1.338,1.238446,0.5633
3,1.1357,1.040728,0.6262
4,0.8884,0.949196,0.6778
5,0.7736,0.940272,0.6785
6,0.6184,0.835595,0.7268
7,0.5585,1.199692,0.6553
8,0.4315,1.572898,0.6026
9,0.356,0.850076,0.7435
10,0.2889,0.897842,0.7531


TrainOutput(global_step=15640, training_loss=0.45282880482466326, metrics={'train_runtime': 5147.5413, 'train_samples_per_second': 194.268, 'train_steps_per_second': 3.038, 'total_flos': 2.020099608576e+18, 'train_loss': 0.45282880482466326, 'epoch': 20.0})

In [16]:
model.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [17]:
trainer.evaluate()

{'eval_loss': 0.8750341534614563,
 'eval_accuracy': 0.756,
 'eval_runtime': 27.1409,
 'eval_samples_per_second': 368.448,
 'eval_steps_per_second': 5.785,
 'epoch': 20.0}

## Definition of distillation training

A class that adapts the Hugging Face Trainer for knowledge distillation. It is based on the template for the Beans dataset.
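
As read from the code of the class, the loss it optimizes can be written as follows. Here $z_s$ and $z_t$ are the student and teacher logits, $y$ the ground-truth label, $T$ the temperature and $\lambda$ the value of `lambda_param`. The cross-entropy term assumes that the student model computes a standard classification loss from the labels.

```latex
\mathcal{L}
  = (1 - \lambda)\,\mathcal{L}_{\mathrm{CE}}(z_s, y)
  + \lambda\, T^{2}\,
    \mathrm{KL}\!\left(
      \operatorname{softmax}(z_t / T)
      \,\middle\|\,
      \operatorname{softmax}(z_s / T)
    \right)
```

The factor $T^{2}$ compensates for the smaller gradients of the softened distributions, so the two terms stay on a comparable scale when the temperature changes.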

In [18]:
class ImageDistilTrainer(Trainer):
    def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None, *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.teacher = teacher_model
        self.student = student_model
        # KLDivLoss expects log-probabilities as input and probabilities as target.
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.teacher.to(device)
        self.teacher.eval()  # The teacher is used for inference only.
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, student, inputs, return_outputs=False):
        student_output = self.student(**inputs)

        # The teacher only provides targets, so no gradients are needed.
        with torch.no_grad():
            teacher_output = self.teacher(**inputs)

        # Soften both distributions with the temperature.
        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        # KL divergence between the softened distributions, rescaled by T^2.
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        # Loss against the ground-truth labels, computed by the student model itself.
        student_target_loss = student_output.loss

        # Weighted combination of the hard-label term and the distillation term.
        loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
        return (loss, student_output) if return_outputs else loss
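
The loss computed in `compute_loss` can be traced by hand for a single sample. The sketch below reimplements it in plain Python under simplifying assumptions: one sample instead of a batch, and toy logits chosen only for illustration. All function names are hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    scaled = [z / temperature for z in logits]
    top = max(scaled)
    exps = [math.exp(z - top) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions with full support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


def cross_entropy(logits, label):
    """Negative log-probability of the true class at temperature 1."""
    return -math.log(softmax(logits)[label])


def distillation_loss(student_logits, teacher_logits, label, temperature, lam):
    soft_teacher = softmax(teacher_logits, temperature)
    soft_student = softmax(student_logits, temperature)
    # Teacher distribution is the target, as with nn.KLDivLoss above.
    kd = kl_divergence(soft_teacher, soft_student) * temperature ** 2
    ce = cross_entropy(student_logits, label)
    return (1.0 - lam) * ce + lam * kd


student = [1.0, 0.5, -0.5]  # toy student logits
teacher = [4.0, 1.0, -2.0]  # toy teacher logits
loss = distillation_loss(student, teacher, label=0, temperature=5, lam=0.6)
matched = distillation_loss(teacher, teacher, label=0, temperature=5, lam=1.0)
print(loss, matched)
```

When the student reproduces the teacher's logits exactly, the distillation term vanishes, which is why `matched` is zero with `lam=1.0`.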

### Training the randomly initialized model with knowledge distillation

In [19]:
reset_seed(42)

In [20]:
teacher_model = AutoModelForImageClassification.from_pretrained(
    "aaraki/vit-base-patch16-224-in21k-finetuned-cifar10",
    num_labels=10
)

student_model = get_random_init_mobilenet()

In [21]:
training_args = get_training_args("./results/cifar10-random-KD", './logs/cifar10-random-KD')

In [22]:
trainer = ImageDistilTrainer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_dataset_hf,
    eval_dataset=test_dataset_hf,
    compute_metrics=compute_metrics,
    temperature = 5,
    lambda_param = 0.6
)
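
The `temperature` argument controls how much the teacher's distribution is flattened before the student is asked to match it. A minimal sketch with toy logits (an assumption made only for illustration) shows the effect of going from temperature 1 to the value 5 used here:

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    scaled = [z / temperature for z in logits]
    top = max(scaled)
    exps = [math.exp(z - top) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [6.0, 2.0, 1.0]  # toy teacher logits

sharp = softmax(logits, temperature=1.0)
soft = softmax(logits, temperature=5.0)

print([round(p, 3) for p in sharp])
print([round(p, 3) for p in soft])
```

At the higher temperature the top class keeps the largest probability, but the other classes receive noticeably more mass. That relative ranking of the wrong classes is the extra signal distillation passes to the student.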

In [23]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.3775,1.069836,0.4218
2,0.9434,0.865689,0.5497
3,0.8097,0.713972,0.6428
4,0.6529,0.641991,0.6835
5,0.5842,0.769138,0.6297
6,0.4944,0.567288,0.7248
7,0.4616,0.726513,0.6638
8,0.3935,0.85648,0.6029
9,0.3507,0.595236,0.7191
10,0.315,0.493688,0.7723


TrainOutput(global_step=15640, training_loss=0.41463849806724606, metrics={'train_runtime': 10542.3033, 'train_samples_per_second': 94.856, 'train_steps_per_second': 1.484, 'total_flos': 2.020099608576e+18, 'train_loss': 0.41463849806724606, 'epoch': 20.0})
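
The cost of the extra teacher forward pass can be read off the reported runtimes. The sketch below compares the `train_runtime` of this distillation run with that of the standard run of the same randomly initialized model; both numbers are copied from the `TrainOutput` lines in this notebook.

```python
baseline_runtime = 5147.5413   # seconds, standard training of the random-init model
distill_runtime = 10542.3033   # seconds, distillation training of the same model

slowdown = distill_runtime / baseline_runtime
print(f"distillation took {slowdown:.2f}x as long")
```

These are single measurements on shared hardware, so the ratio is indicative only.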

In [24]:
student_model.eval()

In [25]:
trainer.evaluate()

{'eval_loss': 0.5087332725524902,
 'eval_accuracy': 0.7792,
 'eval_runtime': 67.221,
 'eval_samples_per_second': 148.763,
 'eval_steps_per_second': 2.336,
 'epoch': 20.0}

## Obtaining the initialized MobileNetV2 model

In [26]:
reset_seed(42)

In [27]:
model_pretrained = get_mobilenet()

In [28]:
print(model_pretrained)

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

### Training only the classification head of MobileNetV2

In [29]:
freeze_model(model_pretrained)

In [30]:
training_args = get_training_args("./results/cifar10-pretrained-head", './logs/cifar10-pretrained-head')

In [31]:
trainer = Trainer(
    model=model_pretrained,
    args=training_args,
    train_dataset=train_dataset_hf,
    eval_dataset=test_dataset_hf,
    compute_metrics=compute_metrics,
)

In [32]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.8904,1.359444,0.6308
2,1.1226,1.207275,0.6241
3,1.0168,0.97121,0.7096
4,0.9041,0.907923,0.7149
5,0.8768,0.984846,0.6807
6,0.8324,0.870077,0.7219
7,0.8274,1.021028,0.6552
8,0.8021,1.105069,0.6294
9,0.7919,0.86678,0.7069
10,0.785,0.874822,0.7126


TrainOutput(global_step=15640, training_loss=0.8626149301943572, metrics={'train_runtime': 3415.2469, 'train_samples_per_second': 292.805, 'train_steps_per_second': 4.579, 'total_flos': 2.020099608576e+18, 'train_loss': 0.8626149301943572, 'epoch': 20.0})

In [33]:
model_pretrained.eval()

In [34]:
trainer.evaluate()

{'eval_loss': 0.7993389368057251,
 'eval_accuracy': 0.7356,
 'eval_runtime': 26.4339,
 'eval_samples_per_second': 378.302,
 'eval_steps_per_second': 5.939,
 'epoch': 20.0}

### Training the initialized MobileNetV2

In [35]:
reset_seed(42)

In [36]:
model_pretrained_whole = get_mobilenet()

In [37]:
training_args = get_training_args("./results/cifar10-pretrained", './logs/cifar10-pretrained')

In [38]:
trainer = Trainer(
    model=model_pretrained_whole,
    args=training_args,
    train_dataset=train_dataset_hf,
    eval_dataset=test_dataset_hf,
    compute_metrics=compute_metrics,
)

In [39]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.6678,0.352117,0.8815
2,0.1829,0.473985,0.8552
3,0.1099,0.307834,0.9086
4,0.0515,0.306266,0.9136
5,0.0356,0.562306,0.8742
6,0.016,0.446769,0.9061
7,0.0135,0.61808,0.8736
8,0.0084,0.747778,0.854
9,0.0059,0.461412,0.9117
10,0.0038,0.43447,0.9198


TrainOutput(global_step=15640, training_loss=0.048249347503666225, metrics={'train_runtime': 4725.7271, 'train_samples_per_second': 211.608, 'train_steps_per_second': 3.31, 'total_flos': 2.020099608576e+18, 'train_loss': 0.048249347503666225, 'epoch': 20.0})

In [40]:
model_pretrained_whole.eval()

In [41]:
trainer.evaluate()

{'eval_loss': 0.4153886139392853,
 'eval_accuracy': 0.9295,
 'eval_runtime': 29.9741,
 'eval_samples_per_second': 333.621,
 'eval_steps_per_second': 5.238,
 'epoch': 20.0}

## Knowledge-distillation training of the initialized MobileNetV2

### Training the initialized model with distillation, classification head only

In [42]:
reset_seed(42)

In [43]:
student_model_pretrained = get_mobilenet()

In [44]:
freeze_model(student_model_pretrained)

In [45]:
training_args = get_training_args("./results/cifar10-pretrained-head-KD", './logs/cifar10-pretrained-head-KD')

In [46]:
trainer = ImageDistilTrainer(
    student_model=student_model_pretrained,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_dataset_hf,
    eval_dataset=test_dataset_hf,
    compute_metrics=compute_metrics,
    temperature = 5,
    lambda_param = 0.5
)

In [47]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.3103,0.962149,0.6373
2,0.8388,0.903257,0.6367
3,0.7938,0.769351,0.7129
4,0.7484,0.747497,0.7146
5,0.7396,0.805449,0.6852
6,0.7223,0.742221,0.7211
7,0.7209,0.832788,0.6644
8,0.7109,0.877352,0.6402
9,0.7076,0.753462,0.7047
10,0.7059,0.75036,0.7211


TrainOutput(global_step=15640, training_loss=0.7418107508393504, metrics={'train_runtime': 7600.697, 'train_samples_per_second': 131.567, 'train_steps_per_second': 2.058, 'total_flos': 2.020099608576e+18, 'train_loss': 0.7418107508393504, 'epoch': 20.0})

In [48]:
student_model_pretrained.eval()

In [49]:
trainer.evaluate()

{'eval_loss': 0.7156311273574829,
 'eval_accuracy': 0.7372,
 'eval_runtime': 31.4214,
 'eval_samples_per_second': 318.254,
 'eval_steps_per_second': 4.997,
 'epoch': 20.0}

### Training the whole initialized model with distillation

In [50]:
reset_seed(42)

In [51]:
student_model_pretrained_whole = get_mobilenet()

In [52]:
training_args = get_training_args("./results/cifar10-pretrained-KD", './logs/cifar10-pretrained-KD')

In [53]:
trainer = ImageDistilTrainer(
    student_model=student_model_pretrained_whole,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_dataset_hf,
    eval_dataset=test_dataset_hf,
    compute_metrics=compute_metrics,
    temperature = 5,
    lambda_param = 0.5
)

In [54]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.5453,0.344106,0.8853
2,0.2293,0.371379,0.8672
3,0.1839,0.247473,0.9193
4,0.1462,0.227526,0.9242
5,0.1315,0.372273,0.8768
6,0.1161,0.256997,0.9144
7,0.1118,0.344152,0.8827
8,0.1046,0.333011,0.8802
9,0.1006,0.237853,0.9223
10,0.099,0.211038,0.9314


TrainOutput(global_step=15640, training_loss=0.12979280393751685, metrics={'train_runtime': 4868.7591, 'train_samples_per_second': 205.391, 'train_steps_per_second': 3.212, 'total_flos': 2.020099608576e+18, 'train_loss': 0.12979280393751685, 'epoch': 20.0})

In [55]:
student_model_pretrained_whole.eval()

In [56]:
trainer.evaluate()

{'eval_loss': 0.1964949071407318,
 'eval_accuracy': 0.9369,
 'eval_runtime': 31.3448,
 'eval_samples_per_second': 319.032,
 'eval_steps_per_second': 5.009,
 'epoch': 20.0}
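
To compare the six runs side by side, the sketch below collects the final `eval_accuracy` values reported above and computes the gain from distillation for each setup. The dictionary keys are labels chosen here for readability.

```python
# (standard training, distillation) eval_accuracy pairs from the runs above.
results = {
    "random init": (0.756, 0.7792),
    "pretrained, head only": (0.7356, 0.7372),
    "pretrained, full model": (0.9295, 0.9369),
}

gains = {name: round(kd - base, 4) for name, (base, kd) in results.items()}

for name, gain in gains.items():
    print(f"{name}: {gain * 100:+.2f} percentage points")
```

The differences should be read with care: each setup was run once with a single seed, the random-init distillation run used `lambda_param = 0.6` while the other two used 0.5, and the test set also served for selecting the best checkpoint.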