# Black Box Model Training
> "Learning is not attained by chance, it must be sought for with ardor and attended to with diligence."
>
> &mdash; <cite>Abigail Adams</cite>.

After looking at a simple example experiment, it is worth looking at the big picture. The big picture requires us to train different models with differently-sized datasets.

We don't want to worry too much about fine-tuning the training, because we cannot.

In [None]:
# default_exp black_box_model_training

In [None]:
# hide
import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


## Goals

* Log as much as possible by default.
* Avoid magic numbers. Magic numbers don't work very well when everything keeps changing.

In [None]:
# exporti

from dataclasses import dataclass
from typing import Optional
import torch
from blackhc.project import is_run_from_ipython
from blackhc.project.utils.ignite_progress_bar import ignite_progress_bar
from ignite.contrib.engines.common import setup_common_training_handlers
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
from ignite.metrics import Accuracy, Loss, RunningAverage
from torch import nn


from batchbald_redux.consistent_mc_dropout import (
    GeometricMeanPrediction,
    SamplerModel,
    multi_sample_loss, BayesianModule, get_log_mean_probs,
)
from batchbald_redux.restoring_early_stopping import (
    PatienceWithSnapshot,
    ReduceLROnPlateauWithScheduleWrapper,
    RestoringEarlyStopping,
)

In [None]:
# exports
from batchbald_redux.trained_model import TrainedModel

LOG_INTERVAL = 10


def train(
    *,
    model,
    training_samples,
    validation_samples,
    train_loader,
    validation_loader,
    patience: Optional[int],
    max_epochs: int,
    device: str,
    training_log: dict,
    loss=None,
    validation_loss=None,
    optimizer=None,
    prefer_accuracy=True,
    train_augmentations=None,
):
    """Train `model` with an Ignite trainer, validating after every epoch.

    Per-epoch validation metrics are appended to `training_log["epochs"]`.
    When `patience` is given, a `RestoringEarlyStopping` instance is created
    and its `best_epoch` is stored in `training_log["best_epoch"]` after the run.

    Args:
        model: The module to train. It is wrapped in `SamplerModel` for
            training and in `GeometricMeanPrediction(SamplerModel(...))` for
            validation.
        training_samples: Forwarded to `SamplerModel` for the training wrapper.
        validation_samples: Forwarded to `SamplerModel` for the validation wrapper.
        train_loader: Loader for training batches. Must expose `.dataset`; if
            that dataset is empty the function returns immediately.
        validation_loader: Loader the validation evaluator is run on after
            every training epoch.
        patience: Patience passed to `RestoringEarlyStopping`. `None` disables
            early stopping.
        max_epochs: Upper bound on training epochs passed to `trainer.run`.
        device: Device the model is moved to and the engines are created with.
        training_log: Dict that is mutated in place ("epochs", and "best_epoch"
            when early stopping is enabled).
        loss: Training loss; defaults to `nn.NLLLoss()`. It is wrapped with
            `multi_sample_loss` before being handed to the trainer.
        validation_loss: Loss used for the "crossentropy" validation metric;
            defaults to `loss`.
        optimizer: Optimizer to use; defaults to Adam over `model.parameters()`
            with `weight_decay=5e-4`.
        prefer_accuracy: If true, early stopping scores epochs by validation
            accuracy, otherwise by negative validation crossentropy.
        train_augmentations: Optional module that is prepended to the training
            model via `torch.nn.Sequential`.

    Returns:
        The optimizer that was used, so training can be continued. On the
        empty-dataset early return this is the `optimizer` argument unchanged
        (possibly `None`).
    """
    # Nothing to train on: skip all setup and hand the optimizer argument back as-is.
    if not len(train_loader.dataset):
        return optimizer

    if loss is None:
        loss = nn.NLLLoss()
    if validation_loss is None:
        validation_loss = loss

    train_model = SamplerModel(model, training_samples)
    if train_augmentations is not None:
        # Augmentations run first, then the sampled model.
        train_model = torch.nn.Sequential(train_augmentations, train_model)

    # Validation wraps the same `model` instance, only with a different sample count.
    validation_model = GeometricMeanPrediction(SamplerModel(model, validation_samples))

    # Move model to device before creating the optimizer
    train_model.to(device)

    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4)

    trainer = create_supervised_trainer(train_model, optimizer, loss_fn=multi_sample_loss(loss), device=device)

    metrics = create_metrics(validation_loss)

    validation_evaluator = create_supervised_evaluator(validation_model, metrics=metrics, device=device)

    # Run a full validation pass at the end of every training epoch.
    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        validation_evaluator.run(validation_loader)

    # Only to look nicer.
    RunningAverage(output_transform=lambda x: x).attach(trainer, "crossentropy")

    # tqdm progress bars are only enabled when running from IPython.
    enable_tqdm_pbars = is_run_from_ipython()

    setup_common_training_handlers(
        trainer, with_pbars=enable_tqdm_pbars, with_gpu_stats=torch.cuda.is_available(), log_every_iters=LOG_INTERVAL
    )

    if enable_tqdm_pbars:
        ProgressBar(persist=False).attach(
            validation_evaluator,
            metric_names="all",
            event_name=Events.ITERATION_COMPLETED(every=LOG_INTERVAL),
        )
    else:
        # Outside IPython, use the project's own progress reporting for the trainer.
        ignite_progress_bar(trainer, desc=lambda engine: "Training", log_interval=LOG_INTERVAL)

    # Any previous "epochs" entry in the caller's log is replaced by a fresh list.
    training_log["epochs"] = []
    epochs_log = training_log["epochs"]

    # Logging
    @validation_evaluator.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        # Copy the metrics so that later epochs cannot mutate the logged entry.
        metrics = dict(engine.state.metrics)
        epochs_log.append(metrics)

        if is_run_from_ipython():
            print(f"Epoch metrics: {metrics}")

    # Add early stopping
    if patience is not None:
        # The score functions take no arguments and read the evaluator's latest metrics.
        if prefer_accuracy:

            def score_function():
                return float(validation_evaluator.state.metrics["accuracy"])

        else:

            def score_function():
                # Negated so that a lower crossentropy yields a higher score.
                return float(-validation_evaluator.state.metrics["crossentropy"])

        early_stopping = RestoringEarlyStopping(
            patience=patience,
            score_function=score_function,
            module=model,
            optimizer=optimizer,
            training_engine=trainer,
            validation_engine=validation_evaluator,
        )
    else:
        early_stopping = None

    # Kick everything off
    trainer.run(train_loader, max_epochs=max_epochs)

    if early_stopping:
        training_log["best_epoch"] = early_stopping.best_epoch

    # Return the optimizer in case we want to continue training.
    return optimizer

