# Black Box Model Training
> "Learning is not attained by chance, it must be sought for with ardor and attended to with diligence.” 
>
> &mdash; <cite>Abigail Adams</cite>.

After looking at a simple example experiment, it is worth looking at the big picture. The big picture requires us to train different models with differently-sized datasets.

We don't want to worry too much about fine-tuning the training, because we cannot.

In [None]:
# default_exp black_box_model_training

In [None]:
# hide
import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


## Goals

* Log as much as possible by default.
* Avoid magic numbers. Magic numbers don't work very well when everything keeps changing.

In [None]:
# exporti

from typing import Optional

import torch
from ignite.contrib.engines.common import (
    add_early_stopping_by_val_score,
    setup_common_training_handlers,
)
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
from ignite.metrics import Accuracy, Loss, RunningAverage
from torch import nn

from batchbald_redux.consistent_mc_dropout import (
    GeometricMeanPrediction,
    SamplerModel,
    multi_sample_loss,
)
from batchbald_redux.restoring_early_stopping import RestoringEarlyStopping

In [None]:
# exports


# Number of iterations between log/progress-bar updates. Used both for the
# common training handlers and for the progress bars attached to evaluators.
LOG_INTERVAL = 10


def train(
    *,
    model,
    training_samples,
    validation_samples,
    train_loader,
    validation_loader,
    patience: Optional[int],
    max_epochs: int,
    device: str,
    training_log: dict,
    loss=None,
    validation_loss=None,
    optimizer=None
):
    """
    Train `model` with an Ignite trainer, validating after every epoch.

    :param model: Module to train. It is moved to `device` and wrapped in a
        `SamplerModel` for training and in
        `GeometricMeanPrediction(SamplerModel(...))` for validation.
    :param training_samples: Passed to `SamplerModel` for the training model
        (presumably the number of samples drawn per input -- confirm against
        `SamplerModel`).
    :param validation_samples: Same as `training_samples`, but for the
        validation model.
    :param train_loader: Data the trainer iterates over.
    :param validation_loader: Data the validation evaluator is run on at the
        end of every training epoch.
    :param patience: Patience handed to `RestoringEarlyStopping`. `None`
        disables early stopping entirely.
    :param max_epochs: Maximum number of epochs to run the trainer for.
    :param device: Device the model is moved to and that is passed to the
        Ignite trainer/evaluator.
    :param training_log: Dict that is filled in-place. `"epochs"` is reset to
        a list receiving one metrics dict per validation run; `"best_epoch"`
        is set after training when early stopping is enabled.
    :param loss: Training loss, wrapped with `multi_sample_loss`. Defaults to
        `nn.NLLLoss()`.
    :param validation_loss: Loss used for the validation "crossentropy"
        metric. Defaults to `loss`.
    :param optimizer: Optimizer to use. Defaults to Adam over
        `model.parameters()` with `weight_decay=5e-4`.
    :return: Optimizer that was used for training.
    """
    if loss is None:
        loss = nn.NLLLoss()
    if validation_loss is None:
        validation_loss = loss

    train_model = SamplerModel(model, training_samples)
    validation_model = GeometricMeanPrediction(SamplerModel(model, validation_samples))

    # Move model to device before creating the optimizer
    train_model.to(device)

    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4)

    trainer = create_supervised_trainer(train_model, optimizer, multi_sample_loss(loss), device=device)

    metrics = create_metrics(validation_loss)

    validation_evaluator = create_supervised_evaluator(validation_model, metrics=metrics, device=device)

    # Run a full validation pass whenever a training epoch finishes.
    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        validation_evaluator.run(validation_loader)

    # Only to look nicer.
    RunningAverage(output_transform=lambda x: x).attach(trainer, "crossentropy")

    setup_common_training_handlers(trainer, with_gpu_stats=torch.cuda.is_available(), log_every_iters=LOG_INTERVAL)

    ProgressBar(persist=False).attach(
        validation_evaluator,
        metric_names="all",
        event_name=Events.ITERATION_COMPLETED(every=LOG_INTERVAL),
    )

    training_log["epochs"] = []
    epochs_log = training_log["epochs"]

    # Logging
    @validation_evaluator.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        # Copy the metrics so later validation runs cannot alter logged entries.
        metrics = dict(engine.state.metrics)
        epochs_log.append(metrics)

    # Add early stopping
    if patience is not None:
        early_stopping = RestoringEarlyStopping(
            patience=patience,
            # Negated because a lower validation cross-entropy is better.
            score_function=lambda: float(-validation_evaluator.state.metrics["crossentropy"]),
            module=model,
            optimizer=optimizer,
            training_engine=trainer,
            validation_engine=validation_evaluator,
        )
    else:
        early_stopping = None

    # Kick everything off
    trainer.run(train_loader, max_epochs=max_epochs)

    # Compare against the sentinel explicitly instead of relying on the
    # truthiness of the early-stopping object.
    if early_stopping is not None:
        training_log["best_epoch"] = early_stopping.best_epoch

    # Return the optimizer in case we want to continue training.
    return optimizer


