In [None]:
# Quiet TensorFlow's logging and report which GPUs are visible.
import os
# The variable is assigned before `import tensorflow` below; presumably the
# native runtime reads it while loading, so keep this order -- TODO confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # one of {'0', '1', '2', '3'}, see below.
# 0: infos, warnings, errors.
# 1: warnings, errors.
# 2: errors.
# 3: none.
import tensorflow as tf
# Query the GPU devices TensorFlow can see and print the resulting list.
physical_devices = tf.config.list_physical_devices('GPU')
print("GPUs available:", physical_devices)
# Also set the verbosity of the tf.compat.v1 logger to ERROR.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

In [None]:
from functools import partial
import math
import numpy as np
import random

from ocml.datasets import build_mnist, tfds_from_sampler, tf_from_tfds, zip_ds
from ocml.evaluate import check_LLC, log_metrics
from ocml.models import spectral_VGG, spectral_VGG_V2
from ocml.plot import plot_preds_ood, plot_imgs_grid, plot_gan
from ocml.priors import uniform_image, Mnist_NDA
from ocml.train import train, SH_KR

In [None]:
from types import SimpleNamespace
import math


def perc_to_margin(img_size, num_channels, perc, domain):
    """Turn a fraction `perc` into a margin for square images.

    Computes ``perc * sqrt(img_size**2 * num_channels * (domain[1] - domain[0]))``.

    Args:
        img_size: side length of the (square) image.
        num_channels: number of channels per pixel.
        perc: multiplicative fraction applied to the square root.
        domain: indexable pair ``(low, high)`` bounding the pixel values.

    Returns:
        The scaled margin.

    NOTE(review): the domain width is *inside* the square root. If the intent
    is a fraction of the L2 diameter of the pixel hypercube, that would be
    ``(high - low) * sqrt(n)`` instead -- confirm against the heuristic used
    by the caller before changing anything.
    """
    low, high = domain[0], domain[1]
    num_values = img_size * img_size * num_channels
    radicand = num_values * (high - low)
    return perc * radicand ** 0.5

def get_config(debug=False):
  """Collect the experiment's hyper-parameters in a SimpleNamespace.

  The margin is ``perc_to_margin(28, 1, 4 / 100, [-1., 1.])`` and lambda is
  ``1 / (0.5 / 100)``; both are printed. The dataset name is read from the
  DATASET_NAME environment variable and defaults to "fashion_mnist".

  Args:
    debug: accepted for the caller's convenience; it is not read anywhere in
      this function.

  Returns:
    A SimpleNamespace whose attributes are, in order: ``dataset_name``, the
    Newton-Raphson / batch / loss settings, the architecture settings and the
    training schedule.
  """
  value_range = [-1., 1.]
  # heuristic of https://arxiv.org/abs/2206.06854
  ratio_images = 0.5 / 100
  ratio_pixels = 4 / 100  # i.e. 4% is handed to perc_to_margin.
  margin = perc_to_margin(28, 1, ratio_pixels, value_range)
  lbda = 1. / ratio_images
  print(f"Margin={margin:.3f} Lambda={lbda:.3f}")

  # Newton-Raphson settings, followed by batch, domain and loss settings.
  solver_and_loss = dict(
    maxiter=16,
    eta=5.,
    level_set=0,
    batch_size=512,
    domain=value_range,
    margin=margin,
    lbda=lbda,
    domain_clip=True,
    deterministic=False,
    negative_augmentation=False,
    overshoot_boundary=False,
  )
  # architecture.
  architecture = dict(
    k_coef_lip=1.,
    strides=False,
    spectral_dense=True,
    pooling=True,
    global_pooling=False,
    groupsort=True,
    conv_widths=[256, 256, 256],
    dense_widths=[256, 256, 256],
  )
  # training.
  training = dict(
    in_labels=[4],
    warmup_epochs=10,
    epochs_per_plot=15,
  )
  # Unpacking order fixes the attribute order of the namespace.
  return SimpleNamespace(
    dataset_name=os.environ.get("DATASET_NAME", "fashion_mnist"),
    **solver_and_loss,
    **architecture,
    **training,
  )

In [None]:
# Debug mode is switched on by the mere presence of SANDBOX in the environment
# (its value is not inspected).
debug = "SANDBOX" in os.environ
config = get_config(debug)
# Keyword arguments shared by the train()/plot_gan() calls further down,
# copied attribute-by-attribute from the config.
train_kwargs = {
  key: getattr(config, key)
  for key in ('domain', 'eta', 'deterministic', 'level_set', 'overshoot_boundary')
}

