In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import numpy as np
import tensorflow_datasets as tfds
from tiny_imagenet import TinyImagenetDataset
import tqdm
from tensorflow.keras.models import Model
from sklearn.metrics import roc_auc_score
from sklearn.calibration import calibration_curve
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import tensorflow_probability as tfp
from scipy.ndimage import gaussian_filter1d

In [None]:
# Turn on memory growth for every GPU that TensorFlow lists as a physical device.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
  tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
import imgaug as ia
from imgaug import augmenters as iaa


def tf_random_rotate_image(image, label):
    """Run `random_rotate_image` on `image` through `tf.py_function`.

    The py_function output is declared as a single float32 tensor and is
    given the same static shape as the input image. `label` is returned
    untouched.

    NOTE(review): `random_rotate_image` is not defined in the visible cells
    of this notebook - confirm it is provided elsewhere before using this.
    """
    static_shape = image.shape
    rotated = tf.py_function(random_rotate_image, [image], [tf.float32])[0]
    rotated.set_shape(static_shape)
    return rotated, label


def augment(image, label):
    """Apply a random imgaug augmentation pipeline to one `(image, label)` example.

    The pipeline (applied in random order) is: horizontal flip (p=0.5),
    vertical flip (p=0.2) and, each with probability 0.5, Gaussian blur,
    crop-and-pad, an affine transform and coarse dropout.

    Args:
        image: a single image tensor (not a batch).
        label: passed through unchanged.

    Returns:
        `(augmented_image, label)` where `augmented_image` is one float32
        tensor carrying the same static shape as `image`.

    NOTE(review): `pad_cval` / `cval` are sampled from (0, 255), which looks
    like a uint8 pixel range, while `process_example_dict` scales images to
    [0, 1] before this function is mapped - confirm the intended value range.
    """
    sometimes = lambda aug: iaa.Sometimes(0.5, aug)
    seq = iaa.Sequential([
        iaa.Fliplr(0.5),
        iaa.Flipud(0.2),
        sometimes(iaa.GaussianBlur(sigma=(0, 2.0))),
        sometimes(iaa.CropAndPad(
            percent=(-0.1, 0.2),
            pad_mode=ia.ALL,
            pad_cval=(0, 255))),
        sometimes(iaa.Affine(
            scale={"x": (0.8, 1.5), "y": (0.8, 1.5)},
            translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
            rotate=(-45, 45),
            shear=(-16, 16),
            order=[0, 1],
            cval=(0, 255),
            mode=ia.ALL)),
        sometimes(iaa.CoarseDropout(
            (0.03, 0.15), size_percent=(0.02, 0.05),
            per_channel=0.2)),
    ], random_order=True)

    def _augment_numpy(eager_image):
        # py_function hands us an EagerTensor; imgaug works on NumPy arrays.
        augmented = seq.augment_image(eager_image.numpy())
        return augmented.astype(np.float32)

    static_shape = image.shape
    # Tout is a single dtype (not a one-element list) so that py_function
    # returns one tensor rather than a list wrapping it.
    augmented_image = tf.py_function(func=_augment_numpy, inp=[image], Tout=tf.float32)
    # py_function drops static shape information; restore it so that
    # downstream batching and Keras layers see a known image shape.
    augmented_image.set_shape(static_shape)

    return augmented_image, label

def make_ds(ds):
    """Build the augmenting input pipeline from a dataset of example dicts.

    Steps, in order: `process_example_dict` -> cache -> `augment` ->
    batches of 128 -> prefetch of 256 batches -> repeat forever.

    Note: a later cell of this notebook defines another `make_ds` (without
    augmentation); once that cell has run, it replaces this definition.
    """
    autotune = tf.data.experimental.AUTOTUNE
    return (
        ds.map(process_example_dict, num_parallel_calls=autotune)
        .cache()
        .map(augment, num_parallel_calls=autotune)
        .batch(128)
        .prefetch(256)
        .repeat()
    )

In [None]:
# helper functions to create TF datasets

# Shared Keras layer that resizes images to 224x224 with bilinear
# interpolation and without aspect-ratio cropping; used by `normalize_img`.
resize = tf.keras.layers.Resizing(
    224, 224, interpolation='bilinear', crop_to_aspect_ratio=False)


def process_example_dict(example_dict):
    """Convert a dataset example dict into an `(image, label)` pair.

    Reads the 'image' and 'label' entries; the image is cast to float32 and
    divided by 255, the label is returned as-is.
    """
    scaled_image = tf.cast(example_dict['image'], tf.float32) / 255.
    return scaled_image, example_dict['label']


def normalize_img(image, label):
    """Resize an image and normalise it: `uint8` -> `float32`.

    The image goes through the module-level `resize` layer, is cast to
    float32 and divided by 255. The label is returned unchanged.
    """
    resized = resize(image)
    return tf.cast(resized, tf.float32) / 255., label


