# Black Box Model Training
> “Learning is not attained by chance, it must be sought for with ardor and attended to with diligence.”
>
> &mdash; <cite>Abigail Adams</cite>.

After looking at a simple example experiment, it is worth stepping back to look at the big picture, which requires us to train different models on differently sized datasets.

We don't want to worry too much about fine-tuning the training procedure, because we cannot: there are too many models and dataset sizes to tune each run by hand.

In [None]:
# default_exp black_box_model_training

In [None]:
# hide
import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


## Goals

* Log as much as possible by default.
* Avoid magic numbers. Magic numbers don't work very well when everything keeps changing.

In [None]:
# exporti

from typing import Optional

import torch
from ignite.contrib.engines.common import (
    add_early_stopping_by_val_score,
    setup_common_training_handlers,
)
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
from ignite.metrics import Accuracy, Loss, RunningAverage
from toma import toma
from torch import nn
from tqdm.auto import tqdm

from batchbald_redux.consistent_mc_dropout import (
    GeometricMeanPrediction,
    SamplerModel,
    multi_sample_loss,
)
from batchbald_redux.restoring_early_stopping import RestoringEarlyStopping

In [None]:
# exports


# Iteration interval for logging: passed as ``log_every_iters`` to the common training
# handlers and used as the refresh interval of the evaluation progress bars.
LOG_INTERVAL = 10


def train(
    *,
    model,
    training_samples,
    validation_samples,
    train_loader,
    validation_loader,
    patience: Optional[int],
    max_epochs: int,
    device: str,
    training_log: dict,
    loss=None,
    validation_loss=None,
    optimizer=None
):
    """
    Train ``model`` with Ignite, evaluating on ``validation_loader`` after every epoch.

    :param model: The module to train. It is moved to ``device``.
    :param training_samples: Number of samples ``SamplerModel`` draws per training input
        (presumably MC-dropout samples — confirm against ``consistent_mc_dropout``).
    :param validation_samples: Number of samples drawn per validation input; they are
        combined with ``GeometricMeanPrediction``.
    :param train_loader: Loader with the training data.
    :param validation_loader: Loader evaluated at the end of every training epoch.
    :param patience: How many epochs to wait for early stopping, scored by the negative
        validation ``crossentropy``. ``None`` disables early stopping.
    :param max_epochs: Maximum number of training epochs.
    :param device: Device to train on (e.g. ``"cuda"``).
    :param training_log: Dict filled in place. ``"epochs"`` is overwritten with a list that
        receives one dict of validation metrics per epoch. If early stopping is enabled,
        ``"best_epoch"`` receives the early-stopping handler's ``best_epoch``.
    :param loss: Training loss; defaults to ``nn.NLLLoss()``. It is wrapped with
        ``multi_sample_loss``.
    :param validation_loss: Loss for the validation ``crossentropy`` metric; defaults to ``loss``.
    :param optimizer: Optimizer to use; defaults to Adam with ``weight_decay=5e-4`` over
        ``model.parameters()``. Pass a previously returned optimizer to continue training.
    :return: Optimizer that was used for training.
    """
    if loss is None:
        loss = nn.NLLLoss()
    if validation_loss is None:
        validation_loss = loss

    # Both wrappers wrap the same ``model``, so training updates what validation evaluates.
    train_model = SamplerModel(model, training_samples)
    validation_model = GeometricMeanPrediction(SamplerModel(model, validation_samples))

    # Move model to device before creating the optimizer
    train_model.to(device)

    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4)

    trainer = create_supervised_trainer(train_model, optimizer, multi_sample_loss(loss), device=device)

    metrics = create_metrics(validation_loss)

    validation_evaluator = create_supervised_evaluator(validation_model, metrics=metrics, device=device)

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        validation_evaluator.run(validation_loader)

    # Only to look nicer: a running average of the trainer's per-iteration output (the loss).
    RunningAverage(output_transform=lambda x: x).attach(trainer, "crossentropy")

    setup_common_training_handlers(trainer, with_gpu_stats=torch.cuda.is_available(), log_every_iters=LOG_INTERVAL)

    ProgressBar(persist=False).attach(
        validation_evaluator,
        metric_names="all",
        event_name=Events.ITERATION_COMPLETED(every=LOG_INTERVAL),
    )

    training_log["epochs"] = []
    epochs_log = training_log["epochs"]

    # Logging: record a copy of the validation metrics after every validation run.
    @validation_evaluator.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        metrics = dict(engine.state.metrics)
        epochs_log.append(metrics)

    # Add early stopping. The handler is not attached explicitly here; presumably
    # RestoringEarlyStopping registers itself on the engines it is given — confirm.
    if patience is not None:
        early_stopping = RestoringEarlyStopping(
            patience=patience,
            score_function=lambda: float(-validation_evaluator.state.metrics["crossentropy"]),
            module=model,
            optimizer=optimizer,
            training_engine=trainer,
            validation_engine=validation_evaluator,
        )
    else:
        early_stopping = None

    # Kick everything off
    trainer.run(train_loader, max_epochs=max_epochs)

    # Explicit ``None`` check: do not depend on the handler object's truthiness.
    if early_stopping is not None:
        training_log["best_epoch"] = early_stopping.best_epoch

    # Return the optimizer in case we want to continue training.
    return optimizer


