In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import tensorflow_datasets as tfds
import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.calibration import calibration_curve
import tensorflow_probability as tfp
from scipy.ndimage import gaussian_filter1d

### CIFAR-10 experiments

In [None]:
# helper functions to create TF datasets

def normalize_img(image, label):
    """Cast `image` to `float32` and divide by 255; `label` is returned unchanged.

    Written as a `(image, label) -> (image, label)` function so it can be
    passed straight to `Dataset.map` on a supervised dataset.
    """
    scaled = tf.cast(image, tf.float32) / 255.
    return scaled, label


def make_ds(ds, batch_size=128, shuffle_buffer_size=int(1e5)):
    """Turn a raw `(image, label)` dataset into an endless, shuffled, batched stream.

    Pipeline order: normalize -> cache -> repeat -> shuffle -> batch -> prefetch.

    Args:
        ds: dataset yielding `(image, label)` pairs.
        batch_size: number of examples per batch (default 128, the value that
            was previously hard-coded).
        shuffle_buffer_size: size of the shuffle buffer (default 100000, the
            value that was previously hard-coded).

    Returns:
        The transformed dataset. Because `repeat()` has no count, iterating it
        never terminates; callers pull batches with `next(iter(...))`.

    Note:
        `repeat()` is applied before `shuffle()`, so examples from consecutive
        passes over the data can be mixed inside the shuffle buffer. The order
        is kept as-is to preserve the existing sampling behaviour.
    """
    autotune = tf.data.experimental.AUTOTUNE
    return (
        ds.map(normalize_img, num_parallel_calls=autotune)
        .cache()
        .repeat()
        .shuffle(shuffle_buffer_size)
        .batch(batch_size)
        .prefetch(autotune)
    )