def make_ds(ds, batch_size=128, prefetch=256, shuffle_buffer=None):
    """Build an endlessly repeating, batched input pipeline.

    Steps, in order: `process_example_dict` -> cache -> (optional shuffle)
    -> repeat -> batch -> prefetch. Because `repeat` comes before `batch`,
    every batch is full-sized.

    Args:
        ds: a `tf.data.Dataset` of example dicts with 'image' and 'label'.
        batch_size: examples per batch (default 128, the previous fixed value).
        prefetch: number of batches to prefetch (default 256, the previous
            fixed value); `tf.data.experimental.AUTOTUNE` may be passed too.
        shuffle_buffer: if given (non-zero), shuffle the cached examples with
            a buffer of this size, reshuffling on each repetition. The default
            `None` keeps the original, unshuffled order.

    Returns:
        The transformed `tf.data.Dataset` yielding `(images, labels)` batches.
    """
    ds = ds.map(
        process_example_dict, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.cache()
    if shuffle_buffer:
        # Shuffle after the cache so that the cached data stays deterministic
        # while every pass over it is drawn in a new order.
        ds = ds.shuffle(shuffle_buffer)
    ds = ds.repeat()
    ds = ds.batch(batch_size)
    ds = ds.prefetch(prefetch)
    return ds

In [None]:
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Dense, Conv2D,  MaxPool2D, Flatten, GlobalAveragePooling2D,  BatchNormalization, Layer, Add
from tensorflow.keras.models import Sequential
from tensorflow.keras.models import Model
import tensorflow as tf


class ResnetBlock(Model):
    """
    Residual block: conv-BN-ReLU-conv-BN on the main path, added to a
    shortcut and followed by a final ReLU.

    With `down_sample=True` the first convolution uses stride 2 and the
    shortcut is a stride-2 1x1 convolution followed by batch normalisation;
    otherwise the shortcut is the unmodified input.
    """

    def __init__(self, channels: int, down_sample=False):
        """
        channels: number of kernels used by every convolution in the block
        down_sample: whether the block halves the spatial resolution
        """
        super().__init__()

        self.__channels = channels
        self.__down_sample = down_sample
        self.__strides = [2, 1] if down_sample else [1, 1]

        # He initialisation (rather than Keras' default 'glorot_uniform') for all convolutions.
        init_scheme = "he_normal"
        main_conv_kwargs = dict(
            kernel_size=(3, 3), padding="same", kernel_initializer=init_scheme)
        first_stride, second_stride = self.__strides

        self.conv_1 = Conv2D(self.__channels, strides=first_stride, **main_conv_kwargs)
        self.bn_1 = BatchNormalization()
        self.conv_2 = Conv2D(self.__channels, strides=second_stride, **main_conv_kwargs)
        self.bn_2 = BatchNormalization()
        self.merge = Add()

        if self.__down_sample:
            # The shortcut must match the main path's shape: stride-2 1x1 convolution.
            self.res_conv = Conv2D(
                self.__channels, strides=2, kernel_size=(1, 1),
                kernel_initializer=init_scheme, padding="same")
            self.res_bn = BatchNormalization()

    def call(self, inputs):
        main_path = tf.nn.relu(self.bn_1(self.conv_1(inputs)))
        main_path = self.bn_2(self.conv_2(main_path))

        # Identity shortcut unless the block down-samples.
        shortcut = inputs
        if self.__down_sample:
            shortcut = self.res_bn(self.res_conv(shortcut))

        return tf.nn.relu(self.merge([main_path, shortcut]))


class ResNet18(Model):
    """ResNet-18 style classifier assembled from eight `ResnetBlock`s.

    Stem: 7x7 stride-2 convolution, batch norm, ReLU, 2x2 max-pool. Then
    two blocks each at 64, 128, 256 and 512 channels (the first block of the
    last three stages down-samples), global average pooling, flatten and a
    softmax dense layer.
    """

    def __init__(self, num_classes, **kwargs):
        """
            num_classes: number of output classes of the softmax layer.
            kwargs: forwarded to the Keras `Model` constructor.
        """
        super().__init__(**kwargs)
        # stem
        self.conv_1 = Conv2D(
            64, (7, 7), strides=2, padding="same", kernel_initializer="he_normal")
        self.init_bn = BatchNormalization()
        self.pool_2 = MaxPool2D(pool_size=(2, 2), strides=2, padding="same")
        # residual stages
        self.res_1_1 = ResnetBlock(64)
        self.res_1_2 = ResnetBlock(64)
        self.res_2_1 = ResnetBlock(128, down_sample=True)
        self.res_2_2 = ResnetBlock(128)
        self.res_3_1 = ResnetBlock(256, down_sample=True)
        self.res_3_2 = ResnetBlock(256)
        self.res_4_1 = ResnetBlock(512, down_sample=True)
        self.res_4_2 = ResnetBlock(512)
        # classification head
        self.avg_pool = GlobalAveragePooling2D()
        self.flat = Flatten()
        self.fc = Dense(num_classes, activation="softmax")

    def call(self, inputs):
        features = self.pool_2(tf.nn.relu(self.init_bn(self.conv_1(inputs))))
        residual_stack = (
            self.res_1_1, self.res_1_2,
            self.res_2_1, self.res_2_2,
            self.res_3_1, self.res_3_2,
            self.res_4_1, self.res_4_2,
        )
        for res_block in residual_stack:
            features = res_block(features)
        return self.fc(self.flat(self.avg_pool(features)))

In [None]:

import numpy as np
from keras.layers import Dense, Dropout, Flatten, Input, GlobalAveragePooling2D, merge, Activation, ZeroPadding2D, Conv2D, MaxPooling2D, BatchNormalization, Concatenate, GlobalMaxPooling2D
from keras.regularizers import l2
from keras.models import Model, Sequential

# Multiplier applied to `num_filter` for the convolutions in add_denseblock / add_transition.
COMPRESSION = 1.0
# Number of input image channels expected by DenseNet().
CHANNEL = 3
# Filter count of DenseNet's first convolution and of every dense block it builds.
NUM_FILTER = 128
# Dropout rate DenseNet() passes to its blocks and transitions (dropout is skipped when not > 0).
DROPOUT_RATE = 0.
# Number of BN-ReLU-Conv layers stacked inside one dense block.
N_LAYERS = 3

# Dense Block
def add_denseblock(input, num_filter = 12, dropout_rate = 0.2):
    """Append a dense block of `N_LAYERS` layers to `input`.

    Each layer is BatchNorm -> ReLU -> 3x3 convolution with
    `int(num_filter * COMPRESSION)` filters (L2-regularised, no bias),
    optionally followed by dropout when `dropout_rate > 0`. The layer's
    output is concatenated on the channel axis with everything before it.

    Returns the tensor after the last concatenation.
    """
    n_filters = int(num_filter*COMPRESSION)
    features = input
    for _ in range(N_LAYERS):
        normed = BatchNormalization(epsilon=1.1e-5)(features)
        activated = Activation('relu')(normed)
        # L2 penalty discourages kernel weights with large magnitude.
        new_maps = Conv2D(n_filters, (3,3), use_bias=False, padding='same',
                          kernel_regularizer=l2(0.0002))(activated)
        if dropout_rate>0:
            new_maps = Dropout(dropout_rate)(new_maps)
        features = Concatenate(axis=-1)([features, new_maps])
    return features


def add_transition(input, num_filter = 12, dropout_rate = 0.2):
    """Append a transition stage to `input`.

    BatchNorm -> ReLU -> 1x1 convolution with `int(num_filter * COMPRESSION)`
    filters (L2-regularised, no bias), optional dropout when
    `dropout_rate > 0`, then a 2x2 stride-2 max-pool that halves the spatial
    resolution.
    """
    normed = BatchNormalization(epsilon=1.1e-5)(input)
    activated = Activation('relu')(normed)
    # L2 penalty discourages kernel weights with large magnitude.
    bottleneck = Conv2D(int(num_filter*COMPRESSION), (1,1), use_bias=False,
                        padding='same', kernel_regularizer=l2(0.0002))(activated)
    if dropout_rate>0:
        bottleneck = Dropout(dropout_rate)(bottleneck)
    return MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(bottleneck)


def output_layer(input, num_classes):
    """Build the classification head on top of `input`.

    BatchNorm -> ReLU -> 1x1 convolution with `num_classes` filters
    (L2-regularised, no bias) -> BatchNorm -> ReLU -> global average pooling
    -> softmax. Returns the per-class probability tensor.
    """
    pre_activation = Activation('relu')(BatchNormalization()(input))
    class_maps = Conv2D(num_classes, (1,1), use_bias=False,
                        kernel_regularizer=l2(0.0002))(pre_activation)
    class_maps = Activation('relu')(BatchNormalization()(class_maps))
    pooled = GlobalAveragePooling2D()(class_maps)
    return Activation('softmax')(pooled)


def DenseNet(num_classes):
    """Assemble the DenseNet-style functional model used by `Trainer`.

    Layout: a 3x3 convolution with `NUM_FILTER` filters, then five dense
    blocks separated by four transitions whose filter counts are 256, 320,
    384 and 512, and finally `output_layer`. The input accepts images of any
    height/width with `CHANNEL` channels.

    Returns a Keras `Model` mapping images to `num_classes` probabilities.
    """
    image_input = Input(shape=(None, None, CHANNEL,))
    features = Conv2D(NUM_FILTER, (3, 3), use_bias=False, padding='same')(image_input)
    features = add_denseblock(features, NUM_FILTER, DROPOUT_RATE)
    # Every transition is followed by another dense block.
    for transition_filters in (256, 320, 384, 512):
        features = add_transition(
            features, num_filter=transition_filters, dropout_rate=DROPOUT_RATE)
        features = add_denseblock(features, NUM_FILTER, DROPOUT_RATE)
    probabilities = output_layer(features, num_classes)
    return Model(inputs=[image_input], outputs=[probabilities])



In [None]:
# create dataset, iterate over it and train, return model and metrics

# Small constant added to predicted probabilities before taking logs, and to
# the variance total in Trainer.get_batch_dimension before dividing by it.
EPS = 1e-3
    
class Trainer:
    
    def __init__(self, params):
        """Create the data iterators, model, optimizer and calibration variable.

        Args:
            params: dict of hyper-parameters, stored by reference as
                `self.params`. Keys read here: 'n_classes', 'input_shape',
                'lr' and 'weight_decay'. Further keys are read by `train`
                and `calibrate_model`.

        Side effects: downloads/prepares Tiny ImageNet through
        `TinyImagenetDataset` and prints the parameters and model summary.
        """
        self.params = params
        print(f"\nRun training with params {self.params}")

        # create imgnet dataset and get iterators
        tiny_imagenet_builder = TinyImagenetDataset()
        tiny_imagenet_builder.download_and_prepare()
        ds_train = tiny_imagenet_builder.as_dataset(split="train")
        ds_test = tiny_imagenet_builder.as_dataset(split="validation")

        self.iter_train = iter(make_ds(ds_train))
        self.iter_test = iter(make_ds(ds_test))

        # define model (ResNet18(self.params['n_classes']) from this notebook
        # can be swapped in here instead)
        self.model = DenseNet(self.params['n_classes'])
        self.model.build(input_shape=(None, *(self.params['input_shape'])))

        # summary() prints the table itself and returns None, so it is not
        # wrapped in print() (which would emit a stray "None").
        self.model.summary()

        # define optimizer
        self.optimizer = tfa.optimizers.SGDW(
            learning_rate=self.params['lr'],
            weight_decay=self.params['weight_decay'],
            nesterov=True)

        # maintain history: one tuple of metrics per training iteration
        self.history = []

        # per-class scaling applied to log-probabilities for post-hoc calibration
        self.tau = tf.Variable(tf.ones((1, self.params['n_classes'])))

    
    def eval_ood(self, X_real, X_fake):
        """Score how well predictive entropy separates `X_real` from `X_fake`.

        For each input the model's probabilities are turned into
        log-probabilities, scaled by `self.tau`, re-normalised with a softmax,
        and the (mean-reduced) entropy of the result is used as the
        detection score. `X_real` examples are labelled 0 and `X_fake`
        examples 1.

        Args:
            X_real: batch of in-distribution inputs.
            X_fake: batch of out-of-distribution inputs.

        Returns:
            ROC AUC of the entropy scores against the real/fake labels.
        """
        def calibrated_entropy(X):
            log_probs = tf.math.log(self.model(X, training=False) + EPS)
            # self.tau has shape (1, n_classes) and broadcasts over the batch.
            probs = tf.nn.softmax(log_probs * self.tau)
            return -1.0 * tf.reduce_mean(probs * tf.math.log(probs + EPS), axis=1)

        entropy_real = calibrated_entropy(X_real).numpy()
        entropy_fake = calibrated_entropy(X_fake).numpy()

        scores = np.concatenate([entropy_real, entropy_fake])
        # Label counts follow the actual batch sizes instead of assuming 128 each.
        labels = np.concatenate([np.zeros(len(entropy_real)), np.ones(len(entropy_fake))])
        return roc_auc_score(labels, scores)

    
    def get_batch_dimension(self, repr):
        """Estimate the dimensionality of a batch of representations via PCA.

        `repr` is standardised with `StandardScaler`, a full PCA is fitted,
        and each component's variance is divided by the total variance plus
        `EPS`. The return value is the number of leading components whose
        cumulative variance share is still <= 0.9.
        """
        standardized = StandardScaler().fit_transform(repr)
        component_var = PCA().fit(standardized).explained_variance_
        variance_share = component_var / (component_var.sum() + EPS)
        within_budget = variance_share.cumsum() <= 0.9
        return np.sum(within_budget)
    
    def get_model_weights(self):
        params = np.array([])
        for layer in t.model.layers:
            for wt in layer.trainable_variables: 
                params = np.concatenate([params, wt.numpy().flatten()], axis=0)
        return params
    
    def set_model_weights(self, param):
        idx=0
        for layer in self.model.layers:
            for wt in layer.trainable_variables: 
                wt.assign(param[prev:prev+np.prod(wt.shape)].reshape(wt.shape))
                idx+=np.prod(wt.shape)
                
    def evaluate_n_random_batches(self, n=10):
        """Average cross-entropy of the current weights over `n` training batches.

        Draws the next `n` batches from `self.iter_train`, runs the model in
        inference mode (`training=False`) and averages the sparse categorical
        cross-entropy of its predictions.

        The version-specific objectives are closures local to `train`, so
        this method evaluates the plain cross-entropy term only.

        Args:
            n: number of batches to average over.

        Returns:
            Scalar tensor with the mean loss.
        """
        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
        total = 0.
        for _ in range(n):
            # Take ONE batch; unpacking the iterator itself would try to
            # exhaust a dataset that repeats forever.
            X, Y = next(self.iter_train)
            total += cross_entropy(y_true=Y, y_pred=self.model(X, training=False))
        return total / n
        
    def compute_sharpness_metric(self, p=100, delta=0.001):
        """Estimate the sharpness of the loss around the current weights.

        The weights x_0 are perturbed inside a random `p`-dimensional
        subspace, x = x_0 + A y, with each coordinate of y confined to
        [-delta * (|(A^+ x_0)_i| + 1), +delta * (|(A^+ x_0)_i| + 1)].
        L-BFGS-B (at most 10 iterations) searches this box for the largest
        loss, and the result is reported relative to the loss at x_0:

            sharpness = 100 * (max_loss - loss(x_0)) / (1 + loss(x_0))

        Losses come from `evaluate_n_random_batches(n=10)`, so every
        evaluation uses fresh batches and the objective is noisy.

        Args:
            p: dimension of the random subspace.
            delta: relative size of the search box.

        Returns:
            The sharpness value as a Python float. The model weights are
            restored to x_0 before returning, even if the search fails.
        """
        # Local import: the notebook only imports `scipy.ndimage`, which does
        # not bind the name `scipy.optimize`.
        import scipy.optimize

        x_0 = self.get_model_weights()
        A_tf = tf.random.normal((x_0.shape[0], p))
        # Coordinates of x_0 in the random subspace: A^+ x_0.
        proj = (tf.linalg.pinv(A_tf) @ tf.constant(x_0[:, None], dtype=A_tf.dtype)).numpy().reshape(-1)
        A = A_tf.numpy()

        half_width = (np.abs(proj) + 1) * delta
        # Lower bound is negative, upper bound positive.
        bounds = list(zip(-half_width, half_width))

        variables = [
            wt for layer in self.model.layers for wt in layer.trainable_variables
        ]

        # for the L-BFGS solver: returns function value and gradient w.r.t. y
        def f(y):
            self.set_model_weights(x_0 + A @ y)
            with tf.GradientTape() as tape:
                loss = self.evaluate_n_random_batches(n=10)
            grads = tape.gradient(loss, variables)
            # Flatten in the same variable order that get/set_model_weights use;
            # variables that do not influence the loss contribute zeros.
            flat_grad = np.concatenate([
                np.zeros(int(np.prod(tuple(wt.shape)))) if g is None
                else np.asarray(g).reshape(-1)
                for g, wt in zip(grads, variables)
            ])
            # Chain rule for x = x_0 + A y: dL/dy = A^T dL/dx.
            grad_y = flat_grad @ A
            # we want to maximize the loss, hence the negative signs
            return -float(loss), -grad_y.astype(np.float64)

        try:
            # Reference loss, measured while the model still sits at x_0.
            fx = float(self.evaluate_n_random_batches(n=10))
            _, neg_maxf, _ = scipy.optimize.fmin_l_bfgs_b(
                f, np.zeros(p), bounds=bounds, maxiter=10)
        finally:
            # reset model weights
            self.set_model_weights(x_0)

        maxf = -float(neg_maxf)
        return (maxf - fx) * 100. / (1 + fx)


        
    def train(self):
        """Run the training loop and record metrics in `self.history`.

        The objective is chosen by `self.params['version']`:
        'baseline', 'min-max-cent', 'max-min-ent', 'min-max-KL-unif' or
        'label-smoothing' (anything else raises `ValueError`). The adversarial
        variants add `params['lambda']` times a loss evaluated at an input
        moved by `params['step_size']` along a normalised input gradient.

        Other keys read: 'n_iters', 'n_classes', 'label-smoothing-factor'
        (label smoothing only) and the optional 'lambda_schedule' /
        'lr_schedule' dicts (`frequency`, `factor`), which multiply lambda /
        the learning rate by `factor` every `frequency` iterations.
        Note that the lambda schedule updates `self.params` in place.

        After the call, `self.history` is a NumPy array with one row per
        iteration: (train loss, test loss, train acc, test acc, train adv
        loss, test adv loss, train entropy, test entropy). Test values are
        means over 10 test batches, refreshed every 20 iterations and repeated
        in between.
        """
        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
        soft_cross_entropy = tf.keras.losses.CategoricalCrossentropy(from_logits=False)

        def clean_metrics(X, Y, training):
            # Cross-entropy, accuracy (in %) and mean predictive entropy on the unperturbed inputs.
            Y_hat = self.model(X, training=training)
            ce_loss = cross_entropy(y_true=Y, y_pred=Y_hat)
            accuracy = tf.reduce_mean(tf.cast(tf.argmax(Y_hat, axis=1) == Y, tf.float32)) * 100.
            entropy_on_original_point = tf.reduce_mean(
                -1.0 * tf.reduce_mean(Y_hat * tf.math.log(Y_hat + EPS), axis=1))
            return ce_loss, accuracy, entropy_on_original_point

        def perturb(X, objective, step_size, training):
            # Move every example by `step_size` (L2 norm) along the gradient of
            # `objective(model(X))` w.r.t. the input, then clip to [0, 1].
            if step_size > 0.:
                with tf.GradientTape() as tape:
                    tape.watch(X)
                    value = objective(self.model(X, training=training))
                grads = tape.gradient(value, X)
                # Per-example norm; the batch size is read from the tensor
                # instead of being hard-coded to 128.
                grads_norm = tf.norm(tf.reshape(grads, (tf.shape(grads)[0], -1)), axis=1)
                # divide_no_nan leaves an example unperturbed when its gradient
                # is exactly zero instead of producing NaNs.
                direction = tf.math.divide_no_nan(grads, grads_norm[:, None, None, None])
                X_perturbed = tf.clip_by_value(X + step_size * direction, 0.0, 1.0)
            else:
                X_perturbed = X
            return X_perturbed

        def uniform_cross_entropy(Y_hat, n_classes):
            # Cross-entropy between the uniform distribution and the predictions.
            return soft_cross_entropy(
                y_true=tf.ones_like(Y_hat) / n_classes, y_pred=Y_hat)

        @tf.function
        def baseline(X, Y, training):
            ce_loss, accuracy, entropy_on_original_point = clean_metrics(X, Y, training)
            return ce_loss, ce_loss, 0., accuracy, entropy_on_original_point

        @tf.function
        def min_max_cent(X, Y, training, params):
            # step in the +ve gradient direction of the cross-entropy w.r.t. X:
            # looks for a nearby point that maximizes cross-entropy
            ce_loss, accuracy, entropy_on_original_point = clean_metrics(X, Y, training)
            X_perturbed = perturb(
                X, lambda Y_hat: cross_entropy(y_true=Y, y_pred=Y_hat),
                params['step_size'], training)
            # compute cross entropy at this new point
            Y_hat = self.model(X_perturbed, training=training)
            loss_adv = cross_entropy(y_true=Y, y_pred=Y_hat)
            return ce_loss + params['lambda'] * loss_adv, ce_loss, loss_adv, accuracy, entropy_on_original_point

        @tf.function
        def max_min_ent(X, Y, training, params):
            # step in the +ve gradient direction of exp(-entropy) w.r.t. X:
            # looks for a nearby point that minimizes predictive entropy
            ce_loss, accuracy, entropy_on_original_point = clean_metrics(X, Y, training)
            X_perturbed = perturb(
                X,
                lambda Y_hat: tf.exp(tf.reduce_mean(Y_hat * tf.math.log(Y_hat + EPS), axis=1)),
                params['step_size'], training)
            # the adversarial loss is exp(-entropy) at the new point, so
            # minimizing it maximizes the entropy there
            Y_hat = self.model(X_perturbed, training=training)
            entropy = -1.0 * tf.reduce_mean(Y_hat * tf.math.log(Y_hat + EPS), axis=1)
            loss_adv = tf.reduce_mean(tf.exp(-1.0 * entropy))
            return ce_loss + params['lambda'] * loss_adv, ce_loss, loss_adv, accuracy, entropy_on_original_point

        @tf.function
        def min_max_KL_unif(X, Y, training, params):
            # step in the +ve gradient direction of the cross-entropy between
            # the uniform distribution and the prediction, w.r.t. X
            ce_loss, accuracy, entropy_on_original_point = clean_metrics(X, Y, training)
            X_perturbed = perturb(
                X, lambda Y_hat: uniform_cross_entropy(Y_hat, params['n_classes']),
                params['step_size'], training)
            # the same uniform cross-entropy at the new point is the adversarial loss
            Y_hat = self.model(X_perturbed, training=training)
            loss_adv = tf.reduce_mean(uniform_cross_entropy(Y_hat, params['n_classes']))
            return ce_loss + params['lambda'] * loss_adv, ce_loss, loss_adv, accuracy, entropy_on_original_point

        @tf.function
        def label_smoothing(X, Y, training, params):
            Y_hat = self.model(X, training=training)
            ce_loss = cross_entropy(y_true=Y, y_pred=Y_hat)
            accuracy = tf.reduce_mean(tf.cast(tf.argmax(Y_hat, axis=1) == Y, tf.float32)) * 100.
            entropy_on_original_point = tf.reduce_mean(
                -1.0 * tf.reduce_mean(Y_hat * tf.math.log(Y_hat + EPS), axis=1))
            smoothing = params['label-smoothing-factor']
            Y_noisy = tf.one_hot(Y, params['n_classes']) * (1 - smoothing)
            Y_noisy += (smoothing / tf.cast(params['n_classes'], tf.float32))
            noisy_loss = soft_cross_entropy(y_true=Y_noisy, y_pred=Y_hat)
            return noisy_loss, ce_loss, 0., accuracy, entropy_on_original_point

        # NOTE(review): the two helpers below are not called anywhere inside
        # train(); they are kept as-is apart from relying on broadcasting of
        # self.tau (shape (1, n_classes)) instead of tf.repeat over a static
        # batch dimension, which may be unknown inside tf.function.
        @tf.function
        def get_calibration_metrics(X, Y):
            logits = tf.math.log(self.model(X, training=False) + EPS) * self.tau
            brier = tf.reduce_mean(tfp.stats.brier_score(labels=Y, logits=logits))
            ece = tfp.stats.expected_calibration_error(num_bins=20, logits=logits, labels_true=Y)
            nll = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y_true=Y, y_pred=logits)
            return brier, ece, nll

        @tf.function
        def eval_ood_helper(X_real, X_fake):
            Y_hat_real = tf.nn.softmax(tf.math.log(self.model(X_real, training=False) + EPS) * self.tau)
            entropy_real = -1.0 * tf.reduce_mean(Y_hat_real * tf.math.log(Y_hat_real + EPS), axis=1)
            Y_hat_fake = tf.nn.softmax(tf.math.log(self.model(X_fake, training=False) + EPS) * self.tau)
            entropy_fake = -1.0 * tf.reduce_mean(Y_hat_fake * tf.math.log(Y_hat_fake + EPS), axis=1)
            return tf.concat([entropy_real, entropy_fake], axis=0)

        # define loss function
        # dispatches on params['version']; every branch returns
        # (total loss, cross-entropy, adversarial loss, accuracy, entropy)
        @tf.function
        def loss_fn(X, Y, training, params):
            if params['version'] == 'baseline':
                return baseline(X, Y, training)
            elif params['version'] == 'min-max-cent':
                return min_max_cent(X, Y, training, params)
            elif params['version'] == 'max-min-ent':
                return max_min_ent(X, Y, training, params)
            elif params['version'] == 'min-max-KL-unif':
                return min_max_KL_unif(X, Y, training, params)
            elif params['version'] == 'label-smoothing':
                return label_smoothing(X, Y, training, params)
            else:
                raise ValueError

        # define step function
        # computes gradients of the total loss and applies them
        @tf.function
        def step_fn(X, Y, params):
            with tf.GradientTape() as tape:
                loss, cent_loss, loss_adv, accuracy, predent = loss_fn(X, Y, True, params)
            grads = tape.gradient(loss, self.model.trainable_variables)
            self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
            return loss, cent_loss, loss_adv, accuracy, predent

        # A previous train() call leaves self.history as an ndarray; turn it
        # back into a list so further iterations can be appended.
        self.history = list(self.history)

        # loop over data n_iters times
        for t in tqdm.trange(self.params['n_iters']):
            train_loss, train_loss_cent, train_loss_adv, train_acc, train_predent = step_fn(
                *next(self.iter_train), self.params)
            if t % 20 == 0:
                # refresh the test metrics from 10 test batches
                test_loss = []
                test_loss_cent = []
                test_loss_adv = []
                test_acc = []
                test_predent = []
                for _ in range(10):
                    res = loss_fn(*next(self.iter_test), False, self.params)
                    test_loss.append(res[0].numpy())
                    test_loss_cent.append(res[1].numpy())
                    test_loss_adv.append(res[2].numpy())
                    test_acc.append(res[3].numpy())
                    test_predent.append(res[4].numpy())
            self.history.append((train_loss.numpy(), np.mean(test_loss), train_acc.numpy(), np.mean(test_acc),
                                train_loss_adv.numpy(), np.mean(test_loss_adv), train_predent.numpy(), np.mean(test_predent)))

            if t % 1000 == 0:
                tf.print("Tr Total:", train_loss, "Tr CE:", train_loss_cent, "Tr Adv:", train_loss_adv, "Tr Acc:", train_acc, "Test Acc", self.history[-1][3])

            if ('lambda_schedule' in self.params) and ((t+1) % self.params['lambda_schedule']['frequency'] == 0):
                self.params['lambda'] *= self.params['lambda_schedule']['factor']

            if ('lr_schedule' in self.params) and ((t+1) % self.params['lr_schedule']['frequency'] == 0):
                self.optimizer.lr.assign(self.optimizer.lr * self.params['lr_schedule']['factor'])

        self.history = np.array(self.history)

        