In [None]:
# Optional Weights & Biases setup. Logging is enabled only when the package
# can be imported and the notebook is not running in debug mode.
import plotly.io as pio
print("PLOTLY_RENDERER:", pio.renderers.default)
try:
  import wandb
  wandb.login()
  wandb_available = True
except ModuleNotFoundError as e:
  print(e)
  print("Wandb logs will be removed.")
  wandb_available = False
plot_wandb = wandb_available and not debug  # Set to False to de-activate Wandb.
if plot_wandb:
  # The run group can be overridden through the WANDB_GROUP environment variable.
  group = os.environ.get("WANDB_GROUP", "sandbox_fashion_mnist")
  wandb.init(project="oneclass", config=config.__dict__, group=group, save_code=True)
elif wandb_available:
  # Logging is disabled but the package is present: best-effort attempt to
  # close a run that an earlier execution of this cell may have left open.
  # Guarded by `wandb_available` so that a missing package no longer triggers
  # (and prints) a NameError on the undefined name `wandb`.
  try:
    wandb.finish()
  except Exception as e:
    print(e)

# Every training call receives a metrics logger that knows whether wandb is on.
train_kwargs['log_metrics_fn'] = partial(log_metrics, plot_wandb=plot_wandb)

In [None]:
# Build the network and the loss. `V2` selects between the two VGG-style
# constructors; only the first variant consumes the width/pooling settings
# stored in the config.
input_shape = (28, 28, 1)
V2 = True
model = (
  spectral_VGG_V2(input_shape, k_coef_lip=config.k_coef_lip)
  if V2 else
  spectral_VGG(
    input_shape,
    conv_widths=config.conv_widths,
    dense_widths=config.dense_widths,
    k_coef_lip=config.k_coef_lip,
    groupsort=config.groupsort,
    pooling=config.pooling,
    global_pooling=config.global_pooling,
  )
)

# Loss parameterised by the margin and lambda computed in get_config().
loss_fn = SH_KR(config.margin, config.lbda)

In [None]:
# Produce and process dataset.
p_dataset = build_mnist(config.dataset_name, config.batch_size, in_labels=config.in_labels)
# NOTE(review): 60 * 1000 and the 1/10 factor presumably encode a 60k-image
# training split with ten equally sized classes -- confirm for the dataset used.
num_images = 60 * 1000
if debug:
  # Debug runs use a fixed, tiny number of steps per epoch.
  epoch_length = 3
else:
  epoch_length = math.ceil(num_images*len(config.in_labels)*(1/10) / config.batch_size)

In [None]:
# Create the optimizer and its learning-rate schedule.
# The schedule is sized for warmup_epochs + 2 * epochs_per_plot epochs of
# `epoch_length` steps each.
decay_steps = epoch_length*(config.warmup_epochs + config.epochs_per_plot*2)
initial_learning_rate = 1e-3
# With power=1. the decay is linear, from 1e-3 down to 1e-3 / 1000 over
# `decay_steps` steps.
learning_rate =  tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=initial_learning_rate, decay_steps=decay_steps,
  end_learning_rate=initial_learning_rate/1000, power=1.)
opt = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)

# Initialize the network.
# NOTE(review): the generator seed is drawn from Python's `random`, which is
# not seeded in the visible cells, so each run gets a different seed.
gen = tf.random.Generator.from_seed(random.randint(0, 1000))
# Keep one batch of in-distribution data around; it is reused for plotting.
p_batch = next(iter(p_dataset))
_ = model(p_batch, training=True)  # throw-away forward pass; output discarded.
model.summary()

In [None]:
# Adversarial distribution.
if config.negative_augmentation:
  # Split the batch between uniform noise images and Mnist_NDA-transformed
  # dataset images; the second half absorbs the remainder of an odd batch size.
  bs_1 = config.batch_size // 2
  bs_2 = config.batch_size - bs_1
  q_random = tfds_from_sampler(uniform_image, gen, bs_1, p_batch.shape[1:], domain=config.domain)
  # build_mnist is called as (dataset_name, batch_size, in_labels=...) in every
  # other cell; the dataset name was missing here, so the batch size ended up
  # in the dataset-name position.
  q_nda = Mnist_NDA().transform(build_mnist(config.dataset_name, bs_2, in_labels=config.in_labels))
  q_dataset = zip_ds(q_random, q_nda)
else:
  # Uniform images only, with the same per-example shape as the real batch.
  q_dataset = tfds_from_sampler(uniform_image, gen, config.batch_size, p_batch.shape[1:], domain=config.domain)
# One fixed batch of adversarial samples, kept for the plots below.
Q0 = next(iter(q_dataset))
plot_imgs_grid(Q0, 'X_ood.png')