In [None]:
# create dataset, iterate over it and train, return model and metrics
class Trainer:
    """Trains a small CNN on CIFAR-10 with an adversarial entropy regulariser.

    `params` is a dict with the keys:
        'input_shape': shape of one input image, passed to the first Conv2D.
        'n_classes':   width of the final softmax layer.
        'lr':          Adam learning rate.
        'lambda':      weight of the adversarial term in the training loss.
        'version':     'original' (perturb along the cross-entropy gradient) or
                       'entropy' (perturb along the entropy gradient).
        'n_iters':     number of training steps run by `train()`.
        'step_size':   optional, length of the perturbation step (default 0.1).

    After `train()`, `self.history` is an array with one row per logged step
    and columns (train_loss, test_loss, train_acc, test_acc,
    train_loss_adv, test_loss_adv).
    """

    # Lower bound used to keep log() and the gradient normalisation finite.
    _EPS = 1e-12

    def __init__(self, params):
        self.params = params
        print(f"\nRun training with params {self.params}")

        # create CIFAR dataset and get iterators
        ds_train, ds_test = tfds.load(
            'cifar10',
            split=['train', 'test'],
            shuffle_files=True,
            as_supervised=True,
            with_info=False,
        )
        self.iter_train = iter(make_ds(ds_train))
        self.iter_test = iter(make_ds(ds_test))

        # define model
        self.model = tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=self.params['input_shape']),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(self.params['n_classes'], activation='softmax')])
        # summary() prints by itself and returns None, so it is not wrapped in print()
        self.model.summary()

        # define optimizer
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.params['lr'])

        # maintain history
        self.history = []

    @staticmethod
    def _entropy(probs):
        """Per-example `-mean_k p_k * log(p_k)` of a batch of probability rows.

        The argument of the log is clipped to `[_EPS, 1]` so that a probability
        of exactly 0 contributes 0 instead of `0 * -inf = NaN`. For
        probabilities >= `_EPS` the value is the same as the unclipped formula.
        """
        safe_probs = tf.clip_by_value(probs, Trainer._EPS, 1.0)
        return -1.0 * tf.reduce_mean(probs * tf.math.log(safe_probs), axis=1)

    def _perturb(self, X, grads):
        """Move every example one normalised step against `grads`, clipped to [0, 1].

        The L2 norm is taken per example over all non-batch axes. The batch size
        is read from the tensor rather than assumed, and the norm is floored at
        `_EPS` so an all-zero gradient leaves the example unchanged instead of
        producing NaN.
        """
        step_size = self.params.get('step_size', 0.1)
        flat_grads = tf.reshape(grads, (tf.shape(grads)[0], -1))
        norms = tf.maximum(tf.norm(flat_grads, axis=1), self._EPS)
        # reshape to (batch, 1, ..., 1) so the division broadcasts per example
        norms = tf.reshape(norms, [-1] + [1] * (len(grads.shape) - 1))
        return tf.clip_by_value(X - step_size * grads / norms, 0.0, 1.0)

    # define loss function
    # computes loss given a model and X, Y
    @tf.function
    def loss_fn(self, X, Y):
        """Return `(total_loss, loss_adv, accuracy)` for one batch.

        Steps:
          1. Forward pass: cross-entropy `loss` and accuracy (in percent).
          2. Take the gradient w.r.t. `X` of the cross-entropy
             (version 'original') or of the predictive entropy
             (version 'entropy'), and step in the negative direction to get
             `X_perturbed`, a nearby point that lowers that quantity.
          3. `loss_adv` is the negative mean entropy at `X_perturbed`, so
             minimising it maximises entropy there.

        `total_loss = loss + params['lambda'] * loss_adv`.

        Raises:
            ValueError: if `params['version']` is neither 'original' nor 'entropy'.
        """
        version = self.params['version']
        if version not in ('original', 'entropy'):
            raise ValueError(f"unknown params['version']: {version!r}")

        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

        # One forward pass serves both the reported loss/accuracy and the
        # input gradient used for the perturbation.
        with tf.GradientTape() as tape:
            tape.watch(X)
            Y_hat = self.model(X)
            loss = cross_entropy(y_true=Y, y_pred=Y_hat)
            if version == 'original':
                attack_objective = loss
            else:
                attack_objective = self._entropy(Y_hat)
        accuracy = tf.reduce_mean(tf.cast(tf.argmax(Y_hat, axis=1) == Y, tf.float32)) * 100.

        grads = tape.gradient(attack_objective, X)
        X_perturbed = self._perturb(X, grads)

        # negative entropy at the perturbed point
        entropy = self._entropy(self.model(X_perturbed))
        loss_adv = -1.0 * tf.reduce_mean(entropy)
        return loss + self.params['lambda'] * loss_adv, loss_adv, accuracy

    # define step function
    # computes gradients and applies them
    @tf.function
    def step_fn(self, X_train, Y_train):
        """Run one optimizer step on a batch; returns what `loss_fn` returns."""
        with tf.GradientTape() as tape:
            loss, loss_adv, accuracy = self.loss_fn(X_train, Y_train)
        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
        return loss, loss_adv, accuracy

    @tf.function
    def eval_ood_helper(self, X_real, X_fake):
        """Return predictive entropies of `X_real` followed by those of `X_fake`."""
        entropy_real = self._entropy(self.model(X_real))
        entropy_fake = self._entropy(self.model(X_fake))
        return tf.concat([entropy_real, entropy_fake], axis=0)

    def eval_ood(self, X_real, X_fake):
        """ROC AUC of predictive entropy as a detector of `X_fake` vs `X_real`.

        `X_real` examples get label 0 and `X_fake` examples label 1; the label
        counts follow the actual batch sizes.
        """
        # Uses this instance's model (previously read a global named `trainer`).
        scores = self.eval_ood_helper(X_real, X_fake).numpy()
        labels = np.concatenate([np.zeros(int(X_real.shape[0])), np.ones(int(X_fake.shape[0]))])
        return roc_auc_score(labels, scores)

    @tf.function
    def get_calibration_metrics(self, X, Y):
        """Return `(brier, ece, nll)` of the model's predictions on `(X, Y)`.

        The model outputs probabilities, so log-probabilities are passed where
        the metric functions expect logits. Probabilities are clipped to
        `[_EPS, 1]` first so the log stays finite.
        """
        logits = tf.math.log(tf.clip_by_value(self.model(X), self._EPS, 1.0))
        brier = tf.reduce_mean(tfp.stats.brier_score(labels=Y, logits=logits))
        ece = tfp.stats.expected_calibration_error(num_bins=20, logits=logits, labels_true=Y)
        nll = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y_true=Y, y_pred=logits)
        return brier, ece, nll

    def train(self):
        """Run `params['n_iters']` training steps, logging metrics every 10 steps.

        Each logged row evaluates `loss_fn` on one fresh test batch. Rows are
        appended to whatever `self.history` already holds, so calling `train()`
        again continues the log instead of failing on the ndarray left by the
        previous call.
        """
        history = list(self.history)

        # loop over data n_iters times
        for t in tqdm.trange(self.params['n_iters']):
            X_train, Y_train = next(self.iter_train)
            train_loss, train_loss_adv, train_acc = self.step_fn(X_train, Y_train)
            if t % 10 == 0:
                X_test, Y_test = next(self.iter_test)
                test_loss, test_loss_adv, test_acc = self.loss_fn(X_test, Y_test)
                history.append((train_loss.numpy(), test_loss.numpy(), train_acc.numpy(), test_acc.numpy(),
                                train_loss_adv.numpy(), test_loss_adv.numpy()))

        self.history = np.array(history)


