# LML Unified Experiment Code
> Resistance is futile.

In [None]:
# default_exp lml_unified_experiment

In [None]:
# hide
import blackhc.project.script

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


Import the modules and functions we are going to use.

In [None]:
# exports

import dataclasses
import traceback
from dataclasses import dataclass
from typing import Dict, List, Optional, Type, Union

import numpy as np
import torch
import torch.utils.data
from blackhc.project import is_run_from_ipython
from blackhc.project.experiment import embedded_experiments
from torch import nn
from torch.utils.data import Dataset

import batchbald_redux.acquisition_functions.batchbald
import batchbald_redux.acquisition_functions.epig
import wandb
from batchbald_redux import acquisition_functions, baseline_acquisition_functions
from batchbald_redux.acquisition_functions import (
    CandidateBatchComputer,
    EvalDatasetBatchComputer,
    EvalModelBatchComputer,
)
from batchbald_redux.active_learning import ActiveLearningData
from batchbald_redux.black_box_model_training import evaluate
from batchbald_redux.dataset_challenges import (
    NamedDataset,
    get_balanced_sample_indices,
    get_base_dataset_index,
    get_target,
)
from batchbald_redux.datasets import get_dataset
from batchbald_redux.di import DependencyInjection
from batchbald_redux.experiment_data import (
    ExperimentData,
    ExperimentDataConfig,
    OoDDatasetConfig,
    StandardExperimentDataConfig,
)
from batchbald_redux.experiment_logging import init_wandb, log2wandb
from batchbald_redux.models import MnistModelTrainer
from batchbald_redux.resnet_models import Cifar10ModelTrainer
from batchbald_redux.train_eval_model import (
    TrainEvalModel,
    TrainSelfDistillationEvalModel,
)
from batchbald_redux.trained_model import BayesianEnsembleModelTrainer, ModelTrainer

In [None]:
# exports


@dataclass
class LmlExperimentData:
    """Datasets and metadata that the LML active-learning experiment runs on."""

    # Source of both the labelled training set and the unlabelled pool
    # (it is wrapped in `ActiveLearningData` by `LmlActiveLearner`).
    train_dataset: Dataset
    # Handed to the model trainer as the validation data during training.
    validation_dataset: Dataset
    # Used for the evaluation metrics computed after every training round.
    test_dataset: Dataset
    # Forwarded to the model trainer as `train_augmentations`.
    train_augmentations: nn.Module
    # Base indices into `train_dataset` that form the initial labelled set.
    initial_training_set_indices: List[int]
    # Device as reported by the loaded dataset split.
    device: str


@dataclass
class LmlExperimentDataConfig(ExperimentDataConfig):
    """Configuration naming the dataset and split sizes for an LML experiment."""

    id_dataset_name: str
    initial_training_set_size: int
    validation_set_size: int
    validation_split_random_state: int

    def load(self, device) -> LmlExperimentData:
        """Load the datasets described by this config.

        Args:
            device: Passed through to the loader as its device argument.

        Returns:
            The `LmlExperimentData` built by `load_distribution_experiment_data`.
        """
        settings = dict(
            id_dataset_name=self.id_dataset_name,
            initial_training_set_size=self.initial_training_set_size,
            validation_set_size=self.validation_set_size,
            validation_split_random_state=self.validation_split_random_state,
        )
        return load_distribution_experiment_data(device=device, **settings)