In [None]:
# exports


def train_with_schedule(
    *,
    model,
    training_samples,
    validation_samples,
    train_loader,
    validation_loader,
    patience_schedule: [int],
    factor_schedule: [int],
    max_epochs: int,
    device: str,
    training_log: dict,
    loss=None,
    validation_loss=None,
    optimizer=None,
    prefer_accuracy=True,
    train_augmentations=None,
):
    """Train `model` with a plateau-based learning-rate schedule.

    The setup mirrors `train`, but instead of early stopping a
    `ReduceLROnPlateauWithScheduleWrapper` is stepped after every validation
    pass. Per-epoch validation metrics are appended to
    `training_log["epochs"]`; each time the scheduler invokes the
    `next_era_callback`, the current trainer epoch is appended to
    `training_log["era_epochs"]`.

    Args:
        model: The module to train.
        training_samples: Forwarded to `SamplerModel` for the training wrapper.
        validation_samples: Forwarded to `SamplerModel` for the validation wrapper.
        train_loader: Loader for training batches. Must expose `.dataset`; if
            that dataset is empty the function returns immediately.
        validation_loader: Loader the validation evaluator is run on after
            every training epoch.
        patience_schedule: Passed through to the scheduler wrapper.
        factor_schedule: Passed through to the scheduler wrapper.
            NOTE(review): annotated as `[int]`, but the example call in this
            notebook passes `[0.1]` - the values look like float factors.
        max_epochs: Upper bound on training epochs passed to `trainer.run`.
        device: Device the model is moved to and the engines are created with.
        training_log: Dict that is mutated in place ("epochs", "era_epochs").
        loss: Training loss; defaults to `nn.NLLLoss()`.
        validation_loss: Loss for the "crossentropy" metric; defaults to `loss`.
        optimizer: Optimizer to use; defaults to Adam over `model.parameters()`
            with `weight_decay=5e-4`.
        prefer_accuracy: If true the scheduler tracks validation accuracy in
            "max" mode, otherwise validation crossentropy in "min" mode.
        train_augmentations: Optional module that is prepended to the training
            model via `torch.nn.Sequential`.

    Returns:
        The optimizer that was used (the unchanged `optimizer` argument on the
        empty-dataset early return).
    """
    # Nothing to train on: skip all setup and hand the optimizer argument back as-is.
    if not len(train_loader.dataset):
        return optimizer

    if loss is None:
        loss = nn.NLLLoss()
    if validation_loss is None:
        validation_loss = loss

    train_model = SamplerModel(model, training_samples)
    if train_augmentations is not None:
        # Augmentations run first, then the sampled model.
        train_model = torch.nn.Sequential(train_augmentations, train_model)

    # Validation wraps the same `model` instance, only with a different sample count.
    validation_model = GeometricMeanPrediction(SamplerModel(model, validation_samples))

    # Move model to device before creating the optimizer
    train_model.to(device)

    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4)

    trainer = create_supervised_trainer(train_model, optimizer, loss_fn=multi_sample_loss(loss), device=device)

    metrics = create_metrics(validation_loss)

    validation_evaluator = create_supervised_evaluator(validation_model, metrics=metrics, device=device)

    # Run a full validation pass at the end of every training epoch.
    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        validation_evaluator.run(validation_loader)

    # Only to look nicer.
    RunningAverage(output_transform=lambda x: x).attach(trainer, "crossentropy")

    # tqdm progress bars are only enabled when running from IPython.
    enable_tqdm_pbars = is_run_from_ipython()

    setup_common_training_handlers(
        trainer, with_pbars=enable_tqdm_pbars, with_gpu_stats=torch.cuda.is_available(), log_every_iters=LOG_INTERVAL
    )

    if enable_tqdm_pbars:
        ProgressBar(persist=False).attach(
            validation_evaluator,
            metric_names="all",
            event_name=Events.ITERATION_COMPLETED(every=LOG_INTERVAL),
        )
    else:
        # Outside IPython, use the project's own progress reporting for the trainer.
        ignite_progress_bar(trainer, desc=lambda engine: "Training", log_interval=LOG_INTERVAL)

    # Any previous "epochs" entry in the caller's log is replaced by a fresh list.
    training_log["epochs"] = []
    epochs_log = training_log["epochs"]

    # Logging
    @validation_evaluator.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        # Copy the metrics so that later epochs cannot mutate the logged entry.
        metrics = dict(engine.state.metrics)
        epochs_log.append(metrics)

        if is_run_from_ipython():
            print(f"Epoch metrics: {metrics}")

    # Unlike in `train`, the score function receives a metrics mapping and the
    # crossentropy is NOT negated: the direction is expressed via `mode` below.
    if prefer_accuracy:
        def score_function(metrics):
            return float(metrics["accuracy"])

    else:

        def score_function(metrics):
            return float(metrics["crossentropy"])

    training_log["era_epochs"] = []

    def next_era_callback():
        # Record the training epoch at which the scheduler reported a new era.
        training_log["era_epochs"].append(trainer.state.epoch)

    # Presumably the wrapper moves to the next factor/patience pair on a plateau and
    # calls `end_callback` once the schedule is exhausted - confirm in
    # batchbald_redux.restoring_early_stopping.
    scheduler = ReduceLROnPlateauWithScheduleWrapper(
        optimizer,
        metrics_transform=score_function,
        factor_schedule=factor_schedule,
        patience_schedule=patience_schedule,
        end_callback=trainer.terminate,
        next_era_callback=next_era_callback,
        mode="max" if prefer_accuracy else "min",
        verbose=True,
    )

    # Registered after the logging handler, so the epoch is logged before the scheduler steps.
    @validation_evaluator.on(Events.EPOCH_COMPLETED)
    def step_scheduler(engine):
        scheduler.step(engine)

    # Kick everything off
    trainer.run(train_loader, max_epochs=max_epochs)

    # Return the optimizer in case we want to continue training.
    return optimizer


