In [2]:
import os

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import tensorflow_datasets as tfds
import tqdm
from tensorflow.keras.models import Model
from sklearn.metrics import roc_auc_score
from sklearn.calibration import calibration_curve
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import tensorflow_probability as tfp
from scipy.ndimage import gaussian_filter1d

In [4]:
# Load the ImageNet-2012 train/validation splits from a pre-built TFDS directory.
# The directory can be overridden via the IMAGENET_DATA_DIR environment variable;
# the default is the path this notebook was originally run with.
IMAGENET_DATA_DIR = os.environ.get('IMAGENET_DATA_DIR', '/amrith/tensorflow_datasets')
ds_train, ds_test = out = tfds.load(
    'imagenet2012',
    split=['train', 'validation'],
    data_dir=IMAGENET_DATA_DIR,
    shuffle_files=True,
    as_supervised=True,
    with_info=False,
)
# NOTE(review): these splits are not passed through make_ds here; make_ds caches the
# whole split in memory, which is not viable at ImageNet scale.

In [None]:
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all up front.
# Memory growth can only be configured before the GPUs are initialised. An earlier cell
# that already ran TF ops (e.g. tfds.load) makes set_memory_growth raise RuntimeError,
# so report that per device instead of aborting a Restart & Run All.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        print(f"Could not enable memory growth on {gpu.name}: {err}")

### CIFAR-10 autoencoder + adv loss experiments

In [None]:
# helper functions to create TF datasets

def normalize_img(image, label):
    """Map a `uint8` image to `float32` in [0, 1]; the label is returned untouched.

    Intended as a `tf.data.Dataset.map` function over (image, label) pairs.
    """
    return tf.cast(image, tf.float32) / 255., label


def make_ds(ds, batch_size=128, shuffle_buffer=int(1e5)):
    """Turn an (image, label) dataset into an infinite, shuffled, batched stream.

    Pipeline: normalize_img -> cache -> repeat forever -> shuffle -> batch -> prefetch.

    Args:
        ds: a tf.data.Dataset of (image, label) pairs. In this notebook it comes from
            tfds.load(..., as_supervised=True).
        batch_size: number of examples per batch (default 128).
        shuffle_buffer: size of the shuffle buffer, in elements (default 100000).

    Returns:
        The transformed tf.data.Dataset. Because repeat() comes before batch(), the
        stream never ends and every batch holds exactly `batch_size` examples.

    Note:
        cache() without a filename keeps the whole normalized split in memory. That is
        fine for CIFAR-10 but not for ImageNet-scale data.
    """
    autotune = tf.data.experimental.AUTOTUNE
    ds = ds.map(normalize_img, num_parallel_calls=autotune)
    ds = ds.cache()
    ds = ds.repeat()
    ds = ds.shuffle(shuffle_buffer)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(autotune)
    return ds


