In [3]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import tensorflow_datasets as tfds
import tqdm
from tensorflow.keras.models import Model
from sklearn.metrics import roc_auc_score
from sklearn.calibration import calibration_curve
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import tensorflow_probability as tfp
from scipy.ndimage import gaussian_filter1d

In [4]:
# Let TensorFlow grow GPU memory on demand instead of reserving the whole device up
# front, so that several models (and other processes) can share the GPU.
gpus = tf.config.experimental.list_physical_devices('GPU')
for physical_gpu in gpus:
    tf.config.experimental.set_memory_growth(physical_gpu, True)

### CIFAR-10 experiments

In [5]:
# helper functions to create TF datasets

def normalize_img(image, label):
    """Map a `uint8` image to `float32` values in [0, 1]; the label is passed through untouched."""
    scaled_image = tf.cast(image, tf.float32) / 255.
    return scaled_image, label


def make_ds(ds, batch_size=128, shuffle_buffer=int(1e5)):
    """Turn a `(image, label)` tf.data dataset into an endless, shuffled, batched pipeline.

    Args:
        ds: dataset of `(uint8 image, label)` pairs, e.g. from
            `tfds.load(..., as_supervised=True)`.
        batch_size: examples per batch. Because `repeat()` is applied before `batch()`,
            every batch is full. NOTE(review): some downstream code may assume 128;
            check before changing the default.
        shuffle_buffer: size of the shuffle buffer.

    Returns:
        A prefetched dataset of `(float32 image in [0, 1], label)` batches that never ends.
    """
    ds = ds.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Cache the normalized images so the decode/normalize work happens only once.
    ds = ds.cache()
    # Repeat before shuffling so iteration never stops and batches can span epoch boundaries.
    ds = ds.repeat()
    ds = ds.shuffle(shuffle_buffer)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    return ds