def load_distribution_experiment_data(
    *,
    id_dataset_name: str,
    initial_training_set_size: int,
    validation_set_size: int,
    validation_split_random_state: int,
    device: str,
) -> LmlExperimentData:
    """Load a dataset split and pick a class-balanced initial training set.

    Args:
        id_dataset_name: Name handed to `get_dataset`.
        initial_training_set_size: Total number of initial samples; it is divided
            (integer division) by the number of classes to get the per-class count.
        validation_set_size: Forwarded to `get_dataset`.
        validation_split_random_state: Forwarded to `get_dataset` and also reused as
            the seed for the balanced index sampling.
        device: Forwarded to `get_dataset` as `device_hint`.

    Returns:
        An `LmlExperimentData` holding the train/validation/test datasets, the train
        augmentations, the sampled initial indices and the split's device.
    """
    splits = get_dataset(
        id_dataset_name,
        root="data",
        validation_set_size=validation_set_size,
        validation_split_random_state=validation_split_random_state,
        normalize_like_cifar10=True,
        device_hint=device,
    )
    training_split = splits.train

    training_targets = training_split.get_targets()
    class_count = training_split.get_num_classes()

    # Same number of initial samples for every class.
    balanced_indices = get_balanced_sample_indices(
        targets=training_targets,
        num_classes=class_count,
        samples_per_class=initial_training_set_size // class_count,
        seed=validation_split_random_state,
    )

    return LmlExperimentData(
        train_dataset=training_split,
        validation_dataset=splits.validation,
        test_dataset=splits.test,
        train_augmentations=splits.train_augmentations,
        initial_training_set_indices=balanced_indices,
        device=splits.device,
    )

In [None]:
# exports


@dataclass
class LmlEstimates:
    """Monte-Carlo log-predictive estimates for the labels of a batch of points."""

    # Sum over points of the log of the per-point sample-averaged label probability.
    marginal_log_predictive: float
    # Log of the sample-averaged probability of all labels jointly.
    joint_log_predictive: float


def get_lml_estimates(log_probs_N_K_C_labels_N):
    """Estimate marginal and joint log predictives of the true labels.

    With K model samples, the two estimates are

        marginal = sum_n log( 1/K * sum_k p(y_n | x_n, w_k) )
        joint    =       log( 1/K * sum_k prod_n p(y_n | x_n, w_k) )

    Args:
        log_probs_N_K_C_labels_N: Pair `(log_probs, labels)`. `log_probs` is indexed
            as `[point, sample, class]` and `labels` holds one class index per point.

    Returns:
        An `LmlEstimates` whose fields are 0-dim tensors.
    """
    log_probs_N_K_C, labels_N = log_probs_N_K_C_labels_N

    # Pick the log-probability of the true class for each (point, sample) pair.
    true_log_probs_N_K = log_probs_N_K_C[list(range(len(labels_N))), :, labels_N]
    log_num_samples = np.log(true_log_probs_N_K.shape[1])

    # Each point is averaged over the K samples separately, so the -log K
    # normalisation has to be applied per point *before* summing over points.
    marginal_log_predictive = (torch.logsumexp(true_log_probs_N_K, dim=1) - log_num_samples).sum(dim=0)
    # For the joint, points are multiplied within a sample first, then samples are averaged.
    joint_log_predictive = torch.logsumexp(true_log_probs_N_K.sum(dim=0), dim=0) - log_num_samples
    return LmlEstimates(marginal_log_predictive, joint_log_predictive)