def train_with_cosine_annealing(
    *,
    model,
    training_samples,
    validation_samples,
    train_loader,
    validation_loader,
    max_epochs: int,
    device: str,
    training_log: dict,
    loss=None,
    validation_loss=None,
    optimizer=None,
    train_augmentations=None,
):
    """Train `model` for `max_epochs` epochs with a cosine-annealed learning rate.

    The setup mirrors `train`, but there is no early stopping: a
    `torch.optim.lr_scheduler.CosineAnnealingLR` with `T_max=max_epochs` is
    stepped once after every validation pass. Per-epoch validation metrics are
    appended to `training_log["epochs"]`.

    Args:
        model: The module to train.
        training_samples: Forwarded to `SamplerModel` for the training wrapper.
        validation_samples: Forwarded to `SamplerModel` for the validation wrapper.
        train_loader: Loader for training batches. Must expose `.dataset`; if
            that dataset is empty the function returns immediately.
        validation_loader: Loader the validation evaluator is run on after
            every training epoch.
        max_epochs: Number of epochs passed to `trainer.run`; also used as the
            scheduler's `T_max`.
        device: Device the model is moved to and the engines are created with.
        training_log: Dict that is mutated in place ("epochs").
        loss: Training loss; defaults to `nn.NLLLoss()`.
        validation_loss: Loss for the "crossentropy" metric; defaults to `loss`.
        optimizer: Optimizer to use; defaults to Adam over `model.parameters()`
            with `weight_decay=5e-4`.
        train_augmentations: Optional module that is prepended to the training
            model via `torch.nn.Sequential`.

    Returns:
        The optimizer that was used (the unchanged `optimizer` argument on the
        empty-dataset early return).
    """
    # Nothing to train on: skip all setup and hand the optimizer argument back as-is.
    if not len(train_loader.dataset):
        return optimizer

    if loss is None:
        loss = nn.NLLLoss()
    if validation_loss is None:
        validation_loss = loss

    train_model = SamplerModel(model, training_samples)
    if train_augmentations is not None:
        # Augmentations run first, then the sampled model.
        train_model = torch.nn.Sequential(train_augmentations, train_model)

    # Validation wraps the same `model` instance, only with a different sample count.
    validation_model = GeometricMeanPrediction(SamplerModel(model, validation_samples))

    # Move model to device before creating the optimizer
    train_model.to(device)

    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4)

    trainer = create_supervised_trainer(train_model, optimizer, loss_fn=multi_sample_loss(loss), device=device)

    metrics = create_metrics(validation_loss)

    validation_evaluator = create_supervised_evaluator(validation_model, metrics=metrics, device=device)

    # Run a full validation pass at the end of every training epoch.
    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        validation_evaluator.run(validation_loader)

    # Only to look nicer.
    RunningAverage(output_transform=lambda x: x).attach(trainer, "crossentropy")

    # tqdm progress bars are only enabled when running from IPython.
    enable_tqdm_pbars = is_run_from_ipython()

    setup_common_training_handlers(
        trainer, with_pbars=enable_tqdm_pbars, with_gpu_stats=torch.cuda.is_available(), log_every_iters=LOG_INTERVAL
    )

    if enable_tqdm_pbars:
        ProgressBar(persist=False).attach(
            validation_evaluator,
            metric_names="all",
            event_name=Events.ITERATION_COMPLETED(every=LOG_INTERVAL),
        )
    else:
        # Outside IPython, use the project's own progress reporting for the trainer.
        ignite_progress_bar(trainer, desc=lambda engine: "Training", log_interval=LOG_INTERVAL)

    # Any previous "epochs" entry in the caller's log is replaced by a fresh list.
    training_log["epochs"] = []
    epochs_log = training_log["epochs"]

    # Logging
    @validation_evaluator.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        # Copy the metrics so that later epochs cannot mutate the logged entry.
        metrics = dict(engine.state.metrics)
        epochs_log.append(metrics)

        if is_run_from_ipython():
            # The epoch number comes from the trainer; the evaluator runs one epoch per call.
            print(f"Epoch {trainer.state.epoch} metrics: {metrics}")

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=max_epochs
    )

    # One scheduler step per validation pass, i.e. once per training epoch.
    @validation_evaluator.on(Events.EPOCH_COMPLETED)
    def step_scheduler(engine):
        scheduler.step()

    # Kick everything off
    trainer.run(train_loader, max_epochs=max_epochs)

    # Return the optimizer in case we want to continue training.
    return optimizer