In [6]:
# Trainer: builds the CIFAR-10 input pipelines and the model, trains it, and records per-iteration metrics in `history`
class Trainer:
    """Train a small CIFAR-10 CNN with an optional input-space regularizer and record metrics.

    `params` keys read by this class:
        input_shape, n_classes, lr, n_iters -- model / optimization settings.
        lambda -- weight of the adversarial term (0 disables it, except for
            'label-smoothing', which has no adversarial term anyway).
        step_size -- L2 length (per example) of the single normalized input-gradient step.
        version -- 'min-max-cent', 'max-min-ent', 'min-max-KL-unif',
            'min-max-cent-label-smooth' or 'label-smoothing'. Any other value is only
            accepted when lambda == 0 (plain cross-entropy training).
        label-smoothing-factor -- required by the two label-smoothing versions.

    "Entropy" throughout this class is the per-class *mean* of -p*log(p), i.e. the
    Shannon entropy divided by n_classes.

    After `train()`, `self.history` is an array with one row per iteration and columns
    (train_loss, test_loss, train_acc, test_acc, train_loss_adv, test_loss_adv,
     train_entropy, test_entropy, train_dim, test_dim).
    """

    def __init__(self, params):
        self.params = params
        print(f"\nRun training with params {self.params}")

        # Endless, shuffled, batched CIFAR-10 train/test pipelines.
        ds_train, ds_test = tfds.load(
            'cifar10',
            split=['train', 'test'],
            shuffle_files=True,
            as_supervised=True,
            with_info=False,
        )
        self.iter_train = iter(make_ds(ds_train))
        self.iter_test = iter(make_ds(ds_test))

        # Small CNN whose head is a softmax: the model outputs probabilities, not logits.
        self.model = tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=params['input_shape']),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(self.params['n_classes'], activation='softmax')])
        # summary() prints by itself and returns None, so it is not wrapped in print().
        self.model.summary()

        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.params['lr'])

        # One row of metrics per training iteration (see class docstring).
        self.history = []

        # Sub-model exposing the penultimate (64-unit Dense) representation.
        self.model_last_layer = Model(self.model.input, self.model.layers[-2].output)

    @staticmethod
    def _entropy(Y_hat):
        """Per-example mean over classes of -p*log(p) (Shannon entropy / n_classes).

        `xlogy` defines 0*log(0) = 0 (with a finite gradient), so a probability that
        underflows to exactly 0 no longer turns the entropy into NaN.
        """
        return -tf.reduce_mean(tf.math.xlogy(Y_hat, Y_hat), axis=1)

    def _smooth_labels(self, Y):
        """Convert integer labels Y into label-smoothed one-hot targets."""
        n_classes = self.params['n_classes']
        smoothing = self.params['label-smoothing-factor']
        return tf.one_hot(Y, n_classes) * (1 - smoothing) + smoothing / n_classes

    def _perturb(self, X, objective_fn):
        """Move X one normalized gradient-ascent step on `objective_fn(model(X))`.

        Each example moves by L2 distance params['step_size'] (before clipping back to
        the valid pixel range [0, 1]). The result is wrapped in `stop_gradient`, so the
        outer training gradient treats the perturbed point as a constant (standard
        adversarial training) instead of back-propagating through the input gradient.
        """
        with tf.GradientTape() as tape:
            tape.watch(X)
            objective = objective_fn(self.model(X))
        grads = tape.gradient(objective, X)
        # Flatten per example using the runtime batch size rather than a fixed 128.
        grads_flat = tf.reshape(grads, (tf.shape(grads)[0], -1))
        # Floor the norm so an all-zero gradient yields no step instead of NaN.
        grads_norm = tf.maximum(tf.norm(grads_flat, axis=1), 1e-12)
        step = self.params['step_size'] * grads / grads_norm[:, None, None, None]
        return tf.stop_gradient(tf.clip_by_value(X + step, 0.0, 1.0))

    @tf.function
    def loss_fn(self, X, Y):
        """Compute the training objective on batch (X, Y).

        Returns:
            (total_loss, adversarial_loss, accuracy in percent, mean entropy of the
            predictions on X). adversarial_loss is 0 when no adversarial term is used.

        Raises:
            ValueError: if lambda != 0 and params['version'] is not a known regularizer.
        """
        version = self.params['version']
        n_classes = self.params['n_classes']
        sparse_cent = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
        cent = tf.keras.losses.CategoricalCrossentropy(from_logits=False)

        Y_hat = self.model(X)
        accuracy = tf.reduce_mean(tf.cast(tf.argmax(Y_hat, axis=1) == Y, tf.float32)) * 100.
        entropy_on_original_point = tf.reduce_mean(self._entropy(Y_hat))
        no_adv_loss = tf.constant(0.0)

        # Plain label smoothing is configured with lambda == 0, so it must be handled
        # before the lambda == 0 shortcut (otherwise it silently trains as the baseline).
        if version == 'label-smoothing':
            loss = cent(y_true=self._smooth_labels(Y), y_pred=Y_hat)
            return loss, no_adv_loss, accuracy, entropy_on_original_point

        loss = sparse_cent(y_true=Y, y_pred=Y_hat)
        if self.params['lambda'] == 0:
            return loss, no_adv_loss, accuracy, entropy_on_original_point

        # Each regularizer is a scalar function of the predictions. One normalized
        # ascent step on it is taken around X, and it is then evaluated (and added to
        # the loss with weight lambda) at the perturbed point.
        objectives = {
            # maximize cross entropy with the true labels near X
            'min-max-cent': lambda probs: sparse_cent(y_true=Y, y_pred=probs),
            # minimize entropy near X by ascending on exp(-entropy); penalizing the same
            # quantity at the perturbed point pushes predictions there toward high entropy
            'max-min-ent': lambda probs: tf.reduce_mean(tf.exp(-self._entropy(probs))),
            # maximize CE(uniform, p) = KL(uniform || p) + const near X
            'min-max-KL-unif': lambda probs: cent(y_true=tf.ones_like(probs) / n_classes, y_pred=probs),
            # maximize cross entropy with label-smoothed targets near X
            'min-max-cent-label-smooth': lambda probs: cent(y_true=self._smooth_labels(Y), y_pred=probs),
        }
        if version not in objectives:
            raise ValueError(f"Unknown version {version!r} (lambda={self.params['lambda']})")
        objective_fn = objectives[version]

        X_perturbed = self._perturb(X, objective_fn)
        loss_adv = objective_fn(self.model(X_perturbed))
        return loss + self.params['lambda'] * loss_adv, loss_adv, accuracy, entropy_on_original_point

    @tf.function
    def step_fn(self, X_train, Y_train):
        """Apply one optimizer step on (X_train, Y_train); returns the same 4-tuple as loss_fn."""
        with tf.GradientTape() as tape:
            loss, loss_adv, accuracy, predent = self.loss_fn(X_train, Y_train)
        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
        return loss, loss_adv, accuracy, predent

    @tf.function
    def eval_ood_helper(self, X_real, X_fake):
        """Per-example entropies of X_real followed by those of X_fake, as one vector."""
        entropy_real = self._entropy(self.model(X_real))
        entropy_fake = self._entropy(self.model(X_fake))
        return tf.concat([entropy_real, entropy_fake], axis=0)

    def eval_ood(self, X_real, X_fake):
        """ROC-AUC of predictive entropy as a score separating X_fake (1) from X_real (0)."""
        scores = self.eval_ood_helper(X_real, X_fake).numpy()
        labels = np.concatenate([np.zeros(X_real.shape[0]), np.ones(X_fake.shape[0])])
        return roc_auc_score(labels, scores)

    @tf.function
    def get_calibration_metrics(self, X, Y):
        """Return (mean Brier score, 20-bin ECE, NLL) of the model on batch (X, Y).

        The model outputs probabilities; their log is passed where logits are expected,
        which is equivalent because softmax(log p) == p.
        """
        logits = tf.math.log(self.model(X))
        brier = tf.reduce_mean(tfp.stats.brier_score(labels=Y, logits=logits))
        ece = tfp.stats.expected_calibration_error(num_bins=20, logits=logits, labels_true=Y)
        nll = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y_true=Y, y_pred=logits)
        return brier, ece, nll

    def get_batch_dimension(self, X, var_threshold=0.9):
        """Count the principal components needed to explain `var_threshold` of the variance.

        PCA is run on the standardized last-layer representations of batch X. Returns 0
        if the representations have no variance at all.
        """
        reps = StandardScaler().fit_transform(self.model_last_layer(X).numpy())
        explained_var = PCA().fit(reps).explained_variance_
        total_var = explained_var.sum()
        if total_var <= 0:
            return 0
        cum_ratio = np.cumsum(explained_var) / total_var
        # First index whose cumulative ratio reaches the threshold, converted to a count.
        return int(min(np.searchsorted(cum_ratio, var_threshold) + 1, len(cum_ratio)))

    def train(self):
        """Run params['n_iters'] optimization steps, appending one history row per step.

        Test metrics and representation dimensions are refreshed every 10 steps; rows
        in between repeat the most recent values. Calling train() again continues
        training and extends the existing history.
        """
        history = list(self.history)
        for t in tqdm.trange(self.params['n_iters']):
            X_train, Y_train = next(self.iter_train)
            train_loss, train_loss_adv, train_acc, train_predent = self.step_fn(X_train, Y_train)
            if t % 1000 == 0:
                print(train_loss, train_loss_adv)

            if t % 10 == 0:
                X_test, Y_test = next(self.iter_test)
                test_loss, test_loss_adv, test_acc, test_predent = self.loss_fn(X_test, Y_test)
                # Reuse the batches already drawn instead of consuming extra ones.
                train_dim = self.get_batch_dimension(X_train)
                test_dim = self.get_batch_dimension(X_test)
            history.append((train_loss.numpy(), test_loss.numpy(), train_acc.numpy(), test_acc.numpy(),
                            train_loss_adv.numpy(), test_loss_adv.numpy(), train_predent.numpy(),
                            test_predent.numpy(), train_dim, test_dim))

        self.history = np.array(history)