@dataclass
class LmlActiveLearner:
    """Active-learning loop that logs LML estimates for every acquired batch.

    Calling the instance repeatedly trains a model on the current training set,
    evaluates it on the test set, asks the acquisition function for a candidate batch
    from the pool, records marginal/joint log-predictive estimates for that batch's
    labels, and then acquires the batch. The loop ends once the training set reaches
    `max_training_set` or an iteration cap is hit.
    """

    acquisition_size: int
    max_training_set: int

    num_validation_samples: int
    num_pool_samples: int

    acquisition_function: Union[CandidateBatchComputer, EvalModelBatchComputer]
    model_trainer: ModelTrainer
    data: LmlExperimentData

    device: Optional

    def __call__(self, log):
        """Run the loop and record everything into `log`.

        Args:
            log: Mutable mapping that receives the keys "seed", "pool_indices" and
                "active_learning_steps" (one dict per training round).

        Raises:
            ValueError: If the acquisition function is not a `CandidateBatchComputer`.
        """
        # NOTE(review): torch.seed() re-seeds the RNG non-deterministically, which
        # overrides any earlier torch.manual_seed call — confirm this is intended.
        log["seed"] = torch.seed()

        # Active Learning setup
        data = self.data

        # Round the target size up to a whole number of acquisition batches.
        self.max_training_set = (
            (self.max_training_set + self.acquisition_size - 1) // self.acquisition_size * self.acquisition_size
        )

        active_learning_data = ActiveLearningData(data.train_dataset)
        active_learning_data.acquire_base_indices(data.initial_training_set_indices)

        # Remove most of the remaining pool set: keep only as many randomly chosen pool
        # points as can still be acquired. The number to discard is computed explicitly
        # because slicing with `[:-budget]` silently discards nothing when budget == 0.
        generator = np.random.default_rng(1137)
        pool_permutation = generator.permutation(len(active_learning_data.pool_dataset))
        pool_budget = max(0, self.max_training_set - len(data.initial_training_set_indices))
        num_discarded = max(0, len(pool_permutation) - pool_budget)
        discard_indices = pool_permutation[:num_discarded]
        active_learning_data.extract_dataset_from_pool_indices(discard_indices)
        log["pool_indices"] = active_learning_data.pool_dataset.indices

        model_trainer = self.model_trainer

        train_loader = model_trainer.get_train_dataloader(active_learning_data.training_dataset)
        pool_loader = model_trainer.get_evaluation_dataloader(active_learning_data.pool_dataset)
        validation_loader = model_trainer.get_evaluation_dataloader(data.validation_dataset)
        test_loader = model_trainer.get_evaluation_dataloader(data.test_dataset)

        log["active_learning_steps"] = []
        active_learning_steps = log["active_learning_steps"]

        acquisition_function = self.acquisition_function

        # Safety cap: allow 1.5x the number of rounds nominally needed to fill the training set.
        num_iterations = 0
        max_iterations = int(
            1.5 * (self.max_training_set - len(active_learning_data.training_dataset)) / self.acquisition_size
        )

        # Active Training Loop
        while True:
            training_set_size = len(active_learning_data.training_dataset)
            print(f"Training set size {training_set_size}:")

            active_learning_steps.append({})
            iteration_log = active_learning_steps[-1]

            iteration_log["training"] = {}

            loss = validation_loss = torch.nn.NLLLoss()

            trained_model = model_trainer.get_trained(
                train_loader=train_loader,
                train_augmentations=data.train_augmentations,
                validation_loader=validation_loader,
                log=iteration_log["training"],
                wandb_key_path="model_training",
                loss=loss,
                validation_loss=validation_loss,
            )

            evaluation_metrics = evaluate(
                model=trained_model,
                num_samples=self.num_validation_samples,
                loader=test_loader,
                device=self.device,
                storage_device="cpu",
            )
            iteration_log["evaluation_metrics"] = evaluation_metrics
            log2wandb(evaluation_metrics, commit=False)
            print(f"Perf after training {evaluation_metrics}")

            # The final model is still trained and evaluated before stopping.
            if training_set_size >= self.max_training_set or num_iterations >= max_iterations:
                log2wandb({}, commit=True)
                print("Done.")
                break

            if isinstance(acquisition_function, CandidateBatchComputer):
                candidate_batch = acquisition_function.compute_candidate_batch(trained_model, pool_loader, self.device)
            else:
                raise ValueError(f"Unknown acquisition function {acquisition_function}!")

            # Describe each candidate: base-dataset index, label and image. Every pool
            # point is in-distribution here, hence the constant "id" tag.
            candidate_global_dataset_indices = []
            candidate_labels = []
            candidate_images = []
            for index in candidate_batch.indices:
                base_di = get_base_dataset_index(active_learning_data.pool_dataset, index)
                candidate_global_dataset_indices.append(("id", base_di.index))
                candidate_labels.append(get_target(active_learning_data.pool_dataset, index).tolist())
                candidate_images.append(wandb.Image(active_learning_data.pool_dataset[index][0]))

            # Lml computation: log-predictives of the candidate batch's labels under
            # the model trained in this round (before the batch is acquired).
            lml_batch_dataloader = model_trainer.get_evaluation_dataloader(
                torch.utils.data.Subset(active_learning_data.pool_dataset, candidate_batch.indices)
            )
            lml_log_probs_N_K_C_labels_N = trained_model.get_log_probs_N_K_C_labels_N(
                lml_batch_dataloader, self.num_pool_samples, self.device, "cpu"
            )
            lml_estimate = get_lml_estimates(lml_log_probs_N_K_C_labels_N)

            iteration_log["lml_estimate"] = lml_estimate
            log2wandb(dict(lml_estimate=dataclasses.asdict(lml_estimate)), commit=False)

            acquisition_info = dict(
                indices=candidate_global_dataset_indices, labels=candidate_labels, scores=candidate_batch.scores
            )
            iteration_log["acquisition"] = acquisition_info

            # One table row per candidate: (dataset tag, base index, image, label, score).
            acquisition_batch_table = wandb.Table(
                data=list(
                    zip(
                        *zip(*candidate_global_dataset_indices),
                        candidate_images,
                        candidate_labels,
                        candidate_batch.scores,
                    )
                ),
                columns=["dataset", "index", "sample", "label", "score"],
            )
            log2wandb(dict(acquisition=acquisition_batch_table), commit=False)

            print(candidate_batch)
            print(candidate_global_dataset_indices)

            active_learning_data.acquire(candidate_batch.indices)

            ls = ", ".join(f"{label} ({score:.4})" for label, score in zip(candidate_labels, candidate_batch.scores))
            print(f"Acquiring (label, score)s: {ls}")

            num_iterations += 1
            log2wandb({}, commit=True)