In [None]:
# Materialise the evaluation tensors: one epoch worth of training batches
# (flattened back to individual 28x28x1 images) plus the 'test' and 'ood'
# splits, both loaded with the same dataset settings.
X_P = tf.reshape(tf_from_tfds(p_dataset.take(epoch_length)), shape=(-1, 28, 28, 1))
X_test, X_ood = (
  tf_from_tfds(build_mnist(config.dataset_name, config.batch_size,
                           in_labels=config.in_labels, split=split_name))
  for split_name in ('test', 'ood')
)
print(f'TrainSize={len(X_P)} TestSize={len(X_test)} OODSize={len(X_ood)}')

In [None]:
# plot_imgs_grid(X_P, 'X_P.png')
# plot_imgs_grid(X_test, 'X_test.png')
# plot_imgs_grid(X_ood, 'X_ood.png')
# plot_imgs_grid(Q0, 'X_ood.png')
# check_LLC(model, Q0, plot_wandb)

In [None]:
# Warm-up phase: train with maxiter=0 for `warmup_epochs` epochs, then plot
# the predictions (with histogram) and the plot_gan figure once at the end.
epoch = 0
for epoch in range(0, config.warmup_epochs):
  # Evaluate the schedule at the optimizer's step counter through the public
  # API; the private `opt._decayed_lr` is not available on every Keras
  # optimizer implementation.
  print(f"Epoch={epoch} LR={float(learning_rate(opt.iterations)):.7f}")
  train(model, opt, loss_fn, gen, p_dataset, q_dataset, epoch_length, maxiter=0, **train_kwargs)
plot_preds_ood(epoch, model, X_P, X_test, X_ood, plot_histogram=True, plot_wandb=plot_wandb)
plot_gan(epoch, model, p_batch, Q0, gen, maxiter=config.maxiter, **train_kwargs)

In [None]:
# First main phase: `epochs_per_plot` epochs with the full `maxiter` budget,
# continuing the epoch counter left by the previous cell.
end_epoch = config.epochs_per_plot+epoch+1
for epoch in range(epoch+1, end_epoch):
  # Evaluate the schedule at the optimizer's step counter through the public
  # API; the private `opt._decayed_lr` is not available on every Keras
  # optimizer implementation.
  print(f"Epoch={epoch} LR={float(learning_rate(opt.iterations)):.7f}")
  train(model, opt, loss_fn, gen, p_dataset, q_dataset, epoch_length, maxiter=config.maxiter, **train_kwargs)
  # Only the last epoch of the phase draws the histogram.
  plot_histogram = (epoch+1 == end_epoch)
  plot_preds_ood(epoch, model, X_P, X_test, X_ood, plot_histogram=plot_histogram, plot_wandb=plot_wandb)
plot_gan(epoch, model, p_batch, Q0, gen, maxiter=config.maxiter, **train_kwargs)

In [None]:
# Second main phase, continuing the epoch counter left by the previous cell.
# Its length is `epochs_per_plot`, matching the learning-rate schedule, whose
# decay_steps cover warmup_epochs + 2 * epochs_per_plot epochs; the previous
# use of `warmup_epochs` here stopped training before the schedule finished.
end_epoch = config.epochs_per_plot+epoch+1
for epoch in range(epoch+1, end_epoch):
  # Evaluate the schedule at the optimizer's step counter through the public
  # API; the private `opt._decayed_lr` is not available on every Keras
  # optimizer implementation.
  print(f"Epoch={epoch} LR={float(learning_rate(opt.iterations)):.7f}")
  train(model, opt, loss_fn, gen, p_dataset, q_dataset, epoch_length, maxiter=config.maxiter, **train_kwargs)
  # Only the last epoch of the phase draws the histogram.
  plot_histogram = (epoch+1 == end_epoch)
  plot_preds_ood(epoch, model, X_P, X_test, X_ood, plot_histogram=plot_histogram, plot_wandb=plot_wandb)
plot_gan(epoch, model, p_batch, Q0, gen, maxiter=config.maxiter, **train_kwargs)

In [None]:
# Close the Weights & Biases run; only attempted when wandb logging was
# enabled for this session (plot_wandb is truthy).
if plot_wandb:
  wandb.finish()

In [None]:
# Draw the plot_gan figure once more for the last trained epoch, using the
# same fixed batches (p_batch, Q0) as the earlier calls.
plot_gan(epoch, model, p_batch, Q0, gen, maxiter=config.maxiter, **train_kwargs)

In [None]:
# Final prediction plot with the histogram enabled. plot_wandb is passed as
# False here, regardless of the session setting -- presumably because the
# wandb run has already been finished above.
plot_preds_ood(epoch, model, X_P, X_test, X_ood, plot_histogram=True, plot_wandb=False)