In [None]:
# plotting utils

def plot_training_metrics(trainer_vec, log_every=10):
    """Plot smoothed train/validation accuracy curves for each trainer.

    Args:
        trainer_vec: sequence of trained objects exposing a `history` array
            whose column 2 is train accuracy and column 3 is test accuracy.
            The first entry is labelled 'baseline', every other one 'ours'.
        log_every: number of training iterations between two history rows.
            The default of 10 matches the logging interval in `Trainer.train`;
            it converts row indices into real iteration counts so the x-axis
            matches its 'iterations' label.
    """
    c_vec = plt.rcParams['axes.prop_cycle'].by_key()['color']
    plt.figure(figsize=(8, 6))
    for idx, (c, trainer) in enumerate(zip(c_vec, trainer_vec)):
        name = 'baseline' if idx == 0 else 'ours'
        history = np.asarray(trainer.history)
        iterations = np.arange(len(history)) * log_every
        # sigma=100 is measured in history rows; markevery counts plotted points
        plt.plot(iterations, gaussian_filter1d(history[:, 2], 100), '-o', c=c,
                 label='%s - train' % name, markevery=1000, markersize=10)
        plt.plot(iterations, gaussian_filter1d(history[:, 3], 100), '-^', c=c,
                 label='%s - val' % name, markevery=1000, markersize=10)
    plt.grid()
    plt.xlabel('iterations')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()
    
    
def plot_OOD_metrics(results):
    """Draw one grouped bar chart per metric from the OOD evaluation results.

    Args:
        results: dict mapping `(corruption_type, trainer_index)` to a 5-tuple
            `(accuracy, ood_auc, brier, ece, nll)`. Trainer index 0 is labelled
            'baseline', any other index 'ours'.

    The corruption types and trainer indices are taken from the keys of
    `results` (corruption types in insertion order), so the function does not
    depend on notebook globals.
    """
    corruption_types = list(dict.fromkeys(ctype for ctype, _ in results))
    trainer_indices = sorted({idx for _, idx in results})
    titles = ['Accuracy $(\\uparrow)$', 'OOD AUC $(\\uparrow)$', 'Brier $(\\downarrow)$',
              'ECE  $(\\downarrow)$', 'NLL  $(\\downarrow)$']

    x = np.arange(len(corruption_types))
    width = 0.25

    plt.figure(figsize=(12, 6))
    for j, title in enumerate(titles):
        plt.subplot(2, 3, j + 1)
        plt.title(title)
        for position, idx in enumerate(trainer_indices):
            y = [results[(ctype, idx)][j] for ctype in corruption_types]
            # centre each group of bars on its tick
            offset = width * (position - (len(trainer_indices) - 1) / 2)
            plt.bar(x + offset, y, width=width, label='%s' % ('ours' if idx else 'baseline'))
        plt.xticks(x, corruption_types, rotation=90)
        plt.grid()
        if j == 0:
            plt.legend(fontsize=12, loc='lower right')
            # Trainer.loss_fn reports accuracy in percent, so the zoom window
            # has to be in percent as well.
            plt.ylim([40, 75])
    plt.tight_layout()
    plt.show()

