# Black Box Model Training
> No worries, no cry

Having looked at a simple example experiment, we now turn to the big picture, which requires us to train different models on differently-sized datasets.

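We don't want to worry too much about fine-tuning the training procedure—and we cannot: with many different models and dataset sizes, tuning each configuration by hand is impractical, so training has to work as a black box.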

In [1]:
# default_exp black_box_model_training

In [1]:
#hide
import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


## Goals

* Log as much as possible by default.
* Avoid magic numbers. Magic numbers don't work very well when everything keeps changing.

In [10]:
#exporti

import torch
from torch import nn

from ignite.contrib.engines.common import setup_common_training_handlers, \
    add_early_stopping_by_val_score
from ignite.contrib.handlers import ProgressBar
from ignite.engine import create_supervised_trainer, create_supervised_evaluator, Events
from ignite.metrics import Accuracy, Loss, RunningAverage
from ignite.utils import apply_to_tensor

In [25]:
#exports


# Iteration interval for regular logging (progress bars and ignite's common training handlers).
LOG_INTERVAL: int = 10
# Iteration interval presumably intended for more expensive logging.
# NOTE(review): not referenced in the visible part of this file.
HEAVY_LOG_INTERVAL: int = 100


def train(*, model, train_loader, val_loader,
          patience:int, max_epochs:int, device:str, epochs_log:list, loss=None):
    """
    Train `model` with Adam, validating after every epoch and early-stopping on validation accuracy.

    :param model: The model to train. It is moved to `device` in place.
    :param train_loader: Data loader providing the training batches.
    :param val_loader: Data loader used to compute the validation metrics after every epoch;
        its accuracy drives early stopping.
    :param patience: How many epochs to wait for early-stopping.
    :param max_epochs: Upper bound on the number of training epochs.
    :param device: Device to train on, e.g. "cuda" or "cpu". GPU stats are only logged on CUDA devices.
    :param epochs_log: List that receives `str(metrics)` of the validation metrics after every
        epoch (mutated in place).
    :param loss: Loss function; defaults to `nn.CrossEntropyLoss()`. It is also wrapped by the
        "crossentropy" validation metric.
    :return: Optimizer that was used for training.
    """
    if loss is None:
        loss = nn.CrossEntropyLoss()

    # Move the model before creating the optimizer so the optimizer holds the on-device parameters.
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4)

    trainer = create_supervised_trainer(model, optimizer, loss, device=device)

    metrics = create_metrics(loss)
    validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        # Evaluate on the validation set after every training epoch.
        validation_evaluator.run(val_loader)

    # The trainer's output is the batch loss; keep a running average of it for progress reporting.
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'crossentropy')

    # Ignite's GPU stats require an available CUDA device; hard-coding them on made
    # training on any other device (e.g. "cpu") fail, so only enable them for CUDA.
    use_gpu_stats = device is not None and torch.device(device).type == "cuda"
    setup_common_training_handlers(trainer, with_gpu_stats=use_gpu_stats, log_every_iters=LOG_INTERVAL)

    ProgressBar(persist=False).attach(validation_evaluator, metric_names="all",
                                      event_name=Events.ITERATION_COMPLETED(every=LOG_INTERVAL))

    @validation_evaluator.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        # Record this epoch's validation metrics (as a string) in the caller-supplied log.
        epochs_log.append(str(engine.state.metrics))

    # Stop training once validation accuracy has not improved for `patience` evaluations.
    add_early_stopping_by_val_score(patience, validation_evaluator, trainer, "accuracy")

    # Kick everything off.
    trainer.run(train_loader, max_epochs=max_epochs)

    # Return the optimizer in case we want to continue training.
    return optimizer


def create_metrics(loss):
    """
    Build fresh ignite metrics for evaluating a classifier.

    :param loss: Loss function wrapped by the "crossentropy" metric.
    :return: Dict mapping "accuracy" and "crossentropy" to new metric instances.
    """
    accuracy_metric = Accuracy()
    loss_metric = Loss(loss)
    return dict(accuracy=accuracy_metric, crossentropy=loss_metric)


We want to use metrics that capture the quality of the model's uncertainty estimates during training.

In [26]:
# experiment

# Example run: train a `BayesianMNISTCNN` (wrapped in `SamplerModel`) on MNIST built with
# `create_repeated_MNIST_dataset(num_repetitions=1, add_noise=False)`, and collect the
# per-epoch validation metrics in `epochs_log`.
from batchbald_redux.fast_mnist import FastMNIST
from batchbald_redux.consistent_mc_dropout import SamplerModel
from batchbald_redux.example_models import BayesianMNISTCNN
from batchbald_redux.repeated_mnist import create_repeated_MNIST_dataset
import torch.utils.data

train_dataset, test_dataset = create_repeated_MNIST_dataset(num_repetitions=1, add_noise=False)

# NOTE(review): the second argument is presumably the number of MC-dropout samples — TODO confirm.
model = SamplerModel(BayesianMNISTCNN(), 1)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64
)

# Also used by the prediction example in a later cell.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=64
)

# Filled by `train` with one entry per epoch.
epochs_log = []

# NOTE(review): `val_loader` is the *training* loader, so the logged metrics and the
# early-stopping criterion are computed on training data. Presumably fine for a quick
# smoke test; use a held-out loader for real experiments.
# NOTE(review): with patience=3 and max_epochs=3, early stopping likely cannot trigger here.
train(model=model, 
      train_loader=train_loader,
      val_loader=train_loader, 
      patience=3, 
      max_epochs=3, 
      device="cuda",
      epochs_log=epochs_log)

# Display the per-epoch validation metrics recorded by `train`.
epochs_log

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=938.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=938.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=938.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=938.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=938.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=938.0), HTML(value='')))




["{'accuracy': 0.9507, 'crossentropy': 0.16540256378240883}",
 "{'accuracy': 0.9603833333333334, 'crossentropy': 0.13504049125565215}",
 "{'accuracy': 0.9656833333333333, 'crossentropy': 0.11748369302990226}"]

## Obtaining predictions

Sometimes we want to obtain our models' actual predictions rather than just aggregate evaluation metrics... I know, right?

The following helper function registers an event handler on an Ignite `Engine` that stores the predictions in a list; `get_predictions` uses it to collect the predictions for a whole data loader:

In [27]:
#exports

def handler_save_predictions(engine, target_list):
    """
    Register a handler on `engine` that appends each batch's predictions to `target_list`.

    The engine's output is expected to be a `(y_pred, y)` tuple (the default output of
    ignite's `create_supervised_evaluator`). After every iteration, `y_pred` is iterated
    along its first dimension and each per-sample entry is appended to `target_list`,
    which is mutated in place.

    :param engine: Ignite `Engine` to attach the handler to.
    :param target_list: List that receives the predictions.
    """
    def _store_batch_predictions(current_engine):
        batch_predictions = current_engine.state.output[0]
        target_list.extend(batch_predictions)

    engine.add_event_handler(Events.ITERATION_COMPLETED, _store_batch_predictions)
        
def get_predictions(*, model, loader, device:str):
    """
    Run `model` over every batch of `loader` and collect its per-sample predictions.

    :param model: The model to evaluate. It is moved to `device` in place.
    :param loader: Data loader whose batches are fed to the model.
    :param device: Device to run inference on, e.g. "cuda".
    :return: List with one prediction tensor per sample, in loader order; the tensors
        stay on `device`.
    """
    # Ignite's `create_supervised_evaluator` only moves the batches to `device`, not the
    # model, so a model that is not already on `device` would fail here. Move it explicitly.
    model.to(device)

    evaluator = create_supervised_evaluator(model, device=device)

    predictions = []
    handler_save_predictions(evaluator, predictions)

    ProgressBar(persist=False).attach(evaluator, metric_names="all",
                                      event_name=Events.ITERATION_COMPLETED(every=LOG_INTERVAL))

    evaluator.run(loader)

    return predictions

### Example

In [28]:
# experiment

# Collect per-sample predictions of the trained model on the test set.
predictions = get_predictions(model=model, loader=test_loader, device="cuda")
# Number of collected predictions (one per test sample).
len(predictions)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=157.0), HTML(value='')))

10000

In [30]:
# Inspect the first ten stored predictions.
predictions[:10]

[tensor([-52.1431, -53.4922, -40.7442, -40.1009, -44.2619, -44.9161, -78.0699,
           0.0000, -48.4833, -27.3636], device='cuda:0'),
 tensor([-22.3552, -21.4889,   0.0000, -25.8736, -25.5011, -29.1754, -20.9751,
         -36.6820, -22.6187, -39.2213], device='cuda:0'),
 tensor([-2.6418e+01, -5.9605e-07, -2.4454e+01, -2.8485e+01, -1.4397e+01,
         -1.9134e+01, -2.5368e+01, -1.7937e+01, -1.9470e+01, -1.7325e+01],
        device='cuda:0'),
 tensor([-1.5497e-05, -2.3048e+01, -1.8644e+01, -2.3828e+01, -1.9778e+01,
         -1.5977e+01, -1.1162e+01, -2.2100e+01, -1.5301e+01, -1.3860e+01],
        device='cuda:0'),
 tensor([-2.9340e+01, -3.4066e+01, -2.7912e+01, -2.5837e+01, -5.5762e-04,
         -2.1208e+01, -2.4429e+01, -2.7754e+01, -1.8389e+01, -7.4921e+00],
        device='cuda:0'),
 tensor([-2.8714e+01, -1.1921e-07, -2.8676e+01, -3.2414e+01, -1.5823e+01,
         -2.2883e+01, -2.9513e+01, -1.7877e+01, -2.2081e+01, -1.8668e+01],
        device='cuda:0'),
 tensor([-4.2945e+01, -2.8