In [1]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR

from vbll.layers.classification import VBLLClassificationD

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fashion-MNIST training split, converted to tensors; fetched into 'data' if needed.
mnist_train_dataset = datasets.FashionMNIST(
    root='data', train=True, download=True, transform=transforms.ToTensor()
)

# Fashion-MNIST held-out split. No download flag is passed here, so this
# presumably relies on the files already placed under 'data' by the call above.
mnist_test_dataset = datasets.FashionMNIST(
    root='data', train=False, transform=transforms.ToTensor()
)

# MNIST digits (held-out split); going by the variable name, these serve as
# out-of-distribution inputs for the Fashion-MNIST model.
mnist_ood_dataset = datasets.MNIST(
    root='data', train=False, download=True, transform=transforms.ToTensor()
)


In [3]:
class Classifier(nn.Module):
    """MLP with two ReLU hidden layers followed by a VBLL classification head.

    ``cfg`` must expose ``IN_FEATURES``, ``HIDDEN_FEATURES``, ``OUT_FEATURES``,
    ``REG_WEIGHT``, ``PARAM`` and ``RETURN_OOD``; it is also kept on the
    instance as ``self.cfg``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        width = cfg.HIDDEN_FEATURES
        self.fc1 = nn.Linear(cfg.IN_FEATURES, width)
        self.fc2 = nn.Linear(width, width)
        self.relu = nn.ReLU()
        self.vbll = VBLLClassificationD(
            width,
            cfg.OUT_FEATURES,
            cfg.REG_WEIGHT,
            parameterization=cfg.PARAM,
            return_ood=cfg.RETURN_OOD,
        )

    def forward(self, x):
        """Flatten every sample in the batch, run both hidden layers, and
        return whatever the VBLL head produces for the resulting features."""
        flat = x.view(x.size(0), -1)  # keep the batch dim, collapse the rest
        features = self.relu(self.fc1(flat))
        features = self.relu(self.fc2(features))
        return self.vbll(features)

In [4]:
# Pick the compute backend: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
    # Only meaningful on CUDA: ask cuDNN for deterministic algorithms.
    torch.backends.cudnn.deterministic = True
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [5]:
class MODEL_CFG:
    """Architecture settings read by ``Classifier``."""
    IN_FEATURES = 784                           # input size of the first linear layer
    HIDDEN_FEATURES = 256                       # width of both hidden layers
    OUT_FEATURES = 10                           # output size handed to the VBLL head
    REG_WEIGHT = 1. / len(mnist_train_dataset)  # reciprocal of the training-set size
    PARAM = 'dense'                             # forwarded as the head's `parameterization`
    RETURN_OOD = True                           # forwarded as the head's `return_ood`


# Build the network from the settings above, then place it on the chosen device.
model = Classifier(MODEL_CFG())
model = model.to(device)

In [9]:
class TRAIN_CFG:
    """Optimisation settings read by ``train``."""
    # Loop control
    num_epochs = 10
    batch_size = 256
    validation_freq = 1       # evaluate when `epoch % validation_freq == 0`

    # Optimizer
    opt = torch.optim.AdamW   # optimizer class; instantiated inside train()
    lr = 1e-3
    wd = 1e-2                 # passed as `weight_decay`
    clip_value = 1            # max gradient norm used for clipping

    # Learning-rate schedule (StepLR)
    lr_decay_every = 10000    # passed as StepLR `step_size`
    lr_decay_gamma = 0.5      # passed as StepLR `gamma`

def eval_acc(preds, y):
    """Return the fraction of rows of ``preds`` whose arg-max along dim 1 equals ``y``.

    The result is a 0-dim float tensor.
    """
    hits = preds.argmax(dim=1) == y
    return hits.float().mean()

def train(train_dataset, test_dataset, model, train_cfg):
    """Fit ``model`` on ``train_dataset`` and periodically score it on ``test_dataset``.

    Args:
        train_dataset: Dataset whose items are indexable as ``(input, label)``.
        test_dataset: Dataset of the same form, used only for evaluation.
        model: Module whose forward output exposes ``train_loss_fn(labels)``,
            ``val_loss_fn(labels)`` and ``predictive.probs``. Inputs are moved
            to the device of the model's first parameter.
        train_cfg: Object with attributes ``num_epochs``, ``batch_size``,
            ``lr``, ``wd``, ``opt``, ``lr_decay_every``, ``lr_decay_gamma``,
            ``clip_value`` and ``validation_freq``.

    Returns:
        dict with keys ``'train_loss'`` and ``'train_acc'`` (one entry per
        epoch, ``num_epochs`` in total) and ``'test_loss'`` and ``'test_acc'``
        (one entry per epoch where ``epoch % validation_freq == 0``). Every
        entry is the unweighted mean of the per-batch values of that pass.
    """
    # Follow the model's own device rather than a notebook-level global, so the
    # function does not depend on hidden kernel state.
    model_device = next(model.parameters()).device

    train_dataloader = DataLoader(train_dataset, batch_size=train_cfg.batch_size, shuffle=True)
    # Evaluation does not need a random order; a fixed order keeps the
    # per-batch averages below identical from one evaluation pass to the next.
    test_dataloader = DataLoader(test_dataset, batch_size=train_cfg.batch_size, shuffle=False)

    optimizer = train_cfg.opt(model.parameters(), lr=train_cfg.lr, weight_decay=train_cfg.wd)
    scheduler = StepLR(optimizer, step_size=train_cfg.lr_decay_every, gamma=train_cfg.lr_decay_gamma)

    output_metrics = {
        'train_loss': [],
        'test_loss': [],
        'train_acc': [],
        'test_acc': []
    }

    # Exactly `num_epochs` passes over the training data.
    for epoch in range(train_cfg.num_epochs):
        running_loss = []
        running_acc = []

        model.train()
        for data in train_dataloader:
            optimizer.zero_grad()
            x = data[0].to(model_device)
            y_label = data[1].to(model_device)
            out = model(x)

            loss = out.train_loss_fn(y_label)
            running_loss.append(loss.item())
            running_acc.append(eval_acc(out.predictive.probs, y_label).item())

            loss.backward()
            # Bound the global gradient norm before the parameter update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_cfg.clip_value)
            optimizer.step()

        output_metrics['train_loss'].append(np.mean(running_loss))
        output_metrics['train_acc'].append(np.mean(running_acc))

        # Stepped once per epoch, so `lr_decay_every` is counted in epochs here.
        # NOTE(review): confirm that epochs (not optimizer steps) is the intended unit.
        scheduler.step()

        if epoch % train_cfg.validation_freq == 0:
            running_test_loss = []
            running_test_acc = []

            model.eval()
            with torch.no_grad():
                for data in test_dataloader:
                    x = data[0].to(model_device)
                    y_label = data[1].to(model_device)

                    out = model(x)
                    loss = out.val_loss_fn(y_label)
                    running_test_loss.append(loss.item())
                    running_test_acc.append(eval_acc(out.predictive.probs, y_label).item())

            output_metrics['test_loss'].append(np.mean(running_test_loss))
            output_metrics['test_acc'].append(np.mean(running_test_acc))

    return output_metrics

In [10]:
# Run training on the Fashion-MNIST splits and keep the returned metric history.
output = train(
    train_dataset=mnist_train_dataset,
    test_dataset=mnist_test_dataset,
    model=model,
    train_cfg=TRAIN_CFG(),
)

In [None]:
# Display the metric history returned by train() as the cell's output.
output