In [7]:
# plotting utils

def plot_training_metrics(baseline_trainer, our_trainer, tags=('baseline', ''), smooth_sigma=100):
    """Plot smoothed train/val accuracy, predictive entropy and last-layer dimension of two runs.

    Args:
        baseline_trainer, our_trainer: trainers whose `history` array was built by `train()`.
        tags: legend names for the two trainers, in the same order.
        smooth_sigma: standard deviation (in iterations) of the Gaussian filter applied
            to every curve.
    """
    # (title, train column, val column, y label, plot absolute value) -- columns index
    # into `history`. Columns 0/1 (loss) and 4/5 (adversarial loss) can be added the same way.
    panels = [
        ("Train/Test accuracy", 2, 3, 'Accuracy', False),
        ("Train/Test predictive entropy", 6, 7, 'Entropy', True),
        ("Train/Test last-layer-rep dim", 8, 9, 'Dimension', False),
    ]
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    fig, axes = plt.subplots(1, len(panels), figsize=(20, 3))

    for ax, (title, train_col, val_col, ylabel, use_abs) in zip(axes, panels):
        for c, trainer, name in zip(colors, (baseline_trainer, our_trainer), tags):
            for col, style, split in ((train_col, '-o', 'train'), (val_col, '-^', 'val')):
                curve = gaussian_filter1d(trainer.history[:, col], smooth_sigma)
                if use_abs:
                    curve = np.abs(curve)
                ax.plot(curve, style, c=c, label='%s - %s' % (name, split), markevery=1000, markersize=10)
        ax.set_title(title)
        ax.grid()
        ax.set_xlabel('iterations')
        ax.set_ylabel(ylabel)

    axes[0].legend()
    plt.show()
    
    
    
