In [None]:
import os
# Set before the tensorflow import below so the setting is already in the environment when TF loads.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # or any of {'0', '1', '2', '3'}
# 0: infos, warnings, errors.
# 1: warnings, errors.
# 2: errors.
# 3: none.
import tensorflow as tf
# List the GPUs TensorFlow can see; an empty list means the notebook runs on CPU.
physical_devices = tf.config.list_physical_devices('GPU')
print("GPUs available:", physical_devices)
# Also restrict the Python-side (tf.compat.v1) logger to errors only.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

In [None]:
from functools import partial
import math
import numpy as np
import random

from ocml.datasets import tfds_from_sampler, tf_from_tfds, zip_ds
from ocml.evaluate import check_LLC, log_metrics
from ocml.models import spectral_VGG, spectral_VGG_V2
from ocml.plot import plot_preds_ood, plot_imgs_grid, plot_gan
from ocml.priors import uniform_image, Mnist_NDA
from ocml.train import train, SH_KR

In [None]:
from types import SimpleNamespace
import math


def perc_to_margin(img_size, num_channels, perc, domain):
    """Turn a fraction `perc` into a margin for square images.

    Returns `perc * sqrt(img_size * img_size * num_channels * (domain[1] - domain[0]))`.

    NOTE(review): the domain width enters un-squared under the square root;
    confirm this matches the intended heuristic.
    """
    low, high = domain[0], domain[1]
    num_features = img_size * img_size * num_channels
    return perc * (num_features * (high - low)) ** 0.5

def get_config(debug=False):
  """Build the experiment configuration as a SimpleNamespace.

  Prints the derived margin and lambda as a side effect. The `debug`
  argument is accepted but not read inside this function.
  """
  domain = [-2., 2.]  # required for images.
  # heuristic of https://arxiv.org/abs/2206.06854
  ratio_images = 0.5 / 100
  ratio_pixels = 0.5 / 100  # i.e. 0.5% (not 5%).
  # NOTE(review): the margin is computed for 28x28x1 images while the model in
  # this notebook is built for 64x64x3 inputs - confirm this is intended.
  margin = perc_to_margin(28, 1, ratio_pixels, domain)
  lbda = 1. / ratio_images
  print(f"Margin={margin:.3f} Lambda={lbda:.3f}")

  # Dataset, Newton-Raphson and loss settings.
  sampling = dict(
    dataset_name="cats_vs_dogs",
    maxiter=50,
    eta=20.,
    level_set=-1.5 * margin,
    batch_size=128,
    domain=domain,
    margin=margin,
    lbda=lbda,
    domain_clip=True,
    deterministic=False,
    negative_augmentation=True,
    overshoot_boundary=False,
  )
  # architecture.
  architecture = dict(
    k_coef_lip=1.,
    strides=False,
    spectral_dense=True,
    pooling=True,
    global_pooling=False,
    groupsort=True,
    conv_widths=[256, 256, 256],
    dense_widths=[256, 256, 256],
  )
  # training.
  schedule = dict(
    in_labels=[0],
    warmup_epochs=10,
    epochs_per_plot=40,
  )
  # Merging in this order keeps the attribute order of the namespace unchanged.
  return SimpleNamespace(**sampling, **architecture, **schedule)

In [None]:
# Debug mode is switched on by the mere presence of the SANDBOX environment variable.
debug = "SANDBOX" in os.environ
config = get_config(debug)
# Subset of the configuration that is forwarded as keyword arguments to the training helpers.
train_kwargs = {
  key: getattr(config, key)
  for key in ('domain', 'eta', 'deterministic', 'level_set', 'overshoot_boundary')
}

In [None]:
import plotly.io as pio
print("PLOTLY_RENDERER:", pio.renderers.default)
# Weights & Biases is optional: the notebook keeps running without it.
try:
  import wandb
  wandb.login()
  wandb_available = True
except ModuleNotFoundError as e:
  print(e)
  print("Wandb logs will be removed.")
  wandb_available = False