In [None]:
# Both runs use the same hyper-parameters; only the weight of the adversarial
# term ('lambda') differs between the baseline and our method.
base_params = {
    'input_shape': (32, 32, 3),
    'n_classes': 10,
    'lambda': 0.0,
    'lr': 3e-4,
    'n_iters': 10000,
    'version': 'original'
}

# baseline: adversarial term switched off
params = dict(base_params)
baseline_trainer = Trainer(params)
baseline_trainer.train()

# our method: adversarial term with weight 1
params = {**base_params, 'lambda': 1.0}
our_trainer = Trainer(params)
our_trainer.train()

# plot accuracy curves of both runs (baseline first)
trainer_vec = [baseline_trainer, our_trainer]
plot_training_metrics(trainer_vec)

In [None]:
# compute OOD metrics
# results maps (corruption_type, trainer index) -> (acc, auc, brier, ece, nll),
# each averaged over 100 batches
corruption_type_list = ['brightness_1', 'elastic_5', 'fog_5', 'frost_5', 'frosted_glass_blur_5']
metric_names = ('acc', 'auc', 'brier', 'ece', 'nll')
results = {}

for idx, trainer in enumerate(trainer_vec):
    print(f"Trainer {idx + 1}")
    for corruption_type in corruption_type_list:
        print(corruption_type)

        # load the corrupted test split and wrap it in the usual input pipeline
        (ds_corrupted,) = tfds.load(
            f'cifar10_corrupted/{corruption_type}',
            split=['test'],
            shuffle_files=True,
            as_supervised=True,
            with_info=False,
        )
        iter_corrupted = iter(make_ds(ds_corrupted))

        # one list of per-batch values per metric
        per_batch = {metric: [] for metric in metric_names}

        # average over 100 batches
        for _ in tqdm.trange(100):
            # accuracy is measured on a clean test batch
            X_test, Y_test = next(trainer.iter_test)
            _, _, test_acc = trainer.loss_fn(X_test, Y_test)

            # AUC compares the clean batch with a corrupted one;
            # calibration is measured on the corrupted batch
            X_corrupted, Y_corrupted = next(iter_corrupted)
            test_auc = trainer.eval_ood(X_test, X_corrupted)
            brier, ece, nll = trainer.get_calibration_metrics(X_corrupted, Y_corrupted)

            per_batch['acc'].append(test_acc.numpy())
            per_batch['auc'].append(test_auc)
            per_batch['brier'].append(brier.numpy())
            per_batch['ece'].append(ece.numpy())
            per_batch['nll'].append(nll.numpy())

        results[(corruption_type, idx)] = tuple(np.mean(per_batch[metric]) for metric in metric_names)

plot_OOD_metrics(results)

### MNIST experiments

In [None]:
# For every corruption type, train a baseline (L = 0) and a regularised
# (L = 100) MLP on clean MNIST and track how well predictive entropy separates
# clean test digits from corrupted ones.
# results maps (corruption_type, L) -> (test accuracy, OOD AUC), each averaged
# over the last 500 history rows.
results = {}
corruption_type_list = ['shot_noise', 'impulse_noise', 'rotate', 'canny_edges']