def evaluate(*, model, num_samples, loader, device, loss=None):
    """
    Evaluate ``model`` on ``loader`` and return the resulting metrics.

    :param model: The module to evaluate. It is moved to ``device`` before evaluation.
    :param num_samples: Number of samples ``SamplerModel`` draws per input; they are
        combined with ``GeometricMeanPrediction``.
    :param loader: Loader with the evaluation data; iterated for a single epoch.
    :param device: Device to evaluate on (e.g. ``"cuda"``).
    :param loss: Loss for the ``crossentropy`` metric; defaults to ``nn.NLLLoss()``.
    :return: The evaluator's metrics dict, with the keys produced by ``create_metrics``.
    """
    sampler_model = SamplerModel(model, num_samples)
    # Ignite's evaluator only moves batches to ``device``, never the model. Move the model
    # here (as ``train`` does); otherwise a CPU-resident model fails on a CUDA device.
    sampler_model.to(device)
    evaluation_model = GeometricMeanPrediction(sampler_model)

    if loss is None:
        loss = nn.NLLLoss()

    metrics = create_metrics(loss)

    evaluator = create_supervised_evaluator(evaluation_model, metrics=metrics, device=device)

    ProgressBar(persist=False).attach(
        evaluator,
        metric_names="all",
        event_name=Events.ITERATION_COMPLETED(every=LOG_INTERVAL),
    )

    # Kick everything off
    evaluator.run(loader, max_epochs=1)

    return evaluator.state.metrics


def create_metrics(loss):
    """Build the Ignite metrics reported during validation and evaluation.

    :param loss: Loss criterion wrapped by the ``crossentropy`` metric.
    :return: Dict mapping ``"accuracy"`` and ``"crossentropy"`` to fresh metric objects.
    """
    metric_map = dict(
        accuracy=Accuracy(),
        crossentropy=Loss(loss),
    )
    return metric_map

We want metrics that also capture the quality of the predicted uncertainty during training: besides accuracy, we track the cross-entropy (negative log-likelihood) of the predictions.

In [None]:
# experiment

import torch.utils.data

from batchbald_redux.consistent_mc_dropout import GeometricMeanPrediction, SamplerModel
from batchbald_redux.example_models import BayesianMNISTCNN
from batchbald_redux.fast_mnist import FastMNIST
from batchbald_redux.repeated_mnist import create_repeated_MNIST_dataset

train_dataset, test_dataset = create_repeated_MNIST_dataset(num_repetitions=1, add_noise=False)

model = BayesianMNISTCNN()

# Shuffle the training data so that every epoch sees the batches in a new order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
# Keep the test loader ordered and complete: get_predictions_labels iterates it several
# times and expects the same items in the same order on every pass.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, drop_last=False)

training_log = {}

# Smoke test: validates on the training data itself. With max_epochs == patience,
# early stopping cannot cut the run short.
train(
    model=model,
    training_samples=1,
    validation_samples=4,
    train_loader=train_loader,
    validation_loader=train_loader,
    patience=3,
    max_epochs=3,
    device="cuda",
    training_log=training_log,
)

training_log

 33%|###3      | 1/3 [00:00<?, ?it/s]

[1/938]   0%|           [00:00<?]

Engine run is terminating due to exception: .


KeyboardInterrupt: 

In [None]:
# Evaluate the trained model on the test set, drawing 4 samples per input.
evaluate(
    model=model,
    loader=test_loader,
    num_samples=4,
    device="cuda",
)

[1/157]   1%|           [00:00<?]

{'accuracy': 0.9715, 'crossentropy': 0.1681874878913164}

## Obtaining predictions

Sometimes we want to obtain the predictions from our models instead of just evaluation metrics... I know, right?

The following helper functions run the model over a data loader and store the sampled predictions (and the labels) in preallocated CPU tensors:

In [None]:
# exports