def evaluate(*, model, num_samples, loader, device, loss=None):
    """
    Run a single evaluation pass of `model` over `loader` and return its metrics.

    :param model: Module to evaluate. It is moved to `device` (unless `device`
        is `None`) and wrapped in `GeometricMeanPrediction(SamplerModel(...))`.
    :param num_samples: Passed to `SamplerModel` (presumably the number of
        samples drawn per input -- confirm against `SamplerModel`).
    :param loader: Data to evaluate on.
    :param device: Device passed to the Ignite evaluator.
    :param loss: Loss used for the "crossentropy" metric. Defaults to
        `nn.NLLLoss()`.
    :return: The evaluator's metrics (`"accuracy"` and `"crossentropy"`).
    """
    if loss is None:
        loss = nn.NLLLoss()

    # The Ignite evaluator only moves the batches to `device`, never the model,
    # so move the model ourselves (as `train` does) to avoid a device mismatch.
    if device is not None:
        model.to(device)

    evaluation_model = GeometricMeanPrediction(SamplerModel(model, num_samples))

    metrics = create_metrics(loss)

    evaluator = create_supervised_evaluator(evaluation_model, metrics=metrics, device=device)

    ProgressBar(persist=False).attach(
        evaluator,
        metric_names="all",
        event_name=Events.ITERATION_COMPLETED(every=LOG_INTERVAL),
    )

    # Kick everything off
    evaluator.run(loader, max_epochs=1)

    return evaluator.state.metrics


def create_metrics(loss):
    """Build the Ignite metrics used for evaluation, keyed by metric name.

    :param loss: Loss function wrapped by the "crossentropy" metric.
    :return: Dict with an "accuracy" and a "crossentropy" metric.
    """
    metrics_by_name = {}
    metrics_by_name["accuracy"] = Accuracy()
    metrics_by_name["crossentropy"] = Loss(loss)
    return metrics_by_name

We want to use metrics that allow us to capture the quality of the produced uncertainty during training.

In [None]:
# experiment

import torch.utils.data

from batchbald_redux.consistent_mc_dropout import GeometricMeanPrediction, SamplerModel
from batchbald_redux.example_models import BayesianMNISTCNN
from batchbald_redux.fast_mnist import FastMNIST
from batchbald_redux.repeated_mnist import create_repeated_MNIST_dataset

# Smoke test of `train`: fit the example model for a few epochs.
train_dataset, test_dataset = create_repeated_MNIST_dataset(num_repetitions=1, add_noise=False)

model = BayesianMNISTCNN()

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, drop_last=False)

# Filled in-place by `train` with the per-epoch validation metrics.
training_log = {}

# NOTE(review): `validation_loader` is the training loader here, so the logged
# "validation" metrics (and early stopping) are computed on the training data.
# Looks intentional for this demo -- confirm.
train(
    model=model,
    training_samples=1,
    validation_samples=4,
    train_loader=train_loader,
    validation_loader=train_loader,
    patience=3,
    max_epochs=3,
    device="cuda",
    training_log=training_log,
)

# Display the collected log as the cell output.
training_log

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=938.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=938.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=938.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=938.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=938.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=938.0), HTML(value='')))


RestoringEarlyStopping: Restoring best parameters. (Score: -0.11676170641109347)
RestoringEarlyStopping: Restoring optimizer.


{'epochs': [{'accuracy': 0.9723333333333334,
   'crossentropy': 0.1676989432533582},
  {'accuracy': 0.9791833333333333, 'crossentropy': 0.13145085282996297},
  {'accuracy': 0.98265, 'crossentropy': 0.11676170641109347}],
 'best_epoch': 3}

In [None]:
# Evaluate the model trained above on the held-out test loader.
evaluate(model=model, num_samples=4, loader=test_loader, device="cuda")

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=157.0), HTML(value='')))

{'accuracy': 0.9844, 'crossentropy': 0.10267819431312382}

## Obtaining predictions

Sometimes, we want to obtain predictions from our models, instead of pure evaluation metrics... I know, right?

The following helper method registers an event handler with an Ignite Engine that stores the predictions in a list:

In [None]:
# exports


def handler_save_predictions(engine, target_list):
    """Register a handler on `engine` that collects predictions into `target_list`.

    After every completed iteration, the first element of the engine's output
    is iterated over and its items are appended to `target_list`.

    :param engine: Ignite engine to attach the handler to.
    :param target_list: List that is extended in-place.
    """

    def _collect_predictions(running_engine):
        iteration_output = running_engine.state.output
        target_list.extend(iteration_output[0])

    engine.add_event_handler(Events.ITERATION_COMPLETED, _collect_predictions)


# TODO: ought to add support for toma here (and large k)
def get_predictions(*, model, loader, device: str):
    """
    Run `model` over `loader` and return all of its predictions as one tensor.

    :param model: Module to obtain predictions from. It is moved to `device`
        (unless `device` is `None`).
    :param loader: Data to predict on.
    :param device: Device passed to the Ignite evaluator.
    :return: Tensor stacking the individual predictions collected by
        `handler_save_predictions`, in the order they were produced.
    """
    # The Ignite evaluator only moves the batches to `device`, never the model,
    # so move the model ourselves (as `train` does) to avoid a device mismatch.
    if device is not None:
        model.to(device)

    evaluator = create_supervised_evaluator(model, device=device)

    predictions = []
    handler_save_predictions(evaluator, predictions)

    ProgressBar(persist=False).attach(
        evaluator,
        metric_names="all",
        event_name=Events.ITERATION_COMPLETED(every=LOG_INTERVAL),
    )

    evaluator.run(loader)

    # The handler stores one entry per prediction; combine them into one tensor.
    predictions = torch.stack(predictions)
    return predictions

### Example

In [None]:
# experiment

# Collect the model's predictions for the test loader and show how many we got.
predictions = get_predictions(model=model, loader=test_loader, device="cuda")
len(predictions)

In [None]:
predictions[:10]

In [None]:
%config Completer.use_jedi = False
%config Completer.jedi_compute_type_timeout = 10000