In [None]:
# exports


@dataclass
class ModelOptimizerStateDicts:
    """Pair of a model state dict and the matching optimizer state dict.

    `train_double_snapshots` fills this either from the state dicts handed to a
    `PatienceWithSnapshot` out-of-patience callback or from
    `model.state_dict()` / `optimizer.state_dict()` after training.
    """

    # State dict of the model (module) at snapshot time.
    model_state_dict: object
    # State dict of the optimizer at snapshot time.
    optimizer_state_dict: object


@dataclass
class DoubleSnapshots:
    """Result of `train_double_snapshots`: one snapshot per tracked criterion."""

    # Snapshot associated with the validation-accuracy patience tracker.
    high_accuracy: ModelOptimizerStateDicts
    # Snapshot associated with the validation-crossentropy patience tracker.
    low_cross_entropy: ModelOptimizerStateDicts


def train_double_snapshots(
    *,
    model: BayesianModule,
    training_samples,
    validation_samples,
    train_loader,
    validation_loader,
    patience: Optional[int],
    max_epochs: int,
    device: str,
    training_log: dict,
    loss=None,
    validation_loss=None,
    optimizer=None,
    train_augmentations: Optional[torch.nn.Module] = None,
) -> DoubleSnapshots:
    """Train `model` once and keep two snapshots: best accuracy and lowest crossentropy.

    Two `PatienceWithSnapshot` trackers watch the validation metrics, one for
    accuracy and one for (negated) crossentropy. Each tracker hands its model
    and optimizer state dicts to a callback when it runs out of patience;
    training is terminated once both trackers are out of patience (or after
    `max_epochs`). Per-epoch validation metrics are appended to
    `training_log["epochs"]`.

    Args:
        model: The module to train.
        training_samples: Forwarded to `SamplerModel` for the training wrapper.
        validation_samples: Forwarded to `SamplerModel` for the validation wrapper.
        train_loader: Loader for training batches.
        validation_loader: Loader the validation evaluator is run on after
            every training epoch.
        patience: Patience for both trackers. With `None` no tracker is
            created and both returned snapshots are the final training state.
        max_epochs: Upper bound on training epochs passed to `trainer.run`.
        device: Device the model is moved to and the engines are created with.
        training_log: Dict that is mutated in place: "epochs" always, plus
            "cross_entropy_best_epoch" / "accuracy_best_epoch" for every
            tracker that produced a snapshot.
        loss: Training loss; defaults to `nn.NLLLoss()`.
        validation_loss: Loss for the "crossentropy" metric; defaults to `loss`.
        optimizer: Optimizer to use; defaults to Adam over `model.parameters()`
            with `weight_decay=5e-4`.
        train_augmentations: Optional module that is prepended to the training
            model via `torch.nn.Sequential`.

    Returns:
        A `DoubleSnapshots` with the accuracy snapshot in `high_accuracy` and
        the crossentropy snapshot in `low_cross_entropy`. A criterion whose
        tracker never ran out of patience falls back to the model/optimizer
        state at the end of training.
    """
    if loss is None:
        loss = nn.NLLLoss()
    if validation_loss is None:
        validation_loss = loss

    train_model = SamplerModel(model, training_samples)
    if train_augmentations is not None:
        # Augmentations run first, then the sampled model.
        train_model = torch.nn.Sequential(train_augmentations, train_model)

    # Validation wraps the same `model` instance, only with a different sample count.
    validation_model = GeometricMeanPrediction(SamplerModel(model, validation_samples))

    # Move model to device before creating the optimizer
    train_model.to(device)

    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4)

    trainer = create_supervised_trainer(train_model, optimizer, loss_fn=multi_sample_loss(loss), device=device)

    metrics = create_metrics(validation_loss)

    validation_evaluator = create_supervised_evaluator(validation_model, metrics=metrics, device=device)

    # Run a full validation pass at the end of every training epoch.
    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        validation_evaluator.run(validation_loader)

    # Only to look nicer.
    RunningAverage(output_transform=lambda x: x).attach(trainer, "crossentropy")

    # tqdm progress bars are only enabled when running from IPython.
    enable_tqdm_pbars = is_run_from_ipython()

    setup_common_training_handlers(
        trainer, with_pbars=enable_tqdm_pbars, with_gpu_stats=torch.cuda.is_available(), log_every_iters=LOG_INTERVAL
    )

    if enable_tqdm_pbars:
        ProgressBar(persist=False).attach(
            validation_evaluator,
            metric_names="all",
            event_name=Events.ITERATION_COMPLETED(every=LOG_INTERVAL),
        )
    else:
        # Outside IPython, use the project's own progress reporting for the trainer.
        ignite_progress_bar(trainer, desc=lambda engine: "Training", log_interval=LOG_INTERVAL)

    # Any previous "epochs" entry in the caller's log is replaced by a fresh list.
    training_log["epochs"] = []
    epochs_log = training_log["epochs"]

    # Logging
    @validation_evaluator.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        # Copy the metrics so that later epochs cannot mutate the logged entry.
        metrics = dict(engine.state.metrics)
        epochs_log.append(metrics)

        if is_run_from_ipython():
            print(f"Epoch metrics: {metrics}")

    # Filled in by the out-of-patience callbacks below (if they ever fire).
    cross_entropy_state_dicts = None
    accuracy_state_dicts = None

    # Add early stopping
    if patience is not None:

        def terminate_if_both_out_of_patience():
            # Both trackers are resolved lazily; they exist by the time a callback can fire.
            if cross_entropy_pws.is_out_of_patience() and accuracy_pws.is_out_of_patience():
                trainer.terminate()

        def cross_entropy_out_of_patience_callback(module_state_dict, optimizer_state_dict):
            nonlocal cross_entropy_state_dicts
            cross_entropy_state_dicts = ModelOptimizerStateDicts(
                model_state_dict=module_state_dict, optimizer_state_dict=optimizer_state_dict
            )
            terminate_if_both_out_of_patience()

        cross_entropy_pws = PatienceWithSnapshot(
            name="LowCrossEntropy-",
            patience=patience,
            # Negated so that a lower crossentropy yields a higher score.
            score_function=lambda: float(-validation_evaluator.state.metrics["crossentropy"]),
            module=model,
            optimizer=optimizer,
            training_engine=trainer,
            validation_engine=validation_evaluator,
            out_of_patience_callback=cross_entropy_out_of_patience_callback,
        )

        def accuracy_out_of_patience_callback(module_state_dict, optimizer_state_dict):
            nonlocal accuracy_state_dicts
            accuracy_state_dicts = ModelOptimizerStateDicts(
                model_state_dict=module_state_dict, optimizer_state_dict=optimizer_state_dict
            )
            terminate_if_both_out_of_patience()

        accuracy_pws = PatienceWithSnapshot(
            name="Accuracy-",
            patience=patience,
            score_function=lambda: float(validation_evaluator.state.metrics["accuracy"]),
            module=model,
            optimizer=optimizer,
            training_engine=trainer,
            validation_engine=validation_evaluator,
            out_of_patience_callback=accuracy_out_of_patience_callback,
        )

    # Kick everything off
    trainer.run(train_loader, max_epochs=max_epochs)

    if cross_entropy_state_dicts is not None:
        training_log["cross_entropy_best_epoch"] = cross_entropy_pws.best_epoch
    else:
        # The tracker never ran out of patience: fall back to the final training state.
        cross_entropy_state_dicts = ModelOptimizerStateDicts(
            model_state_dict=model.state_dict(), optimizer_state_dict=optimizer.state_dict()
        )

    if accuracy_state_dicts is not None:
        # Log the accuracy tracker's best epoch under its own key (this used to
        # overwrite "cross_entropy_best_epoch" with the crossentropy value again).
        training_log["accuracy_best_epoch"] = accuracy_pws.best_epoch
    else:
        # The tracker never ran out of patience: fall back to the final training state.
        accuracy_state_dicts = ModelOptimizerStateDicts(
            model_state_dict=model.state_dict(), optimizer_state_dict=optimizer.state_dict()
        )

    return DoubleSnapshots(high_accuracy=accuracy_state_dicts, low_cross_entropy=cross_entropy_state_dicts)

