# Some Models
> To avoid copy-pasta #2

In [None]:
# default_exp models

In [None]:
# exports
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn
import torch.optim
from torch import nn as nn
from torch.nn import functional as F, Module
from torch.utils.data import DataLoader, Dataset

from batchbald_redux.active_learning import RandomFixedLengthSampler
from batchbald_redux.black_box_model_training import train
from batchbald_redux.consistent_mc_dropout import (
    BayesianModule,
    ConsistentMCDropout,
    ConsistentMCDropout2d,
)

from batchbald_redux.model_optimizer_factory import ModelOptimizer, ModelOptimizerFactory

In [None]:
# exports
from batchbald_redux.trained_model import ModelTrainer, TrainedModel, TrainedBayesianModel


class BayesianMNISTCNN(BayesianModule):
    """Two-conv / two-linear MNIST classifier with consistent MC dropout.

    Each convolution and the first linear layer is followed by a
    consistent-MC-dropout module; the network ends in a ``log_softmax``
    over the class dimension.

    Args:
        num_classes: Width of the final linear layer (number of output classes).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Convolutional feature extractor (single input channel).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5)
        self.conv1_drop = ConsistentMCDropout2d()
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        self.conv2_drop = ConsistentMCDropout2d()

        # Classifier head. 1024 input features correspond to 64 channels x 4 x 4,
        # which is what a 1x28x28 image yields after the two conv + pool stages.
        self.fc1 = nn.Linear(in_features=1024, out_features=128)
        self.fc1_drop = ConsistentMCDropout()
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

    def mc_forward_impl(self, input: torch.Tensor):
        """Run the network and return per-class log-probabilities.

        Args:
            input: Image batch fed to ``conv1``; the leading dimension is
                recovered by ``view(-1, 1024)`` after the conv stages.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding ``log_softmax`` outputs.
        """
        # Stage 1: conv -> dropout -> 2x2 max-pool -> ReLU.
        stage1 = self.conv1_drop(self.conv1(input))
        stage1 = F.relu(F.max_pool2d(stage1, 2))

        # Stage 2: same pattern on the 32-channel feature map.
        stage2 = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(stage2, 2))

        # Flatten to the width expected by fc1 (1024).
        flat = stage2.view(-1, self.fc1.in_features)

        hidden = F.relu(self.fc1_drop(self.fc1(flat)))
        logits = self.fc2(hidden)

        return F.log_softmax(logits, dim=1)

In [None]:
BayesianMNISTCNN()

BayesianMNISTCNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv1_drop): ConsistentMCDropout2d(p=0.5)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): ConsistentMCDropout2d(p=0.5)
  (fc1): Linear(in_features=1024, out_features=128, bias=True)
  (fc1_drop): ConsistentMCDropout(p=0.5)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [None]:
# exports


class BayesianMNISTCNN_EBM(BayesianModule):
    """Without Softmax.

    Same layer stack as the MNIST CNN in this module (two conv layers and two
    linear layers, each but the last followed by consistent MC dropout), but
    ``mc_forward_impl`` returns the raw outputs of the final linear layer.

    Args:
        num_classes: Width of the final linear layer (number of output classes).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Convolutional feature extractor (single input channel).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5)
        self.conv1_drop = ConsistentMCDropout2d()
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        self.conv2_drop = ConsistentMCDropout2d()

        # Classifier head. 1024 input features correspond to 64 channels x 4 x 4,
        # which is what a 1x28x28 image yields after the two conv + pool stages.
        self.fc1 = nn.Linear(in_features=1024, out_features=128)
        self.fc1_drop = ConsistentMCDropout()
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

    def mc_forward_impl(self, input: torch.Tensor):
        """Run the network and return unnormalised class scores (logits).

        Args:
            input: Image batch fed to ``conv1``; the leading dimension is
                recovered by ``view(-1, 1024)`` after the conv stages.

        Returns:
            Tensor of shape ``(N, num_classes)`` taken directly from ``fc2``.
        """
        # Stage 1: conv -> dropout -> 2x2 max-pool -> ReLU.
        stage1 = self.conv1_drop(self.conv1(input))
        stage1 = F.relu(F.max_pool2d(stage1, 2))

        # Stage 2: same pattern on the 32-channel feature map.
        stage2 = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(stage2, 2))

        # Flatten to the width expected by fc1 (1024).
        flat = stage2.view(-1, self.fc1.in_features)

        hidden = F.relu(self.fc1_drop(self.fc1(flat)))

        # No softmax here: callers receive the logits as-is.
        return self.fc2(hidden)