def plot_OOD_metrics(results, tags, corruption_types=None):
    """Bar-plot OOD metrics: one bar group per corruption type, one bar per trainer.

    Args:
        results: dict mapping (corruption_type, trainer_idx) to the tuple
            (accuracy, ood_auc, brier, ece, nll).
        tags: legend name for each trainer index 0..len(tags)-1.
        corruption_types: x-axis order of corruption types; defaults to the order in
            which they first appear among the keys of `results`.
    """
    if corruption_types is None:
        corruption_types = list(dict.fromkeys(ctype for ctype, _ in results))
    titles = ['Accuracy OOD $(\\uparrow)$', 'OOD AUC $(\\uparrow)$', 'Brier $(\\downarrow)$',
              'ECE  $(\\downarrow)$', 'NLL  $(\\downarrow)$']
    x = np.arange(len(corruption_types))
    # Share 80% of each tick's slot between all trainers so bars never overlap.
    width = 0.8 / len(tags)

    plt.figure(figsize=(12, 6))
    for j, title in enumerate(titles):
        plt.subplot(2, 3, j + 1)
        plt.title(title)
        for idx, tag in enumerate(tags):
            y = [results[(ctype, idx)][j] for ctype in corruption_types]
            # Centre the group of bars on the tick.
            plt.bar(x + (idx - (len(tags) - 1) / 2) * width, y, width=width, label=f'{tag}')
        plt.xticks(x, corruption_types, rotation=90)
        plt.grid()
        if j == 0:
            plt.legend(fontsize=12, loc='lower right')

    plt.tight_layout()
    plt.show()