In [None]:
# exports


def evaluate(*, model: TrainedModel, loader, num_samples, device, storage_device, loss=None):
    """Compute accuracy and crossentropy of a `TrainedModel` on `loader`.

    The per-sample log probabilities returned by the model are combined with
    `get_log_mean_probs`, and both metrics are computed on that result.

    Args:
        model: Trained model providing `get_log_probs_N_K_C_labels_N`.
        loader: Data loader forwarded to the model.
        num_samples: Number of samples forwarded to the model.
        device: Device forwarded to the model.
        storage_device: Storage device forwarded to the model.
        loss: Loss used for the crossentropy; defaults to `nn.NLLLoss()`.

    Returns:
        A dict with "accuracy" (a Python float) and "crossentropy" (whatever
        `loss` returns, without conversion - a tensor for `nn.NLLLoss`).
    """
    log_probs_N_K_C, labels_N = model.get_log_probs_N_K_C_labels_N(
        loader=loader,
        num_samples=num_samples,
        device=device,
        storage_device=storage_device,
    )

    criterion = nn.NLLLoss() if loss is None else loss

    mean_log_probs_N_C = get_log_mean_probs(log_probs_N_K_C)
    crossentropy = criterion(mean_log_probs_N_C, labels_N)

    # Fraction of inputs whose most likely class equals the label.
    predicted_N = torch.argmax(mean_log_probs_N_C, dim=1)
    num_correct = torch.eq(predicted_N, labels_N).sum().item()
    accuracy = num_correct / len(labels_N)

    return {"accuracy": accuracy, "crossentropy": crossentropy}