@dataclass
class LmlUnifiedExperiment:
    """Top-level experiment: builds data, trainer and acquisition function, then runs the loop.

    The acquisition function and the model trainer are given as classes and are
    instantiated through `DependencyInjection` from this object's own attributes.
    """

    seed: int

    experiment_data_config: LmlExperimentDataConfig

    acquisition_size: int = 5
    max_training_set: int = 200

    max_training_epochs: int = 300

    num_pool_samples: int = 100
    num_validation_samples: int = 20
    num_training_samples: int = 1

    device: str = "cuda"
    acquisition_function: Union[
        Type[CandidateBatchComputer], Type[EvalModelBatchComputer]
    ] = None  # acquisition_functions.BALD
    model_trainer_factory: Type[ModelTrainer] = None  # Cifar10ModelTrainer
    ensemble_size: int = 1

    temperature: float = 1.0
    coldness: float = 1.0
    stochastic_mode: acquisition_functions.StochasticMode = acquisition_functions.StochasticMode.TopK

    def load_experiment_data(self) -> LmlExperimentData:
        """Print the data config and load it for this experiment's device."""
        data_config = self.experiment_data_config
        print(data_config)
        return data_config.load(self.device)

    # Simple Dependency Injection
    def create_acquisition_function(self):
        """Instantiate `acquisition_function` from this experiment's attributes."""
        return DependencyInjection(vars(self)).create_dataclass_type(self.acquisition_function)

    def create_model_trainer(self) -> ModelTrainer:
        """Instantiate `model_trainer_factory` from this experiment's attributes."""
        return DependencyInjection(vars(self)).create_dataclass_type(self.model_trainer_factory)

    def run(self, store, project=None, entity=None):
        """Run the experiment, recording results into `store` and to wandb.

        Args:
            store: Mutable mapping that receives dataset info, the initial training
                indices and everything the active learner logs.
            project: Forwarded to `init_wandb`.
            entity: Forwarded to `init_wandb`.
        """
        init_wandb(self, project=project, entity=entity)

        torch.manual_seed(self.seed)

        # Active Learning setup
        experiment_data = self.load_experiment_data()
        store["dataset_info"] = dict(
            training=repr(experiment_data.train_dataset),
            test=repr(experiment_data.test_dataset),
        )
        store["initial_training_set_indices"] = experiment_data.initial_training_set_indices

        print(wandb.config)

        # Mirror the same information into the wandb run config.
        wandb.config.initial_training_set_indices = experiment_data.initial_training_set_indices
        wandb.config["dataset_info"] = store["dataset_info"]

        batch_computer = self.create_acquisition_function()
        trainer = self.create_model_trainer()
        # More than one ensemble member: wrap the trainer in an ensemble trainer.
        if self.ensemble_size > 1:
            trainer = BayesianEnsembleModelTrainer(model_trainer=trainer, ensemble_size=self.ensemble_size)

        learner = LmlActiveLearner(
            acquisition_size=self.acquisition_size,
            max_training_set=self.max_training_set,
            num_validation_samples=self.num_validation_samples,
            num_pool_samples=self.num_pool_samples,
            acquisition_function=batch_computer,
            model_trainer=trainer,
            data=experiment_data,
            device=self.device,
        )
        learner(store)

        wandb.finish()