# post hoc platt scaling
def calibrate_model(trainer):
    """Fit `trainer.tau` post hoc on training batches.

    For `n_iters // 2` steps: take the next training batch, compute the
    model's log-probabilities in inference mode, multiply them by
    `trainer.tau` (repeated along the batch axis) and minimise the sparse
    categorical cross-entropy of the result (treated as logits) with Adam at
    10x the 'lr-calibrator' rate. The gradient w.r.t. tau is clipped to norm
    0.1. Only `trainer.tau` is updated; the loss and tau are printed every
    1000 steps.
    """
    tau_optimizer = tf.keras.optimizers.Adam(learning_rate=trainer.params['lr-calibrator']*10.)
    n_steps = trainer.params['n_iters']//2
    for step in tqdm.trange(n_steps):
        X, Y = next(trainer.iter_train)
        # The forward pass sits outside the tape: only tau receives gradients.
        log_probs = tf.math.log(trainer.model(X, training=False) + EPS)
        with tf.GradientTape() as tape:
            tape.watch(trainer.tau)
            scaled_logits = log_probs * tf.repeat(trainer.tau, log_probs.shape[0], axis=0)
            loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(
                y_true=Y, y_pred=scaled_logits)
        grad_tau = tf.clip_by_norm(tape.gradient(loss, trainer.tau), 0.1)
        tau_optimizer.apply_gradients([(grad_tau, trainer.tau)])
        if step % 1000 == 0:
            print(f"Calibration Loss: {loss}, {trainer.tau}")