In [None]:
BayesianMNISTCNN_EBM()

BayesianMNISTCNN_EBM(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv1_drop): ConsistentMCDropout2d(p=0.5)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): ConsistentMCDropout2d(p=0.5)
  (fc1): Linear(in_features=1024, out_features=128, bias=True)
  (fc1_drop): ConsistentMCDropout(p=0.5)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [None]:
# exports


class MnistOptimizerFactory(ModelOptimizerFactory):
    """Factory that pairs a fresh ``BayesianMNISTCNN`` with an Adam optimizer."""

    def create_model_optimizer(self) -> ModelOptimizer:
        """Build a new, untrained model together with its optimizer.

        Returns:
            A ``ModelOptimizer`` wrapping a default-constructed
            ``BayesianMNISTCNN`` and an Adam optimizer over its parameters
            (``weight_decay=5e-4``, other Adam settings left at their defaults).
        """
        network = BayesianMNISTCNN()
        return ModelOptimizer(
            model=network,
            optimizer=torch.optim.Adam(network.parameters(), weight_decay=5e-4),
        )


@dataclass
class MnistModelTrainer(ModelTrainer):
    """Trainer that fits ``BayesianMNISTCNN`` models through the shared ``train`` loop.

    Every call to ``get_trained`` / ``get_distilled`` starts from a freshly
    constructed model and optimizer and returns the result wrapped in a
    ``TrainedBayesianModel``.
    """

    # Forwarded to ``train(...)`` as ``device``.
    device: str

    # Forwarded to ``train(...)`` as ``training_samples`` / ``validation_samples``
    # (presumably the number of MC-dropout samples per input -- TODO confirm).
    num_training_samples: int = 1
    num_validation_samples: int = 20
    # Forwarded to ``train(...)`` as ``patience`` and ``max_epochs``.
    num_patience_epochs: int = 20
    max_training_epochs: int = 120

    # Length argument handed to ``RandomFixedLengthSampler`` for the training loader.
    min_samples_per_epoch: int = 1024
    # Batch sizes for the training and evaluation data loaders.
    num_training_batch_size: int = 64
    num_evaluation_batch_size: int = 128

    @staticmethod
    def create_model_optimizer() -> ModelOptimizer:
        """Return a new ``BayesianMNISTCNN`` with an Adam optimizer (``weight_decay=5e-4``)."""
        network = BayesianMNISTCNN()
        return ModelOptimizer(
            model=network,
            optimizer=torch.optim.Adam(network.parameters(), weight_decay=5e-4),
        )

    def get_train_dataloader(self, dataset: Dataset):
        """Wrap ``dataset`` in a training ``DataLoader``.

        Indices come from a ``RandomFixedLengthSampler`` built with
        ``min_samples_per_epoch``; incomplete final batches are dropped.
        """
        epoch_sampler = RandomFixedLengthSampler(dataset, self.min_samples_per_epoch)
        return DataLoader(
            dataset,
            batch_size=self.num_training_batch_size,
            sampler=epoch_sampler,
            drop_last=True,
        )

    def get_evaluation_dataloader(self, dataset: Dataset):
        """Wrap ``dataset`` in an in-order ``DataLoader`` that keeps every sample."""
        return DataLoader(
            dataset,
            batch_size=self.num_evaluation_batch_size,
            shuffle=False,
            drop_last=False,
        )

    def get_trained(
        self,
        *,
        train_loader: DataLoader,
        train_augmentations: Optional[Module],
        validation_loader: DataLoader,
        log,
    ) -> TrainedModel:
        """Train a fresh model on ``train_loader`` and return it.

        Args:
            train_loader: Batches used for optimisation.
            train_augmentations: Optional module passed through to ``train``.
            validation_loader: Batches passed to ``train`` for validation.
            log: Object passed to ``train`` as ``training_log``.

        Returns:
            A ``TrainedBayesianModel`` wrapping the trained network.
        """
        bundle = self.create_model_optimizer()
        network = bundle.model

        # Loss arguments are omitted, so ``train`` falls back to its own defaults.
        train(
            model=network,
            optimizer=bundle.optimizer,
            train_loader=train_loader,
            train_augmentations=train_augmentations,
            validation_loader=validation_loader,
            training_samples=self.num_training_samples,
            validation_samples=self.num_validation_samples,
            patience=self.num_patience_epochs,
            max_epochs=self.max_training_epochs,
            device=self.device,
            training_log=log,
        )

        return TrainedBayesianModel(network)

    def get_distilled(
        self,
        *,
        prediction_loader: DataLoader,
        train_augmentations: Optional[Module],
        validation_loader: DataLoader,
        log,
    ) -> TrainedModel:
        """Train a fresh model to match the targets served by ``prediction_loader``.

        Training minimises a KL divergence with ``log_target=True`` (so the
        targets are presumably log-probabilities -- verify against the caller),
        validation uses NLL, and ``train`` is told to prefer accuracy.

        Args:
            prediction_loader: Batches used as the training data.
            train_augmentations: Optional module passed through to ``train``.
            validation_loader: Batches passed to ``train`` for validation.
            log: Object passed to ``train`` as ``training_log``.

        Returns:
            A ``TrainedBayesianModel`` wrapping the distilled network.
        """
        bundle = self.create_model_optimizer()
        network = bundle.model

        distillation_loss = torch.nn.KLDivLoss(log_target=True, reduction="batchmean")

        train(
            model=network,
            optimizer=bundle.optimizer,
            loss=distillation_loss,
            validation_loss=torch.nn.NLLLoss(),
            train_loader=prediction_loader,
            train_augmentations=train_augmentations,
            validation_loader=validation_loader,
            training_samples=self.num_training_samples,
            validation_samples=self.num_validation_samples,
            patience=self.num_patience_epochs,
            max_epochs=self.max_training_epochs,
            prefer_accuracy=True,
            device=self.device,
            training_log=log,
        )

        return TrainedBayesianModel(network)