def evaluate_old(*, model, num_samples, loader, device, loss=None):
    """Evaluate a raw model on `loader` with an Ignite evaluator.

    The model is moved to `device`, wrapped in
    `GeometricMeanPrediction(SamplerModel(model, num_samples))` and run for a
    single epoch with the metrics from `create_metrics`.

    Args:
        model: Module to evaluate.
        num_samples: Forwarded to `SamplerModel`.
        loader: Data loader to evaluate on.
        device: Device for the model and the evaluator.
        loss: Loss for the "crossentropy" metric; defaults to `nn.NLLLoss()`.

    Returns:
        The evaluator's metrics mapping after the run.
    """
    # TODO: rewrite this on top of TrainedModel?
    # Add "get_log_prob_predictions" which returns the mean?
    # Compute accuracy etc based on that?

    # Move model to device
    model.to(device)

    averaged_model = GeometricMeanPrediction(SamplerModel(model, num_samples))

    criterion = nn.NLLLoss() if loss is None else loss

    evaluator = create_supervised_evaluator(
        averaged_model,
        metrics=create_metrics(criterion),
        device=device,
    )

    progress_bar = ProgressBar(persist=False)
    progress_bar.attach(
        evaluator,
        metric_names="all",
        event_name=Events.ITERATION_COMPLETED(every=LOG_INTERVAL),
    )

    # Kick everything off
    evaluator.run(loader, max_epochs=1)

    return evaluator.state.metrics


def create_metrics(loss):
    """Build the Ignite metrics used for validation and evaluation.

    Args:
        loss: Loss function wrapped by the "crossentropy" metric.

    Returns:
        A dict mapping "accuracy" to an `Accuracy` metric and "crossentropy"
        to a `Loss` metric built around `loss`.
    """
    return dict(accuracy=Accuracy(), crossentropy=Loss(loss))

We want to use metrics that allow us to capture the quality of the produced uncertainty during training.

In [None]:
# experiment

import torch.utils.data

from batchbald_redux.consistent_mc_dropout import GeometricMeanPrediction, SamplerModel
from batchbald_redux.dataset_challenges import create_repeated_MNIST_dataset
from batchbald_redux.models import BayesianMNISTCNN
from batchbald_redux.fast_mnist import FastMNIST

# Build the train/test datasets with a single repetition and no added noise.
train_dataset, test_dataset = create_repeated_MNIST_dataset(num_repetitions=1, add_noise=False)

# NOTE(review): `train_dataset * 0.5` relies on the dataset type overloading `*`;
# presumably it keeps half of the training set (the recorded run below shows
# 469 batches of 64) - confirm against the dataset class.
train_loader = torch.utils.data.DataLoader(train_dataset * 0.5, batch_size=64)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, drop_last=False)

In [None]:
# experiment

# Fresh model and an empty log dict that `train` fills in place.
model = BayesianMNISTCNN()

training_log = {}