@torch.no_grad()
def get_predictions_labels(*, model, num_samples, num_classes, loader, device: str):
    """
    Collect sampled predictions and labels for every item yielded by ``loader``.

    Samples are computed in chunks of at most 128 via ``toma``, which may retry with
    smaller chunks. The loader is iterated once per chunk, so it must yield the whole
    dataset in the same order on every pass (no shuffling, no ``drop_last``).

    :param model: The module to sample from; moved to ``device`` and put in eval mode.
    :param num_samples: Total number of samples per input.
    :param num_classes: Size of the last dimension of each prediction.
    :param loader: Loader yielding ``(inputs, labels)`` batches over ``loader.dataset``.
    :param device: Device the inputs are moved to for the forward pass.
    :return: ``(predictions, labels)`` on the CPU. ``predictions`` is a float tensor of shape
        ``(N, num_samples, num_classes)``; ``labels`` is a long tensor of shape ``(N,)``;
        ``N = len(loader.dataset)``.
    :raises ValueError: If the loader yields different labels on different passes, or does
        not yield exactly ``N`` items (which would leave rows uninitialised).
    """
    model.to(device=device)

    N = len(loader.dataset)
    predictions = torch.empty((N, num_samples, num_classes), dtype=torch.float, device="cpu")
    labels = torch.empty(N, dtype=torch.long, device="cpu")

    # The context manager closes the progress bar even if an exception escapes.
    with tqdm(total=N * num_samples, desc="get_predictions_labels", leave=False) as pbar:

        @toma.execute.range(0, num_samples, 128)
        def get_prediction_batch(start, end):
            # If toma restarts the first chunk (presumably after an out-of-memory retry),
            # restart the progress count too.
            if start == 0:
                pbar.reset()

            model.eval()

            prediction_model = SamplerModel(model, end - start)

            data_start = 0
            for batch_x, batch_labels in loader:
                batch_x = batch_x.to(device=device)

                batch_predictions = prediction_model(batch_x)

                batch_size = len(batch_predictions)
                data_end = data_start + batch_size

                predictions[data_start:data_end, start:end].copy_(batch_predictions.float(), non_blocking=True)

                batch_labels = batch_labels.to(device="cpu", dtype=torch.long)
                if start == 0:
                    # The first chunk records the labels; later chunks only check them.
                    labels[data_start:data_end].copy_(batch_labels)
                elif not torch.equal(labels[data_start:data_end], batch_labels):
                    # torch.equal gives a single bool. ``assert a == b`` on multi-element
                    # tensors raised "Boolean value of Tensor ... is ambiguous" instead.
                    raise ValueError(
                        "get_predictions_labels: loader yielded different labels across passes; "
                        "use a deterministic loader (e.g. shuffle=False)"
                    )

                data_start = data_end

                pbar.update(batch_size * (end - start))

            if data_start != N:
                raise ValueError(
                    f"get_predictions_labels: loader yielded {data_start} items, "
                    f"but len(loader.dataset) == {N}"
                )

    return predictions, labels


def get_predictions(*, model, num_samples, num_classes, loader, device: str):
    """Return only the sampled predictions for ``loader``, dropping the labels.

    Thin wrapper around ``get_predictions_labels`` with the same keyword arguments.
    """
    result = get_predictions_labels(
        model=model,
        num_samples=num_samples,
        num_classes=num_classes,
        loader=loader,
        device=device,
    )
    return result[0]

### Example

In [None]:
# experiment

# Sample 10 predictions per test item; the first dimension is the number of test items.
predictions = get_predictions(
    model=model,
    loader=test_loader,
    device="cuda",
    num_samples=10,
    num_classes=10,
)
len(predictions)

get_predictions_labels:   0%|          | 0/100000 [00:00<?, ?it/s]

10000

In [None]:
# experiment

# Sample 7 predictions per test item, together with the matching labels.
predictions, labels = get_predictions_labels(
    model=model,
    loader=test_loader,
    device="cuda",
    num_samples=7,
    num_classes=10,
)
(predictions.shape, labels.shape)

get_predictions_labels:   0%|          | 0/70000 [00:00<?, ?it/s]

(torch.Size([10000, 7, 10]), torch.Size([10000]))

In [None]:
# Inspect the sampled predictions of the first ten test items (shape [10, 7, 10] after the previous cell).
predictions[:10]

In [None]:
# IPython tab-completion settings: turn off Jedi-based completion and set the Jedi
# type-computation timeout to 10000 ms.
%config Completer.use_jedi = False
%config Completer.jedi_compute_type_timeout = 10000