In [None]:
# baseline
# Hyper-parameters for the baseline run (plain cross-entropy objective).
# 'lambda' weights the adversarial loss term and 'step_size' is the input
# perturbation size; the 'baseline' objective in Trainer.train uses neither.
# 'lr-calibrator' is read by calibrate_model, not by train().
# 'lr_schedule' multiplies the learning rate by `factor` every `frequency`
# iterations (a factor of 1. leaves it unchanged).
params = {
    'input_shape': (64, 64, 3),
    'n_classes': 200,
    'lambda': 0.0,
    'lr': 5e-4,
    'lr-calibrator': 1e-4,
    'step_size': 0.,
    'n_iters': 10000,
    'version': 'baseline',
    'weight_decay': 0.0001,
    'lr_schedule':{
        'frequency':8000,
        'factor': 1.
    },
    'label-smoothing-factor': 0.
}
baseline_trainer = Trainer(params)
baseline_trainer.train()

In [None]:
# max-min-ent
# Hyper-parameters for the 'max-min-ent' run: the adversarial term is weighted
# by lambda=5 and inputs are perturbed with step_size=0.2.
# The learning rate is multiplied by 0.1 every 8000 iterations; the lambda
# schedule fires every 500 iterations with factor 1. (i.e. lambda stays fixed).
params = {
    'input_shape': (64, 64, 3),
    'n_classes': 200,
    'lambda': 5.,
    'lr': 5e-4,
    'lr-calibrator': 1e-4,
    'step_size': 0.2,
    'n_iters': 10000,
    'version': 'max-min-ent',
    'weight_decay': 0.0001,
    'lr_schedule':{
        'frequency':8000,
        'factor': 0.1
    },
    'lambda_schedule': {
        'frequency': 500,
        'factor': 1.
    }
}
max_min_ent_trainer = Trainer(params)
max_min_ent_trainer.train()

