# Notebook for distillation training on the CIFAR10 dataset
This notebook trains MobileNetV2 on the CIFAR10 dataset; the teacher model is a ViT fine-tuned on the same dataset.

MobileNetV2 is used in three setups: with random initialization, with an ImageNet-pretrained initialization where only the classification head is trained, and with the same pretrained initialization where the whole model is trained. Each of the three setups is trained both in the standard way and with knowledge distillation from the teacher model mentioned above.

Distillation uses logits precomputed in the precompute_logits notebook.
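
The precompute_logits notebook itself is not shown here. As a rough sketch of the idea, assuming it stores one teacher logit vector per image under a `b'logits'` key in each CIFAR-10 batch file (the key that the dataset wrapper in this notebook reads), the step could look like the following. The helper name `add_logits_to_batch` and the toy values are hypothetical.

```python
import os
import pickle
import tempfile


def add_logits_to_batch(batch_path, teacher_logits):
    """Store teacher logits next to the images and labels of one batch file.

    Hypothetical helper: `teacher_logits` holds one list of class scores per image.
    """
    with open(batch_path, "rb") as fo:
        batch = pickle.load(fo, encoding="bytes")

    if len(teacher_logits) != len(batch[b"labels"]):
        raise ValueError("expected one logit vector per image")

    batch[b"logits"] = teacher_logits
    with open(batch_path, "wb") as fo:
        pickle.dump(batch, fo)


# Toy batch with two "images" (shortened to 4 bytes each) and 3 classes.
tmp_dir = tempfile.mkdtemp()
toy_path = os.path.join(tmp_dir, "data_batch_1")
with open(toy_path, "wb") as fo:
    pickle.dump({b"data": [[0, 1, 2, 3], [4, 5, 6, 7]], b"labels": [2, 0]}, fo)

add_logits_to_batch(toy_path, [[0.1, 0.2, 3.0], [2.5, 0.3, 0.1]])

with open(toy_path, "rb") as fo:
    reloaded = pickle.load(fo, encoding="bytes")
print(sorted(reloaded.keys()))
```

Keeping the logits in the same file as the images means the student training never has to load or run the teacher model.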

In [4]:
%pip install transformers[torch] huggingface_hub datasets evaluate torchvision

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
Installing collected packages: evaluate
Successfully installed evaluate-0.4.3

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


## Library imports and method definitions

In [29]:
from transformers import Trainer, TrainingArguments, MobileNetV2Config, MobileNetV2ForImageClassification, AutoModelForImageClassification
from torchvision import transforms
from torch.utils.data import Dataset
import torch.nn.functional as F
from PIL import Image
import torch.nn as nn
import numpy as np
import evaluate
import random
import pickle
import torch
import os

Resets the random seeds so that results are reproducible.
Some of these settings can probably be removed.

TODO: Remove the unnecessary settings.

In [None]:
def reset_seed(seed=42):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) 
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
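
One way to decide which of these settings are actually needed is to check reproducibility directly: reset the seed, draw some numbers, reset again, and compare. The sketch below does this for Python's `random` module only, to stay dependency-free; the same pattern would apply to `np.random` and `torch` draws. The helper `draws_after_seed` is hypothetical.

```python
import random


def draws_after_seed(seed, n=5):
    """Reseed Python's RNG and return the next n draws."""
    random.seed(seed)
    return [random.random() for _ in range(n)]


first = draws_after_seed(42)
second = draws_after_seed(42)
other = draws_after_seed(43)

print(first == second)  # same seed, same sequence
print(first == other)   # different seed, different sequence
```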

A new wrapper that works directly with the files of the downloaded and modified CIFAR10 dataset.
Loading the dataset with the same method as before is not possible because the modified files have a different checksum.

The logits are now also read directly from the dataset.

In [30]:
class CustomCIFAR10(Dataset):
    def __init__(self, root, train=True, transform=None, target_transform=None):
        self.root = root
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        self.data = []
        self.targets = []
        self.logits = []
        
        if self.train:
            # The training split is stored in five batch files.
            for i in range(1, 6):
                data_file = os.path.join(self.root, 'cifar-10-batches-py', f'data_batch_{i}')
                with open(data_file, 'rb') as fo:
                    batch = pickle.load(fo, encoding='bytes')
                    self.data.append(batch[b'data'])
                    self.targets.extend(batch[b'labels'])
                    self.logits.extend(batch[b'logits'])
        else:
            # The test split is a single batch file.
            data_file = os.path.join(self.root, 'cifar-10-batches-py', 'test_batch')
            with open(data_file, 'rb') as fo:
                batch = pickle.load(fo, encoding='bytes')
                self.data.append(batch[b'data'])
                self.targets.extend(batch[b'labels'])
                self.logits.extend(batch[b'logits'])

        self.data = np.concatenate(self.data, axis=0)
        self.targets = np.array(self.targets)
        self.logits = np.array(self.logits)


    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = self.data[index].reshape(3, 32, 32).transpose(1, 2, 0)
        label = self.targets[index]
        logit = self.logits[index]
        
        image = Image.fromarray(image.astype('uint8'), 'RGB')
        logit = torch.tensor(logit, dtype=torch.float)
        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)
            
        return {
            'pixel_values': image,
            'labels': label,
            'logits': logit
        }
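
The `reshape(3, 32, 32).transpose(1, 2, 0)` call in `__getitem__` relies on the CIFAR-10 row layout: each image is a flat row of 3072 bytes holding the 1024 red values first, then green, then blue, each plane stored row by row. The pure-Python sketch below spells out the same index mapping; the function name `pixel_at` is hypothetical.

```python
def pixel_at(flat_row, row, col):
    """Return the (R, G, B) triple at (row, col) from a flat CIFAR-10 row.

    Equivalent to flat_row.reshape(3, 32, 32).transpose(1, 2, 0)[row, col].
    """
    plane = 32 * 32          # number of values per colour channel
    offset = row * 32 + col  # position inside one channel plane
    return tuple(flat_row[channel * plane + offset] for channel in range(3))


# Synthetic row: red plane is all 10, green all 20, blue all 30,
# except a marker value in the blue plane at row 1, column 2.
flat = [10] * 1024 + [20] * 1024 + [30] * 1024
flat[2 * 1024 + 1 * 32 + 2] = 99

print(pixel_at(flat, 0, 0))
print(pixel_at(flat, 1, 2))
```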


Definition of the accuracy metric used during model training.

In [None]:
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
    return {"accuracy": acc["accuracy"]}
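
For reference, the metric amounts to taking the arg-max over the class dimension and counting matches with the labels. A dependency-free sketch of that computation (not the `evaluate` implementation itself; `accuracy_from_logits` is a hypothetical name):

```python
def accuracy_from_logits(logits, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


toy_logits = [
    [2.0, 0.1, -1.0],  # predicts class 0
    [0.3, 0.2, 1.5],   # predicts class 2
    [0.0, 4.0, 1.0],   # predicts class 1
    [1.0, 0.9, 0.8],   # predicts class 0
]
toy_labels = [0, 2, 2, 0]

print(accuracy_from_logits(toy_logits, toy_labels))  # 3 of 4 correct -> 0.75
```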

Training arguments for the trainer.

In [None]:
def get_training_args(output_dir, logging_dir, remove_unused_columns):
    return (
        TrainingArguments(
        output_dir=output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=5e-5,  # Default value
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        num_train_epochs=20,
        weight_decay=0.01,
        seed=42,  # Default value
        metric_for_best_model="accuracy",
        load_best_model_at_end=True,
        fp16=True, 
        logging_dir=logging_dir,
        remove_unused_columns=remove_unused_columns,
    ))


Randomly initialized MobileNetV2.

In [35]:
def get_random_init_mobilenet():
    reset_seed(42)
    student_config = MobileNetV2Config()
    student_config.num_labels = 10
    return MobileNetV2ForImageClassification(student_config)

Freezing the model so that only the classification head is trained.

In [None]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False

    for param in model.classifier.parameters():
        param.requires_grad = True

MobileNetV2 initialized with pretrained (ImageNet) weights.

In [None]:
def get_mobilenet():
    model_pretrained = MobileNetV2ForImageClassification.from_pretrained("google/mobilenet_v2_1.0_224")
    in_features = model_pretrained.classifier.in_features

    model_pretrained.classifier = nn.Linear(in_features, 10)  # Replace the classification head
    model_pretrained.num_labels = 10
    model_pretrained.config.num_labels = 10

    return model_pretrained

In [None]:
reset_seed(42)

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

Applying the transforms to the dataset.

In [None]:
transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
test = CustomCIFAR10(root='./data/10-logits', train=False, transform=transform)
train = CustomCIFAR10(root='./data/10-logits', train=True, transform=transform)
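
`ToTensor` scales pixel values to the range [0, 1], and `Normalize` with mean 0.5 and standard deviation 0.5 then maps that range to [-1, 1] per channel. The arithmetic, as a small sketch with a hypothetical helper name:

```python
def normalize_pixel(byte_value, mean=0.5, std=0.5):
    """Map a raw 0..255 pixel value to the range produced by the transform."""
    scaled = byte_value / 255.0   # what ToTensor does to uint8 images
    return (scaled - mean) / std  # what Normalize does per channel


for raw in (0, 128, 255):
    print(raw, round(normalize_pixel(raw), 4))
```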

### Standard training of the randomly initialized model

In [None]:
training_args = get_training_args("./results/cifar10-random", './logs/cifar10-random', True)
model = get_random_init_mobilenet()

In [38]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=test,
    compute_metrics=compute_metrics,
)

In [39]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.9994,1.465344,0.4619
2,1.338,1.238446,0.5633
3,1.1357,1.040728,0.6262
4,0.8884,0.949196,0.6778
5,0.7736,0.940272,0.6785
6,0.6184,0.835595,0.7268
7,0.5585,1.199692,0.6553
8,0.4315,1.572898,0.6026
9,0.356,0.850076,0.7435
10,0.2889,0.897842,0.7531


TrainOutput(global_step=15640, training_loss=0.45282880482466326, metrics={'train_runtime': 2628.6583, 'train_samples_per_second': 380.422, 'train_steps_per_second': 5.95, 'total_flos': 2.020099608576e+18, 'train_loss': 0.45282880482466326, 'epoch': 20.0})
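
The `global_step=15640` reported above is consistent with the training arguments: CIFAR10 has 50,000 training images, the per-device batch size is 64, and training runs for 20 epochs. A quick check of that arithmetic, assuming a single device and no gradient accumulation:

```python
import math

train_images = 50_000   # size of the CIFAR10 training split
batch_size = 64         # per_device_train_batch_size
epochs = 20             # num_train_epochs

steps_per_epoch = math.ceil(train_images / batch_size)  # last batch is partial
total_steps = steps_per_epoch * epochs

print(steps_per_epoch, total_steps)  # 782 15640
```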

In [40]:
model.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [41]:
trainer.evaluate()

{'eval_loss': 0.8750341534614563,
 'eval_accuracy': 0.756,
 'eval_runtime': 12.619,
 'eval_samples_per_second': 792.457,
 'eval_steps_per_second': 12.442,
 'epoch': 20.0}

## Definition of distillation training

A class that adapts the Hugging Face Trainer for knowledge distillation. It now works with the logits stored in the dataset.

In [42]:
class ImageDistilTrainer(Trainer):
    def __init__(self, student_model=None, temperature=None, lambda_param=None, *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.student = student_model
        # KLDivLoss expects log-probabilities as input and probabilities as target.
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Precomputed teacher logits come from the dataset; remove them before the forward pass.
        teacher_logits = inputs.pop("logits")

        # `labels` stay in `inputs`, so the student output also carries the cross-entropy loss.
        student_output = model(**inputs)

        # Soften both distributions with the temperature.
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        # Scale by T^2 so the distillation term stays comparable to the hard-label loss.
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        # Standard loss against the ground-truth labels.
        student_target_loss = student_output.loss

        # Weighted combination: lambda_param is the weight of the distillation term.
        loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
        return (loss, student_output) if return_outputs else loss
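
To make the loss concrete, here is the same computation for a single example in plain Python: soften teacher and student logits with temperature T, take the KL divergence between the softened teacher and student distributions, scale it by T², and mix it with the cross-entropy on the true label using `lambda_param`. This is only an illustrative sketch with made-up logits; the function names are hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of logits / temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, label, temperature, lambda_param):
    """Combined loss for one example, mirroring compute_loss above."""
    soft_teacher = softmax(teacher_logits, temperature)
    soft_student = softmax(student_logits, temperature)
    # KL(teacher || student) on the softened distributions.
    kl = sum(t * (math.log(t) - math.log(s)) for t, s in zip(soft_teacher, soft_student))
    # Cross-entropy against the hard label uses the unsoftened student output.
    cross_entropy = -math.log(softmax(student_logits)[label])
    return (1.0 - lambda_param) * cross_entropy + lambda_param * kl * temperature ** 2


teacher = [6.0, 1.0, -2.0]
student = [2.0, 1.5, 0.0]
print(distillation_loss(student, teacher, label=0, temperature=5, lambda_param=0.3))
```

When the student reproduces the teacher logits exactly, the KL term vanishes and only the weighted cross-entropy remains.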

### Training the randomly initialized model with knowledge distillation

In [43]:
reset_seed(42)

In [44]:
student_model = get_random_init_mobilenet()

In [45]:
training_args = get_training_args("./results/cifar10-random-KD", './logs/cifar10-random-KD', False)

In [46]:
trainer = ImageDistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=test,
    compute_metrics=compute_metrics,
    temperature = 5,
    lambda_param = 0.3
)

In [47]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.7047,1.573836,0.423
2,1.191,1.402557,0.5728
3,1.0417,1.290165,0.6502
4,0.844,1.284608,0.6738
5,0.7506,1.362965,0.6422
6,0.6409,1.242605,0.6926
7,0.5949,1.382708,0.6663
8,0.4986,1.37808,0.6323
9,0.4486,1.273678,0.7031
10,0.3996,1.176472,0.7561


TrainOutput(global_step=15640, training_loss=0.5201061487807642, metrics={'train_runtime': 2702.9095, 'train_samples_per_second': 369.972, 'train_steps_per_second': 5.786, 'total_flos': 2.020099608576e+18, 'train_loss': 0.5201061487807642, 'epoch': 20.0})

In [48]:
student_model.eval()


In [49]:
trainer.evaluate()

{'eval_loss': 1.1884382963180542,
 'eval_accuracy': 0.769,
 'eval_runtime': 12.707,
 'eval_samples_per_second': 786.971,
 'eval_steps_per_second': 12.355,
 'epoch': 20.0}

## Obtaining the pretrained MobileNetV2 model

In [50]:
reset_seed(42)

In [51]:
model_pretrained = get_mobilenet()

In [52]:
print(model_pretrained)

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [53]:
freeze_model(model_pretrained)

In [54]:
training_args = get_training_args("./results/cifar10-pretrained-head", './logs/cifar10-pretrained-head', True)

In [55]:
trainer = Trainer(
    model=model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=test,
    compute_metrics=compute_metrics,
)

In [56]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.8904,1.359444,0.6308
2,1.1226,1.207275,0.6241
3,1.0168,0.97121,0.7096
4,0.9041,0.907923,0.7149
5,0.8768,0.984846,0.6807
6,0.8324,0.870077,0.7219
7,0.8274,1.021028,0.6552
8,0.8021,1.105069,0.6294
9,0.7919,0.86678,0.7069
10,0.785,0.874822,0.7126


TrainOutput(global_step=15640, training_loss=0.8626149301943572, metrics={'train_runtime': 1631.6784, 'train_samples_per_second': 612.866, 'train_steps_per_second': 9.585, 'total_flos': 2.020099608576e+18, 'train_loss': 0.8626149301943572, 'epoch': 20.0})

In [57]:
model_pretrained.eval()


In [58]:
trainer.evaluate()

{'eval_loss': 0.7993389368057251,
 'eval_accuracy': 0.7356,
 'eval_runtime': 12.7355,
 'eval_samples_per_second': 785.204,
 'eval_steps_per_second': 12.328,
 'epoch': 20.0}

### Training the whole pretrained MobileNetV2

In [59]:
reset_seed(42)

In [60]:
model_pretrained_whole = get_mobilenet()

In [61]:
training_args = get_training_args("./results/cifar10-pretrained", './logs/cifar10-pretrained', True)

In [62]:
trainer = Trainer(
    model=model_pretrained_whole,
    args=training_args,
    train_dataset=train,
    eval_dataset=test,
    compute_metrics=compute_metrics,
)

In [63]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.6678,0.352117,0.8815
2,0.1829,0.473985,0.8552
3,0.1099,0.307834,0.9086
4,0.0515,0.306266,0.9136
5,0.0356,0.562306,0.8742
6,0.016,0.446769,0.9061
7,0.0135,0.61808,0.8736
8,0.0084,0.747778,0.854
9,0.0059,0.461412,0.9117
10,0.0038,0.43447,0.9198


TrainOutput(global_step=15640, training_loss=0.048249347503666225, metrics={'train_runtime': 3023.316, 'train_samples_per_second': 330.763, 'train_steps_per_second': 5.173, 'total_flos': 2.020099608576e+18, 'train_loss': 0.048249347503666225, 'epoch': 20.0})

In [64]:
model_pretrained_whole.eval()


In [65]:
trainer.evaluate()

{'eval_loss': 0.4153886139392853,
 'eval_accuracy': 0.9295,
 'eval_runtime': 19.3723,
 'eval_samples_per_second': 516.201,
 'eval_steps_per_second': 8.104,
 'epoch': 20.0}

## Knowledge distillation training of the pretrained MobileNetV2

### Training the pretrained model with distillation: classification head only

In [66]:
reset_seed(42)

In [67]:
student_model_pretrained = get_mobilenet()

In [68]:
freeze_model(student_model_pretrained)

In [69]:
training_args = get_training_args("./results/cifar10-pretrained-head-KD", './logs/cifar10-pretrained-head-KD', False)

In [70]:
trainer = ImageDistilTrainer(
    student_model=student_model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=test,
    compute_metrics=compute_metrics,
    temperature = 5,
    lambda_param = 0.6
)

In [71]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.1942,1.295173,0.638
2,0.773,1.352618,0.6388
3,0.7344,1.305183,0.7129
4,0.6947,1.334172,0.7149
5,0.6873,1.368952,0.6847
6,0.6717,1.329848,0.7203
7,0.6704,1.372895,0.6647
8,0.6614,1.381742,0.6395
9,0.6585,1.350707,0.7035
10,0.6571,1.364204,0.7202


TrainOutput(global_step=15640, training_loss=0.688703476498499, metrics={'train_runtime': 2465.457, 'train_samples_per_second': 405.604, 'train_steps_per_second': 6.344, 'total_flos': 2.020099608576e+18, 'train_loss': 0.688703476498499, 'epoch': 20.0})

In [72]:
student_model_pretrained.eval()


In [73]:
trainer.evaluate()

{'eval_loss': 1.321333408355713,
 'eval_accuracy': 0.7356,
 'eval_runtime': 22.1336,
 'eval_samples_per_second': 451.803,
 'eval_steps_per_second': 7.093,
 'epoch': 20.0}

### Training the whole pretrained model with distillation

In [74]:
reset_seed(42)

In [75]:
student_model_pretrained_whole = get_mobilenet()

In [76]:
training_args = get_training_args("./results/cifar10-pretrained-KD", './logs/cifar10-pretrained-KD', False)

In [77]:
trainer = ImageDistilTrainer(
    student_model=student_model_pretrained_whole.to(device),
    args=training_args,
    train_dataset=train,
    eval_dataset=test,
    compute_metrics=compute_metrics,
    temperature = 5,
    lambda_param = 0.4
)

In [78]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.5882,1.11303,0.8868
2,0.2465,1.138879,0.8644
3,0.1968,1.028809,0.9216
4,0.1559,0.991875,0.9252
5,0.1398,1.172694,0.8737
6,0.1223,1.053637,0.9134
7,0.118,1.166352,0.8826
8,0.1099,1.100621,0.878
9,0.1057,1.038172,0.9194
10,0.1042,0.995655,0.9324


TrainOutput(global_step=15640, training_loss=0.13773624153088426, metrics={'train_runtime': 3402.8965, 'train_samples_per_second': 293.867, 'train_steps_per_second': 4.596, 'total_flos': 2.020099608576e+18, 'train_loss': 0.13773624153088426, 'epoch': 20.0})

In [79]:
student_model_pretrained_whole.eval()


In [80]:
trainer.evaluate()

{'eval_loss': 0.9895767569541931,
 'eval_accuracy': 0.936,
 'eval_runtime': 18.7393,
 'eval_samples_per_second': 533.637,
 'eval_steps_per_second': 8.378,
 'epoch': 20.0}
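
To compare the six runs at a glance, the final `eval_accuracy` values reported in the outputs above can be collected and the gain from distillation computed per setup. The numbers below are copied from this notebook's outputs; the dictionary layout is only an illustration.

```python
# Final eval_accuracy per setup: (standard training, with distillation).
results = {
    "random init": (0.756, 0.769),
    "pretrained, head only": (0.7356, 0.7356),
    "pretrained, whole model": (0.9295, 0.936),
}

gains = {name: round(kd - base, 4) for name, (base, kd) in results.items()}

for name, (base, kd) in results.items():
    print(f"{name:24s} standard={base:.4f}  KD={kd:.4f}  gain={gains[name]:+.4f}")
```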