# Black Box Model Training
> No worries, no cry

After looking at a simple example experiment, it is worth looking at the big picture. The big picture requires us to train different models with differently-sized datasets.

We don't want to worry too much about fine-tuning the training procedure, because we cannot.

In [1]:
# default_exp black_box_model_training

In [2]:
#hide
import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


## Goals

* Log as much as possible by default.
* Avoid magic numbers. Magic numbers don't work very well when everything keeps changing.

In [3]:
#exporti

from typing import Optional

import torch
from torch import nn

from ignite.contrib.engines.common import setup_common_training_handlers, \
    add_early_stopping_by_val_score
from ignite.contrib.handlers import TensorboardLogger
from ignite.contrib.handlers.tensorboard_logger import OutputHandler, OptimizerParamsHandler, \
    WeightsScalarHandler, WeightsHistHandler, GradsScalarHandler, GradsHistHandler
from ignite.engine import create_supervised_trainer, create_supervised_evaluator, Events
from ignite.metrics import Accuracy, Loss

In [4]:
#exports


# Number of trainer iterations between light-weight log events: used as `log_every_iters`
# for ignite's common training handlers and for the TensorBoard epoch / batch-loss scalars.
LOG_INTERVAL: int = 10
# Number of trainer iterations between the heavier TensorBoard weight and gradient
# scalar/histogram logs.
HEAVY_LOG_INTERVAL: int = 100


def train(*, model, train_loader, val_loader, metric_loader,
          patience: int, max_epochs: int, tb_log_dir: Optional[str] = None, device: Optional[str] = None):
    """Train `model` with AdamW and cross-entropy loss, early-stopping on validation accuracy.

    After every training epoch the model is evaluated on `metric_loader` and on `val_loader`.
    Training ends after `max_epochs` epochs, or earlier when early stopping (driven by the
    "accuracy" metric of the validation evaluator) triggers.

    :param model: The `nn.Module` to train. It is moved to `device` in place.
    :param train_loader: Data loader yielding the training batches.
    :param val_loader: Data loader whose "accuracy" metric drives early stopping.
    :param metric_loader: We compute metrics for debugging and introspection purposes with this data.
    :param patience: How many epochs to wait for early-stopping.
    :param max_epochs: Upper bound on the number of training epochs.
    :param tb_log_dir: If truthy, TensorBoard logs are written to this directory; otherwise no
        TensorBoard logger is created.
    :param device: Device passed to `model.to` and to the ignite trainer/evaluators.
    :return: Optimizer that was used for training.
    """
    # Move model before creating optimizer.
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters())
    criterion = nn.CrossEntropyLoss()

    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)

    # Give each evaluator its own metric instances: a Metric attached to two engines keeps a single
    # accumulated state that both engines would update and reset.
    metric_evaluator = create_supervised_evaluator(model, metrics=create_metrics(criterion), device=device)
    validation_evaluator = create_supervised_evaluator(model, metrics=create_metrics(criterion), device=device)

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        metric_evaluator.run(metric_loader)
        validation_evaluator.run(val_loader)

    setup_common_training_handlers(trainer, log_every_iters=LOG_INTERVAL)

    # Only created when a log directory is given; must stay None otherwise so the cleanup below
    # does not reference an unbound name.
    tb_logger = None
    if tb_log_dir:
        tb_logger = TensorboardLogger(log_dir=tb_log_dir)
        configure_tb_logger(tb_logger, model, trainer, metric_evaluator, validation_evaluator)

    # Add early stopping
    add_early_stopping_by_val_score(patience, validation_evaluator, trainer, "accuracy")

    # kick everything off
    try:
        trainer.run(train_loader, max_epochs=max_epochs)
    finally:
        # Close the TensorBoard logger (if any) even when training is interrupted or raises.
        if tb_logger is not None:
            tb_logger.close()

    # Return the optimizer in case we want to continue training.
    return optimizer


def create_metrics(criterion):
    """Build a new dict of ignite metrics for one evaluator.

    These are placeholder ("dummy") metrics until the actual metrics are decided on.
    Every call constructs new metric objects.

    :param criterion: Loss function wrapped by the ignite `Loss` metric.
    :return: Dict mapping "accuracy" to an `Accuracy` metric and "loss" to a `Loss` metric.
    """
    accuracy_metric = Accuracy()
    loss_metric = Loss(criterion)
    return dict(accuracy=accuracy_metric, loss=loss_metric)