In [None]:
# min-max-KL-unif
# Hyper-parameters for the 'min-max-KL-unif' run: the adversarial term is
# weighted by lambda=0.1 and inputs are perturbed with step_size=0.2.
# The learning rate is multiplied by 0.1 every 8000 iterations; the lambda
# schedule fires every 500 iterations with factor 1. (i.e. lambda stays fixed).
params = {
    'input_shape': (64, 64, 3),
    'n_classes': 200,
    'lambda': 0.1,
    'lr': 5e-4,
    'lr-calibrator': 1e-4,
    'step_size': 0.2,
    'n_iters': 10000,
    'version': 'min-max-KL-unif',
    'weight_decay': 0.0001,
    'lr_schedule':{
        'frequency':8000,
        'factor': 0.1
    },
    'lambda_schedule': {
        'frequency': 500,
        'factor': 1.
    }
}
min_max_KL_unif_trainer = Trainer(params)
min_max_KL_unif_trainer.train()

In [None]:
# compute ID metrics
# Mean train / test accuracy over the last 20 recorded iterations
# (history columns 2 and 3) for the baseline and the max-min-ent run.
# NOTE: plot_ID_metrics and plot_training_metrics are defined in the
# "plotting utils" cell further down; that cell has to be run first.
ID_results = [
    [np.mean(baseline_trainer.history[:, 2][-20:]), np.mean(baseline_trainer.history[:, 3][-20:])],
    [np.mean(max_min_ent_trainer.history[:, 2][-20:]), np.mean(max_min_ent_trainer.history[:, 3][-20:])],
]
plot_ID_metrics(ID_results, tags=['baseline', 'max_min_ent'])
# Plot the trainer that the 'max_min_ent' tag (and ID_results above) refers
# to; previously the min-max-KL-unif trainer was drawn under this label.
plot_training_metrics(baseline_trainer, max_min_ent_trainer, ['baseline', 'max_min_ent'])