In [None]:
# create dataset, iterate over it and train, return model and metrics
class Trainer:
    """Jointly trains a convolutional autoencoder and a classifier on CIFAR-10.

    A single encoder feeds two heads: a decoder that reconstructs the input image and
    a dense classifier. Training minimises
    ``cls_loss + params['lambda'] * recon_loss`` over the encoder, the decoder and the
    classifier head together.

    Args:
        params: dict with keys 'input_shape', 'lr', 'n_iters', 'n_classes' and
            'lambda'. The optional key 'eval_every' (default 10, must be >= 1) sets
            how often, in iterations, a fresh test batch is evaluated.

    Attributes:
        history: list of per-iteration metric rows while training. It is an
            (n_rows, 8) ndarray once train() returns. Column order:
            train_total, test_total, train_recon, test_recon,
            train_cls, test_cls, train_acc, test_acc.
    """

    def __init__(self, params):
        self.params = params
        print(f"\nRun training with params {self.params}")

        # Create the CIFAR-10 datasets and infinite iterators over their batches.
        ds_train, ds_test = tfds.load(
            'cifar10',
            split=['train', 'test'],
            shuffle_files=True,
            as_supervised=True,
            with_info=False,
        )
        self.iter_train = iter(make_ds(ds_train))
        self.iter_test = iter(make_ds(ds_test))

        # Define the model.
        # The encoder is a conv stack ending in Flatten, producing one feature vector per image.
        self.encoder = tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(256, (3, 3), activation='relu', input_shape=params['input_shape']),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same',),
            tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same',),
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Flatten(),
        ])
        # The decoder maps encoder features back to an image. Its layer sizes produce a
        # 32x32x3 output whatever 'input_shape' is, so the reconstruction loss assumes
        # 32x32x3 inputs.
        # NOTE(review): the final ReLU keeps outputs >= 0 but does not cap them at 1,
        # although targets lie in [0, 1]; a sigmoid may be a better fit.
        self.decoder = tf.keras.models.Sequential([
            tf.keras.layers.Dense(1024, activation='relu', input_shape=[self.encoder.output.shape[1]]),
            tf.keras.layers.Reshape((4, 4, 64)),
            tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu'),
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu'),
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Conv2DTranspose(32, (3, 3), activation='relu'),
            tf.keras.layers.Conv2DTranspose(3, (3, 3), activation='relu'),
        ])
        # The classifier head sits on the same encoder features and outputs softmax probabilities.
        self.classifier_layer = tf.keras.models.Sequential([
            tf.keras.layers.Dense(1024, activation='relu', input_shape=[self.encoder.output.shape[1]]),
            tf.keras.layers.Dense(512, activation='relu'),
            tf.keras.layers.Dense(self.params['n_classes'], activation='softmax')
        ])
        # Model.summary() prints the table itself and returns None, so it is not wrapped
        # in print() (that used to emit a stray "None" line after each table).
        print('----Encoder----')
        self.encoder.summary()
        print('----Decoder----')
        self.decoder.summary()
        print('----Classifier----')
        self.classifier_layer.summary()

        # recon_model = encoder + decoder, classifier_model = encoder + classifier layer
        self.recon_model = tf.keras.models.Sequential([self.encoder, self.decoder])
        self.classifier_model = tf.keras.models.Sequential([self.encoder, self.classifier_layer])

        # A single optimizer is shared by all trainable variables.
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.params['lr'])

        # Per-iteration metric rows (see the class docstring for the column order).
        self.history = []

    @tf.function
    def loss_fn(self, X, Y):
        """Compute the joint loss and metrics for one batch.

        Args:
            X: image batch. It is assumed to be (batch, 32, 32, 3) float32 in [0, 1],
                as produced by make_ds; confirm if other inputs are used.
            Y: integer class labels, one per image.

        Returns:
            Tuple (total_loss, recon_loss, cls_loss, accuracy) of scalar tensors, where
            total_loss = cls_loss + self.params['lambda'] * recon_loss and accuracy is
            a percentage.
        """
        X_enc = self.encoder(X)
        # Reconstruct: per-example L2 norm of the pixel error, averaged over the batch.
        X_hat = self.decoder(X_enc)
        # Use the dynamic batch size; the static shape[0] may be None when traced.
        bsize = tf.shape(X)[0]
        recon_loss = tf.reduce_mean(
            tf.norm(tf.reshape(X_hat - X, (bsize, -1)), axis=1))
        # Classify.
        Y_hat = self.classifier_layer(X_enc)
        cls_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(y_true=Y, y_pred=Y_hat)
        # argmax yields int64; cast labels so the comparison never mismatches dtypes.
        correct = tf.equal(tf.argmax(Y_hat, axis=1), tf.cast(Y, tf.int64))
        accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) * 100.
        # Use this trainer's own lambda, not whatever global `params` happens to exist.
        return cls_loss + recon_loss * self.params['lambda'], recon_loss, cls_loss, accuracy

    @tf.function
    def step_fn(self, X_train, Y_train):
        """Apply one Adam update to encoder, decoder and classifier on one batch.

        Returns the same (total_loss, recon_loss, cls_loss, accuracy) tuple as loss_fn,
        computed before the update is applied.
        """
        with tf.GradientTape() as tape:
            total_loss, recon_loss, cls_loss, accuracy = self.loss_fn(X_train, Y_train)
        trainable_variables = (self.encoder.trainable_variables
                               + self.decoder.trainable_variables
                               + self.classifier_layer.trainable_variables)
        grads = tape.gradient(total_loss, trainable_variables)
        self.optimizer.apply_gradients(zip(grads, trainable_variables))
        return total_loss, recon_loss, cls_loss, accuracy

    def train(self):
        """Run params['n_iters'] training steps, appending one history row per step.

        Test metrics are recomputed on a fresh test batch every
        params.get('eval_every', 10) iterations. The last values are repeated in the
        rows in between. Calling train() again continues training and extends
        self.history instead of failing.
        """
        eval_every = self.params.get('eval_every', 10)
        # history is a list before the first call and an ndarray afterwards;
        # normalise to a list of rows so repeated calls can keep appending.
        rows = list(self.history)
        for t in tqdm.trange(self.params['n_iters']):
            train_metrics = self.step_fn(*next(self.iter_train))
            if t % eval_every == 0:
                test_metrics = self.loss_fn(*next(self.iter_test))
            # Interleave as (train_total, test_total, train_recon, test_recon, ...),
            # converting every metric, accuracy included, to a NumPy scalar.
            row = []
            for train_value, test_value in zip(train_metrics, test_metrics):
                row.extend((train_value.numpy(), test_value.numpy()))
            rows.append(tuple(row))
        self.history = np.array(rows)