In [None]:
# Train every variant on CIFAR-10. All runs share BASE_PARAMS; each experiment lists
# only what differs. Seeds are reset before every run so the runs are comparable.
# `params` is (re)assigned at module level before each run, as before.
SEED = 0
BASE_PARAMS = {
    'input_shape': (32, 32, 3),
    'n_classes': 10,
    'lr': 5e-4,
    'n_iters': 12000,
}
EXPERIMENTS = {
    # plain cross entropy
    'baseline': {'lambda': 0.0, 'step_size': 0., 'version': 'NA'},
    # min max cross-entropy
    'min-max-cent': {'lambda': 10.0, 'step_size': 0.2, 'version': 'min-max-cent'},
    # max-min entropy
    'max-min-ent': {'lambda': 10.0, 'step_size': 0.2, 'version': 'max-min-ent'},
    # min max KL with uniform
    'min-max-KL-unif': {'lambda': 0.25, 'step_size': 0.1, 'version': 'min-max-KL-unif'},
    # label smoothing (no adversarial term, hence lambda = 0)
    'label-smoothing': {'lambda': 0., 'step_size': 0., 'version': 'label-smoothing',
                        'label-smoothing-factor': 0.1},
}

trainers = {}
for experiment_name, overrides in EXPERIMENTS.items():
    tf.random.set_seed(SEED)
    np.random.seed(SEED)
    params = {**BASE_PARAMS, **overrides}
    trainers[experiment_name] = Trainer(params)
    trainers[experiment_name].train()

baseline_trainer = trainers['baseline']
min_max_cent_trainer = trainers['min-max-cent']
max_min_ent_trainer = trainers['max-min-ent']
min_max_KL_unif_trainer = trainers['min-max-KL-unif']
label_smoothing_trainer = trainers['label-smoothing']


Run training with params {'input_shape': (32, 32, 3), 'n_classes': 10, 'lambda': 0.0, 'lr': 0.0005, 'step_size': 0.0, 'n_iters': 12000, 'version': 'NA'}


2021-10-05 09:51:28.764960: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set
2021-10-05 09:51:28.765353: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:941] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-10-05 09:51:28.766101: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties: 
pciBusID: 0000:00:04.0 name: Tesla P100-PCIE-16GB computeCapability: 6.0
coreClock: 1.3285GHz coreCount: 56 deviceMemorySize: 15.90GiB deviceMemoryBandwidth: 681.88GiB/s
2021-10-05 09:51:28.766152: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
2021-10-05 09:51:28.766232: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2021-10-05 09:51:28.766254: I tensorflow/stream_executor/platform/

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 30, 30, 32)        896       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 15, 15, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 13, 13, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 6, 6, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 4, 4, 64)          36928     
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 64)                6

  0%|                                                 | 0/12000 [00:00<?, ?it/s]2021-10-05 09:51:29.766382: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2021-10-05 09:51:29.766977: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2199995000 Hz
2021-10-05 09:51:33.369259: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2021-10-05 09:51:33.670315: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.11
2021-10-05 09:51:33.673473: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.8


tf.Tensor(2.3119864, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)


  9%|███▏                                 | 1024/12000 [00:10<00:52, 209.89it/s]

tf.Tensor(1.1929727, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)


 17%|██████▎                              | 2030/12000 [00:15<00:44, 223.50it/s]

tf.Tensor(1.0098134, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)


 25%|█████████▎                           | 3030/12000 [00:19<00:42, 212.07it/s]

tf.Tensor(0.72637284, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)


 34%|████████████▍                        | 4038/12000 [00:24<00:36, 218.59it/s]

tf.Tensor(0.80525386, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)


 42%|███████████████▍                     | 5024/12000 [00:29<00:32, 216.50it/s]

tf.Tensor(0.90484774, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)


 50%|██████████████████▌                  | 6037/12000 [00:33<00:27, 217.04it/s]

tf.Tensor(0.77912486, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)


 59%|█████████████████████▋               | 7021/12000 [00:38<00:23, 214.49it/s]

tf.Tensor(0.75604033, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)


 67%|████████████████████████▋            | 8023/12000 [00:43<00:18, 215.36it/s]

tf.Tensor(0.5690528, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)


 75%|███████████████████████████▉         | 9041/12000 [00:47<00:13, 213.27it/s]

tf.Tensor(0.7019794, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)


 84%|██████████████████████████████      | 10022/12000 [00:52<00:09, 207.50it/s]