def configure_tb_logger(tb_logger, model, trainer, train_evaluator, validation_evaluator):
    """Attach TensorBoard logging handlers to the trainer and both evaluators.

    Everything is logged against the trainer's iteration counter, so training and evaluation
    curves share one x-axis. The epoch number is logged as a separate "epoch" scalar.

    :param tb_logger: `TensorboardLogger` to attach the handlers to.
    :param model: Model whose weights and gradients are logged.
    :param trainer: Training engine; its `state.iteration` is the global step for all handlers.
    :param train_evaluator: Evaluator whose metrics are logged under the "metrics" tag after each of its epochs.
    :param validation_evaluator: Evaluator whose metrics are logged under the "validation" tag.
    """
    def global_step_transform(_engine, _event_name):
        return trainer.state.iteration

    def log_epoch(engine, logger, event_name):
        logger.writer.add_scalar("epoch", engine.state.epoch, engine.state.iteration)

    # Compared to the default tb_logger behavior it is better to log everything using a single step counter
    # and log epoch numbers separately.
    tb_logger.attach(
        trainer,
        log_handler=log_epoch,
        event_name=Events.ITERATION_COMPLETED(every=LOG_INTERVAL),
    )

    # Log trainer metrics
    tb_logger.attach(
        trainer,
        log_handler=OutputHandler(
            tag="training", output_transform=lambda loss: {"batchloss": loss}, metric_names="all",
            global_step_transform=global_step_transform
        ),
        event_name=Events.ITERATION_COMPLETED(every=LOG_INTERVAL),
    )

    # Log the metrics of `train_evaluator` as well; without this handler they never reach TensorBoard.
    tb_logger.attach(
        train_evaluator,
        log_handler=OutputHandler(tag="metrics", metric_names="all", global_step_transform=global_step_transform),
        event_name=Events.EPOCH_COMPLETED,
    )

    # Log validation evaluator metrics.
    tb_logger.attach(
        validation_evaluator,
        log_handler=OutputHandler(tag="validation", metric_names="all", global_step_transform=global_step_transform),
        event_name=Events.EPOCH_COMPLETED,
    )

    # Log weights and gradients (less frequently, every HEAVY_LOG_INTERVAL iterations).
    for handler_cls in (WeightsScalarHandler, WeightsHistHandler, GradsScalarHandler, GradsHistHandler):
        tb_logger.attach(
            trainer,
            log_handler=handler_cls(model),
            event_name=Events.ITERATION_COMPLETED(every=HEAVY_LOG_INTERVAL),
        )

We want metrics that capture the quality of the model's uncertainty estimates during training.

In [8]:
# experiment

from batchbald_redux.fast_mnist import FastMNIST
from batchbald_redux.consistent_mc_dropout import SamplerModel
from batchbald_redux.example_models import BayesianMNISTCNN
from batchbald_redux.repeated_mnist import create_repeated_MNIST_dataset
import torch.utils.data

train_dataset, test_dataset = create_repeated_MNIST_dataset(num_repetitions=1, add_noise=False)

model = SamplerModel(BayesianMNISTCNN(), 1)

# Reshuffle the training data every epoch (DataLoader does not shuffle by default).
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)

# The test data is only evaluated, so its order does not matter.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=64,
)

# Fall back to the CPU so the experiment also runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): val_loader is the training data, so early stopping is driven by training accuracy;
# the test set is only used for the introspection metrics.
train(model=model,
      train_loader=train_loader,
      val_loader=train_loader,
      metric_loader=test_loader,
      patience=3,
      max_epochs=10,
      tb_log_dir=None,
      device=device)


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=938.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=938.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=938.0), HTML(value='')))

Current run is terminating due to exception: .
Engine run is terminating due to exception: .


KeyboardInterrupt: 

## Uncertainty metric idea

Plot the probability of being correct against the model's confidence (i.e. a calibration/reliability plot)!
Also plot histograms and cumulative distributions of the probability of being correct.