In [None]:
# plotting utils
def plot_training_metrics(baseline_trainer, our_trainer=None, tags=('baseline', ''), sigma=100):
    """Plot smoothed train/val curves for one or two trainers side by side.

    Draws four panels: total loss, reconstruction loss, classification loss and
    accuracy. Each trainer's ``history`` must be a 2-D array whose columns are
    (train_total, val_total, train_recon, val_recon, train_cls, val_cls,
    train_acc, val_acc), the layout Trainer.train() produces.

    Args:
        baseline_trainer: trainer drawn with the first colour of the style cycle.
        our_trainer: optional second trainer drawn with the second colour; skipped when None.
        tags: display names indexed by trainer position (baseline first).
        sigma: standard deviation of the Gaussian smoothing, in iterations.

    Returns:
        None. The figure is left to the inline backend; returning it would make
        Jupyter render it a second time.
    """
    # One entry per panel: (title, y-axis label, train column, val column).
    panels = [
        ("Train/Test total loss", 'Total loss', 0, 1),
        ("Train/Test recon loss", 'Recon loss', 2, 3),
        ("Train/Test cls loss", 'Cls loss', 4, 5),
        ("Train/Test accuracy", 'accuracy', 6, 7),
    ]
    trainer_vec = [baseline_trainer, our_trainer]
    c_vec = plt.rcParams['axes.prop_cycle'].by_key()['color']

    fig, axes = plt.subplots(1, len(panels), figsize=(20, 3))
    for ax, (title, ylabel, train_col, val_col) in zip(axes, panels):
        ax.set_title(title)
        for idx, (c, trainer) in enumerate(zip(c_vec, trainer_vec)):
            if trainer is None:
                continue
            name = tags[idx]
            ax.plot(gaussian_filter1d(trainer.history[:, train_col], sigma), '-o', c=c,
                    label='%s - train' % name, markevery=1000, markersize=10)
            ax.plot(gaussian_filter1d(trainer.history[:, val_col], sigma), '-^', c=c,
                    label='%s - val' % name, markevery=1000, markersize=10)
        ax.grid()
        ax.set_xlabel('iterations')
        ax.set_ylabel(ylabel)

    # The labels are identical in every panel, so one legend on the last panel suffices.
    axes[-1].legend()

In [None]:
# test recon loss
# Hyper-parameters for the joint reconstruction + classification run.
# Key order is kept as-is because Trainer prints this dict.
params = {
    'input_shape': (32, 32, 3),  # CIFAR-10 images: height, width, channels
    'lr': 3e-4,                  # Adam learning rate
    'n_iters': 15_000,           # number of optimizer steps
    'n_classes': 10,             # CIFAR-10 classes
    'lambda': 0.05,              # weight of the reconstruction loss in the total loss
}

trainer = Trainer(params)
trainer.train()
plot_training_metrics(baseline_trainer=trainer)

In [None]:
# First recorded row: train/test total, recon and cls losses, then train/test accuracy.
trainer.history[0, :]

In [None]:
# Grab one batch of normalized images for visual inspection; the labels are discarded.
# NOTE(review): this pulls from the *training* iterator, so it advances
# trainer.iter_train. Use trainer.iter_test to inspect held-out images instead.
imgs, _ = next(trainer.iter_train)

In [None]:
# Per-image shape: drop the leading batch dimension.
imgs.shape[1:]

In [None]:
# make_ds scaled pixels to [0, 1]; map the first image back to uint8 for display.
first_img = tf.cast(imgs[0] * 255., tf.uint8)
plt.imshow(first_img)

In [None]:
# L2 distance between the first image and its autoencoder reconstruction.
recon_batch = trainer.decoder(trainer.encoder(imgs))
tf.norm(recon_batch[0] - imgs[0])

In [None]:
# Show the reconstruction of the first image. The decoder's final ReLU does not cap
# outputs at 1, and tf.cast does not saturate when converting to uint8. So clip to
# [0, 1] first; otherwise over-bright pixels wrap around into garbage colours.
recon_img = tf.clip_by_value(trainer.decoder(trainer.encoder(imgs))[0], 0., 1.)
plt.imshow(tf.cast(recon_img * 255., tf.uint8))