# Notebook for distillation training on the CIFAR100 dataset
This notebook trains MobileNetV2 on the CIFAR100 dataset. The teacher model is a ViT fine-tuned on the same dataset.

MobileNetV2 is used in three setups: with random initialization, with only the classification head trained on top of an initialized MobileNetV2 (pretrained on ImageNet), and with the whole initialized model trained. Each of the three setups is trained in the standard way and also with distillation from the teacher model mentioned above.

Distillation uses precomputed logits from the precompute_logits notebook.
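
A minimal sketch of how cached teacher logits can travel with each training sample. `LogitDataset`, its constructor arguments, and the `logits` key are hypothetical names chosen for illustration; the actual `base.CustomCIFAR100L` class is not shown in this notebook and may be organized differently.

```python
from typing import Any, Dict, List, Sequence, Tuple


class LogitDataset:
    """Pairs every (image, label) sample with the teacher logits cached for it.

    Hypothetical illustration, not the notebook's CustomCIFAR100L class.
    """

    def __init__(
        self,
        samples: Sequence[Tuple[Any, int]],
        teacher_logits: Sequence[List[float]],
    ) -> None:
        if len(samples) != len(teacher_logits):
            raise ValueError("need exactly one logit vector per sample")
        self.samples = samples
        self.teacher_logits = teacher_logits

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        image, label = self.samples[index]
        # The teacher is never called here: its output was computed once, offline.
        return {
            "pixel_values": image,
            "labels": label,
            "logits": self.teacher_logits[index],
        }
```

Storing the logits next to the samples means the teacher does not have to run a forward pass during student training. The extra `logits` key is probably also the reason the distillation runs pass `remove_unused_columns=False`, so that the key is not dropped before it reaches the loss.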

In [1]:
%pip install transformers[torch] huggingface_hub datasets evaluate torchvision

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


## Library imports and method definitions

In [1]:
from transformers import Trainer, EarlyStoppingCallback
from torch.utils.data import ConcatDataset
import torch
import base

In [2]:
dataset_part = base.get_dataset_part()

Initialized (ImageNet-pretrained) MobileNetV2.

In [3]:
base.reset_seed()

In [4]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


Applying the transformations to the dataset.

In [5]:
transform = base.base_transforms()

train = base.CustomCIFAR100L(root='./data/100-logits', dataset_part=dataset_part.TRAIN, transform=transform)
eval = base.CustomCIFAR100L(root='./data/100-logits', dataset_part=dataset_part.EVAL, transform=transform)
test = base.CustomCIFAR100L(root='./data/100-logits', dataset_part=dataset_part.TEST, transform=transform)


In [6]:
augment_transform = base.aug_transforms()

train_aug = base.CustomCIFAR100L(root='./data/100-logits', dataset_part=dataset_part.TRAIN, transform=augment_transform)

In [7]:
train_aug = base.remove_diff_pred_class(train, train_aug)
train_combo = ConcatDataset([train, train_aug])

### Standard training of the randomly initialized model

In [8]:
training_args = base.get_training_args(output_dir="./results/cifar100-random", logging_dir='./logs/cifar100-random', lr=0.0005, epochs=30)
model = base.get_random_init_mobilenet(100)

In [9]:
base.count_parameters(model)

model size: 9.103MB.
Total Trainable Params: 2351972.


Index,Modules,Parameters
0,mobilenet_v2.conv_stem.first_conv.convolution....,864
1,mobilenet_v2.conv_stem.first_conv.normalizatio...,32
2,mobilenet_v2.conv_stem.first_conv.normalizatio...,32
3,mobilenet_v2.conv_stem.conv_3x3.convolution.we...,288
4,mobilenet_v2.conv_stem.conv_3x3.normalization....,32
...,...,...
153,mobilenet_v2.conv_1x1.convolution.weight,409600
154,mobilenet_v2.conv_1x1.normalization.weight,1280
155,mobilenet_v2.conv_1x1.normalization.bias,1280
156,classifier.weight,128000


In [10]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 5)]
)

In [11]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,3.9937,3.417222,0.1572,0.141172,0.1572,0.12413
2,3.2137,2.752205,0.2794,0.283398,0.2794,0.250506
3,2.7377,2.390321,0.3625,0.360134,0.3625,0.342836
4,2.3884,2.153566,0.414,0.414411,0.414,0.396061
5,2.1141,1.999104,0.4506,0.458645,0.4506,0.43766
6,1.8977,1.868613,0.4857,0.49205,0.4857,0.474608
7,1.7046,1.77753,0.51,0.517645,0.51,0.50082
8,1.5522,1.711429,0.5271,0.533046,0.5271,0.519891
9,1.3946,1.699551,0.5318,0.539368,0.5318,0.5274
10,1.2475,1.662014,0.543,0.543743,0.543,0.536978


TrainOutput(global_step=8755, training_loss=1.6272482653606422, metrics={'train_runtime': 5010.951, 'train_samples_per_second': 394.608, 'train_steps_per_second': 3.083, 'total_flos': 2.380203558798557e+18, 'train_loss': 1.6272482653606422, 'epoch': 17.0})

In [12]:
model.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [13]:
trainer.evaluate(test)

{'eval_loss': 1.623160719871521,
 'eval_accuracy': 0.5606,
 'eval_precision': 0.5716917684856422,
 'eval_recall': 0.5606,
 'eval_f1': 0.5604123533731383,
 'eval_runtime': 21.4916,
 'eval_samples_per_second': 465.298,
 'eval_steps_per_second': 3.676,
 'epoch': 17.0}

## Definition of distillation training

A class that adapts the Hugging Face Trainer for knowledge distillation. It now works with logits stored in the dataset.
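
The loss that such a trainer typically optimizes can be sketched in plain Python. This is an assumption based on the common knowledge-distillation convention, not the code of `base.ImageDistilTrainer`, which is not shown here: the temperature-softened teacher and student distributions are compared with KL divergence (scaled by the squared temperature), and `lambda_param` blends that term with the ordinary cross-entropy on the true label. The function names below are hypothetical.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax of logits divided by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(
    student_logits: List[float],
    teacher_logits: List[float],
    label: int,
    temperature: float = 6.0,
    lambda_param: float = 0.75,
) -> float:
    """Blend of hard-label cross-entropy and soft-label KL divergence (assumed form)."""
    # Hard term: cross-entropy of the student against the true class.
    hard = -math.log(softmax(student_logits)[label])
    # Soft term: KL(teacher || student) on temperature-softened distributions.
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    soft = sum(t * math.log(t / s) for t, s in zip(p_teacher, p_student) if t > 0)
    # Scaling by T^2 keeps the soft-term gradients comparable across temperatures.
    soft *= temperature ** 2
    return (1.0 - lambda_param) * hard + lambda_param * soft
```

Under this convention, `lambda_param=1` means the student learns from the teacher's logits only, and `lambda_param=0` reduces to standard supervised training. In a PyTorch trainer the same terms would usually be computed on batched tensors rather than on Python lists.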

### Training the randomly initialized model with knowledge distillation

In [14]:
base.reset_seed()

In [15]:
student_model = base.get_random_init_mobilenet(100)

In [16]:
training_args = base.get_training_args(output_dir="./results/cifar100-random-KD", logging_dir='./logs/cifar100-random-KD', remove_unused_columns=False, epochs=30, lr=0.00047, lambda_param=.75, temp=6)

In [17]:
trainer = base.ImageDistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_combo,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 5)]
)

In [18]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.9811,2.711757,0.1693,0.18017,0.1693,0.128997
2,2.4962,2.241964,0.3007,0.30643,0.3007,0.25876
3,2.1712,1.985207,0.3811,0.39528,0.3811,0.355237
4,1.9303,1.807191,0.4362,0.449677,0.4362,0.417526
5,1.7486,1.677898,0.4763,0.492969,0.4763,0.457196
6,1.6051,1.547201,0.516,0.528934,0.516,0.501914
7,1.4865,1.492544,0.5303,0.542028,0.5303,0.519945
8,1.381,1.423602,0.5512,0.561862,0.5512,0.542417
9,1.2844,1.39642,0.5521,0.568849,0.5521,0.548224
10,1.199,1.336429,0.5771,0.581431,0.5771,0.571775


TrainOutput(global_step=10300, training_loss=1.3414556825508193, metrics={'train_runtime': 5416.7574, 'train_samples_per_second': 365.045, 'train_steps_per_second': 2.852, 'total_flos': 2.800239480939479e+18, 'train_loss': 1.3414556825508193, 'epoch': 20.0})

In [19]:
student_model.eval()


In [20]:
trainer.evaluate(test)

{'eval_loss': 1.1311068534851074,
 'eval_accuracy': 0.6051,
 'eval_precision': 0.6187711296715633,
 'eval_recall': 0.6051000000000001,
 'eval_f1': 0.6058615769375616,
 'eval_runtime': 20.7986,
 'eval_samples_per_second': 480.802,
 'eval_steps_per_second': 3.798,
 'epoch': 20.0}

## Obtaining the initialized (pretrained) MobileNetV2 model

In [29]:
base.reset_seed()

In [4]:
model_pretrained = base.get_mobilenet(100)

In [5]:
base.count_parameters(model_pretrained)

model size: 9.103MB
Total Trainable Params: 2351972


Index,Modules,Parameters
0,mobilenet_v2.conv_stem.first_conv.convolution....,864
1,mobilenet_v2.conv_stem.first_conv.normalizatio...,32
2,mobilenet_v2.conv_stem.first_conv.normalizatio...,32
3,mobilenet_v2.conv_stem.conv_3x3.convolution.we...,288
4,mobilenet_v2.conv_stem.conv_3x3.normalization....,32
...,...,...
153,mobilenet_v2.conv_1x1.convolution.weight,409600
154,mobilenet_v2.conv_1x1.normalization.weight,1280
155,mobilenet_v2.conv_1x1.normalization.bias,1280
156,classifier.weight,128000


In [31]:
print(model_pretrained)


In [32]:
model_pretrained = base.freeze_model(model_pretrained)

In [33]:
training_args = base.get_training_args(output_dir="./results/cifar100-pretrained-head", logging_dir='./logs/cifar100-pretrained-head')

In [34]:
trainer = Trainer(
    model=model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 4)]
)

In [35]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,4.3666,4.12372,0.1839,0.208631,0.1839,0.177796
2,3.9275,3.765841,0.3024,0.325632,0.3024,0.293932
3,3.6436,3.570344,0.3498,0.359684,0.3498,0.333879
4,3.4753,3.458225,0.3587,0.382333,0.3587,0.344111
5,3.3925,3.428162,0.3769,0.382731,0.3769,0.359658


TrainOutput(global_step=1565, training_loss=3.761084686376797, metrics={'train_runtime': 469.9788, 'train_samples_per_second': 425.551, 'train_steps_per_second': 3.33, 'total_flos': 4.248451694592e+17, 'train_loss': 3.761084686376797, 'epoch': 5.0})

In [36]:
model_pretrained.eval()


In [37]:
trainer.evaluate(test)

{'eval_loss': 3.432467222213745,
 'eval_accuracy': 0.3773,
 'eval_precision': 0.3774894959232473,
 'eval_recall': 0.3772999999999999,
 'eval_f1': 0.35815750868611934,
 'eval_runtime': 18.3126,
 'eval_samples_per_second': 546.073,
 'eval_steps_per_second': 4.314,
 'epoch': 5.0}

### Training the initialized (pretrained) MobileNetV2

In [38]:
base.reset_seed()

In [39]:
model_pretrained_whole = base.get_mobilenet(100)

In [40]:
training_args = base.get_training_args(output_dir="./results/cifar100-pretrained", logging_dir='./logs/cifar100-pretrained')

In [41]:
trainer = Trainer(
    model=model_pretrained_whole,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 5)]
)

In [42]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,3.2751,2.070399,0.5244,0.543151,0.5244,0.508957
2,1.6438,1.473536,0.6097,0.625245,0.6097,0.605971
3,1.2066,1.26809,0.6559,0.6618,0.6559,0.651263
4,1.0111,1.259078,0.6489,0.670386,0.6489,0.648706
5,0.9242,1.154022,0.6805,0.685636,0.6805,0.674981


TrainOutput(global_step=1565, training_loss=1.61216064063124, metrics={'train_runtime': 585.1473, 'train_samples_per_second': 341.794, 'train_steps_per_second': 2.675, 'total_flos': 4.248451694592e+17, 'train_loss': 1.61216064063124, 'epoch': 5.0})

In [43]:
model_pretrained_whole.eval()


In [44]:
trainer.evaluate(test)

{'eval_loss': 1.1403955221176147,
 'eval_accuracy': 0.6805,
 'eval_precision': 0.6896003843930464,
 'eval_recall': 0.6805000000000001,
 'eval_f1': 0.6767723010859202,
 'eval_runtime': 17.6611,
 'eval_samples_per_second': 566.216,
 'eval_steps_per_second': 4.473,
 'epoch': 5.0}

## Training the initialized (pretrained) MobileNetV2 with knowledge distillation

### Training the initialized model with distillation: classification head only

In [45]:
base.reset_seed()

In [46]:
student_model_pretrained = base.get_mobilenet(100)

In [47]:
student_model_pretrained = base.freeze_model(student_model_pretrained)

In [48]:
training_args = base.get_training_args(output_dir="./results/cifar100-pretrained-head-KD", logging_dir='./logs/cifar100-pretrained-head-KD', remove_unused_columns=False, temp=6, lambda_param=.8)

In [49]:
trainer = base.ImageDistilTrainer(
    student_model=student_model_pretrained,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 2)]
)

In [50]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,3.5178,3.355098,0.1797,0.220875,0.1797,0.168106
2,3.2409,3.142756,0.2831,0.346824,0.2831,0.279037
3,3.0797,3.033154,0.3392,0.379779,0.3392,0.325473
4,2.99,2.978981,0.3436,0.406908,0.3436,0.332064
5,2.9475,2.958369,0.3593,0.402591,0.3593,0.34664


TrainOutput(global_step=1565, training_loss=3.155180383261781, metrics={'train_runtime': 433.0658, 'train_samples_per_second': 461.824, 'train_steps_per_second': 3.614, 'total_flos': 4.248451694592e+17, 'train_loss': 3.155180383261781, 'epoch': 5.0})

In [51]:
student_model_pretrained.eval()


In [52]:
trainer.evaluate(test)

{'eval_loss': 2.81398868560791,
 'eval_accuracy': 0.3537,
 'eval_precision': 0.39045062851507895,
 'eval_recall': 0.35369999999999996,
 'eval_f1': 0.33894118699147735,
 'eval_runtime': 17.613,
 'eval_samples_per_second': 567.761,
 'eval_steps_per_second': 4.485,
 'epoch': 5.0}

### Training the whole initialized model with distillation

In [53]:
base.reset_seed()

In [54]:
student_model_pretrained_whole = base.get_mobilenet(100)

In [55]:
training_args = base.get_training_args("./results/cifar100-pretrained-KD", './logs/cifar100-pretrained-KD', remove_unused_columns=False, temp=6, lambda_param=1)

In [56]:
trainer = base.ImageDistilTrainer(
    student_model=student_model_pretrained_whole.to(device),
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [57]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.6861,2.036413,0.4577,0.530295,0.4577,0.426668
2,1.7625,1.579123,0.566,0.607231,0.566,0.547062
3,1.4056,1.371883,0.619,0.641029,0.619,0.603205
4,1.2316,1.354033,0.6157,0.645628,0.6157,0.605793
5,1.1537,1.25149,0.6543,0.673847,0.6543,0.64309


TrainOutput(global_step=1565, training_loss=1.647921840679912, metrics={'train_runtime': 533.0318, 'train_samples_per_second': 375.212, 'train_steps_per_second': 2.936, 'total_flos': 4.248451694592e+17, 'train_loss': 1.647921840679912, 'epoch': 5.0})

In [58]:
student_model_pretrained_whole.eval()


In [59]:
trainer.evaluate(test)

{'eval_loss': 1.0503957271575928,
 'eval_accuracy': 0.651,
 'eval_precision': 0.6701687624603764,
 'eval_recall': 0.6510000000000001,
 'eval_f1': 0.6393124124046187,
 'eval_runtime': 13.9925,
 'eval_samples_per_second': 714.669,
 'eval_steps_per_second': 5.646,
 'epoch': 5.0}
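
The test accuracies reported by the `trainer.evaluate(test)` cells above can be put side by side with a short script. The numbers are copied from those outputs; the dictionary layout and the helper name `accuracy_gain` are illustrative. The runs also differ in learning rate, early-stopping patience, and number of epochs, so the differences are indicative rather than a controlled comparison.

```python
# Test accuracy per setup, copied from the evaluate() outputs in this notebook.
test_accuracy = {
    "random init": {"baseline": 0.5606, "distilled": 0.6051},
    "pretrained, head only": {"baseline": 0.3773, "distilled": 0.3537},
    "pretrained, whole model": {"baseline": 0.6805, "distilled": 0.6510},
}


def accuracy_gain(setup: str) -> float:
    """Distilled accuracy minus baseline accuracy for one setup."""
    scores = test_accuracy[setup]
    return scores["distilled"] - scores["baseline"]


print(f"{'setup':<26}{'baseline':>10}{'distilled':>11}{'gain':>9}")
for setup, scores in test_accuracy.items():
    print(
        f"{setup:<26}{scores['baseline']:>10.4f}"
        f"{scores['distilled']:>11.4f}{accuracy_gain(setup):>+9.4f}"
    )
```

In these runs, distillation improved the randomly initialized model, while both pretrained setups scored slightly lower with distillation than without it.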