# Lower bound that keeps log() and the gradient normalisation finite.
_EPS = 1e-12


def mean_entropy(probs):
  """Per-example `-mean_k p_k * log(p_k)`; a probability of 0 contributes 0, not NaN."""
  safe_probs = tf.clip_by_value(probs, _EPS, 1.0)
  return -1.0 * tf.reduce_mean(probs * tf.math.log(safe_probs), axis=1)


for corruption_type in corruption_type_list:
  (ds_train, ds_test), ds_info = tfds.load(
      'mnist_corrupted/identity',
      split=['train', 'test'],
      shuffle_files=True,
      as_supervised=True,
      with_info=True,
  )

  (_, ds_corrupted), ds_info = tfds.load(
      'mnist_corrupted/%s' % corruption_type,
      split=['train', 'test'],
      shuffle_files=True,
      as_supervised=True,
      with_info=True,
  )

  # NOTE: the two helpers below rebind the module-level names `normalize_img`
  # and `make_ds` used by the CIFAR cells; this version flattens each image.
  def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`, flattened to a vector."""
    image = tf.cast(image, tf.float32) / 255.
    image = tf.reshape(image, (-1,))
    return image, label

  def make_ds(ds):
    """Normalize, cache, repeat, shuffle, batch (128) and prefetch `ds`."""
    ds = ds.map(
      normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.cache().repeat()
    ds = ds.shuffle(ds_info.splits['train'].num_examples)
    ds = ds.batch(128)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    return ds

  ds_train = make_ds(ds_train)
  iter_train = iter(ds_train)

  ds_test = make_ds(ds_test)
  iter_test = iter(ds_test)

  ds_corrupted = make_ds(ds_corrupted)
  iter_corrupted = iter(ds_corrupted)

  for L in [0.0, 100.0]:
    print(corruption_type, L)
    model = tf.keras.models.Sequential([
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
    ])
    optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)

    @tf.function
    def loss_fn(X, Y):
      """Return `(total_loss, loss2, accuracy, X_perturbed)` for one batch.

      Both versions return the same 4-tuple, which is what `step_fn` and the
      training loop unpack.
      """
      Y_hat = model(X)
      loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
      accuracy = tf.reduce_mean(tf.cast(tf.argmax(Y_hat, axis=1) == Y, tf.float32))

      ### Original version
      version = 'original'
      if version == 'original':
        # step against the cross-entropy gradient, then pull the prediction
        # at the perturbed point towards the uniform distribution
        with tf.GradientTape() as tape:
          tape.watch(X)
          Y_hat = model(X)
          loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
        grads = tape.gradient(loss, X)
        # the norm is floored so an all-zero gradient cannot produce NaN
        grads = -0.001 * grads / tf.maximum(tf.norm(grads, axis=1), _EPS)[:, None]
        X = tf.clip_by_value(X + grads, 0.0, 1.0)

        Y_hat = model(X)
        loss2 = tf.keras.losses.CategoricalCrossentropy(from_logits=False)(y_true=tf.ones_like(Y_hat) / 10.0, y_pred=Y_hat)
        return loss + L * loss2, loss2, accuracy, X
      elif version == 'entropy':
        # step against the entropy gradient, then maximise entropy there
        with tf.GradientTape() as tape:
          tape.watch(X)
          Y_hat = model(X)
          entropy = mean_entropy(Y_hat)
        grads = tape.gradient(entropy, X)
        grads = -0.001 * grads / tf.maximum(tf.norm(grads, axis=1), _EPS)[:, None]
        X = tf.clip_by_value(X + grads, 0.0, 1.0)

        Y_hat = model(X)
        entropy = mean_entropy(Y_hat)
        loss2 = -1.0 * tf.reduce_mean(entropy)
        return loss + L * loss2, loss2, accuracy, X
      else:
        raise ValueError

    @tf.function
    def step_fn(X_train, Y_train):
      """Apply one optimizer step; returns what `loss_fn` returns."""
      with tf.GradientTape() as tape:
        loss, loss2, accuracy, X = loss_fn(X_train, Y_train)
      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      return loss, loss2, accuracy, X

    @tf.function
    def eval_ood_helper(X_real, X_fake):
      """Return predictive entropies of `X_real` followed by those of `X_fake`."""
      entropy_real = mean_entropy(model(X_real))
      entropy_fake = mean_entropy(model(X_fake))
      return tf.concat([entropy_real, entropy_fake], axis=0)

    def eval_ood(X_real, X_fake):
      """ROC AUC of entropy as a score for `X_fake` (label 1) vs `X_real` (label 0)."""
      scores = eval_ood_helper(X_real, X_fake).numpy()
      labels = np.concatenate([np.zeros(int(X_real.shape[0])), np.ones(int(X_fake.shape[0]))])
      auc = roc_auc_score(labels, scores)
      return auc

    history = []
    for t in tqdm.trange(5000):
      X_train, Y_train = next(iter_train)
      train_loss, train_loss2, train_acc, X = step_fn(X_train, Y_train)
      if t % 10 == 0:
        X_test, Y_test = next(iter_test)
        test_loss, test_loss2, test_acc, _ = loss_fn(X_test, Y_test)
        X_corrupted, _ = next(iter_corrupted)
        test_auc = eval_ood(X_test, X_corrupted)
      # One row per iteration: the test metrics are refreshed every 10 steps
      # and carried forward in between.
      history.append((train_loss.numpy(), test_loss.numpy(), train_acc.numpy(), test_acc.numpy(),
                      train_loss2.numpy(), test_loss2.numpy(), test_auc))
      if t % 100 == 0:
        print(test_auc)
    history = np.array(history)
    test_acc = np.mean(history[-500:, 3])
    test_auc = np.mean(history[-500:, 6])
    results[(corruption_type, L)] = (test_acc, test_auc)

    plt.figure(figsize=(12, 6))
    plt.subplot(221)
    plt.title('Loss')
    plt.plot(history[:, 0], label='train')
    plt.plot(history[:, 1], label='val')
    plt.grid()
    plt.legend()

    plt.subplot(222)
    plt.title('Accuracy')
    plt.plot(history[:, 2], label='train')
    plt.plot(history[:, 3], label='val')
    plt.grid()
    plt.legend()

    plt.subplot(223)
    plt.title('Loss2')
    plt.plot(history[:, 4], label='train')
    plt.plot(history[:, 5], label='val')
    plt.grid()
    plt.legend()

    plt.subplot(224)
    plt.title('AUC')
    # labelled so that legend() has an entry to show
    plt.plot(history[:, 6], label='val')
    plt.grid()
    plt.legend()
    plt.show()

In [None]:
# Side-by-side bar charts of the MNIST results: one panel per metric, one pair
# of bars (baseline / ours) per corruption type.
bar_width = 0.3
panel_titles = ['Accuracy', 'OOD AUC']
x = np.arange(len(corruption_type_list))

plt.figure(figsize=(12, 3))
for index, title in enumerate(panel_titles):
  plt.subplot(1, len(panel_titles), index + 1)
  plt.title(title)
  for L in [0.0, 100.0]:
    # metric `index` of every corruption type for this value of L
    y = [results[(ctype, L)][index] for ctype in corruption_type_list]
    method = 'ours' if L else 'baseline'
    # baseline bars sit left of the tick, ours right of it
    plt.bar(x + bar_width * (1 if L else 0) - 0.15, y, width=bar_width,
            label='L = %s (%s)' % (L, method))
  plt.xticks(x, corruption_type_list)
  plt.grid()
  plt.legend(fontsize=12, loc='lower right')
  # zoom the accuracy panel into the range where the differences are visible
  if index == 0:
    plt.ylim([0.9, 1.0])
plt.show()