In [None]:
import tensorflow as tf

# Ask TensorFlow which GPU devices it can see and report the result,
# so the reader knows up front what hardware the run will use.
physical_devices = tf.config.list_physical_devices(device_type='GPU')
print("GPUs available:", physical_devices)

In [None]:
import math
import numpy as np

from ocml.datasets import build_mnist
from ocml.evaluate import check_LLC
from ocml.models import spectral_VGG
from ocml.plot import plot_preds_ood, plot_preds_ood
from ocml.priors import uniform_sampler_images
from ocml.train import train_loop, SH_KR

In [None]:
from types import SimpleNamespace
import math

def get_config(debug=False):
  """Build the experiment configuration used by this notebook.

  Args:
    debug: when True, `num_epochs` is 1 instead of 101 so the notebook
      can be run end-to-end quickly.

  Returns:
    A new `types.SimpleNamespace` holding the hyper-parameters. Every call
    returns fresh objects (including the `domain` list), so mutating one
    config does not affect another.
  """
  domain = [-1., 1.]
  n_pixels = 28 * 28
  # Margin is 2% of sqrt(n_pixels * domain width).
  # NOTE(review): the previous comment read "5% pixels for real images"
  # while the coefficient is 2/100 -- confirm which one is intended.
  margin = (2/100) * (n_pixels * (domain[1] - domain[0]))**0.5
  lbda = 200.  # weak Hinge regularization, less KR.
  config = SimpleNamespace(
      dataset_name = "mnist",
      optimizer = "adam",
      maxiter = 16,
      batch_size = 128,
      domain = domain,
      margin = margin,
      lbda = lbda,
      k_coef_lip = 1.,
      strides = False,
      num_epochs = 1 if debug else 101,
      spectral_dense = True,
      domain_clip = True,
      deterministic = True,
      pooling = True,
      groupsort = False,
      conv_widths = [128, 128, 128],
      dense_widths = [128, 128, 128],
      in_labels = [4],
      # Read by the loss-selection cell (`config.conventional`); without
      # it that cell raises AttributeError. False keeps the SH_KR loss.
      conventional = False,
    )
  return config

In [None]:
debug = False
config = get_config(debug)

In [None]:
# Optional Weights & Biases setup. Logging is enabled only when the
# `wandb` package can be imported and the notebook is not in debug mode.
import os

# NOTE(review): confirm this matches the notebook's actual file name.
os.environ['WANDB_NOTEBOOK_NAME'] = 'run_toy2d.ipynb'
try:
  import wandb
  wandb.login()
  wandb_available = True
except ModuleNotFoundError as e:
  print(e)
  print("Wandb logs will be removed.")
  wandb_available = False

plot_wandb = wandb_available and not debug  # Set to False to de-activate Wandb.
if plot_wandb:
  wandb.init(project="oneclass", config=config.__dict__)
elif wandb_available:
  # Logging is disabled for this run: close any run that may still be open.
  # Guarded by `wandb_available` so a missing package is not reported a
  # second time as a NameError. Failures here are printed, not raised,
  # because finishing is best-effort.
  try:
    wandb.finish()
  except Exception as e:
    print(e)

In [None]:
# Build the model for 28x28 single-channel inputs.
# NOTE(review): `discriminator` is not imported in the cells visible here
# (`spectral_VGG` is imported from ocml.models but never used) -- confirm
# which constructor is intended.
input_shape = (28, 28, 1)
model = discriminator(input_shape, conv_widths=config.conv_widths,
                      dense_widths=config.dense_widths,
                      k_coef_lip=config.k_coef_lip)

# Choose the loss. `conventional` is treated as an optional config field:
# when the config does not define it, fall back to False (SH_KR loss)
# instead of raising AttributeError.
# NOTE(review): `BCE` is not imported in the cells visible here -- confirm
# where it comes from before enabling the conventional path.
if getattr(config, "conventional", False):
  loss_fn = BCE()
else:
  loss_fn = SH_KR(config.margin, config.lbda)

In [None]:
# Training data, batched with the configured batch size.
dataset = build_mnist(config.batch_size)
# Steps per epoch: the number of batches needed to cover 50,000 examples.
epoch_length = math.ceil(50000 / config.batch_size)

# Resolve the optimizer from its string identifier in the config.
opt = tf.keras.optimizers.get(config.optimizer)

In [None]:
# Fixed-seed random generator; also passed to train_loop in the training cell.
gen = tf.random.Generator.from_seed(1234)
# Take a single batch from the training dataset.
free_batch = next(iter(dataset))
# Forward pass whose output is discarded -- presumably to build the model's
# variables before `model.summary()` is called below (TODO confirm).
_ = model(free_batch, training=True)  # garbage forward.
# Sample `batch_size` seeds with the per-example shape of the batch
# (batch axis dropped); reused by every `check_LLC` call.
seeds = uniform_sampler_images(gen, config.batch_size, free_batch.shape[1:])
model.summary()

In [None]:
# Iterators over the training, test and OOD splits; the training loop pulls
# batches from them with `next(...)` every epoch.
# NOTE(review): `build_dataset` and `plot_imgs` are not imported in the cells
# visible here (only `build_mnist` is) -- confirm where they are defined.
it_dataset = iter(dataset)
it_test = iter(build_dataset(config.batch_size, split='test'))
it_ood = iter(build_dataset(config.batch_size, split='ood'))
# Preview one batch from each split.
# NOTE(review): all three calls pass 'dataset.png'; if that argument is an
# output path, each plot overwrites the previous one -- confirm.
plot_imgs(next(it_dataset), 'dataset.png')
plot_imgs(next(it_test), 'dataset.png')
plot_imgs(next(it_ood), 'dataset.png')

In [None]:
# Run the check once on the untrained model, before the training loop,
# using the fixed `seeds` batch. `plot_wandb` is passed positionally here,
# whereas the training loop passes it by keyword.
check_LLC(model, seeds, plot_wandb)

In [None]:
# Main loop: each epoch trains, re-runs the check on the fixed seeds,
# evaluates on one batch per split and plots.
# NOTE(review): `evaluate` and `plot_advs` are not imported in the cells
# visible here, and the `next(...)` calls assume the iterators never run
# out over `num_epochs` epochs -- confirm both.
num_epochs = config.num_epochs
for epoch in range(num_epochs):
  train_loop(
      model, opt, loss_fn, gen, dataset, epoch_length,
      config.domain, config.maxiter,
      plot_wandb=plot_wandb,
  )
  check_LLC(model, seeds, plot_wandb=plot_wandb)
  evaluate(
      epoch, model,
      next(it_dataset), next(it_test), next(it_ood),
      plot_wandb=plot_wandb,
  )
  plot_advs(epoch, next(it_dataset), save_file=True)