plot_wandb = wandb_available and not debug  # Set to False to de-activate Wandb.
if plot_wandb:
  group = os.environ.get("WANDB_GROUP", "sandbox_fashion_mnist")
  wandb.init(project="ocml_fashion", config=config.__dict__, group=group, save_code=True)
elif wandb_available:
  # Close any run left open by a previous execution of this cell. Only attempted
  # when the import succeeded, instead of relying on a swallowed NameError.
  try:
    wandb.finish()
  except Exception as e:
    print(e)

# Metrics logging honours the same switch as the rest of the notebook.
train_kwargs['log_metrics_fn'] = partial(log_metrics, plot_wandb=plot_wandb)

In [None]:
# Height, width and channels of the images fed to the network.
input_shape = (64, 64, 3)
# Network from ocml.models; k_coef_lip comes from the configuration.
model = spectral_VGG_V2(input_shape, k_coef_lip=config.k_coef_lip, scale=1)

# Loss from ocml.train, parameterised by the margin and lambda printed by get_config.
loss_fn = SH_KR(config.margin, config.lbda)

In [None]:
import tensorflow_datasets as tfds
from typing import Iterable, Iterator, Mapping, Optional, Sequence, Tuple

# Side length (pixels) of the square images produced by the preprocessing.
IMAGE_SIZE = 64 # 112  # 224
NUM_OF_CHANNELS = 3
# Padding used to compute the central crop ratio IMAGE_SIZE / (IMAGE_SIZE + CROP_PADDING).
CROP_PADDING = 32
# Per-channel mean / standard deviation on the 0-255 scale, used by _normalize_image.
MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# NOTE(review): not referenced in the visible cells, which use num_images = 23262 instead - confirm it is still needed.
NUM_IMAGES = 9469

def _center_crop(
    image: tf.Tensor,
    original_shape: Optional[tf.Tensor] = None,
) -> tf.Tensor:
    """Crop the central square of `image`.

    The square side is `IMAGE_SIZE / (IMAGE_SIZE + CROP_PADDING)` times the
    shorter image side. No resizing happens here; the caller resizes.

    Args:
        image: image tensor whose first two axes are height and width.
        original_shape: optional precomputed shape of `image`; defaults to
            `tf.shape(image)`.

    Returns:
        The cropped image.
    """
    if original_shape is None:
        original_shape = tf.shape(image)
    image_height = original_shape[0]
    image_width = original_shape[1]

    crop_ratio = IMAGE_SIZE / (IMAGE_SIZE + CROP_PADDING)
    shorter_side = tf.cast(tf.minimum(image_height, image_width), tf.float32)
    crop_size = tf.cast(crop_ratio * shorter_side, tf.int32)

    # The +1 rounds the offsets up when the leftover margin is odd.
    offset_height = ((image_height - crop_size) + 1) // 2
    offset_width = ((image_width - crop_size) + 1) // 2
    return tf.image.crop_to_bounding_box(
        image, offset_height, offset_width, crop_size, crop_size)

def _distorted_bounding_box_crop(
    image: tf.Tensor,
    *,
    jpeg_shape: tf.Tensor,
    bbox: tf.Tensor,
    min_object_covered: float,
    aspect_ratio_range: Tuple[float, float],
    area_range: Tuple[float, float],
    max_attempts: int,
) -> tf.Tensor:
    """Crop `image` to a randomly sampled, distorted bounding box.

    All keyword arguments are forwarded to
    `tf.image.sample_distorted_bounding_box`; the whole image is used as the
    reference box when no bounding box is supplied.
    """
    sampling_options = dict(
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True,
    )
    begin, size, _ = tf.image.sample_distorted_bounding_box(jpeg_shape, **sampling_options)

    # Only the two spatial entries are needed; the third one is ignored.
    top, left, _ = tf.unstack(begin)
    height, width, _ = tf.unstack(size)
    return tf.image.crop_to_bounding_box(image, top, left, height, width)

