# Some Models
> Shared model definitions, collected in one place to avoid copy-pasting them (#2).

In [None]:
# default_exp models

In [None]:
# exports
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn
import torch.optim
from torch import nn as nn
from torch.nn import functional as F, Module
from torch.utils.data import DataLoader

from batchbald_redux.black_box_model_training import train
from batchbald_redux.consistent_mc_dropout import (
    BayesianModule,
    ConsistentMCDropout,
    ConsistentMCDropout2d,
)

from batchbald_redux.model_optimizer_factory import ModelOptimizer, ModelOptimizerFactory

In [None]:
# exports
from batchbald_redux.trained_model import ModelTrainer, TrainedModel, TrainedBayesianModel


class BayesianMNISTCNN(BayesianModule):
    """MC-dropout CNN for single-channel 28x28 images that returns log-probabilities.

    Layout: two 5x5 convolutions (32 and 64 channels), each followed by consistent MC dropout,
    2x2 max-pooling and ReLU; then a 1024 -> 128 linear layer with consistent MC dropout and
    ReLU; then a final linear layer with ``num_classes`` outputs and ``log_softmax``.

    Args:
        num_classes: Number of output classes (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv1_drop = ConsistentMCDropout2d()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv2_drop = ConsistentMCDropout2d()
        # 1024 = 64 channels * 4 * 4: a 28x28 input shrinks 28 -> 24 -> 12 -> 8 -> 4 through
        # the two conv(5) + max_pool(2) stages.
        self.fc1 = nn.Linear(1024, 128)
        self.fc1_drop = ConsistentMCDropout()
        self.fc2 = nn.Linear(128, num_classes)

    def mc_forward_impl(self, input: torch.Tensor):
        """Compute per-class log-probabilities for a batch of images.

        Args:
            input: Image batch, assumed to be (N, 1, 28, 28) — TODO confirm against what
                ``BayesianModule`` passes in.

        Returns:
            Tensor of shape (N, num_classes) holding ``log_softmax`` scores along dim 1.
        """
        x = F.relu(F.max_pool2d(self.conv1_drop(self.conv1(input)), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per sample rather than ``view(-1, 1024)``: the batch dimension is preserved,
        # so an unexpected spatial size raises at fc1 instead of silently regrouping rows.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1_drop(self.fc1(x)))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [None]:
# Instantiate with the default number of classes so the cell output shows the layer structure.
BayesianMNISTCNN(num_classes=10)

BayesianMNISTCNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv1_drop): ConsistentMCDropout2d(p=0.5)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): ConsistentMCDropout2d(p=0.5)
  (fc1): Linear(in_features=1024, out_features=128, bias=True)
  (fc1_drop): ConsistentMCDropout(p=0.5)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [None]:
# exports


class BayesianMNISTCNN_EBM(BayesianModule):
    """Without Softmax.

    Same layers as ``BayesianMNISTCNN``, but ``mc_forward_impl`` returns the raw outputs of the
    final linear layer (unnormalised logits) instead of applying ``log_softmax``. The ``EBM``
    suffix presumably means these logits are consumed as energies elsewhere — confirm with callers.

    Args:
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv1_drop = ConsistentMCDropout2d()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv2_drop = ConsistentMCDropout2d()
        # 1024 = 64 channels * 4 * 4 for a 28x28 input after two conv(5) + max_pool(2) stages.
        self.fc1 = nn.Linear(1024, 128)
        self.fc1_drop = ConsistentMCDropout()
        self.fc2 = nn.Linear(128, num_classes)

    def mc_forward_impl(self, input: torch.Tensor):
        """Compute unnormalised class logits for a batch of images.

        Args:
            input: Image batch, assumed to be (N, 1, 28, 28) — TODO confirm against what
                ``BayesianModule`` passes in.

        Returns:
            Tensor of shape (N, num_classes) with raw logits (no softmax applied).
        """
        x = F.relu(F.max_pool2d(self.conv1_drop(self.conv1(input)), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per sample rather than ``view(-1, 1024)``: the batch dimension is preserved,
        # so an unexpected spatial size raises at fc1 instead of silently regrouping rows.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1_drop(self.fc1(x)))
        return self.fc2(x)

In [None]:
# Instantiate with the default number of classes so the cell output shows the layer structure.
BayesianMNISTCNN_EBM(num_classes=10)

BayesianMNISTCNN_EBM(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv1_drop): ConsistentMCDropout2d(p=0.5)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): ConsistentMCDropout2d(p=0.5)
  (fc1): Linear(in_features=1024, out_features=128, bias=True)
  (fc1_drop): ConsistentMCDropout(p=0.5)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [None]:
# exports


class MnistOptimizerFactory(ModelOptimizerFactory):
    """Factory that pairs a freshly initialised ``BayesianMNISTCNN`` with an Adam optimizer."""

    def create_model_optimizer(self) -> ModelOptimizer:
        """Build a new, untrained model together with its optimizer.

        Returns:
            ``ModelOptimizer`` holding a new ``BayesianMNISTCNN`` and a ``torch.optim.Adam``
            over its parameters (default learning rate, weight decay 5e-4).
        """
        network = BayesianMNISTCNN()
        adam = torch.optim.Adam(params=network.parameters(), weight_decay=5e-4)
        return ModelOptimizer(model=network, optimizer=adam)


@dataclass
class MnistModelTrainer(ModelTrainer):
    """``ModelTrainer`` that fits fresh ``BayesianMNISTCNN`` models via ``train``.

    Every field is forwarded unchanged to ``train``.
    """

    num_training_samples: int  # passed as train(training_samples=...)
    num_validation_samples: int  # passed as train(validation_samples=...)
    num_patience_epochs: int  # passed as train(patience=...)
    max_training_epochs: int  # passed as train(max_epochs=...)
    device: str  # passed as train(device=...)

    @staticmethod
    def create_model_optimizer() -> ModelOptimizer:
        """Return a new ``BayesianMNISTCNN`` paired with Adam (default lr, weight decay 5e-4)."""
        net = BayesianMNISTCNN()
        return ModelOptimizer(model=net, optimizer=torch.optim.Adam(net.parameters(), weight_decay=5e-4))

    def _fit(self, model_optimizer: ModelOptimizer, *, loader: DataLoader, train_augmentations: Optional[Module],
             validation_loader: DataLoader, log, **extra_train_kwargs) -> TrainedModel:
        """Run ``train`` on ``model_optimizer`` with this trainer's settings and wrap the model.

        ``extra_train_kwargs`` are forwarded verbatim to ``train`` (e.g. a custom ``loss``).
        """
        train(
            model=model_optimizer.model,
            optimizer=model_optimizer.optimizer,
            training_samples=self.num_training_samples,
            validation_samples=self.num_validation_samples,
            train_loader=loader,
            train_augmentations=train_augmentations,
            validation_loader=validation_loader,
            patience=self.num_patience_epochs,
            max_epochs=self.max_training_epochs,
            device=self.device,
            training_log=log,
            **extra_train_kwargs,
        )
        return TrainedBayesianModel(model_optimizer.model)

    def get_trained(self, *, train_loader: DataLoader, train_augmentations: Optional[Module], validation_loader: DataLoader,
                    log) -> TrainedModel:
        """Train a fresh model on ``train_loader`` using ``train``'s default loss.

        Returns:
            The trained network wrapped in ``TrainedBayesianModel``.
        """
        return self._fit(
            self.create_model_optimizer(),
            loader=train_loader,
            train_augmentations=train_augmentations,
            validation_loader=validation_loader,
            log=log,
        )

    def get_distilled(self, *, prediction_loader: DataLoader, train_augmentations: Optional[Module],
                      validation_loader: DataLoader, log) -> TrainedModel:
        """Distil a fresh model from the targets in ``prediction_loader``.

        Training minimises ``KLDivLoss(log_target=True, reduction="batchmean")``, so the targets
        yielded by ``prediction_loader`` are expected to be log-probabilities (presumably a
        teacher's log-softmax output — confirm with callers). Validation uses ``NLLLoss`` and
        ``prefer_accuracy=True`` is passed to ``train``.

        Returns:
            The distilled network wrapped in ``TrainedBayesianModel``.
        """
        return self._fit(
            self.create_model_optimizer(),
            loader=prediction_loader,
            train_augmentations=train_augmentations,
            validation_loader=validation_loader,
            log=log,
            loss=torch.nn.KLDivLoss(log_target=True, reduction="batchmean"),
            validation_loss=torch.nn.NLLLoss(),
            prefer_accuracy=True,
        )