In [None]:
# Experiment grid on MNIST. The loops nest from outermost (seed) to innermost
# (acquisition function), so consecutive entries share a seed and differ in the
# acquisition function. The position in this list is the job id used below.
configs = [
    LmlUnifiedExperiment(
        seed=seed + 45682,
        experiment_data_config=LmlExperimentDataConfig(
            id_dataset_name="MNIST",
            initial_training_set_size=20,
            validation_set_size=4096,
            validation_split_random_state=0,
        ),
        model_trainer_factory=MnistModelTrainer,
        max_training_set=1000,
        acquisition_function=acquisition_function,
        acquisition_size=acquisition_size,
        num_pool_samples=num_pool_samples,
        stochastic_mode=stochastic_mode,
        coldness=coldness,
    )
    for seed in range(5)
    for acquisition_size in (10,)
    for num_pool_samples in (100,)
    for coldness in (1,)
    for stochastic_mode in (acquisition_functions.StochasticMode.Power,)
    for acquisition_function in (
        acquisition_functions.BALD,
        acquisition_functions.Random,
    )
]

# Script entry point: run the experiment grid only when executed as a script and
# `is_run_from_ipython()` is false.
if not is_run_from_ipython() and __name__ == "__main__":
    # Each yielded pair is a job id (an index into `configs`) and the store to record into.
    for job_id, store in embedded_experiments(__file__, len(configs)):
        config = configs[job_id]
        # Shift the configured seed by the job id before running.
        config.seed += job_id
        print(config)
        store["config"] = dataclasses.asdict(config)
        # NOTE(review): `run` is handed `store` itself below, so nothing visible here
        # writes into store["log"] — confirm whether this entry is still needed.
        store["log"] = {}

        try:
            config.run(store=store)
        except Exception:
            # Keep the traceback in the store, then let the failure propagate.
            store["exception"] = traceback.format_exc()
            raise

In [None]:
len(configs)

10

## MNIST only

In [None]:
# experiment
# MNIST experiment (ood_exposure=False)

# Small single-run configuration: 20 initial samples, at most 5 training epochs,
# and a training-set limit of 40 with 10 points acquired per round.
experiment = LmlUnifiedExperiment(
    experiment_data_config=LmlExperimentDataConfig(
        id_dataset_name="MNIST",
        initial_training_set_size=20,
        validation_set_size=4096,
        validation_split_random_state=0,
    ),
    seed=1,
    max_training_epochs=5,
    max_training_set=20 + 20,
    acquisition_function=acquisition_functions.BALD,
    acquisition_size=10,
    model_trainer_factory=MnistModelTrainer,
    num_pool_samples=2,
    device="cuda",
)