def random_crop(image, image_size):
    """Take a random crop of `image`, falling back to a center crop.

    `image_size` is accepted but not read; the caller does the resizing.
    """
    input_shape = tf.shape(image)
    # A single box covering the whole image.
    whole_image_box = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
    cropped = _distorted_bounding_box_crop(
        image,
        jpeg_shape=input_shape,
        bbox=whole_image_box,
        min_object_covered=0.1,
        aspect_ratio_range=(3 / 4, 4 / 3),
        area_range=(0.08, 1.0),
        max_attempts=10)
    # An unchanged shape is treated as a failed random crop.
    crop_failed = tf.reduce_all(tf.equal(input_shape, tf.shape(cropped)))
    if crop_failed:
        cropped = _center_crop(cropped, input_shape)
    return cropped

def _normalize_image(image: tf.Tensor) -> tf.Tensor:
    """Standardise each channel with the fixed MEAN_RGB / STDDEV_RGB constants."""
    channel_mean = tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype)
    channel_std = tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype)
    return (image - channel_mean) / channel_std

def preprocess_cats_dogs(image, filename=None, label=None, augmentation=False, domain=None):
    """Crop, resize and normalise one image; returns the image only.

    The signature is compatible with the way `build_cats_dogs` maps it over an
    `as_supervised=True` dataset, i.e.
    `partial(fn, domain=None, augmentation=...)(image, label)`. In that call the
    label lands in the second positional slot (`filename`); it is not used.

    Args:
        image: image tensor with height, width and 3 channels.
        filename: unused; kept for backward compatibility of positional calls.
        label: unused; kept for backward compatibility of positional calls.
        augmentation: when true, apply a random crop and a random horizontal
            flip; otherwise apply a deterministic center crop.
        domain: unused; accepted because the dataset builder passes it.

    Returns:
        An image resized to (IMAGE_SIZE, IMAGE_SIZE) and normalised per channel.
    """
    if augmentation:
        image = random_crop(image, IMAGE_SIZE)
        image = tf.image.random_flip_left_right(image)
    else:
        image = _center_crop(image)
    image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE], tf.image.ResizeMethod.BICUBIC)
    image = _normalize_image(image)
    return image

def filter_labels(white_list):
  """Build a dataset predicate that keeps examples whose label is in `white_list`."""
  allowed = tf.constant(white_list, dtype=tf.int64)
  def filter_fun(image, label):
    is_allowed = tf.equal(label, allowed)
    return tf.math.reduce_any(is_allowed)
  return filter_fun

def build_cats_dogs(ds_name, batch_size, in_labels, domain, split='train', preprocess_fun=None):
  """Build a batched tf.data pipeline of preprocessed cats_vs_dogs images.

  Args:
    ds_name: accepted but not read; 'cats_vs_dogs' is always loaded.
    batch_size: number of images per batch.
    in_labels: labels considered in-distribution.
    domain: accepted but not read; `domain=None` is passed to the preprocessing.
    split: 'train' or 'test' keep `in_labels`; 'ood' keeps the other label(s)
      among {0, 1}. Only 'train' is augmented, repeated and shuffled.
    preprocess_fun: optional replacement for `preprocess_cats_dogs`; called as
      `preprocess_fun(image, label, domain=None, augmentation=...)`.

  Returns:
    A batched, prefetched dataset (infinite when split == 'train').

  Raises:
    ValueError: if `split` is not one of 'train', 'test', 'ood'.

  NOTE(review): every split reads the tfds 'train' split, so 'test' contains the
  same images as 'train' (without augmentation) - confirm this is intended.
  """
  if split in ['train', 'test']:
    label_set = in_labels
  elif split == 'ood':
    label_set = list(set(range(0, 2)).difference(in_labels))
  else:
    # Previously this fell through and failed later on an unbound `label_set`.
    raise ValueError(f"Unknown split {split!r}; expected 'train', 'test' or 'ood'.")
  ds = tfds.load('cats_vs_dogs', split='train', as_supervised=True, shuffle_files=True)
  ds = ds.filter(filter_labels(label_set))
  if preprocess_fun is None:
    preprocess_fun = preprocess_cats_dogs
  augmentation = split == 'train'
  ds = ds.map(partial(preprocess_fun, domain=None, augmentation=augmentation))
  to_shuffle = 2  # shuffle buffer size, expressed in batches.
  if split == 'train':
    ds = ds.repeat().shuffle(to_shuffle*batch_size)  # always repeat a dataset
  ds = ds.batch(batch_size).prefetch(4)
  return ds