tf.Tensor(0.61864173, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)


 92%|█████████████████████████████████   | 11031/12000 [00:57<00:04, 211.54it/s]

tf.Tensor(0.49456102, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)


100%|████████████████████████████████████| 12000/12000 [01:01<00:00, 194.40it/s]



Run training with params {'input_shape': (32, 32, 3), 'n_classes': 10, 'lambda': 10.0, 'lr': 0.0005, 'step_size': 0.2, 'n_iters': 12000, 'version': 'min-max-cent'}
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_3 (Conv2D)            (None, 30, 30, 32)        896       
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 15, 15, 32)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 13, 13, 64)        18496     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 6, 6, 64)          0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 4, 4, 64)          36928     
_________________________________________________________________
flatten_1 (Flatten)  

  0%|                                                 | 0/12000 [00:00<?, ?it/s]

tf.Tensor(25.611135, shape=(), dtype=float32) tf.Tensor(2.3302844, shape=(), dtype=float32)


  8%|███▏                                  | 1019/12000 [00:14<01:53, 97.14it/s]

tf.Tensor(15.611682, shape=(), dtype=float32) tf.Tensor(1.4344823, shape=(), dtype=float32)


 10%|███▉                                  | 1226/12000 [00:16<01:48, 99.46it/s]

In [None]:
# Compare each regularized run against the baseline, one figure per run.
comparisons = [
    (min_max_cent_trainer, 'min-max-cent'),
    (max_min_ent_trainer, 'max-min-ent'),
    (min_max_KL_unif_trainer, 'min-max-KL-unif'),
    (label_smoothing_trainer, 'label-smoothing'),
]
for other_trainer, other_tag in comparisons:
    plot_training_metrics(baseline_trainer, other_trainer, ['baseline', other_tag])

In [None]:
# Evaluate every trained model on a few CIFAR-10-C corruptions:
# - accuracy on corrupted images,
# - AUC of predictive entropy for telling clean test images from corrupted ones,
# - calibration (Brier score, ECE, NLL) on corrupted images.
N_EVAL_BATCHES = 100
tags = ['baseline', 'min-max-cent', 'max-min-ent', 'min-max-KL-unif', 'label-smoothing']
trainer_vec = [baseline_trainer, min_max_cent_trainer, max_min_ent_trainer,
               min_max_KL_unif_trainer, label_smoothing_trainer]
results = {}
corruption_type_list = ['brightness_1', 'elastic_5', 'fog_5', 'frost_5', 'frosted_glass_blur_5']

for corruption_type in corruption_type_list:
    print(corruption_type)

    # Load each corrupted split once and share it across all trainers.
    (ds_corrupted_raw,) = tfds.load(
        'cifar10_corrupted/%s' % corruption_type,
        split=['test'],
        shuffle_files=True,
        as_supervised=True,
        with_info=False,
    )

    for idx, trainer in enumerate(trainer_vec):
        print(f"Trainer {idx + 1} ({tags[idx]})")
        iter_corrupted = iter(make_ds(ds_corrupted_raw))

        # One row per batch: (accuracy, auc, brier, ece, nll); averaged below.
        batch_metrics = []
        for _ in tqdm.trange(N_EVAL_BATCHES):
            X_test = next(trainer.iter_test)[0]
            X_corrupted, Y_corrupted = next(iter_corrupted)
            # loss_fn returns (loss, loss_adv, accuracy, entropy); only accuracy is needed.
            _, _, test_acc_OOD, _ = trainer.loss_fn(X_corrupted, Y_corrupted)
            test_auc = trainer.eval_ood(X_test, X_corrupted)
            brier, ece, nll = trainer.get_calibration_metrics(X_corrupted, Y_corrupted)
            batch_metrics.append((test_acc_OOD.numpy(), test_auc, brier.numpy(), ece.numpy(), nll.numpy()))

        results[(corruption_type, idx)] = tuple(np.mean(batch_metrics, axis=0))

plot_OOD_metrics(results, tags)