In [8]:
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import numpy as np
import tensorflow_datasets as tfds
import tqdm
from tensorflow.keras.models import Model
from sklearn.metrics import roc_auc_score
from sklearn.calibration import calibration_curve
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import tensorflow_probability as tfp
from scipy.ndimage import gaussian_filter1d

In [9]:
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
  tf.config.experimental.set_memory_growth(gpu, True)

### cifar experiments

In [10]:
# helper functions to create TF datasets

def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`."""
    image = tf.cast(image, tf.float32) / 255.
    # image = tf.reshape(image, (-1,))
    return image, label


def make_ds(ds, augment=False):
    ds = ds.map(
        normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.cache()
    ds = ds.repeat()
    ds = ds.shuffle(int(1e5))
    ds = ds.batch(128)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    return ds



In [11]:
EPS = 1e-3

@tf.function
def baseline(X, Y, model, training, **params):
    Y_hat = model(X, training=training) 
    ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
    accuracy = tf.reduce_mean(tf.cast(tf.argmax(Y_hat, axis=1) == Y, tf.float32)) * 100.
    entropy_on_original_point = tf.reduce_mean(-1.0 * tf.reduce_mean(Y_hat * tf.math.log(Y_hat + EPS), axis=1))
    return ce_loss, ce_loss, 0., accuracy, entropy_on_original_point
        
@tf.function
def min_max_cent(X, Y, model, training, **params):
    # compute gradient of cross entropy loss wrt X and take a step in +ve direction
    # this would try to find a point in the neighborhood of X that maximizes cross entropy
    Y_hat = model(X, training=training) 
    ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
    accuracy = tf.reduce_mean(tf.cast(tf.argmax(Y_hat, axis=1) == Y, tf.float32)) * 100.
    entropy_on_original_point = tf.reduce_mean(-1.0 * tf.reduce_mean(Y_hat * tf.math.log(Y_hat + EPS), axis=1))    
    with tf.GradientTape() as tape:
        tape.watch(X)
        Y_hat = model(X, training=training)
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
    grads = tape.gradient(loss, X)
    grads_norm = tf.norm(tf.reshape(grads, (128, -1)), axis=1)
    grads = params['step_size'] * grads / grads_norm[:, None, None, None]
    X_perturbed = tf.clip_by_value(X + grads, 0.0, 1.0)
    # compute cross entropy at this new point
    Y_hat = model(X_perturbed, training=training)
    loss_adv = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
    return ce_loss + params['lambda'] * loss_adv, ce_loss, loss_adv, accuracy, entropy_on_original_point

@tf.function
def max_min_ent(X, Y, model, training, **params):
    # compute grad of entropy wrt X and take a step in negative direction
    # this would find a point in the neighborhood of X that would minimize entropy
    Y_hat = model(X, training=training) 
    ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
    accuracy = tf.reduce_mean(tf.cast(tf.argmax(Y_hat, axis=1) == Y, tf.float32)) * 100.
    entropy_on_original_point = tf.reduce_mean(-1.0 * tf.reduce_mean(Y_hat * tf.math.log(Y_hat + EPS), axis=1))    
    with tf.GradientTape() as tape:
        tape.watch(X)
        Y_hat = model(X, training=training)
        exp_neg_entropy = tf.exp(tf.reduce_mean(Y_hat * tf.math.log(Y_hat + EPS), axis=1))
    grads = tape.gradient(exp_neg_entropy, X)
    grads_norm = tf.norm(tf.reshape(grads, (128, -1)), axis=1)
    grads = params['step_size'] * grads / grads_norm[:, None, None, None]      
    X_perturbed = tf.clip_by_value(X + grads, 0.0, 1.0)
    # compute entropy at this new point and multiply it by -1 (raise to exp. for better grads), since we want to maximize entropy
    Y_hat = model(X_perturbed, training=training)
    entropy = -1.0 * tf.reduce_mean(Y_hat * tf.math.log(Y_hat + EPS), axis=1)
    loss_adv = tf.reduce_mean(tf.exp(-1.0 * entropy))
    return ce_loss + params['lambda'] * loss_adv, ce_loss, loss_adv, accuracy, entropy_on_original_point

@tf.function
def min_max_KL_unif(X, Y, model, training, **params):
    # compute grad of KL(unif, p_\theta) wrt X and take a step in +ve direction
    # this would find a point in the neighborhood of X that would maximize KL
    Y_hat = model(X, training=training) 
    ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
    accuracy = tf.reduce_mean(tf.cast(tf.argmax(Y_hat, axis=1) == Y, tf.float32)) * 100.
    entropy_on_original_point = tf.reduce_mean(-1.0 * tf.reduce_mean(Y_hat * tf.math.log(Y_hat + 1e-3), axis=1)) 
    with tf.GradientTape() as tape:
        tape.watch(X)
        Y_hat = model(X, training=training)
        KL_unif = tf.keras.losses.CategoricalCrossentropy(from_logits=False)(
            y_true=tf.ones_like(Y_hat) / tf.cast(params['n_classes'], tf.float32), y_pred=Y_hat)
    grads = tape.gradient(KL_unif, X)
    grads_norm = tf.norm(tf.reshape(grads, (128, -1)), axis=1)
    grads = params['step_size'] * grads / grads_norm[:, None, None, None]      
    X_perturbed = tf.clip_by_value(X + grads, 0.0, 1.0)
    # compute entropy at this new point and multiply it by -1, since we want to maximize entropy
    Y_hat = model(X_perturbed, training=training)
    KL_unif = tf.keras.losses.CategoricalCrossentropy(from_logits=False)(
        y_true=tf.ones_like(Y_hat) / tf.cast(params['n_classes'], tf.float32), y_pred=Y_hat)
    loss_adv = tf.reduce_mean(KL_unif)
    return ce_loss + params['lambda'] * loss_adv, ce_loss, loss_adv, accuracy, entropy_on_original_point

@tf.function
def label_smoothing(X, Y, model, training, **params):
    Y_hat = model(X, training=training)
    ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
    accuracy = tf.reduce_mean(tf.cast(tf.argmax(Y_hat, axis=1) == Y, tf.float32)) * 100.
    entropy_on_original_point = tf.reduce_mean(-1.0 * tf.reduce_mean(Y_hat * tf.math.log(Y_hat + 1e-3), axis=1)) 
    Y = tf.one_hot(Y, params['n_classes'])
    Y_noisy = Y * (1 - params['label-smoothing-factor']) 
    Y_noisy += (params['label-smoothing-factor'] / tf.cast(params['n_classes'], tf.float32))
    noisy_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False)(y_true=Y_noisy, y_pred=Y_hat)
    return noisy_loss, ce_loss, 0., accuracy, entropy_on_original_point

@tf.function
def get_calibration_metrics(X, Y, model, tau):
    logits = tf.math.log(model(X, training=False) + EPS) * tf.repeat(tau, X.shape[0], axis=0)
    brier = tf.reduce_mean(tfp.stats.brier_score(labels=Y, logits=logits))
    ece = tfp.stats.expected_calibration_error(num_bins=20, logits=logits, labels_true=Y)
    nll = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y_true=Y, y_pred=logits)
    return brier, ece, nll
    
    
@tf.function
def eval_ood_helper(model, X_real, X_fake, tau):
    Y_hat_real = tf.nn.softmax(tf.math.log(model(X_real, training=False) + EPS) * tf.repeat(tau, X_real.shape[0], axis=0))
    entropy_real = -1.0 * tf.reduce_mean(Y_hat_real * tf.math.log(Y_hat_real + EPS), axis=1)
    Y_hat_fake = tf.nn.softmax(tf.math.log(model(X_fake, training=False) + EPS) * tf.repeat(tau, X_fake.shape[0], axis=0))
    entropy_fake = -1.0 * tf.reduce_mean(Y_hat_fake * tf.math.log(Y_hat_fake + EPS), axis=1)
    return tf.concat([entropy_real, entropy_fake], axis=0)


def conv_to_tensors(p):
    return {k: tf.constant(v) if type(v) != dict else conv_to_tensors(v) for k, v in p.items()}

# define loss function
# computes loss given a model and X, Y
@ tf.function
def loss_fn(X, Y, model, training, **params):

    if params['version'] == 'baseline':
        return baseline(X, Y, model, training, **params)
    elif params['version'] == 'min-max-cent':
        return min_max_cent(X, Y, model, training, **params)
    elif params['version'] == 'max-min-ent':
        return max_min_ent(X, Y, model, training, **params)
    elif params['version'] == 'min-max-KL-unif': 
        return min_max_KL_unif(X, Y, model, training, **params)
    elif params['version'] == 'label-smoothing':
        return label_smoothing(X, Y, model, training, **params)
    else:
        return 0., 0., 0., 0., 0.

# define step function
# computes gradients and applies them
@tf.function
def step_fn(X, Y, model, optimizer, **params):
    with tf.GradientTape() as tape:
        loss, cent_loss, loss_adv, accuracy, predent = loss_fn(X, Y, model, True, **params)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss, cent_loss, loss_adv, accuracy, predent


    

In [34]:
# create dataset, iterate over it and train, return model and metrics

EPS = 1e-3
    
class Trainer:
    
    def __init__(self, params):
        self.params = params
        print(f"\nRun training with params {self.params}")
    
        ds_train, ds_test = out = tfds.load(
            'cifar10',
            split=['train', 'test'],
            data_dir='/amrith/tensorflow_datasets',
            shuffle_files=True,
            as_supervised=True,
            with_info=False,
        )
        
        ds_train = make_ds(ds_train)
        ds_test = make_ds(ds_test)
        self.iter_train = iter(ds_train)
        self.iter_test = iter(ds_test)

        # define model
        self.model = tf.keras.models.Sequential([
            tf.keras.layers.RandomFlip(mode='horizontal', input_shape=self.params['input_shape']),
            tf.keras.layers.RandomTranslation(0.1, 0.1),
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same', input_shape=params['input_shape']),
            tf.keras.layers.BatchNormalization(fused=True),
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization(fused=True),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization(fused=True),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization(fused=True),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization(fused=True),
            tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization(fused=True),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(1024, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(self.params['n_classes'], activation='softmax')])
    
    
        # self.model = tf.keras.applications.resnet50.ResNet50(
        #     include_top=True, weights='imagenet', input_shape=self.params['input_shape'])
        
        print(self.model.summary())

        # define optimizer
        self.optimizer = tfa.optimizers.AdamW(learning_rate=self.params['lr'], weight_decay=params['weight_decay'])
        
        # maintain history
        self.history = []
        
        # get last layer reps
        self.model_last_layer = Model(self.model.input, self.model.layers[-2].output)
        
        # temperature for platt scaling
        self.tau = tf.Variable(tf.ones((1, self.params['n_classes'])))

    
    def eval_ood(self, X_real, X_fake):
        logits = eval_ood_helper(X_real, X_fake, self.tau).numpy()
        labels = np.concatenate([np.zeros(128), np.ones(128)])
        auc = roc_auc_score(labels, logits)
        return auc

    
    def get_batch_dimension(self, repr): # dimension of the last layer representations from a batch
        repr = StandardScaler().fit_transform(repr)
        pca = PCA()
        pca.fit(repr)
        explained_var = pca.explained_variance_
        explained_var /= (explained_var.sum() + EPS)
        return np.sum([explained_var.cumsum() <= 0.9]) # returns no. of dimensions that account for 90% of variance
    
    def get_model_weights(self):
        params = np.array([])
        for layer in t.model.layers:
            for wt in layer.trainable_variables: 
                params = np.concatenate([params, wt.numpy().flatten()], axis=0)
        return params
    
    def set_model_weights(self, param):
        idx=0
        for layer in self.model.layers:
            for wt in layer.trainable_variables: 
                wt.assign(param[prev:prev+np.prod(wt.shape)].reshape(wt.shape))
                idx+=np.prod(wt.shape)
                
    def evaluate_n_random_batches(self, n=10):
        loss = 0.
        for _ in range(n):
            loss += self.loss_fn(*(self.iter_train), trainable=False)
        return loss / n
        
    def compute_sharpness_metric(self, p=100, delta=0.001):
        x_0 = self.get_model_weights() 
        A = tf.random.normal((x_0.shape[0], p))
        proj = tf.linalg.pinv(A) @ x_0
        y_min = (tf.math.abs(proj)+1)*delta
        y_max = (tf.math.abs(proj)+1)*(-delta)
        y_0 = tf.Variable(np.random.zeros(p), trainable=True)
        # for LBFS solver, returns func evaluation and gradient
        def f(y):
            with tf.GradientTape() as tape:
                tape.watch()
                self.set_model_weights(x_0 + A@y)
                loss = - self.evaluate_n_random_batches(n=10)
                # we want to maximize the loss hence, the negative sign
            return loss, tape.gradient(loss, y)
        _, neg_maxf, _ = scipy.optimize.fmin_l_bfgs_b(
            f, y_0, bounds=zip(y_min, y_max), maxiter=10)
        maxf = -neg_maxf
        fx = self.evaluate_n_random_batches(n=10)
        sharpness = (maxf - fx) * 100. / (1 + fx)
        # reset model weights
        self.set_model_weights(x_0)
        return sharpness


        
    def train(self):
        
        @tf.function
        def baseline(X, Y, training):
            print("Tracing baseline")
            Y_hat = self.model(X, training=training) 
            ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
            accuracy = tf.reduce_mean(tf.cast(tf.argmax(Y_hat, axis=1) == Y, tf.float32)) * 100.
            entropy_on_original_point = tf.reduce_mean(-1.0 * tf.reduce_mean(Y_hat * tf.math.log(Y_hat + EPS), axis=1))
            return ce_loss, ce_loss, 0., accuracy, entropy_on_original_point

        @tf.function
        def min_max_cent(X, Y, training, params):
            # compute gradient of cross entropy loss wrt X and take a step in +ve direction
            # this would try to find a point in the neighborhood of X that maximizes cross entropy
            Y_hat = self.model(X, training=training) 
            ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
            accuracy = tf.reduce_mean(tf.cast(tf.argmax(Y_hat, axis=1) == Y, tf.float32)) * 100.
            entropy_on_original_point = tf.reduce_mean(-1.0 * tf.reduce_mean(Y_hat * tf.math.log(Y_hat + EPS), axis=1))    
            if params['step_size'] > 0.:
                with tf.GradientTape() as tape:
                    tape.watch(X)
                    Y_hat = self.model(X, training=training)
                    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
                grads = tape.gradient(loss, X)
                grads_norm = tf.norm(tf.reshape(grads, (128, -1)), axis=1)
                grads = params['step_size'] * grads / grads_norm[:, None, None, None]
                X_perturbed = tf.clip_by_value(X + grads, 0.0, 1.0)
            else:
                X_perturbed = X
            # compute cross entropy at this new point
            Y_hat = self.model(X_perturbed, training=training)
            loss_adv = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
            return ce_loss + params['lambda'] * loss_adv, ce_loss, loss_adv, accuracy, entropy_on_original_point

        @tf.function
        def max_min_ent(X, Y, training, params):
            # compute grad of entropy wrt X and take a step in negative direction
            # this would find a point in the neighborhood of X that would minimize entropy
            Y_hat = self.model(X, training=training) 
            ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
            accuracy = tf.reduce_mean(tf.cast(tf.argmax(Y_hat, axis=1) == Y, tf.float32)) * 100.
            entropy_on_original_point = tf.reduce_mean(-1.0 * tf.reduce_mean(Y_hat * tf.math.log(Y_hat + EPS), axis=1))    
            if params['step_size'] > 0.:
                with tf.GradientTape() as tape:
                    tape.watch(X)
                    Y_hat = self.model(X, training=training)
                    exp_neg_entropy = tf.exp(tf.reduce_mean(Y_hat * tf.math.log(Y_hat + EPS), axis=1))
                grads = tape.gradient(exp_neg_entropy, X)
                grads_norm = tf.norm(tf.reshape(grads, (128, -1)), axis=1)
                grads = params['step_size'] * grads / grads_norm[:, None, None, None]      
                X_perturbed = tf.clip_by_value(X + grads, 0.0, 1.0)
            else:
                X_perturbed = X
            # compute entropy at this new point and multiply it by -1 (raise to exp. for better grads), since we want to maximize entropy
            Y_hat = self.model(X_perturbed, training=training)
            entropy = -1.0 * tf.reduce_mean(Y_hat * tf.math.log(Y_hat + EPS), axis=1)
            loss_adv = tf.reduce_mean(tf.exp(-1.0 * entropy))
            return ce_loss + params['lambda'] * loss_adv, ce_loss, loss_adv, accuracy, entropy_on_original_point

        @tf.function
        def min_max_KL_unif(X, Y, training, params):
            # compute grad of KL(unif, p_\theta) wrt X and take a step in +ve direction
            # this would find a point in the neighborhood of X that would maximize KL
            Y_hat = self.model(X, training=training) 
            ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
            accuracy = tf.reduce_mean(tf.cast(tf.argmax(Y_hat, axis=1) == Y, tf.float32)) * 100.
            entropy_on_original_point = tf.reduce_mean(-1.0 * tf.reduce_mean(Y_hat * tf.math.log(Y_hat + 1e-3), axis=1)) 
            if params['step_size'] > 0.:
                with tf.GradientTape() as tape:
                    tape.watch(X)
                    Y_hat = self.model(X, training=training)
                    KL_unif = tf.keras.losses.CategoricalCrossentropy(from_logits=False)(
                        y_true=tf.ones_like(Y_hat) / params['n_classes'], y_pred=Y_hat)
                grads = tape.gradient(KL_unif, X)
                grads_norm = tf.norm(tf.reshape(grads, (128, -1)), axis=1)
                grads = self.params['step_size'] * grads / grads_norm[:, None, None, None]      
                X_perturbed = tf.clip_by_value(X + grads, 0.0, 1.0)
            else:
                X_perturbed = X
            # compute entropy at this new point and multiply it by -1, since we want to maximize entropy
            Y_hat = self.model(X_perturbed, training=training)
            KL_unif = tf.keras.losses.CategoricalCrossentropy(from_logits=False)(
                y_true=tf.ones_like(Y_hat) / params['n_classes'], y_pred=Y_hat)
            loss_adv = tf.reduce_mean(KL_unif)
            return ce_loss + params['lambda'] * loss_adv, ce_loss, loss_adv, accuracy, entropy_on_original_point

        @tf.function
        def label_smoothing(X, Y, training, params):
            Y_hat = self.model(X, training=training)
            ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
            accuracy = tf.reduce_mean(tf.cast(tf.argmax(Y_hat, axis=1) == Y, tf.float32)) * 100.
            entropy_on_original_point = tf.reduce_mean(-1.0 * tf.reduce_mean(Y_hat * tf.math.log(Y_hat + 1e-3), axis=1)) 
            Y = tf.one_hot(Y, params['n_classes'])
            Y_noisy = Y * (1 - params['label-smoothing-factor']) 
            Y_noisy += (self.params['label-smoothing-factor'] / tf.cast(params['n_classes'], tf.float32))
            noisy_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False)(y_true=Y_noisy, y_pred=Y_hat)
            return noisy_loss, ce_loss, 0., accuracy, entropy_on_original_point

        @tf.function
        def get_calibration_metrics(X, Y):
            logits = tf.math.log(self.model(X, training=False) + EPS) * tf.repeat(self.tau, X.shape[0], axis=0)
            brier = tf.reduce_mean(tfp.stats.brier_score(labels=Y, logits=logits))
            ece = tfp.stats.expected_calibration_error(num_bins=20, logits=logits, labels_true=Y)
            nll = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y_true=Y, y_pred=logits)
            return brier, ece, nll


        @tf.function
        def eval_ood_helper(X_real, X_fake):
            Y_hat_real = tf.nn.softmax(tf.math.log(self.model(X_real, training=False) + EPS) * tf.repeat(self.tau, X_real.shape[0], axis=0))
            entropy_real = -1.0 * tf.reduce_mean(Y_hat_real * tf.math.log(Y_hat_real + EPS), axis=1)
            Y_hat_fake = tf.nn.softmax(tf.math.log(self.model(X_fake, training=False) + EPS) * tf.repeat(self.tau, X_fake.shape[0], axis=0))
            entropy_fake = -1.0 * tf.reduce_mean(Y_hat_fake * tf.math.log(Y_hat_fake + EPS), axis=1)
            return tf.concat([entropy_real, entropy_fake], axis=0)



        # define loss function
        # computes loss given a model and X, Y
        @ tf.function
        def loss_fn(X, Y, training, params):

            if params['version'] == 'baseline':
                return baseline(X, Y, training)
            elif params['version'] == 'min-max-cent':
                return min_max_cent(X, Y, training, params)
            elif params['version'] == 'max-min-ent':
                return max_min_ent(X, Y, training, params)
            elif params['version'] == 'min-max-KL-unif': 
                return min_max_KL_unif(X, Y, training, params)
            elif params['version'] == 'label-smoothing':
                return label_smoothing(X, Y, training, params)
            else:
                raise ValueError
                
        # define step function
        # computes gradients and applies them
        @tf.function
        def step_fn(X, Y, params):
            with tf.GradientTape() as tape:
                loss, cent_loss, loss_adv, accuracy, predent = loss_fn(X, Y, True, params)
            grads = tape.gradient(loss, self.model.trainable_variables)
            self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
            return loss, cent_loss, loss_adv, accuracy, predent
        
        # loop over data n_iters times
        for t in tqdm.trange(self.params['n_iters']):
            train_loss, train_loss_cent, train_loss_adv, train_acc, train_predent = step_fn(*next(self.iter_train), self.params)    
            if t % 1000 == 0:
                tf.print("Total:", train_loss, "CE:", train_loss_cent, "Adv:", train_loss_adv, "Acc:", train_acc)
            if t % 20 == 0:
                test_loss = []
                test_loss_cent = []
                test_loss_adv = []
                test_acc = []
                test_predent = []
                for _ in range(10):
                    res = loss_fn(*next(self.iter_test), False, self.params)
                    test_loss.append(res[0].numpy())
                    test_loss_cent.append(res[1].numpy())
                    test_loss_adv.append(res[2].numpy())
                    test_acc.append(res[3].numpy())
                    test_predent.append(res[4].numpy())
                train_dim = self.get_batch_dimension(self.model_last_layer(next(self.iter_train)[0], training=False))
                test_dim = self.get_batch_dimension(self.model_last_layer(next(self.iter_test)[0], training=False))
            self.history.append((train_loss.numpy(), np.mean(test_loss), train_acc.numpy(), np.mean(test_acc),
                                train_loss_adv.numpy(), np.mean(test_loss_adv), train_predent.numpy(), np.mean(test_predent)))
                                # train_dim, test_dim))
            
            if ('lambda_schedule' in self.params) and ((t+1) % self.params['lambda_schedule']['frequency'] == 0):
                self.params['lambda'] *= self.params['lambda_schedule']['factor']
                
            if ('lr_schedule' in self.params) and ((t+1) % self.params['lr_schedule']['frequency'] == 0):
                self.optimizer.lr.assign(self.optimizer.lr * self.params['lr_schedule']['factor'])

        self.history = np.array(self.history)

        
# post hoc platt scaling
def calibrate_model(trainer):
    # loop over data n_iters times
    tau_optimizer = tf.keras.optimizers.Adam(learning_rate=trainer.params['lr-calibrator']*10.)
    for t in tqdm.trange(trainer.params['n_iters']//2):
        X, Y = next(trainer.iter_train)
        Y_pred_logits = tf.math.log(trainer.model(X, training=False) + EPS)
        with tf.GradientTape() as tape:
            tape.watch(trainer.tau)
            Y_pred_logits *= tf.repeat(trainer.tau, Y_pred_logits.shape[0], axis=0)
            loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y_true=Y, y_pred=Y_pred_logits)
        grad_tau = tape.gradient(loss, trainer.tau)
        grad_tau = tf.clip_by_norm(grad_tau, 0.1)
        tau_optimizer.apply_gradients([(grad_tau, trainer.tau)])
        if t % 1000 == 0:
            print(f"Calibration Loss: {loss}, {trainer.tau}")

In [35]:
# baseline
params = {
    'input_shape': (32, 32, 3),
    'n_classes': 10,
    'lambda': 0.0,
    'lr': 5e-4,
    'lr-calibrator': 1e-4,
    'step_size': 0.,
    'n_iters': 10000,
    'version': 'baseline',
    'weight_decay': 0.0001,
    'lr_schedule':{
        'frequency':8000,
        'factor': 0.1
    },
    'label-smoothing-factor': 0.
}
baseline_trainer = Trainer(params)
baseline_trainer.train()


Run training with params {'input_shape': (32, 32, 3), 'n_classes': 10, 'lambda': 0.0, 'lr': 0.0005, 'lr-calibrator': 0.0001, 'step_size': 0.0, 'n_iters': 10000, 'version': 'baseline', 'weight_decay': 0.0001, 'lr_schedule': {'frequency': 8000, 'factor': 0.1}, 'label-smoothing-factor': 0.0}
Model: "sequential_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
random_flip_13 (RandomFlip)  (None, 32, 32, 3)         0         
_________________________________________________________________
random_translation_13 (Rando (None, 32, 32, 3)         0         
_________________________________________________________________
conv2d_78 (Conv2D)           (None, 32, 32, 32)        896       
_________________________________________________________________
batch_normalization_78 (Batc (None, 32, 32, 32)        128       
_________________________________________________________________
conv2d_79 (Conv2D)        

  0%|                                                                                                                              | 0/10000 [00:00<?, ?it/s]

Tracing baseline
Total: 4.17955542 CE: 4.17955542 Adv: 0 Acc: 8.59375
Tracing baseline


 10%|███████████▌                                                                                                       | 1000/10000 [00:30<03:41, 40.72it/s]

Total: 0.889262378 CE: 0.889262378 Adv: 0 Acc: 67.1875


 20%|███████████████████████                                                                                            | 2000/10000 [00:56<02:44, 48.70it/s]

Total: 0.667754352 CE: 0.667754352 Adv: 0 Acc: 77.34375


 30%|██████████████████████████████████▌                                                                                | 3000/10000 [01:22<02:20, 49.69it/s]

Total: 0.383449674 CE: 0.383449674 Adv: 0 Acc: 85.9375


 40%|█████████████████████████████████████████████▉                                                                     | 3996/10000 [01:48<02:22, 42.22it/s]

Total: 0.371704 CE: 0.371704 Adv: 0 Acc: 85.15625


 50%|█████████████████████████████████████████████████████████▌                                                         | 5000/10000 [02:13<01:47, 46.69it/s]

Total: 0.414470553 CE: 0.414470553 Adv: 0 Acc: 85.15625


 60%|█████████████████████████████████████████████████████████████████████                                              | 6000/10000 [02:39<01:25, 47.05it/s]

Total: 0.320916235 CE: 0.320916235 Adv: 0 Acc: 89.84375


 70%|████████████████████████████████████████████████████████████████████████████████▍                                  | 6993/10000 [03:05<01:11, 42.09it/s]

Total: 0.308200955 CE: 0.308200955 Adv: 0 Acc: 89.0625


 80%|███████████████████████████████████████████████████████████████████████████████████████████▉                       | 7993/10000 [03:31<00:47, 41.98it/s]

Total: 0.295175344 CE: 0.295175344 Adv: 0 Acc: 91.40625


 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████▍           | 8997/10000 [03:56<00:22, 44.71it/s]

Total: 0.180043638 CE: 0.180043638 Adv: 0 Acc: 93.75


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [04:22<00:00, 38.15it/s]


In [None]:
# max-min-ent
params = {
    'input_shape': (32, 32, 3),
    'n_classes': 10,
    'lambda': 5.,
    'lr': 5e-4,
    'lr-calibrator': 1e-4,
    'step_size': 0.,
    'n_iters': 10000,
    'version': 'min-max-KL-unif',
    'weight_decay': 0.0001,
    'lr_schedule':{
        'frequency':8000,
        'factor': 0.1
    },
    'lambda_schedule': {
        'frequency': 500,
        'factor': 1.
    }
}
max_min_ent_trainer = Trainer(params)
max_min_ent_trainer.train()


Run training with params {'input_shape': (32, 32, 3), 'n_classes': 10, 'lambda': 5.0, 'lr': 0.0005, 'lr-calibrator': 0.0001, 'step_size': 0.0, 'n_iters': 10000, 'version': 'min-max-KL-unif', 'weight_decay': 0.0001, 'lr_schedule': {'frequency': 8000, 'factor': 0.1}, 'lambda_schedule': {'frequency': 500, 'factor': 1.0}}
Model: "sequential_14"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
random_flip_14 (RandomFlip)  (None, 32, 32, 3)         0         
_________________________________________________________________
random_translation_14 (Rando (None, 32, 32, 3)         0         
_________________________________________________________________
conv2d_84 (Conv2D)           (None, 32, 32, 32)        896       
_________________________________________________________________
batch_normalization_84 (Batc (None, 32, 32, 32)        128       
______________________________________________________________

  0%|                                                                                                                              | 0/10000 [00:00<?, ?it/s]

Total: 22.9821129 CE: 4.05168819 Adv: 3.78608489 Acc: 9.375


 10%|███████████▌                                                                                                        | 999/10000 [00:44<04:22, 34.32it/s]

Total: 13.6791277 CE: 2.02478671 Adv: 2.33086824 Acc: 47.65625


 20%|██████████████████████▉                                                                                            | 1997/10000 [01:21<04:37, 28.82it/s]

Total: 13.5871162 CE: 1.90086699 Adv: 2.33724976 Acc: 57.8125


 30%|██████████████████████████████████▍                                                                                | 2991/10000 [01:54<02:18, 50.72it/s]

Total: 13.52 CE: 1.7884903 Adv: 2.34630203 Acc: 73.4375


 40%|█████████████████████████████████████████████▉                                                                     | 3991/10000 [02:15<01:59, 50.11it/s]

Total: 13.5442724 CE: 1.77408898 Adv: 2.35403681 Acc: 67.1875


 44%|███████████████████████████████████████████████████▏                                                               | 4446/10000 [02:43<02:52, 32.26it/s]

In [None]:
# min-max-KL-unif
params = {
    'input_shape': (32, 32, 3),
    'n_classes': 10,
    'lambda': 0.2,
    'lr': 5e-4,
    'lr-calibrator': 1e-4,
    'step_size': 0.,
    'n_iters': 10000,
    'version': 'min-max-KL-unif',
    'weight_decay': 0.0001,
    'lr_schedule':{
        'frequency':8000,
        'factor': 0.1
    },
    'lambda_schedule': {
        'frequency': 500,
        'factor': 1.
    }
}
min_max_KL_unif_trainer = Trainer(params)
min_max_KL_unif_trainer.train()

In [None]:
# compute ID metrics
ID_results = [
    [np.mean(baseline_trainer.history[:, 2][-20:]), np.mean(baseline_trainer.history[:, 3][-20:])],
    [np.mean(max_min_ent_trainer.history[:, 2][-20:]), np.mean(max_min_ent_trainer.history[:, 3][-20:])],    
    [np.mean(min_max_KL_unif_trainer.history[:, 2][-20:]), np.mean(min_max_KL_unif_trainer.history[:, 3][-20:])],
]
plot_ID_metrics(ID_results, tags = ['baseline', 'max_min_ent', 'min_max_KL_unif'])
plot_training_metrics(baseline_trainer, min_max_KL_unif_trainer, ['baseline', 'max_min_ent', 'min_max_KL_unif'])

In [16]:
# plotting utils
def plot_training_metrics(baseline_trainer, our_trainer, tags=['baseline', '']):
    
    trainer_vec = [baseline_trainer, our_trainer]
    
    c_vec = plt.rcParams['axes.prop_cycle'].by_key()['color']
    plt.figure(figsize=(20, 3))

    # total loss
    plt.subplot(1, 4, 1)
    plt.title("Train/Test total loss")
    for idx, (c, trainer) in enumerate(zip(c_vec, trainer_vec)):
        name = tags[idx]
        plt.plot(gaussian_filter1d(trainer.history[:, 0], 100), '-o', c=c, label='%s - train' % name, markevery=1000, markersize=10)
        plt.plot(gaussian_filter1d(trainer.history[:, 1], 100), '-^', c=c, label='%s - val' % name, markevery=1000, markersize=10)
    plt.grid()
    plt.xlabel('iterations')
    plt.ylabel('total loss')
    plt.legend()
    
    
    # accuracy
    plt.subplot(1, 4, 2)
    plt.title("Train/Test accuracy")
    for idx, (c, trainer) in enumerate(zip(c_vec, trainer_vec)):
        name = tags[idx]
        plt.plot(gaussian_filter1d(trainer.history[:, 2], 100), '-o', c=c, label='%s - train' % name, markevery=1000, markersize=10)
        plt.plot(gaussian_filter1d(trainer.history[:, 3], 100), '-^', c=c, label='%s - val' % name, markevery=1000, markersize=10)
    plt.grid()
    plt.xlabel('iterations')
    plt.ylabel('Accuracy')
    plt.legend()
    

    # entropy
    plt.subplot(1, 4, 3)
    plt.title("Train/Test predictive entropy")
    for idx, (c, trainer) in enumerate(zip(c_vec, trainer_vec)):
        name = tags[idx]
        plt.plot(np.abs(gaussian_filter1d(trainer.history[:, 6], 100)), '-o', c=c, label='%s - train' % name, markevery=1000, markersize=10)
        plt.plot(np.abs(gaussian_filter1d(trainer.history[:, 7], 100)), '-^', c=c, label='%s - val' % name, markevery=1000, markersize=10)
    plt.grid()
    plt.xlabel('iterations')
    plt.ylabel('Entropy')
    
    # dimensionality
    plt.subplot(1, 4, 4)
    plt.title("Train/Test last-layer-rep dim")
    for idx, (c, trainer) in enumerate(zip(c_vec, trainer_vec)):
        name = tags[idx]
        plt.plot(gaussian_filter1d(trainer.history[:, 8], 100), '-o', c=c, label='%s - train' % name, markevery=1000, markersize=10)
        plt.plot(gaussian_filter1d(trainer.history[:, 9], 100), '-^', c=c, label='%s - val' % name, markevery=1000, markersize=10)
    plt.grid()
    plt.xlabel('iterations')
    plt.ylabel('# dims that capture 90% of feature var')


    plt.show()
    
    
    
def plot_OOD_metrics(results, tags):
    plt.figure(figsize=(12, 6))
    for (j, title) in enumerate(['Accuracy OOD $(\\uparrow)$', 'OOD AUC $(\\uparrow)$', 'Brier $(\\downarrow)$', 'ECE  $(\\downarrow)$', 'NLL  $(\\downarrow)$']):
        plt.subplot(2, 3, j + 1)
        plt.title(title)
        for idx in range(len(tags)):
            x = np.arange(len(corruption_type_list))
            y = [results[(ctype, idx)][j] for ctype in corruption_type_list]
            width = 0.1
            offset = width
            plt.bar(x + width * (idx + 1) - offset, y, width=width)
        plt.xticks(np.arange(len(corruption_type_list)) + 2*width, corruption_type_list, rotation=90)
        plt.grid()

    plt.subplot(2, 3, 6)
    plt.title('legend')
    for idx in range(len(tags)):
        plt.scatter(0., 0., label=f'{tags[idx]}')
    plt.legend(fontsize=8, loc='lower right')
        
    plt.tight_layout()
    plt.show()
    
    
def plot_ID_metrics(results, tags):
    plt.figure(figsize=(5, 6))
    for (j, title) in enumerate(['Train Accuracy ID $(\\uparrow)$', 'Test Accuracy ID $(\\uparrow)$']):
        plt.subplot(1, 2, j + 1)
        plt.title(title)
        width = 1.
        offset = width    
        for idx in range(len(tags)):
            plt.bar(width * (idx + 1), results[idx][j], width=width)
        plt.xticks(np.arange(len(tags)) + width, tags, rotation=90)
        plt.grid()
    plt.tight_layout()
    plt.show()

In [None]:
# baseline
params = {
    'input_shape': (32, 32, 3),
    'n_classes': 10,
    'lambda': 0.0,
    'lr': 3e-4,
    'lr-calibrator': 1e-4,
    'step_size': 0.,
    'n_iters': 12000,
    'version': 'NA',
    'weight_decay': 0.
}
baseline_trainer = Trainer(params)
baseline_trainer.train()
# calibrate_model(baseline_trainer)


# min max cross-entropy
params = {
    'input_shape': (32, 32, 3),
    'n_classes': 10,
    'lambda': 4.0,
    'lr': 3e-4,
    'lr-calibrator': 1e-4,
    'step_size': 0.1,
    'n_iters': 12000,
    'version': 'min-max-cent',
    'weight_decay': 0.
}
min_max_cent_trainer = Trainer(params)
min_max_cent_trainer.train()
# calibrate_model(min_max_cent_trainer)



# # max-min-ent
params = {
    'input_shape': (32, 32, 3),
    'n_classes': 10,
    'lambda': 4.0,
    'lr': 3e-4,
    'lr-calibrator': 1e-4,
    'step_size': 0.1,
    'n_iters': 12000,
    'version': 'max-min-ent',
    'weight_decay': 0.
}
max_min_ent_trainer = Trainer(params)
max_min_ent_trainer.train()
# calibrate_model(max_min_ent_trainer)


# min max KL with unif
params = {
    'input_shape': (32, 32, 3),
    'n_classes': 10,
    'lambda': 0.1,
    'lr': 3e-4,
    'lr-calibrator': 1e-4,
    'step_size': 0.1,
    'n_iters': 12000,
    'version': 'min-max-KL-unif',
    'weight_decay': 0.
}
min_max_KL_unif_trainer = Trainer(params)
min_max_KL_unif_trainer.train()
# calibrate_model(min_max_KL_unif_trainer)

# label smoothing
params = {
    'input_shape': (32, 32, 3),
    'n_classes': 10,
    'lambda': 0.,
    'lr': 3e-4,
    'lr-calibrator': 1e-4,
    'step_size': 0.,
    'n_iters': 12000,
    'version': 'label-smoothing',
    'label-smoothing-factor': 0.1,
    'weight_decay': 0.
}
label_smoothing_trainer = Trainer(params)
label_smoothing_trainer.train()
# calibrate_model(label_smoothing_trainer)

In [None]:
baseline_trainer.history[:, 2].shape

In [None]:
# compute ID metrics
ID_results = [
    [np.mean(baseline_trainer.history[:, 2][-20:]), np.mean(baseline_trainer.history[:, 3][-20:])],
    [np.mean(min_max_cent_trainer.history[:, 2][-20:]), np.mean(min_max_cent_trainer.history[:, 3][-20:])],
    [np.mean(max_min_ent_trainer.history[:, 2][-20:]), np.mean(max_min_ent_trainer.history[:, 3][-20:])],
    [np.mean(min_max_KL_unif_trainer.history[:, 2][-20:]), np.mean(min_max_KL_unif_trainer.history[:, 3][-20:])],
    [np.mean(label_smoothing_trainer.history[:, 2][-20:]), np.mean(label_smoothing_trainer.history[:, 3][-20:])],
    #     [np.mean(min_max_cent_trainer_alp0.history[:, 2][-20:]), np.mean(min_max_cent_trainer_alp0.history[:, 3][-50:])],
    #     [np.mean(max_min_ent_trainer_alp0.history[:, 2][-20:]), np.mean(max_min_ent_trainer_alp0.history[:, 3][-50:])],
    #     [np.mean(min_max_KL_unif_trainer_alp0.history[:, 2][-20:]), np.mean(min_max_KL_unif_trainer_alp0.history[:, 3][-50:])],
]
plot_ID_metrics(ID_results, tags = ['baseline', 'min-max-cent', 'max-min-ent', 'min-max-KL-unif', 'label-smoothing'])

In [None]:
# plot
plot_training_metrics(baseline_trainer, min_max_cent_trainer, ['baseline', 'min-max-cent'])
plot_training_metrics(baseline_trainer, max_min_ent_trainer, ['baseline', 'max-min-ent'])
plot_training_metrics(baseline_trainer, min_max_KL_unif_trainer, ['baseline', 'min-max-KL-unif'])
plot_training_metrics(baseline_trainer, label_smoothing_trainer, ['baseline', 'label-smoothing'])

# plot_training_metrics(baseline_trainer, min_max_cent_trainer_alp0, ['baseline', 'min-max-cent-alp0'])
# plot_training_metrics(baseline_trainer, max_min_ent_trainer_alp0, ['baseline', 'max-min-ent-alp0'])
# plot_training_metrics(baseline_trainer, min_max_KL_unif_trainer_alp0, ['baseline', 'min-max-KL-unif-alp0'])



In [None]:
# compute OOD metrics
trainer_vec = [baseline_trainer, min_max_KL_unif_trainer]
#                min_max_cent_trainer, max_min_ent_trainer, min_max_KL_unif_trainer, label_smoothing_trainer]
tags = ['baseline', 'min-max-KL-unif']
#         'min-max-cent', 'max-min-ent', 'min-max-KL-unif', 'label-smoothing']

results = {}
corruption_type_list = ['brightness_1', 'elastic_5', 'fog_5', 'frost_5', 'frosted_glass_blur_5']

for idx, trainer in enumerate(trainer_vec):
    print(f"Trainer {idx + 1}")
    for corruption_type in corruption_type_list:
        print(corruption_type)
        
        # load corrupted dataset
        (ds_corrupted,) = tfds.load(
            'cifar10_corrupted/%s' % corruption_type,
            split=['test'],
            shuffle_files=True,
            as_supervised=True,
            with_info=False,
        )
        
        # make dataset and get iterator
        ds_corrupted = make_ds(ds_corrupted)
        iter_corrupted = iter(ds_corrupted)

        # record metrics
        acc_OOD_vec = []
        auc_vec = []
        brier_vec = []
        ece_vec = []
        nll_vec = []
        
        # average over a 100 batches
        for _ in tqdm.trange(100):
            X_test, Y_test = next(trainer.iter_test)
            X_corrupted, Y_corrupted = next(iter_corrupted)
            _, _, test_acc_OOD, _ = loss_fn(X_corrupted, Y_corrupted, False)
            test_auc = trainer.eval_ood(X_test, X_corrupted)
            brier, ece, nll = get_calibration_metrics(X_corrupted, Y_corrupted, self.tau)

            acc_OOD_vec.append(test_acc_OOD.numpy())
            auc_vec.append(test_auc)
            brier_vec.append(brier.numpy())
            ece_vec.append(ece.numpy())
            nll_vec.append(nll.numpy())
        
        results[(corruption_type, idx)] = (np.mean(acc_OOD_vec), np.mean(auc_vec), np.mean(brier_vec), np.mean(ece_vec), np.mean(nll_vec))

plot_OOD_metrics(results, tags)

In [None]:
plot_OOD_metrics(results, tags)

In [None]:
# rt = Trainer(params)
rt.compute_sharpness_metric()

In [None]:
with tf.GradientTape() as tape:
    loss, _, _, _ = t.loss_fn(*next(t.iter_train), training=False)
grad = np.concatenate([g.numpy().flatten() for g in tape.gradient(loss, t.model.trainable_variables)], axis=0)

In [None]:
tf.reduce_sum(tf.linalg.pinv(A))

In [None]:
A

In [None]:
A = tf.random.normal((2396330, 100))

In [None]:
# A = tf.Variable(np.random.randn((2396330, 100)), trainable=False)
tf.reduce_sum(tf.linalg.pinv(A) @ tf.random.normal((2396330,))[:, None])

In [None]:
a = 3
@tf.function
def a_function_with_python_side_effect_2(x, **xargs):
  print("Tracing!") # An eager-only side effect.
  return a + x + xargs['0'] * xargs['1']['b'] + (tf.constant(2.) if ('version' in xargs) and (xargs['version'] == 'add2') else tf.constant(0.))


@tf.function
def a_function_with_python_side_effect(x, **xargs):
#   print("Tracing!") # An eager-only side effect.
  return a_function_with_python_side_effect_2(x, **xargs)

# This is traced the first time.


def conv_to_tensors(p):
    return {k: tf.constant(v) if type(v) != dict else conv_to_tensors(v) for k, v in p.items()}

x = tf.Variable(20.)
with tf.GradientTape() as tape:
    tape.watch(x)
    p = {'0':1., '1': {'a': 2., 'b':3.}, 'version': 'dontadd2'}
    print(a_function_with_python_side_effect(x, **conv_to_tensors(p)))
    # The second time through, you won't see the side effect.
#     x.assign(10)
#     p = {'0':3, '1':{'a': 2, 'b':-3}, 'version': 'add2'}
#     print(a_function_with_python_side_effect(x, **conv_to_tensors(p)))
# grad = tape.gradient(v, x)
# grad

In [None]:


x = tf.Variable(2.)
# with tf.GradientTape() as tape:
#     tape.watch(x)
# for i in range(10):
#     xargs['0'] = 2. * i
#     x.assign(i*1.0)
    
xargs = {'0':1., '1': {'a': 2., 'b':3.}, 'version': 'dontadd2'}

@tf.function
def a_function_with_python_side_effect(x, xargs):
  print("Tracing!") # An eager-only side effect.
  return x + xargs['0'] * xargs['1']['b'] + (tf.constant(2.) if ('version' in xargs) and (xargs['version'] == 'add2') else tf.constant(0.))


print(a_function_with_python_side_effect(x, xargs))

# xargs['0']=2
# tf.function(a_function_with_python_side_effect).get_concrete_function(x)
print(a_function_with_python_side_effect(x, xargs))


# tf.function(evaluate).get_concrete_function(new_model, x)
# print(a_function_with_python_side_effect(x, **conv_to_tensors(p)))
    
    # The second time through, you won't see the side effect.
#     x.assign(10)
#     p = {'0':3, '1':{'a': 2, 'b':-3}, 'version': 'add2'}
#     print(a_function_with_python_side_effect(x, **conv_to_tensors(p)))
# grad = tape.gradient(v, x)
# grad