In [None]:
# Produce and process dataset.
# Infinite, augmented stream of in-distribution batches (split defaults to 'train').
p_dataset = build_cats_dogs(config.dataset_name, config.batch_size, in_labels=config.in_labels, domain=config.domain)
# Hard-coded dataset size; presumably the number of cats_vs_dogs images - TODO confirm.
num_images = 23262
# Steps per epoch: assumes each of the two labels covers half of the images; debug mode uses 3 steps.
epoch_length = math.ceil(num_images*len(config.in_labels)*(1/2) / config.batch_size) if not debug else 3

In [None]:
# Create optimizer class.
# The schedule spans the warm-up plus two blocks of epochs_per_plot epochs, matching the training cells below.
decay_steps = epoch_length*(config.warmup_epochs + config.epochs_per_plot*2)
initial_learning_rate = 2.5e-4
# power=1. makes the decay linear, ending at initial_learning_rate/1000.
learning_rate =  tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=initial_learning_rate, decay_steps=decay_steps,
  end_learning_rate=initial_learning_rate/1000, power=1.)
opt = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)

# Initialize the network.
# NOTE(review): the generator seed is drawn from the unseeded `random` module, so runs are not reproducible.
gen = tf.random.Generator.from_seed(random.randint(0, 1000))
# One batch of in-distribution images; reused later for shapes and plots.
p_batch = next(iter(p_dataset))
_ = model(p_batch, training=True)  # garbage forward.
model.summary()

In [None]:
# Adversarial distribution.
if config.negative_augmentation:
  # Split the batch in two halves: uniform noise and negative data augmentation.
  bs_1, bs_2 = config.batch_size // 2, config.batch_size - (config.batch_size // 2)
  q_random = tfds_from_sampler(uniform_image, gen, bs_1, p_batch.shape[1:], domain=config.domain)
  # The dataset builder defined in this notebook is build_cats_dogs; the
  # previously referenced `build_ds` does not exist here and raised NameError.
  nda_source = build_cats_dogs(config.dataset_name, bs_2, in_labels=config.in_labels, domain=config.domain)
  q_nda = Mnist_NDA(bs_1, p_batch.shape[1:]).transform(gen, nda_source)
  q_dataset = zip_ds(q_random, q_nda)
else:
  q_dataset = tfds_from_sampler(uniform_image, gen, config.batch_size, p_batch.shape[1:], domain=config.domain)
# First adversarial batch, kept for the plots below.
Q0 = next(iter(q_dataset))
plot_imgs_grid(Q0, 'X_ood.png')

In [None]:
# One epoch worth of training batches, flattened into a single stack of images.
# The target shape follows the model input shape (64, 64, 3); the former
# hard-coded (32, 32, 3) would have scrambled every image into four.
X_P = tf.reshape(tf_from_tfds(p_dataset.take(epoch_length)), shape=(-1,) + tuple(input_shape))
X_test = tf_from_tfds(build_cats_dogs(config.dataset_name, config.batch_size, in_labels=config.in_labels, domain=config.domain, split='test'))
X_ood = tf_from_tfds(build_cats_dogs(config.dataset_name, config.batch_size, in_labels=config.in_labels, domain=config.domain, split='ood'))
print(f'TrainSize={len(X_P)} TestSize={len(X_test)} OODSize={len(X_ood)}')

In [None]:
# Optional sanity plots, currently disabled:
# plot_imgs_grid(X_P, 'X_P.png')
# plot_imgs_grid(X_test, 'X_test.png')
# plot_imgs_grid(X_ood, 'X_ood.png')
# Writes the first adversarial batch to 'X_ood.png' (same call as in the adversarial-distribution cell).
plot_imgs_grid(Q0, 'X_ood.png')
# check_LLC(model, Q0, plot_wandb)