# Use a plain dict as the store; the bare `results` expression displays it in the notebook.
results = {}
experiment.run(results)
results

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33moatml-andreas-kirsch[0m (use `wandb login --relogin` to force relogin)


LmlExperimentDataConfig(id_dataset_name='MNIST', initial_training_set_size=20, validation_set_size=4096, validation_split_random_state=0)
{'Dataclass': '__main__.LmlUnifiedExperiment', 'seed': 1, 'experiment_data_config': {'Dataclass': '__main__.LmlExperimentDataConfig', 'id_dataset_name': 'MNIST', 'initial_training_set_size': 20, 'validation_set_size': 4096, 'validation_split_random_state': 0}, 'acquisition_size': 10, 'max_training_set': 40, 'max_training_epochs': 5, 'num_pool_samples': 2, 'num_validation_samples': 20, 'num_training_samples': 1, 'device': 'cuda', 'acquisition_function': 'batchbald_redux.acquisition_functions.bald.BALD', 'model_trainer_factory': 'batchbald_redux.models.MnistModelTrainer', 'ensemble_size': 1, 'temperature': 1.0, 'coldness': 1.0, 'stochastic_mode': 'StochasticMode.TopK'}
Creating: BALD(
	acquisition_size=10,
	num_pool_samples=2,
	coldness=1.0,
	stochastic_mode=StochasticMode.TopK
)
Creating: MnistModelTrainer(
	device=cuda,
	num_training_samples=1,
	num_

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return floored.astype(np.int)


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.633544921875, 'crossentropy': 1.6153091192245483}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.665771484375, 'crossentropy': 1.049702525138855}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.6884765625, 'crossentropy': 1.0206776857376099}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.715576171875, 'crossentropy': 0.9736806750297546}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.689453125, 'crossentropy': 1.0641498565673828}
RestoringEarlyStopping: 1 / 20
RestoringEarlyStopping: Restoring best parameters. (Score: 0.715576171875)
RestoringEarlyStopping: Restoring optimizer.
{'model_training/val_metrics': <wandb.data_types.Table object at 0x7f543e24eca0>, 'model_training/best_epoch': 3, 'model_training/best_val_accuracy': 0.715576171875, 'model_training/best_val_crossentropy': 0.9736806750297546}


get_predictions_labels:   0%|          | 0/200000 [00:00<?, ?it/s]

Perf after training {'accuracy': 0.7305, 'crossentropy': tensor(0.9253), '_timestamp': 1654790487, '_runtime': 17}
20 40


get_predictions_labels:   0%|          | 0/40 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/20 [00:00<?, ?it/s]

Entropy:   0%|          | 0/20 [00:00<?, ?it/s]

get_predictions_labels:   0%|          | 0/20 [00:00<?, ?it/s]

CandidateBatch(scores=[0.6384381754120134, 0.5322415286180691, 0.5260113756832867, 0.5171422857261013, 0.43788164175231065, 0.41219195967240185, 0.35280675544516427, 0.2784379799719978, 0.23271255500333565, 0.21250867786466032], indices=[17, 4, 5, 14, 0, 11, 12, 6, 18, 15])
[('id', 42257), ('id', 10772), ('id', 12675), ('id', 33931), ('id', 4772), ('id', 18874), ('id', 20397), ('id', 14713), ('id', 51472), ('id', 34214)]
Acquiring (label, score)s: 7 (0.6384), 0 (0.5322), 9 (0.526), 3 (0.5171), 8 (0.4379), 5 (0.4122), 6 (0.3528), 4 (0.2784), 9 (0.2327), 4 (0.2125)
Training set size 30:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.3115234375, 'crossentropy': 1.981034755706787}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.688232421875, 'crossentropy': 1.1834875345230103}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.724609375, 'crossentropy': 0.9258241057395935}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.72119140625, 'crossentropy': 0.8667286038398743}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.7412109375, 'crossentropy': 0.8418170213699341}
RestoringEarlyStopping: Restoring best parameters. (Score: 0.7412109375)
RestoringEarlyStopping: Restoring optimizer.
{'model_training/val_metrics': <wandb.data_types.Table object at 0x7f541079beb0>, 'model_training/best_epoch': 4, 'model_training/best_val_accuracy': 0.7412109375, 'model_training/best_val_crossentropy': 0.8418170213699341}


get_predictions_labels:   0%|          | 0/200000 [00:00<?, ?it/s]

Perf after training {'accuracy': 0.7385, 'crossentropy': tensor(0.8186), '_timestamp': 1654790494, '_runtime': 24}
30 40