In [None]:
# min-max-KL-unif
# Second 'min-max-KL-unif' run: lambda raised to 0.2 and training length
# doubled to 20000 iterations, with the learning-rate drop (x0.1) moved to
# iteration 16000. The lambda schedule again uses factor 1. (lambda stays fixed).
params = {
    'input_shape': (64, 64, 3),
    'n_classes': 200,
    'lambda': 0.2,
    'lr': 5e-4,
    'lr-calibrator': 1e-4,
    'step_size': 0.2,
    'n_iters': 20000,
    'version': 'min-max-KL-unif',
    'weight_decay': 0.0001,
    'lr_schedule':{
        'frequency':16000,
        'factor': 0.1
    },
    'lambda_schedule': {
        'frequency': 500,
        'factor': 1.
    }
}
min_max_KL_unif_trainer2 = Trainer(params)
min_max_KL_unif_trainer2.train()

In [None]:
# compute ID metrics
# One row per trainer: mean of history columns 2 (train accuracy) and
# 3 (test accuracy) over the last 20 recorded iterations.
ID_results = [
    [np.mean(trainer.history[-20:, column]) for column in (2, 3)]
    for trainer in (baseline_trainer, min_max_KL_unif_trainer2)
]
plot_ID_metrics(ID_results, tags=['baseline', 'min_max_KL_unif'])
plot_training_metrics(
    baseline_trainer, min_max_KL_unif_trainer2, ['baseline', 'min_max_KL_unif'])