# Short run: at most 3 epochs with an early-stopping patience of 3.
# NOTE(review): `validation_loader` is the *training* loader here, so the logged
# "validation" metrics are computed on training data - confirm this is intended.
train(
    model=model,
    training_samples=1,
    validation_samples=4,
    train_loader=train_loader,
    validation_loader=train_loader,
    patience=3,
    max_epochs=3,
    device="cuda",
    training_log=training_log,
)

# Show the collected log as the cell output.
training_log

 33%|###3      | 1/3 [00:00<?, ?it/s]

[1/469]   0%|           [00:00<?]

[1/469]   0%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9594, 'crossentropy': 0.2248309172709783}


[1/469]   0%|           [00:00<?]

[1/469]   0%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9734666666666667, 'crossentropy': 0.1687226093451182}


[1/469]   0%|           [00:00<?]

[1/469]   0%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9788666666666667, 'crossentropy': 0.13679596647024156}
RestoringEarlyStopping: Restoring best parameters. (Score: 0.9788666666666667)
RestoringEarlyStopping: Restoring optimizer.


{'epochs': [{'accuracy': 0.9594, 'crossentropy': 0.2248309172709783},
  {'accuracy': 0.9734666666666667, 'crossentropy': 0.1687226093451182},
  {'accuracy': 0.9788666666666667, 'crossentropy': 0.13679596647024156}],
 'best_epoch': 3}

In [None]:
# Evaluate the model trained above on the test loader, passing num_samples=4.
evaluate_old(model=model, num_samples=4, loader=test_loader, device="cuda")

[1/157]   1%|           [00:00<?]

{'accuracy': 0.9715, 'crossentropy': 0.1681874878913164}

In [None]:
# experiment

# Fresh model and an empty log dict that `train_with_schedule` fills in place.
model = BayesianMNISTCNN()

training_log = {}

# Train for up to 60 epochs with the plateau schedule; validation runs on the
# test loader here. Two patience values are paired with a single factor.
train_with_schedule(
    model=model,
    training_samples=1,
    validation_samples=4,
    train_loader=train_loader,
    validation_loader=test_loader,
    patience_schedule=[3, 3],
    factor_schedule=[0.1],
    max_epochs=60,
    device="cuda",
    training_log=training_log,
)

# Show the collected log as the cell output.
training_log

  2%|1         | 1/60 [00:00<?, ?it/s]

[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.8138, 'crossentropy': 0.8318569969177246}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9051, 'crossentropy': 0.5015810224533082}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9231, 'crossentropy': 0.4067144854307175}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9376, 'crossentropy': 0.33887774846553803}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9488, 'crossentropy': 0.284973241519928}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9463, 'crossentropy': 0.2858008557677269}
Epoch 6: 0.9463 worse than 0.9488, patience: 1/3!


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9544, 'crossentropy': 0.2509216430902481}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9573, 'crossentropy': 0.2428978340089321}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9608, 'crossentropy': 0.22644489660859107}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9608, 'crossentropy': 0.21278315509557724}
Epoch 10: 0.9608 worse than 0.9608, patience: 1/3!


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9618, 'crossentropy': 0.2145445026308298}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.963, 'crossentropy': 0.21237999706566335}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9631, 'crossentropy': 0.21369002146720886}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9673, 'crossentropy': 0.188571187569201}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9667, 'crossentropy': 0.1944719546556473}
Epoch 15: 0.9667 worse than 0.9673, patience: 1/3!


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9679, 'crossentropy': 0.19227623180747033}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9673, 'crossentropy': 0.20758895537257194}
Epoch 17: 0.9673 worse than 0.9679, patience: 1/3!


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9689, 'crossentropy': 0.19289123542010783}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9701, 'crossentropy': 0.18630976665019988}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9695, 'crossentropy': 0.19547787671089173}
Epoch 20: 0.9695 worse than 0.9701, patience: 1/3!


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9706, 'crossentropy': 0.18417280072569847}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9718, 'crossentropy': 0.18118949906229972}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9694, 'crossentropy': 0.1888744118079543}
Epoch 23: 0.9694 worse than 0.9718, patience: 1/3!


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9683, 'crossentropy': 0.189695987880975}
Epoch 24: 0.9683 worse than 0.9718, patience: 2/3!


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9715, 'crossentropy': 0.1844632924541831}
Epoch 25: 0.9715 worse than 0.9718, patience: 3/3!


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9707, 'crossentropy': 0.19293912473917008}
Epoch 26: 0.9707 worse than 0.9718, patience: 4/3!
Epoch    26: reducing learning rate of group 0 to 1.0000e-04.


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9744, 'crossentropy': 0.16678278848044575}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9732, 'crossentropy': 0.170559820997715}
Epoch 28: 0.9732 worse than 0.9744, patience: 1/3!


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9747, 'crossentropy': 0.16037208632528782}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9751, 'crossentropy': 0.16294795689582825}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9757, 'crossentropy': 0.1579607660241425}


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9749, 'crossentropy': 0.15656672970429064}
Epoch 32: 0.9749 worse than 0.9757, patience: 1/3!


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9742, 'crossentropy': 0.165166756830737}
Epoch 33: 0.9742 worse than 0.9757, patience: 2/3!


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9744, 'crossentropy': 0.16184359027966858}
Epoch 34: 0.9744 worse than 0.9757, patience: 3/3!