In [None]:
# Warm-up phase: train with maxiter=0.
# `epoch` is initialised so that it exists even when warmup_epochs is 0; it is reused by the next cells.
epoch = 0
for epoch in range(0, config.warmup_epochs):
  # NOTE(review): `_decayed_lr` is a private optimizer method - may not exist in every Keras version.
  print(f"Epoch={epoch} LR={float(opt._decayed_lr(tf.float32)):.7f}")
  train(model, opt, loss_fn, gen, p_dataset, q_dataset, epoch_length, maxiter=0, **train_kwargs)
# Diagnostics after the warm-up, using the last epoch index.
plot_preds_ood(epoch, model, X_P, X_test, X_ood, plot_histogram=True, plot_wandb=plot_wandb)
plot_gan(epoch, model, p_batch, Q0[:16], gen, maxiter=config.maxiter, **train_kwargs)

In [None]:
# First training block of epochs_per_plot epochs, continuing from the `epoch` left by the previous cell.
end_epoch = config.epochs_per_plot+epoch+1
for epoch in range(epoch+1, end_epoch):
  print(f"Epoch={epoch} LR={float(opt._decayed_lr(tf.float32)):.7f}")
  train(model, opt, loss_fn, gen, p_dataset, q_dataset, epoch_length, maxiter=config.maxiter, **train_kwargs)
  # The histogram is only requested on the last epoch of the block.
  plot_histogram = (epoch+1 == end_epoch)
  plot_preds_ood(epoch, model, X_P, X_test, X_ood, plot_histogram=plot_histogram, plot_wandb=plot_wandb)
plot_gan(epoch, model, p_batch, Q0[:16], gen, maxiter=config.maxiter, **train_kwargs)

In [None]:
# Second training block: same statements as the previous cell, continuing from its final `epoch`.
# Re-running this cell trains for another epochs_per_plot epochs.
end_epoch = config.epochs_per_plot+epoch+1
for epoch in range(epoch+1, end_epoch):
  print(f"Epoch={epoch} LR={float(opt._decayed_lr(tf.float32)):.7f}")
  train(model, opt, loss_fn, gen, p_dataset, q_dataset, epoch_length, maxiter=config.maxiter, **train_kwargs)
  # The histogram is only requested on the last epoch of the block.
  plot_histogram = (epoch+1 == end_epoch)
  plot_preds_ood(epoch, model, X_P, X_test, X_ood, plot_histogram=plot_histogram, plot_wandb=plot_wandb)
plot_gan(epoch, model, p_batch, Q0[:16], gen, maxiter=config.maxiter, **train_kwargs)

In [None]:
def compute_certificates(model, config, X_train, X_test, X_ood):
  """Compute ROC-AUC values after shifting scores by a radius-dependent offset.

  For each radius r, the offset is e = (r / 255) * domain width; in-distribution
  scores are lowered by e and OOD scores raised by e before calling
  `calibrate_accuracy`.

  Args:
    model: Keras model producing one score per image.
    config: namespace providing `domain`.
    X_train, X_test, X_ood: image sets passed to `model.predict`.

  Returns:
    Dict mapping 'certificate_r={r}_train' / 'certificate_r={r}_test' to the
    third value returned by `calibrate_accuracy`.

  The dict is printed, and logged to wandb only when the notebook-level
  `plot_wandb` flag is true (an unconditional `wandb.log` fails when wandb is
  missing or has no active run).

  NOTE(review): `calibrate_accuracy` is not imported in the visible cells -
  confirm where it comes from. The last radius here is 256 whereas `attack`
  uses 255.
  """
  # run XPs.
  y_train = model.predict(X_train, batch_size=256, verbose=0).flatten()
  y_test = model.predict(X_test, batch_size=256, verbose=0).flatten()
  y_ood = model.predict(X_ood, batch_size=256, verbose=0).flatten()

  msg = dict()
  attacks_radii = [0, 8, 16, 36, 72, 144, 256]
  for r in attacks_radii:
    e = (r / 255) * (config.domain[1] - config.domain[0])
    # Only the third returned value is reported.
    _, _, roc_auc_train = calibrate_accuracy(y_train-e, y_ood+e)
    _, _, roc_auc_test = calibrate_accuracy(y_test-e, y_ood+e)
    msg[f'certificate_r={r:}_train'] = roc_auc_train
    msg[f'certificate_r={r:}_test'] = roc_auc_test
  print(msg)
  if plot_wandb:
    wandb.log(msg)
  return msg

