In [None]:
import os
import sys
ENV = 'local' # ['colab', 'local']

if ENV == 'colab':
  # for running on google drive:
  from google.colab import drive
  drive.mount("/content/drive/", force_remount=True)

  !pip install fvcore -q

  module_dir = "/content/drive/My Drive/Bayes-Stochastic-Depth/"
  # append local module to path
  module_path = os.path.abspath(os.path.join(module_dir))
  if module_path not in sys.path:
      sys.path.append(module_path)

  data_dir = "/content/drive/My Drive/Bayes-Stochastic-Depth/data"
elif ENV == 'local':
  data_dir = "./data"

In [None]:
from typing import Any, Callable, List, Optional, Type, Union, Tuple, Dict
from torch import Tensor

import torch
from torch import nn, optim, mps
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms

import matplotlib.pyplot as plt
import copy
from tqdm import tqdm

from models import resnet
from data import load_CIFAR
from visualization import show_cifar_images
from utils import (
    AccuracyMetric,
    get_confusion_matrix,
    count_FLOPS,
    count_parameters,
    calculate_storage_in_mb,
    bayes_eval,
    bayes_forward,
    parse_loss,
    parse_scheduler,
)

## Prepare dataset

In [None]:
# Global constants: CIFAR-10 class names, index -> name lookup, compute device.
classes = [
    "plane", "car", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# label index -> human-readable class name
CLASS_MAP = dict(enumerate(classes))
FILL_PIX = None

# pick the best available accelerator: CUDA first, then Apple MPS, else CPU
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# deliberate override: run everything on the CPU regardless of the detection above
DEVICE = torch.device("cpu")

### Move dataset into memory

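Since the entire CIFAR-10 dataset (train + val + test) is only about 180 MB, it fits comfortably in GPU memory. Loading it onto the device once, instead of copying every batch from CPU to GPU during training, speeds up training considerably.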

**Comparison**
* Training time with per-batch CPU → GPU transfers: `60s`
* Training time with the whole dataset already on the GPU: `40s`

In [None]:
# Load the CIFAR splits. The result is indexed below by 'train', 'val' and 'test'.
# Passing DEVICE presumably places the whole dataset on that device up front —
# confirm in data.load_CIFAR.
datasets = load_CIFAR(data_dir, DEVICE)

### Preview images

In [None]:
# Preview validation images with their labels; (2, 8) is presumably the grid layout (rows, cols).
show_cifar_images((2, 8), show_labels=True, dataset=datasets['val'])
# to show test images with predictions (`preds` must be computed first, e.g. from a
# trained model — it is not defined in this notebook):
# show_cifar_images((2, 8), show_labels=True, dataset=datasets['test'], preds=preds)

## ResNet preview

In [None]:
resnet18 = resnet("resnet18", num_classes=10)

# For comparison, torchvision's own ResNet-18. Note: this call builds a randomly
# initialised model; it does NOT load ImageNet-pretrained weights.
# resnet18_torch = torchvision.models.resnet.resnet18(num_classes=10)

n_params = count_parameters(resnet18, print_table=False)
print(f"Number of parameters: {n_params/1e6 :.3f} M")
storage_mb = calculate_storage_in_mb(resnet18)
print(f"Parameter in MB: {storage_mb:.3f}")
n_expected_layers = resnet18.expected_layers()
print(f"Number of expected layers: {n_expected_layers}")
# FLOP count (disabled):
# print(f"Number of FLOPS: {count_FLOPS(resnet18, input_dim=(32, 32))/1e9:.3f} GFLOPS")

In [None]:
# Smoke test: one forward pass on a random batch of 16 CIFAR-sized (3x32x32) inputs.
# no_grad avoids recording an autograd graph. Otherwise `dummy_output` would keep
# that graph (and the tensors it saved) alive even after the model is deleted.
dummy_input = torch.randn(16, 3, 32, 32)
with torch.no_grad():
    dummy_output = resnet18(dummy_input)

del resnet18
dummy_output.shape

## Training Manager

In [None]:
class TrainingManager:
    """
    Bundles everything needed for one training run: dataloaders, model, loss,
    metric, optimizer, scheduler, metric history and early-stopping state.

    Args:
        config: run configuration dict. ``config["model"]`` is required; every
            other key read in ``_load`` falls back to a default.
        datasets: mapping with "train", "val" and "test" datasets.
    """

    def __init__(self, config, datasets):
        self.datasets = datasets
        self.config = config

        self._load()

    def _load(self) -> None:
        """Read the config and build dataloaders, model, optimizer, scheduler and history."""
        # read config (defaults apply when a key is missing)
        self.T = self.config.get("T", 10)
        self.batch_size = self.config.get("batch_size", 16)
        self.input_dim = self.config.get("input_dim", (32, 32))
        self.total_epochs = self.config.get("total_epochs", 100)
        self.weight_decay = self.config.get("weight_decay", 1e-4)
        self.baseline_lr = self.config.get("lr", 1e-4)
        self.patience = self.config.get("patience", 20)
        self.track_metrics = self.config.get("track_metrics", ["loss"])
        self.num_classes = self.config.get("num_classes", 10)

        # initiate dataloaders; val/test use batch_size=1 because `evaluate`
        # runs bayes_eval one sample at a time
        self.dataloaders = {
            "train": DataLoader(
                self.datasets["train"], batch_size=self.batch_size, shuffle=True
            ),
            "val": DataLoader(self.datasets["val"], batch_size=1, shuffle=False),
            "test": DataLoader(self.datasets["test"], batch_size=1, shuffle=False),
        }

        # initiate model
        self.net = resnet(
            resnet_name=self.config["model"],
            num_classes=self.num_classes,
            dropout_mode=self.config.get("dropout_mode", "none"),
            dropout_p=self.config.get("dropout_p", 0.0),
            sd_mode=self.config.get("sd_mode", "none"),
            sd_p=self.config.get("sd_p", 0.0),
        )
        self.net.to(DEVICE)

        # initiate loss function and metric
        self.criterion = parse_loss(self.config.get("loss", "CE"))
        self.metric = AccuracyMetric()

        # initiate optimizer and scheduler (the scheduler is stepped once per epoch in `train`)
        self.optimizer = torch.optim.RMSprop(
            self.net.parameters(), weight_decay=self.weight_decay, lr=self.baseline_lr
        )
        self.scheduler = parse_scheduler(
            self.optimizer, self.config.get("scheduler", "none"), self.total_epochs
        )

        # initialize history: "val" holds results of the regular (Bayes mode off)
        # validation pass, "val_b" those of the Bayes-mode pass
        self.epoch = 0
        self.history = {"lr": []}

        for split in ["train", "val", "val_b"]:
            self.history[f"loss/{split}"] = []
            self.history[f"acc/avg/{split}"] = []
            self.history[f"acc/global/{split}"] = []
            for c in CLASS_MAP.values():
                self.history[f"acc/{c}/{split}"] = []

        self.history["best_val_loss"] = 999
        self.history["best_val_acc"] = 0
        self.history["best_epoch"] = 0
        self.patience_count = 0
        # deep copy of the best model seen so far; set by `_check_early_stop`
        self.best_net = None

    def _check_early_stop(self, min_epochs: int = 10, split: Optional[str] = None) -> bool:
        """
        Update the best-model snapshot and patience counter from the latest
        validation results, and report whether training should stop.

        Args:
            min_epochs: nothing is tracked before this many epochs have run.
            split: history split to monitor ("val" or "val_b"). When None,
                "val" is used if it has entries, otherwise "val_b".

        Returns:
            True once none of ``self.track_metrics`` has improved for
            ``self.patience`` consecutive checks, else False.
        """
        if split is None:
            split = "val" if self.history["loss/val"] else "val_b"
        losses = self.history[f"loss/{split}"]
        if self.epoch < min_epochs or len(losses) <= 1:
            return False

        loss = losses[-1]
        acc = self.history[f"acc/avg/{split}"][-1]
        curr_min_loss = self.history.get("best_val_loss", 999)
        curr_max_acc = self.history.get("best_val_acc", 0)

        # compare loss and acc to current best (only for tracked metrics)
        update_best = False
        if "loss" in self.track_metrics and loss < curr_min_loss:
            print(f"loss decreased by {(curr_min_loss-loss)/curr_min_loss*100 :.3f} %")
            self.history["best_val_loss"] = loss
            update_best = True
        if "acc" in self.track_metrics and acc > curr_max_acc:
            print(
                f"acc increased by {(acc-curr_max_acc)/(curr_max_acc+1e-16)*100 :.3f} %"
            )
            self.history["best_val_acc"] = acc
            update_best = True

        if update_best:
            print("saving best net...")
            self.best_net = copy.deepcopy(self.net)
            self.history["best_epoch"] = self.epoch
            self.patience_count = 0
        else:
            # no tracked metric improved; untracked metrics must not reset patience
            self.patience_count += 1

        if self.patience_count >= self.patience:
            print(
                f"Tracked metrics {self.track_metrics} have not improved for {self.patience} epochs, terminate training"
            )
            return True

        return False

    def evaluate(
        self,
        set: str = "test",
        bayes_mode: bool = True,
        T: Optional[int] = None,
        **kwargs,
    ) -> Tuple[float, Dict[str, float]]:
        """
        Evaluate the model on one split, one sample at a time.

        Args:
            set: key into ``self.dataloaders`` ("train", "val" or "test").
            bayes_mode: if True, enable the net's Bayes mode and pass ``T``
                (default ``self.T``) plus a preallocated buffer to ``bayes_eval``;
                if False, disable it and pass ``T = 0``.
            T: number of stochastic passes in Bayes mode; ignored otherwise.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            (mean per-sample loss, accuracy dict produced by ``self.metric``)
        """
        self.net.eval()
        if bayes_mode:
            self.net.set_bayes_mode(True)
            T = self.T if T is None else T
            # self.input_dim already carries the config default, so a config
            # without "input_dim" still works here
            h, w = self.input_dim
            # scratch buffer handed to bayes_eval; assumes 3-channel (RGB) inputs
            buffer_tensor = torch.empty(
                size=(T, 3, h, w), dtype=torch.float32, device=DEVICE
            )
        else:
            self.net.set_bayes_mode(False)
            T = 0
            buffer_tensor = None

        # accumulate summed per-sample loss and a confusion matrix over the split
        loss = 0
        confusion_matrix = torch.zeros(
            (self.num_classes, self.num_classes), device=DEVICE
        )

        with torch.no_grad():
            for X_batch, y_batch in tqdm(self.dataloaders[set]):
                # iterate per sample so bayes_eval always receives a single input
                for x, y in zip(X_batch, y_batch):
                    y_logits, y_pred = bayes_eval(self.net, x, T, buffer=buffer_tensor)
                    loss += self.criterion(y, y_logits).item()
                    confusion_matrix += get_confusion_matrix(
                        y, y_pred, self.num_classes
                    )

        # average over samples and derive accuracies from the confusion matrix
        loss /= len(self.dataloaders[set].dataset)
        accs = self.metric(confusion_matrix)

        return loss, accs

    def train(self, epochs, eval_mode="bayes") -> None:
        """
        Main training loop.

        Runs up to ``epochs`` epochs, stopping earlier on early stop or once
        ``self.total_epochs`` epochs have been run in total. After each epoch the
        validation set is evaluated according to ``eval_mode`` and all results
        are appended to ``self.history``.

        Args:
            epochs: number of epochs to train in this call
            eval_mode: "regular" (Bayes mode off, logged as "val"), "bayes"
                (Bayes mode on, logged as "val_b") or "all" (both)

        Raises:
            ValueError: if ``eval_mode`` is not one of the values above.
        """
        if eval_mode not in ("regular", "bayes", "all"):
            raise ValueError(
                f"eval_mode must be 'regular', 'bayes' or 'all', got {eval_mode!r}"
            )
        run_regular = eval_mode in ("regular", "all")
        run_bayes = eval_mode in ("bayes", "all")
        # validation splits produced (and reported) every epoch
        val_splits = (["val"] if run_regular else []) + (["val_b"] if run_bayes else [])
        # early stopping follows the regular pass when it runs, else the Bayes pass
        monitor_split = val_splits[0]

        torch.cuda.empty_cache()  # release unoccupied cached CUDA memory
        for _ in range(epochs):
            self.net.train()
            # stochastic regularization should always be on during training
            self.net.set_bayes_mode(True)
            # per-epoch accumulators for the train split
            train_loss = 0
            confusion_matrix = torch.zeros(
                (self.num_classes, self.num_classes), device=DEVICE
            )

            for X_batch, y_batch in tqdm(self.dataloaders["train"]):
                self.optimizer.zero_grad()
                # forward pass
                y_logits = self.net(X_batch)
                # compute loss
                loss = self.criterion(y_batch, y_logits)
                # back prop
                loss.backward()
                # update parameters
                self.optimizer.step()
                # weight by batch size so the epoch mean is per sample
                # (assumes the criterion averages over the batch — confirm in parse_loss)
                batch_size = X_batch.shape[0]
                train_loss += loss.item() * batch_size
                # update confusion matrix
                y_pred = torch.argmax(y_logits, 1, keepdim=False)
                confusion_matrix += get_confusion_matrix(
                    y_batch, y_pred, self.num_classes
                )
            self.scheduler.step()

            # compute average train loss and accuracy
            n_train = len(self.dataloaders["train"].dataset)
            train_loss /= n_train
            accs = self.metric(confusion_matrix)
            self._log_results("train", train_loss, accs)

            # validation: regular pass -> "val", Bayes pass -> "val_b"
            if run_regular:
                val_loss, val_accs = self.evaluate("val", False)
                self._log_results("val", val_loss, val_accs)
            if run_bayes:
                val_loss_b, val_accs_b = self.evaluate("val", True)
                self._log_results("val_b", val_loss_b, val_accs_b)

            self.history["lr"].append(self.scheduler.get_last_lr()[0])

            # check for early stop (also snapshots the best net)
            early_stop = self._check_early_stop(split=monitor_split)
            self.epoch += 1

            # print results for every split evaluated this epoch
            print(f"Epoch {self.epoch}")
            for name in ["train"] + val_splits:
                print(
                    f"{name}: Acc(g) = {self.history[f'acc/global/{name}'][-1]*100:.4f}, Acc(c) = {self.history[f'acc/avg/{name}'][-1]*100:.4f}, Loss = {self.history[f'loss/{name}'][-1]:.4f}"
                )

            if early_stop or self.epoch >= self.total_epochs:
                print("Terminate training")
                return

    def _log_results(self, set: str, loss: float, accs: Dict[str, float]) -> None:
        """Append one split's loss and accuracy entries to ``self.history``; no-op when ``loss`` is None."""
        if loss is None:
            return
        self.history[f"loss/{set}"].append(loss)
        for name in accs.keys():
            self.history[f"{name}/{set}"].append(accs.get(name, 0))
        return

## Training

In [None]:
# Baseline run: plain ResNet-18 with dropout and stochastic depth disabled,
# CE loss, no LR schedule.
config = dict(
    model="resnet18",
    num_classes=10,
    # stochastic regularisation (all off for this baseline)
    dropout_mode="none",
    dropout_p=0.0,
    sd_mode="none",
    sd_p=0.0,
    # optimisation
    loss="CE",
    lr=1e-4,
    weight_decay=1e-4,
    scheduler="none",
    # Bayes-mode evaluation passes, batching and input size
    T=10,
    batch_size=16,
    input_dim=(32, 32),
    # stopping criteria
    total_epochs=100,
    patience=20,
    track_metrics=["loss", "acc"],
)

training_manager = TrainingManager(config, datasets)

In [None]:
# Short sanity run: a single epoch, validating with Bayes mode switched off.
training_manager.train(epochs=1, eval_mode="regular")