[1/47]   2%|2          [00:00<?]

[1/157]   1%|           [00:00<?]

Epoch metrics: {'accuracy': 0.9754, 'crossentropy': 0.15807696827538312}
Epoch 35: 0.9754 worse than 0.9757, patience: 4/3!


{'epochs': [{'accuracy': 0.8138, 'crossentropy': 0.8318569969177246},
  {'accuracy': 0.9051, 'crossentropy': 0.5015810224533082},
  {'accuracy': 0.9231, 'crossentropy': 0.4067144854307175},
  {'accuracy': 0.9376, 'crossentropy': 0.33887774846553803},
  {'accuracy': 0.9488, 'crossentropy': 0.284973241519928},
  {'accuracy': 0.9463, 'crossentropy': 0.2858008557677269},
  {'accuracy': 0.9544, 'crossentropy': 0.2509216430902481},
  {'accuracy': 0.9573, 'crossentropy': 0.2428978340089321},
  {'accuracy': 0.9608, 'crossentropy': 0.22644489660859107},
  {'accuracy': 0.9608, 'crossentropy': 0.21278315509557724},
  {'accuracy': 0.9618, 'crossentropy': 0.2145445026308298},
  {'accuracy': 0.963, 'crossentropy': 0.21237999706566335},
  {'accuracy': 0.9631, 'crossentropy': 0.21369002146720886},
  {'accuracy': 0.9673, 'crossentropy': 0.188571187569201},
  {'accuracy': 0.9667, 'crossentropy': 0.1944719546556473},
  {'accuracy': 0.9679, 'crossentropy': 0.19227623180747033},
  {'accuracy': 0.9673, 'cro

## Obtaining predictions

Sometimes, we want to obtain predictions from our models instead of pure evaluation metrics... I know, right?

The following helper method registers an event handler with an Ignite Engine that stores the predictions in a list:

TODO: MOVE THIS/REMOVE THIS

### Example

In [None]:
# experiment


# NOTE(review): `get_predictions` is neither defined nor imported in the visible
# part of this notebook (see the TODO above) - this cell presumably fails with a
# NameError unless the helper is restored or imported.
predictions = get_predictions(model=model, num_samples=10, num_classes=10, loader=test_loader, device="cuda")
len(predictions)

get_predictions_labels:   0%|          | 0/100000 [00:00<?, ?it/s]

10000

In [None]:
# experiment

# NOTE(review): `get_predictions_labels` is neither defined nor imported in the
# visible part of this notebook - confirm where it is supposed to come from.
predictions, labels = get_predictions_labels(
    model=model, num_samples=7, num_classes=10, loader=test_loader, device="cuda"
)
# The recorded output shows shapes (10000, 7, 10) for predictions and (10000,) for labels.
predictions.shape, labels.shape

get_predictions_labels:   0%|          | 0/70000 [00:00<?, ?it/s]

(torch.Size([10000, 7, 10]), torch.Size([10000]))

In [None]:
# Peek at the first ten entries of the predictions tensor.
predictions[:10]

tensor([[[-2.8367e+01, -2.4773e+01, -1.6460e+01, -1.5410e+01, -2.6516e+01,
          -2.5284e+01, -4.4725e+01, -5.9605e-07, -1.9745e+01, -1.5331e+01],
         [-2.0561e+01, -2.3827e+01, -1.9392e+01, -2.4907e+01, -2.7014e+01,
          -2.3035e+01, -3.5366e+01, -1.6689e-06, -2.3001e+01, -1.3284e+01],
         [-2.8653e+01, -2.2573e+01, -1.1184e+01, -1.6494e+01, -3.4791e+01,
          -2.9149e+01, -4.1066e+01, -1.4067e-05, -2.3491e+01, -2.0259e+01],
         [-1.8519e+01, -1.0138e+01, -1.1501e+01, -9.9980e+00, -1.2409e+01,
          -1.5788e+01, -2.3412e+01, -6.1302e-04, -1.0140e+01, -7.6546e+00],
         [-2.5508e+01, -1.6356e+01, -1.3611e+01, -1.8889e+01, -2.4214e+01,
          -3.2515e+01, -4.0044e+01, -1.3113e-06, -2.1938e+01, -2.1855e+01],
         [-2.7792e+01, -2.1009e+01, -1.6808e+01, -2.0931e+01, -2.7830e+01,
          -2.6055e+01, -3.9284e+01, -9.5367e-07, -2.4812e+01, -1.3849e+01],
         [-2.9338e+01, -2.4237e+01, -2.1119e+01, -2.8949e+01, -3.0423e+01,
          -3.5695e+