In [None]:
# plotting utils
def plot_training_metrics(baseline_trainer, our_trainer, tags=['baseline', '']):
    """Plot smoothed train/validation curves of two trainers side by side.

    Three panels are drawn on a 1x4 grid (the fourth slot is left empty):
    total loss (history columns 0/1), accuracy (columns 2/3) and predictive
    entropy (absolute value of columns 6/7). Each curve is smoothed with a
    Gaussian filter of sigma 100; train curves use 'o' markers, validation
    curves '^' markers, and each trainer gets one colour from the default
    colour cycle.

    Args:
        baseline_trainer: trainer whose `history` is a 2-D array as produced
            by `Trainer.train`.
        our_trainer: second trainer to compare against the baseline.
        tags: legend names for the two trainers, in the same order.
    """
    trainer_vec = [baseline_trainer, our_trainer]
    c_vec = plt.rcParams['axes.prop_cycle'].by_key()['color']

    def smooth(series):
        return gaussian_filter1d(series, 100)

    # (title, y-label, train column, validation column, curve transform)
    panels = [
        ("Train/Test total loss", 'total loss', 0, 1, smooth),
        ("Train/Test accuracy", 'Accuracy', 2, 3, smooth),
        ("Train/Test predictive entropy", 'Entropy', 6, 7, lambda series: np.abs(smooth(series))),
    ]

    plt.figure(figsize=(20, 3))
    for position, (title, ylabel, train_col, val_col, curve) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        plt.title(title)
        for idx, (c, trainer) in enumerate(zip(c_vec, trainer_vec)):
            name = tags[idx]
            plt.plot(curve(trainer.history[:, train_col]), '-o', c=c,
                     label='%s - train' % name, markevery=1000, markersize=10)
            plt.plot(curve(trainer.history[:, val_col]), '-^', c=c,
                     label='%s - val' % name, markevery=1000, markersize=10)
        plt.grid()
        plt.xlabel('iterations')
        plt.ylabel(ylabel)
        # Every panel gets a legend (the entropy panel used to lack one).
        plt.legend()

    plt.show()
    
    
    