In [None]:
# slow

from batchbald_redux import dataset_challenges

# Smoke run: train the MNIST model on the full training split, using the test
# split as the validation loader. Requires a CUDA device.
device = "cuda"

fast_mnist_train, fast_mnist_test = dataset_challenges.create_MNIST_dataset(device)
model_trainer = MnistModelTrainer(device)

train_loader = model_trainer.get_train_dataloader(fast_mnist_train)
test_loader = model_trainer.get_evaluation_dataloader(fast_mnist_test)

# `log` is handed to the training loop as its `training_log`.
log = {}
trained_model = model_trainer.get_trained(
    train_loader=train_loader,
    train_augmentations=None,
    validation_loader=test_loader,
    log=log,
)

In [None]:
# slow

# Repeat the run on a scaled-down training set (`dataset * 0.2`; presumably a
# 20% subset -- see the printed dataset description), still validating on the
# full test split. Relies on `fast_mnist_train`, `fast_mnist_test` and
# `model_trainer` from the previous cell.
subset_mnist = fast_mnist_train * 0.2
print(subset_mnist)

train_loader = model_trainer.get_train_dataloader(subset_mnist)
test_loader = model_trainer.get_evaluation_dataloader(fast_mnist_test)

# Start from an empty log so this run's entries are not mixed with the previous one.
log = {}
trained_model = model_trainer.get_trained(
    train_loader=train_loader,
    train_augmentations=None,
    validation_loader=test_loader,
    log=log,
)

