In [1]:
#%load_ext tensorboard
import os
import datetime

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as K
import tensorflow.keras.layers as layers
from tensorflow.keras.datasets import mnist
from sklearn.metrics import roc_auc_score
from tqdm.notebook import tqdm_notebook

from util import load_data, create_dataset

In [2]:
(x_train, y_train), (x_test, y_test) = load_data(reshape=1)

In [3]:
def create_model():
    # Fully-connected Variational Autoencoder
    input_size = 28 * 28
    inter_size = 64
    latent_size = 2

    inputs      = keras.Input(shape=(input_size,))
    h           = layers.Dense(inter_size, activation='relu')(inputs)
    z_mean      = layers.Dense(latent_size)(h)
    z_log_sigma = layers.Dense(latent_size)(h)

    # Function for sampling the latent feature space
    def sampling(args):
        z_mean, z_log_sigma = args
        epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_size),
                                  mean=0., stddev=0.1)
        return z_mean + K.exp(z_log_sigma) * epsilon

    z = layers.Lambda(sampling)([z_mean, z_log_sigma])

    # Create encoder
    encoder = keras.Model(inputs, [z_mean, z_log_sigma, z], name='encoder')

    # Create decoder
    latent_inputs  = keras.Input(shape=(latent_size,), name='z_sampling')
    x              = layers.Dense(inter_size, activation='relu')(latent_inputs)
    outputs        = layers.Dense(input_size, activation='sigmoid')(x)
    decoder = keras.Model(latent_inputs, outputs, name='decoder')

    # instantiate VAE model
    outputs = decoder(encoder(inputs)[2])
    vae = keras.Model(inputs, outputs, name='vae_mlp')

    # Loss
    reconstruction_loss = keras.losses.binary_crossentropy(inputs, outputs)
    reconstruction_loss *= input_size
    kl_loss = 1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = K.mean(reconstruction_loss + kl_loss)
    vae.add_loss(vae_loss)
    vae.compile(optimizer='adam')
    return vae

def fit_model(model, x_train, y_train, x_test, y_test, normal=4, verbose=1):
    #logdir = os.path.join("logs", str(normal)+"_"+datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    #tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
    callbacks = [
        keras.callbacks.EarlyStopping(patience=10)
        #tensorboard_callback
    ]
    train = x_train[y_train==normal]
    test  = x_test[y_test==normal]
    history = model.fit(train, train,
                    epochs=100,
                    batch_size=256,
                    verbose=verbose,
                    shuffle=True,
                    validation_data=(test, test),
                    callbacks=callbacks)
    return model

In [4]:
evals = np.zeros((10, 30))
for i in range(10):
    # Evaluate for all numbers
    for j in tqdm_notebook(range(30)):
        # Evaluate each method 30 times
        model = create_model()
        model = fit_model(model, x_train, y_train, x_test, y_test, normal=i, verbose=0)

        x = np.copy( x_test )
        y = y_test
        labels = np.copy( y )
        labels[ y == i ] = 0
        labels[ y != i ] = 1
        
        xhat = model.predict(x)
        
        x = x.reshape(len(x), 28*28)
        xhat = xhat.reshape(len(xhat), 28*28)
        
        err  = np.sum(np.abs(x-xhat), axis=1)
        # Max-Min normalize the error
        err /= np.max(err)
        # Compute AUC
        AUC = roc_auc_score(labels, err)
        evals[i,j] = AUC
    print(np.mean(evals[i,:])*100)

print(np.mean(evals[:,:5], axis=1)*100)
print(np.std(evals, axis=1)*100)

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


95.54568607930976


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


99.83334053882143


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


87.81465289591243


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


85.72742467795167


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


87.87840664751815


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


91.25928322019179


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


95.89794347941265


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


93.65773699036576


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


79.54221135140358


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


92.01616989746053
[95.14559143 99.84190662 88.52210422 85.76322537 87.78364972 90.78040839
 96.029651   93.65173926 79.12487812 92.04149199]
[1.04027055 0.02136071 1.18924201 1.01794828 0.72314154 0.62976151
 0.34802407 0.80290864 0.69382041 0.2161652 ]
