# Trained Model Interface
> "Why simple, when you can use design patterns?"

In [None]:
# default_exp train_eval_model

In [None]:
# hide
import blackhc.project.script
from nbdev.showdoc import *

In [None]:
# exports
from dataclasses import dataclass
from typing import List, Optional, Type

import torch
import torch.nn
import torch.utils.data
from torch import nn

from batchbald_redux.active_learning import RandomFixedLengthSampler
from batchbald_redux.black_box_model_training import train, train_with_schedule
from batchbald_redux.consistent_mc_dropout import get_log_mean_probs
from batchbald_redux.dataset_challenges import (
    RandomLabelsDataset,
    ReplaceTargetsDataset,
)
from batchbald_redux.model_optimizer_factory import ModelOptimizerFactory
from batchbald_redux.trained_model import TrainedMCDropoutModel, TrainedModel

In [None]:
# exports


class TrainEvalModel:
    """Interface for callables that train an evaluation model on demand.

    Subclasses override `__call__`; this base implementation only raises
    `NotImplementedError`.
    """

    def __call__(self, *, training_log, device) -> TrainedModel:
        """Train and return a model (keyword-only arguments).

        Args:
            training_log: Passed through to the training routine by subclasses.
            device: Device on which subclasses run training/prediction.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


@dataclass
class TrainSelfDistillationEvalModel(TrainEvalModel):
    """Train a fresh evaluation model by distilling `trained_model`'s predictions.

    When called, the instance:

    1. concatenates `training_dataset` and `eval_dataset`,
    2. asks `trained_model` for log-probabilities over that concatenation and
       reduces them with `get_log_mean_probs`,
    3. wraps the concatenation in a `ReplaceTargetsDataset` that uses those
       log-probabilities as targets,
    4. trains a newly created model against them with a KL-divergence loss
       (validation uses `NLLLoss` on `validation_loader`),
    5. returns the new model wrapped in `TrainedMCDropoutModel`.

    Attributes:
        num_pool_samples: Passed as `num_samples` to the returned
            `TrainedMCDropoutModel`.
        num_training_samples: Forwarded to `train` as `training_samples`.
        num_validation_samples: Forwarded to `train` as `validation_samples`.
        num_patience_epochs: Forwarded to `train` as `patience`.
        max_epochs: Forwarded to `train` as `max_epochs`.
        training_dataset: First part of the distillation dataset.
        eval_dataset: Second part of the distillation dataset.
        validation_loader: Loader handed to `train` for validation.
        training_batch_size: Batch size of the distillation training loader.
        model_optimizer_factory: Factory *class*; it is instantiated without
            arguments and `create_model_optimizer()` is called on the instance.
        trained_model: Teacher model that provides the distillation targets.
        min_samples_per_epoch: Second argument to `RandomFixedLengthSampler`.
        train_augmentations: Optional module forwarded to `train`.
        prefer_accuracy: Forwarded to `train`.
        prediction_batch_size: Batch size used only for the teacher's
            prediction pass (defaults to the previously hard-coded 512).
    """

    num_pool_samples: int
    num_training_samples: int
    num_validation_samples: int
    num_patience_epochs: int
    max_epochs: int
    training_dataset: torch.utils.data.Dataset
    eval_dataset: torch.utils.data.Dataset
    validation_loader: torch.utils.data.DataLoader
    training_batch_size: int
    model_optimizer_factory: Type[ModelOptimizerFactory]
    trained_model: TrainedModel
    min_samples_per_epoch: int
    # TODO: remove the default?
    train_augmentations: Optional[nn.Module] = None
    # TODO: remove the default!
    prefer_accuracy: bool = True
    # Formerly a hard-coded 512; kept as the default so existing callers are unaffected.
    prediction_batch_size: int = 512

    def __call__(self, *, training_log, device) -> TrainedMCDropoutModel:
        """Run self-distillation training and return the distilled model.

        Args:
            training_log: Forwarded to `train`.
            device: Used for the teacher's prediction pass and forwarded to `train`.

        Returns:
            A `TrainedMCDropoutModel` wrapping the newly trained model.
        """
        train_eval_dataset = torch.utils.data.ConcatDataset([self.training_dataset, self.eval_dataset])
        # No shuffling and no dropped batch: every sample gets exactly one prediction,
        # in dataset order (presumably what ReplaceTargetsDataset relies on -- confirm).
        train_eval_loader = torch.utils.data.DataLoader(
            train_eval_dataset, batch_size=self.prediction_batch_size, drop_last=False
        )

        # NOTE(review): the _N_K_C / _N_C suffixes suggest (samples, MC draws, classes)
        # reduced to (samples, classes) -- confirm against the helper implementations.
        eval_log_probs_N_C = get_log_mean_probs(
            self.trained_model.get_log_probs_N_K_C(train_eval_loader, device=device)
        )

        eval_self_distillation_dataset = ReplaceTargetsDataset(dataset=train_eval_dataset, targets=eval_log_probs_N_C)
        train_eval_self_distillation_loader = torch.utils.data.DataLoader(
            eval_self_distillation_dataset,
            batch_size=self.training_batch_size,
            sampler=RandomFixedLengthSampler(eval_self_distillation_dataset, self.min_samples_per_epoch),
            drop_last=True,
        )

        # The factory is a class here, so instantiate it before asking for a model/optimizer.
        eval_model_optimizer = self.model_optimizer_factory().create_model_optimizer()

        # Targets are log-probabilities, hence log_target=True.
        loss = torch.nn.KLDivLoss(log_target=True, reduction="batchmean")

        train(
            model=eval_model_optimizer.model,
            optimizer=eval_model_optimizer.optimizer,
            train_augmentations=self.train_augmentations,
            loss=loss,
            validation_loss=torch.nn.NLLLoss(),
            training_samples=self.num_training_samples,
            validation_samples=self.num_validation_samples,
            train_loader=train_eval_self_distillation_loader,
            validation_loader=self.validation_loader,
            patience=self.num_patience_epochs,
            max_epochs=self.max_epochs,
            prefer_accuracy=self.prefer_accuracy,
            device=device,
            training_log=training_log,
        )

        return TrainedMCDropoutModel(num_samples=self.num_pool_samples, model=eval_model_optimizer.model)


@dataclass
class TrainSelfDistillationEvalModelWithSchedule(TrainEvalModel):
    """Self-distillation trainer that delegates to `train_with_schedule`.

    When called, the instance:

    1. concatenates `training_dataset` and `eval_dataset`,
    2. asks `trained_model` for log-probabilities over that concatenation and
       reduces them with `get_log_mean_probs`,
    3. wraps the concatenation in a `ReplaceTargetsDataset` that uses those
       log-probabilities as targets,
    4. trains a newly created model against them with a KL-divergence loss via
       `train_with_schedule` (validation uses `NLLLoss` on `validation_loader`),
    5. returns the new model wrapped in `TrainedMCDropoutModel`.

    Attributes:
        num_pool_samples: Passed as `num_samples` to the returned
            `TrainedMCDropoutModel`.
        num_training_samples: Forwarded as `training_samples`.
        num_validation_samples: Forwarded as `validation_samples`.
        patience_schedule: Forwarded as `patience_schedule`.
        factor_schedule: Forwarded as `factor_schedule`.
        max_epochs: Forwarded as `max_epochs`.
        training_dataset: First part of the distillation dataset.
        eval_dataset: Second part of the distillation dataset.
        validation_loader: Loader handed to the training routine for validation.
        training_batch_size: Batch size of the distillation training loader.
        model_optimizer_factory: Factory *class*; it is instantiated without
            arguments and `create_model_optimizer()` is called on the instance.
        trained_model: Teacher model that provides the distillation targets.
        min_samples_per_epoch: Second argument to `RandomFixedLengthSampler`.
        prefer_accuracy: Forwarded as `prefer_accuracy`.
        train_augmentations: Optional module forwarded as `train_augmentations`.
        prediction_batch_size: Batch size used only for the teacher's
            prediction pass (defaults to the previously hard-coded 512).
    """

    num_pool_samples: int
    num_training_samples: int
    num_validation_samples: int
    patience_schedule: List[int]
    # NOTE(review): element type kept as int from the original annotation -- confirm
    # whether train_with_schedule also accepts non-integer factors.
    factor_schedule: List[int]
    max_epochs: int
    training_dataset: torch.utils.data.Dataset
    eval_dataset: torch.utils.data.Dataset
    validation_loader: torch.utils.data.DataLoader
    training_batch_size: int
    model_optimizer_factory: Type[ModelOptimizerFactory]
    trained_model: TrainedModel
    min_samples_per_epoch: int
    prefer_accuracy: bool
    # TODO: remove the default?
    train_augmentations: Optional[nn.Module] = None
    # Formerly a hard-coded 512; kept as the default so existing callers are unaffected.
    prediction_batch_size: int = 512

    def __call__(self, *, training_log, device) -> TrainedMCDropoutModel:
        """Run scheduled self-distillation training and return the distilled model.

        Args:
            training_log: Forwarded to `train_with_schedule`.
            device: Used for the teacher's prediction pass and forwarded to
                `train_with_schedule`.

        Returns:
            A `TrainedMCDropoutModel` wrapping the newly trained model.
        """
        train_eval_dataset = torch.utils.data.ConcatDataset([self.training_dataset, self.eval_dataset])
        # No shuffling and no dropped batch: every sample gets exactly one prediction,
        # in dataset order (presumably what ReplaceTargetsDataset relies on -- confirm).
        train_eval_loader = torch.utils.data.DataLoader(
            train_eval_dataset, batch_size=self.prediction_batch_size, drop_last=False
        )

        # NOTE(review): the _N_K_C / _N_C suffixes suggest (samples, MC draws, classes)
        # reduced to (samples, classes) -- confirm against the helper implementations.
        eval_log_probs_N_C = get_log_mean_probs(
            self.trained_model.get_log_probs_N_K_C(train_eval_loader, device=device)
        )

        eval_self_distillation_dataset = ReplaceTargetsDataset(dataset=train_eval_dataset, targets=eval_log_probs_N_C)
        train_eval_self_distillation_loader = torch.utils.data.DataLoader(
            eval_self_distillation_dataset,
            batch_size=self.training_batch_size,
            sampler=RandomFixedLengthSampler(eval_self_distillation_dataset, self.min_samples_per_epoch),
            drop_last=True,
        )

        # The factory is a class here, so instantiate it before asking for a model/optimizer.
        eval_model_optimizer = self.model_optimizer_factory().create_model_optimizer()

        # Targets are log-probabilities, hence log_target=True.
        loss = torch.nn.KLDivLoss(log_target=True, reduction="batchmean")

        train_with_schedule(
            model=eval_model_optimizer.model,
            optimizer=eval_model_optimizer.optimizer,
            loss=loss,
            train_augmentations=self.train_augmentations,
            validation_loss=torch.nn.NLLLoss(),
            training_samples=self.num_training_samples,
            validation_samples=self.num_validation_samples,
            train_loader=train_eval_self_distillation_loader,
            validation_loader=self.validation_loader,
            patience_schedule=self.patience_schedule,
            factor_schedule=self.factor_schedule,
            max_epochs=self.max_epochs,
            prefer_accuracy=self.prefer_accuracy,
            device=device,
            training_log=training_log,
        )

        return TrainedMCDropoutModel(num_samples=self.num_pool_samples, model=eval_model_optimizer.model)


@dataclass
class TrainRandomLabelEvalModel(TrainEvalModel):
    """Train a fresh model on the training set plus a randomly-labelled eval set.

    When called, the instance concatenates `training_dataset` with
    `RandomLabelsDataset(eval_dataset, seed=random_labels_seed)`, trains a
    newly created model on it with `NLLLoss` (also used for validation), and
    returns the model wrapped in `TrainedMCDropoutModel`.

    Attributes:
        num_pool_samples: Passed as `num_samples` to the returned
            `TrainedMCDropoutModel`.
        num_training_samples: Forwarded to `train` as `training_samples`.
        num_validation_samples: Forwarded to `train` as `validation_samples`.
        num_patience_epochs: Forwarded to `train` as `patience`.
        max_epochs: Forwarded to `train` as `max_epochs`.
        training_dataset: Dataset used with its own targets.
        eval_dataset: Dataset wrapped in `RandomLabelsDataset`.
        validation_loader: Loader handed to `train` for validation.
        training_batch_size: Batch size of the training loader.
        model_optimizer_factory: Factory *instance*; `create_model_optimizer()`
            is called on it directly.
        random_labels_seed: Seed passed to `RandomLabelsDataset`. Defaults to 0,
            the value that used to be hard-coded.
    """

    num_pool_samples: int
    num_training_samples: int
    num_validation_samples: int
    num_patience_epochs: int
    max_epochs: int
    training_dataset: torch.utils.data.Dataset
    eval_dataset: torch.utils.data.Dataset
    validation_loader: torch.utils.data.DataLoader
    training_batch_size: int
    model_optimizer_factory: ModelOptimizerFactory
    # Formerly a hard-coded 0; callers that need different random labels per run
    # can now pass a distinct seed.
    random_labels_seed: int = 0

    def __call__(self, *, training_log, device) -> TrainedMCDropoutModel:
        """Train on the combined dataset and return the resulting model.

        Args:
            training_log: Forwarded to `train`.
            device: Forwarded to `train`.

        Returns:
            A `TrainedMCDropoutModel` wrapping the newly trained model.
        """
        # TODO: support one_hot!
        train_eval_dataset = torch.utils.data.ConcatDataset(
            [self.training_dataset, RandomLabelsDataset(self.eval_dataset, seed=self.random_labels_seed)]
        )
        train_eval_loader = torch.utils.data.DataLoader(
            train_eval_dataset, batch_size=self.training_batch_size, drop_last=True, shuffle=True
        )

        eval_model_optimizer = self.model_optimizer_factory.create_model_optimizer()

        # The same criterion serves as training and validation loss.
        loss = torch.nn.NLLLoss()

        train(
            model=eval_model_optimizer.model,
            optimizer=eval_model_optimizer.optimizer,
            loss=loss,
            validation_loss=loss,
            training_samples=self.num_training_samples,
            validation_samples=self.num_validation_samples,
            train_loader=train_eval_loader,
            validation_loader=self.validation_loader,
            patience=self.num_patience_epochs,
            max_epochs=self.max_epochs,
            device=device,
            training_log=training_log,
        )

        return TrainedMCDropoutModel(num_samples=self.num_pool_samples, model=eval_model_optimizer.model)


@dataclass
class TrainExplicitEvalModel(TrainEvalModel):
    """Train a fresh model on the training and eval datasets with their own targets.

    Calling the instance concatenates `training_dataset` and `eval_dataset`,
    fits a newly created model with `NLLLoss` (used for training and
    validation alike) and hands it back wrapped in `TrainedMCDropoutModel`
    with `num_samples=num_pool_samples`.
    """

    num_pool_samples: int
    num_training_samples: int
    num_validation_samples: int
    num_patience_epochs: int
    max_epochs: int
    training_dataset: torch.utils.data.Dataset
    eval_dataset: torch.utils.data.Dataset
    validation_loader: torch.utils.data.DataLoader
    training_batch_size: int
    model_optimizer_factory: ModelOptimizerFactory

    def __call__(self, *, training_log, device):
        """Fit on the concatenated data and return the wrapped model.

        `training_log` and `device` are passed straight through to `train`.
        """
        # TODO: support one_hot!
        # TODO: different seed needed!
        # NOTE(review): no seed is used in this method; the TODO above may have been
        # carried over from TrainRandomLabelEvalModel -- confirm before acting on it.
        combined_dataset = torch.utils.data.ConcatDataset([self.training_dataset, self.eval_dataset])
        combined_loader = torch.utils.data.DataLoader(
            combined_dataset,
            shuffle=True,
            drop_last=True,
            batch_size=self.training_batch_size,
        )

        # Here the factory is an instance, so it is used directly.
        bundle = self.model_optimizer_factory.create_model_optimizer()

        nll = torch.nn.NLLLoss()

        train_kwargs = dict(
            model=bundle.model,
            optimizer=bundle.optimizer,
            loss=nll,
            validation_loss=nll,
            training_samples=self.num_training_samples,
            validation_samples=self.num_validation_samples,
            train_loader=combined_loader,
            validation_loader=self.validation_loader,
            patience=self.num_patience_epochs,
            max_epochs=self.max_epochs,
            device=device,
            training_log=training_log,
        )
        train(**train_kwargs)

        return TrainedMCDropoutModel(model=bundle.model, num_samples=self.num_pool_samples)