In [None]:
# Certificates on the training sample, the test images and the OOD images.
compute_certificates(model, config, X_P, X_test, X_ood)

In [None]:
from sklearn.metrics import roc_auc_score
import tqdm

def proj_l2_ball(x, x_0, eps):
  """Project `x` onto the L2 ball of radius `eps` centred at `x_0` (norm over the last axis)."""
  offset = x - x_0
  norm = tf.reduce_sum(offset**2, axis=-1, keepdims=True)**0.5
  # Lower-bound the norm to keep the division below well defined.
  norm = tf.maximum(norm, 1e-6 * eps)
  # Points already inside the ball are left untouched (factor 1).
  shrink = tf.where(norm > eps, eps / norm, 1.)
  return x_0 + offset * shrink

def random_ball(x_0, eps):
  """Sample one random point per row inside the L2 ball of radius `eps` around `x_0`.

  A Gaussian vector with one extra coordinate is normalised to (almost) unit
  length and its last coordinate dropped, which yields a vector of norm at
  most 1. That vector is then scaled by `eps`; previously the scaling was
  missing, so the start point lay in the unit ball whatever `eps` was.

  Args:
    x_0: 2-D tensor (rows are flattened inputs) with a static shape.
    eps: ball radius.

  Returns:
    A tensor of the same shape as `x_0`.
  """
  n = tf.random.normal((x_0.shape[0], x_0.shape[1]+1,))
  l = tf.reduce_sum(n**2, axis=-1, keepdims=True)**0.5
  n = n / (l + eps*1e-6)
  n = n[:,:-1]  # drop last coordinate.
  return x_0 + eps * n

def gd_step(x, x0, label, eps, model, step):
  """One projected gradient step on flattened inputs.

  The model is evaluated on `x` reshaped to (-1, 64, 64, 3); the gradient of its
  first output is followed with sign `-label`, then the result is projected on
  the eps-ball around `x0` and clipped to [-1, 1].

  NOTE(review): the clip range is hard-coded to [-1, 1] while config.domain is
  [-2, 2] - confirm.
  """
  with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(x)
    scores = model(tf.reshape(x, (-1, 64, 64, 3)), training=False)
  # Per-example Jacobian; keep the row of the first (presumably only) output.
  grad = tape.batch_jacobian(scores, x)[:,0,:]
  # descent: decreases OOD score of OOD, increases of normal data
  moved = x - step * label * grad
  moved = proj_l2_ball(moved, x0, eps)
  return tf.clip_by_value(moved, -1, 1.)

@tf.function
def l2_pgd(model, x0, label, eps, attempts=1, random_start=True):
  """L2 projected gradient attack with optional random restarts.

  Args:
    model: model evaluated on inputs reshaped to (-1, 64, 64, 3).
    x0: flattened clean inputs, one row per example.
    label: per-example sign with a trailing axis of size 1 (as built by
      `l2_pgd_batch`); positive rows keep their lowest score over the
      attempts, the others their highest.
    eps: radius of the L2 ball around `x0`.
    attempts: number of independent attacks.
    random_start: start from `random_ball(x0, eps)` instead of `x0`.

  Returns:
    (y_best, delta): the retained scores and the mean L2 distance to `x0`
    measured on the last attempt.

  NOTE(review): inputs are clipped to [-1, 1] while config.domain is [-2, 2] -
  confirm.
  """
  y_best = None
  for _ in range(attempts):
    if random_start:
      x = random_ball(x0, eps)
    else:
      x = x0
    x = tf.clip_by_value(x, -1, 1.)
    max_iter = 50
    step = 0.025 * eps
    for _ in range(max_iter):
      x = gd_step(x, x0, label, eps, model, step)
    delta = tf.reduce_mean(tf.reduce_sum((x - x0)**2, axis=-1)**0.5)
    y = model(tf.reshape(x, (-1, 64, 64, 3)), training=False)
    if y_best is None:
      y_best = y
    else:
      y_min = tf.minimum(y, y_best)
      y_max = tf.maximum(y, y_best)
      # Keep the trailing axis of `label` so the condition matches the shape of
      # the scores; `label[:,0]` would broadcast (B,) against (B, 1) to (B, B).
      y_best = tf.where(label > 0., y_min, y_max)
  return y_best, delta