def plot_OOD_metrics(results, tags):
    """Draw grouped bar charts of five OOD metrics per corruption type.

    Panels (2x3 grid): OOD accuracy, OOD AUC, Brier score, ECE and NLL; the
    sixth slot only carries the legend. Within a panel there is one group of
    bars per entry of the global `corruption_type_list` and one bar per tag.

    Args:
        results: mapping from `(corruption_type, tag_index)` to a sequence
            whose j-th entry is the value of the j-th metric listed above.
        tags: method names, in tag-index order.

    NOTE(review): `corruption_type_list` is read from module scope and is not
    defined in the visible cells - confirm it exists before calling this.
    """
    # Bar width is loop-invariant; binding it here also keeps the xticks
    # call below valid when `tags` is empty.
    width = 0.1
    offset = width

    titles = ['Accuracy OOD $(\\uparrow)$', 'OOD AUC $(\\uparrow)$', 'Brier $(\\downarrow)$',
              'ECE  $(\\downarrow)$', 'NLL  $(\\downarrow)$']

    plt.figure(figsize=(12, 6))
    for (j, title) in enumerate(titles):
        plt.subplot(2, 3, j + 1)
        plt.title(title)
        for idx in range(len(tags)):
            x = np.arange(len(corruption_type_list))
            y = [results[(ctype, idx)][j] for ctype in corruption_type_list]
            plt.bar(x + width * (idx + 1) - offset, y, width=width)
        plt.xticks(np.arange(len(corruption_type_list)) + 2*width, corruption_type_list, rotation=90)
        plt.grid()

    # Legend-only panel: one dummy point per tag so that the legend colours
    # follow the same default colour cycle as the bars.
    plt.subplot(2, 3, 6)
    plt.title('legend')
    for idx in range(len(tags)):
        plt.scatter(0., 0., label=f'{tags[idx]}')
    plt.legend(fontsize=8, loc='lower right')

    plt.tight_layout()
    plt.show()
    
    
def plot_ID_metrics(results, tags):
    """Draw in-distribution train and test accuracy as bar charts.

    Two panels side by side; in each, one unit-width bar per tag, where
    `results[idx][0]` is the train accuracy and `results[idx][1]` the test
    accuracy of the method named `tags[idx]`.
    """
    bar_width = 1.
    panel_titles = ['Train Accuracy ID $(\\uparrow)$', 'Test Accuracy ID $(\\uparrow)$']

    plt.figure(figsize=(5, 6))
    for panel, title in enumerate(panel_titles):
        plt.subplot(1, 2, panel + 1)
        plt.title(title)
        for idx, _ in enumerate(tags):
            plt.bar(bar_width * (idx + 1), results[idx][panel], width=bar_width)
        plt.xticks(np.arange(len(tags)) + bar_width, tags, rotation=90)
        plt.grid()
    plt.tight_layout()
    plt.show()

In [None]:
import tensorflow as tf

class Test(tf.keras.Model):
  """Toy model: one Dense(1) layer trained with Adam inside a `tf.function`."""

  def __init__(self):
    super().__init__()
    self.model = tf.keras.layers.Dense(1)
    self.optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

  @tf.function
  def step(self, X, y):
    """Take one Adam step on the mean squared error and return that loss."""
    with tf.GradientTape() as tape:
        residual = self.model(X) - y
        loss = tf.reduce_mean(residual ** 2)
    weights = self.model.trainable_variables
    self.optimizer.apply_gradients(zip(tape.gradient(loss, weights), weights))
    return loss

t = Test()

In [None]:
# Random toy regression data for the `Test` model: 2 examples with 10 features
# each, and one target value per example.
X = tf.random.normal((2, 10))
y = tf.random.normal((2, 1))

In [None]:
# Run 1000 optimisation steps on the toy data, printing the loss tensor
# returned by every step.
for _ in range(1000):
    print(t.step(X, y))