('FastMNIST (Train)')~x0.2


  1%|          | 1/120 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.4094, 'crossentropy': 1.9953440071105957}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.7788, 'crossentropy': 1.1713629653930664}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8632, 'crossentropy': 0.7872836687088013}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.8799, 'crossentropy': 0.6420327793598175}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9026, 'crossentropy': 0.5500037454128265}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9183, 'crossentropy': 0.506918834733963}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9294, 'crossentropy': 0.43596688237190245}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.935, 'crossentropy': 0.4013356751203537}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9457, 'crossentropy': 0.37077586855888367}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.943, 'crossentropy': 0.34797975437641143}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9515, 'crossentropy': 0.3268883366584778}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9526, 'crossentropy': 0.31412847449779513}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9527, 'crossentropy': 0.30681407272815703}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9556, 'crossentropy': 0.2916686305999756}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9607, 'crossentropy': 0.26903275294303897}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9594, 'crossentropy': 0.2621008333206177}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9573, 'crossentropy': 0.2734296598315239}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9612, 'crossentropy': 0.2479740678191185}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9603, 'crossentropy': 0.24664798829555512}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9629, 'crossentropy': 0.24069213716983795}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9643, 'crossentropy': 0.2280900681257248}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9631, 'crossentropy': 0.2271005573153496}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9683, 'crossentropy': 0.21772976952791215}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9655, 'crossentropy': 0.23733734132051468}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9683, 'crossentropy': 0.2267121258020401}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9716, 'crossentropy': 0.20480671402215958}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9691, 'crossentropy': 0.2090889817237854}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9697, 'crossentropy': 0.20590644285678864}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9691, 'crossentropy': 0.2046949967622757}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9703, 'crossentropy': 0.20656840767264367}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9709, 'crossentropy': 0.202680607932806}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9708, 'crossentropy': 0.2007450315952301}