def l2_pgd_batch(model, images, labels, eps, batch_size):
  """Attack `images` batch by batch and collect the resulting scores.

  Args:
    model: model to attack.
    images: images whose count is a multiple of `batch_size`; they are flattened
      to 12288 (= 64 * 64 * 3) values each.
    labels: one sign per image.
    eps: attack radius; 0 evaluates the clean images without attacking.
    batch_size: number of images per attacked batch.

  Returns:
    (scores, deltas): a flat array with one score per image, and the list of
    per-batch mean perturbation norms ([0.] when no attack was run).
  """
  scores = []
  images = tf.reshape(images, (-1, batch_size, 12288))
  labels = tf.reshape(labels, (-1, batch_size, 1))
  deltas = []
  for x0, label in tqdm.tqdm(zip(images, labels)):
    if eps == 0.:
      x = tf.reshape(x0, shape=(-1, 64, 64, 3))
      score = model(x, training=False)
    else:
      score, delta = l2_pgd(model, x0, label, eps)
      deltas.append(float(delta))
    scores.append(score.numpy().flatten())
  if not deltas:
    # Placeholder so that callers can still average the list. It is no longer
    # mixed into real measurements, which used to bias their mean towards 0.
    deltas = [0.]
  scores = np.concatenate(scores, axis=0)
  return scores, deltas

def attack(model, config, X_train, X_test, X_ood, batch_size):
  """Measure the test-vs-OOD ROC-AUC under L2 attacks of increasing radius.

  Args:
    model: model to attack.
    config: namespace providing `domain`.
    X_train: accepted but not read.
    X_test, X_ood: in-distribution and OOD images; each set is shuffled and
      truncated to a multiple of `batch_size`.
    batch_size: number of images per attacked batch.

  Returns:
    Dict with 'r={r}_test' (ROC-AUC in percent) and 'r={r}_deltas' (mean
    perturbation norm) for every radius, plus 'deltas' holding the value of the
    last radius (kept for backward compatibility).

  The dict is printed, and logged to wandb only when the notebook-level
  `plot_wandb` flag is true.
  """
  X_ood = np.random.permutation(X_ood)[:(len(X_ood) // batch_size)*batch_size]
  X_test = np.random.permutation(X_test)[:(len(X_test) // batch_size)*batch_size]
  images = tf.constant(np.concatenate([X_test, X_ood], axis=0))
  # +1 for in-distribution images, -1 for OOD images.
  labels = tf.concat([tf.ones((len(X_test),)), -tf.ones((len(X_ood),))], axis=0)
  msg = dict()
  attacks_radii = [0, 8, 16, 36, 72, 144, 255]
  for r in attacks_radii:
    e = (r / 255) * (config.domain[1] - config.domain[0])
    scores, deltas = l2_pgd_batch(model, images, labels, eps=e, batch_size=batch_size)
    deltas = np.mean(np.array(deltas))
    print("deltas:", deltas)
    msg['deltas'] = deltas
    # Per-radius copy: the shared 'deltas' key is overwritten at every radius.
    msg[f'r={r:}_deltas'] = deltas
    roc_auc_test = roc_auc_score((labels+1)/2, scores)*100
    print(f'r={r:}_test={roc_auc_test}%')
    msg[f'r={r:}_test'] = roc_auc_test
    print(scores[:10], scores[-10:])
  print(msg)
  if plot_wandb:
    wandb.log(msg)
  return msg

In [None]:
# Run the attack with small batches of 10 images.
attack(model, config, X_P, X_test, X_ood, batch_size=10)

In [None]:
# Close the wandb run, if one was opened by this notebook.
if plot_wandb:
  wandb.finish()