get_predictions_labels:   0%|          | 0/20 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/10 [00:00<?, ?it/s]

Entropy:   0%|          | 0/10 [00:00<?, ?it/s]

get_predictions_labels:   0%|          | 0/20 [00:00<?, ?it/s]

CandidateBatch(scores=[0.5039396624648, 0.11559749183180545, 0.022129452678531503, 0.010923715540287343, 0.0044587734035281545, 0.0013395915672130727, 0.0010455881146414457, 0.0003782144064454155, 0.0001937470539422276, 7.375469733071432e-05], indices=[9, 0, 8, 6, 7, 5, 2, 4, 1, 3])
[('id', 54609), ('id', 5447), ('id', 38349), ('id', 17476), ('id', 28686), ('id', 16197), ('id', 9157), ('id', 15373), ('id', 6636), ('id', 14867)]
Acquiring (label, score)s: 6 (0.5039), 5 (0.1156), 4 (0.02213), 6 (0.01092), 1 (0.004459), 6 (0.00134), 9 (0.001046), 6 (0.0003782), 1 (0.0001937), 0 (7.375e-05)
Training set size 40:


 20%|##        | 1/5 [00:00<?, ?it/s]

[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.3603515625, 'crossentropy': 1.8856765031814575}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.647216796875, 'crossentropy': 1.1425728797912598}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.7236328125, 'crossentropy': 0.8968830704689026}


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.7109375, 'crossentropy': 0.9118012189865112}
RestoringEarlyStopping: 1 / 20


[1/16]   6%|6          [00:00<?]

[1/32]   3%|3          [00:00<?]

Epoch metrics: {'accuracy': 0.732177734375, 'crossentropy': 0.8665317296981812}
RestoringEarlyStopping: Restoring best parameters. (Score: 0.732177734375)
RestoringEarlyStopping: Restoring optimizer.
{'model_training/val_metrics': <wandb.data_types.Table object at 0x7f543e242f70>, 'model_training/best_epoch': 4, 'model_training/best_val_accuracy': 0.732177734375, 'model_training/best_val_crossentropy': 0.8665317296981812}


get_predictions_labels:   0%|          | 0/200000 [00:00<?, ?it/s]

Perf after training {'accuracy': 0.7538, 'crossentropy': tensor(0.7971), '_timestamp': 1654790502, '_runtime': 32}
40 40
Done.



VBox(children=(Label(value='0.541 MB of 0.541 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁▃█
crossentropy,█▂▁
model_training/best_epoch,▁██
model_training/best_val_accuracy,▁█▆
model_training/best_val_crossentropy,█▁▂

0,1
accuracy,0.7538
crossentropy,0.79714
model_training/best_epoch,4.0
model_training/best_val_accuracy,0.73218
model_training/best_val_crossentropy,0.86653


{'dataset_info': {'training': "'MNIST (Train, seed=0, 55904 samples)'",
  'test': "'MNIST (Test)'"},
 'initial_training_set_indices': [47227,
  11511,
  18383,
  41080,
  32837,
  24393,
  23904,
  11784,
  20439,
  35043,
  27367,
  30426,
  32361,
  26116,
  24386,
  4689,
  44895,
  24211,
  17212,
  3478],
 'seed': 4749488752848790432,
 'pool_indices': array([ 4772,  5447,  6636,  9157, 10772, 12675, 14713, 14867, 15373,
        16197, 17476, 18874, 20397, 28686, 33931, 34214, 38349, 42257,
        51472, 54609]),
 'active_learning_steps': [{'training': {'epochs': [{'accuracy': 0.633544921875,
      'crossentropy': 1.6153091192245483},
     {'accuracy': 0.665771484375, 'crossentropy': 1.049702525138855},
     {'accuracy': 0.6884765625, 'crossentropy': 1.0206776857376099},
     {'accuracy': 0.715576171875, 'crossentropy': 0.9736806750297546},
     {'accuracy': 0.689453125, 'crossentropy': 1.0641498565673828}],
    'best_epoch': 3},
   'evaluation_metrics': {'accuracy': 0.7305,
    '