RestoringEarlyStopping: 6 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.967, 'crossentropy': 0.20886228442788124}
RestoringEarlyStopping: 7 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9729, 'crossentropy': 0.19618470904827118}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9732, 'crossentropy': 0.19611933329105377}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9738, 'crossentropy': 0.1847836072385311}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9733, 'crossentropy': 0.18611027659773827}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9739, 'crossentropy': 0.18001386508345604}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9775, 'crossentropy': 0.1674060090482235}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9729, 'crossentropy': 0.1818251102209091}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.974, 'crossentropy': 0.1747796884894371}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9743, 'crossentropy': 0.16952655346691609}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9775, 'crossentropy': 0.164365275901556}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9766, 'crossentropy': 0.16601514192819594}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9761, 'crossentropy': 0.16493101998865603}
RestoringEarlyStopping: 6 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9773, 'crossentropy': 0.17206964423060417}
RestoringEarlyStopping: 7 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.975, 'crossentropy': 0.17816308836340905}
RestoringEarlyStopping: 8 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9775, 'crossentropy': 0.15955212568044663}
RestoringEarlyStopping: 9 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9783, 'crossentropy': 0.15314062087535857}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9782, 'crossentropy': 0.16765735062360765}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9764, 'crossentropy': 0.1592094441652298}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9763, 'crossentropy': 0.16179060089886188}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9775, 'crossentropy': 0.16191123164892196}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9811, 'crossentropy': 0.14821218244433404}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9789, 'crossentropy': 0.15140198409557343}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9788, 'crossentropy': 0.15476766859292984}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9806, 'crossentropy': 0.1552284104704857}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9795, 'crossentropy': 0.14894249083399771}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9784, 'crossentropy': 0.15906811181902886}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9777, 'crossentropy': 0.16086507358551025}
RestoringEarlyStopping: 6 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9791, 'crossentropy': 0.14977331364750862}
RestoringEarlyStopping: 7 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9773, 'crossentropy': 0.16522770585417748}
RestoringEarlyStopping: 8 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9807, 'crossentropy': 0.14948120200037956}
RestoringEarlyStopping: 9 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9803, 'crossentropy': 0.15035601996481418}
RestoringEarlyStopping: 10 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9802, 'crossentropy': 0.14568143537342548}
RestoringEarlyStopping: 11 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9797, 'crossentropy': 0.14922687017321587}
RestoringEarlyStopping: 12 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9814, 'crossentropy': 0.1505122820407152}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9818, 'crossentropy': 0.13978955081701278}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9815, 'crossentropy': 0.14726527796983718}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9791, 'crossentropy': 0.15016660173386334}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9812, 'crossentropy': 0.1475597570002079}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9813, 'crossentropy': 0.14631702497005464}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9816, 'crossentropy': 0.14835208004713057}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9822, 'crossentropy': 0.14683693165183068}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9788, 'crossentropy': 0.14925556500703097}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9811, 'crossentropy': 0.14512694598138332}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9794, 'crossentropy': 0.14510144341588022}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9808, 'crossentropy': 0.14316769663095474}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9802, 'crossentropy': 0.14238641458153725}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9807, 'crossentropy': 0.14452635582685472}
RestoringEarlyStopping: 6 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9805, 'crossentropy': 0.14731967992782594}
RestoringEarlyStopping: 7 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9804, 'crossentropy': 0.14613135629594326}
RestoringEarlyStopping: 8 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9801, 'crossentropy': 0.1444930137038231}
RestoringEarlyStopping: 9 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9805, 'crossentropy': 0.14659327845573425}
RestoringEarlyStopping: 10 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9816, 'crossentropy': 0.1520789356648922}
RestoringEarlyStopping: 11 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9813, 'crossentropy': 0.14676751902103424}
RestoringEarlyStopping: 12 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9818, 'crossentropy': 0.1371298746585846}
RestoringEarlyStopping: 13 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9819, 'crossentropy': 0.13903203086256982}
RestoringEarlyStopping: 14 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9826, 'crossentropy': 0.13630688445568084}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9825, 'crossentropy': 0.13889403404295445}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9809, 'crossentropy': 0.14042093126177788}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.981, 'crossentropy': 0.14360348327457906}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9833, 'crossentropy': 0.1358983141809702}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9837, 'crossentropy': 0.13061524590849877}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9823, 'crossentropy': 0.13092151409983635}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9835, 'crossentropy': 0.13183467201292515}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9791, 'crossentropy': 0.1460676993727684}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9826, 'crossentropy': 0.13181567433178426}
RestoringEarlyStopping: 4 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9826, 'crossentropy': 0.12675335685908795}
RestoringEarlyStopping: 5 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9814, 'crossentropy': 0.1308151033759117}
RestoringEarlyStopping: 6 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9822, 'crossentropy': 0.13530614287555218}
RestoringEarlyStopping: 7 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.983, 'crossentropy': 0.12484264653474092}
RestoringEarlyStopping: 8 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9834, 'crossentropy': 0.12668347802460195}
RestoringEarlyStopping: 9 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.982, 'crossentropy': 0.14050724054276942}
RestoringEarlyStopping: 10 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9833, 'crossentropy': 0.1272612214833498}
RestoringEarlyStopping: 11 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9827, 'crossentropy': 0.13000124952197076}
RestoringEarlyStopping: 12 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9818, 'crossentropy': 0.1327671136111021}
RestoringEarlyStopping: 13 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.983, 'crossentropy': 0.1257890721693635}
RestoringEarlyStopping: 14 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9829, 'crossentropy': 0.12695484325885772}
RestoringEarlyStopping: 15 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9834, 'crossentropy': 0.12922529191672802}
RestoringEarlyStopping: 16 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9826, 'crossentropy': 0.13050008742213248}
RestoringEarlyStopping: 17 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9831, 'crossentropy': 0.13368907009512185}
RestoringEarlyStopping: 18 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9839, 'crossentropy': 0.12314552030414343}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9837, 'crossentropy': 0.1273065314412117}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9818, 'crossentropy': 0.12561971281468867}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9848, 'crossentropy': 0.12864230203926563}


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.983, 'crossentropy': 0.13225059789121152}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9836, 'crossentropy': 0.12992607839554549}
RestoringEarlyStopping: 2 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9803, 'crossentropy': 0.14164143002033233}
RestoringEarlyStopping: 3 / 20


[1/16]   6%|6          [00:00<?]

[1/79]   1%|1          [00:00<?]

Epoch metrics: {'accuracy': 0.9807, 'crossentropy': 0.14529887140989303}
RestoringEarlyStopping: 4 / 20
RestoringEarlyStopping: Restoring best parameters. (Score: 0.9848)
RestoringEarlyStopping: Restoring optimizer.
