In [None]:
# Automatically re-import files when updated
%load_ext autoreload
%autoreload 2  
import time, json, pandas as pd, IPython
# Load packages
from models import *
from utils import *
from matplotlib import pyplot as plt

# Load models

In [None]:
# Load the trained decoder, encoder and classifier from disk.
# The custom layer classes/factories from `models` are registered for the
# duration of loading, presumably because the saved models contain them.
custom_layers = {
    'Inception': Inception,
    "downsampler": downsampler,
    "upsampler": upsampler,
}
with tf.keras.utils.custom_object_scope(custom_layers):
    # Loaded in the same order as before: decoder, encoder, classifier.
    decoder, encoder, classifier = (
        tf.keras.models.load_model(model_path)
        for model_path in ("decoder.keras", "encoder.keras", "classifier.keras")
    )

# Visualize results

In [None]:
# Get data
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
# Convert to tf dataset
data_train = dataset_from_arrays(X_train, y_train)
data_test = dataset_from_arrays(X_test, y_test)

# Take one batch of test examples to visualise. `imgs`/`labels` were previously
# never defined in this notebook, so this cell raised NameError on a fresh kernel.
# NOTE(review): assumes `dataset_from_arrays` yields batched (image, one-hot label)
# pairs — the argmax over `label` below and CategoricalCrossentropy both rely on
# one-hot labels; confirm against utils.
imgs, labels = next(iter(data_test))

# Seconds each example stays on screen before the next one replaces it
DISPLAY_SECONDS = 2

# For originals
original_predictions = classifier(imgs)
# For adversarials: FGSM perturbation of the originals against the classifier
adv_imgs = fgsm(classifier, tf.keras.losses.CategoricalCrossentropy(), imgs, labels)
adv_predictions = classifier(adv_imgs)
# For reconstructions: encode + decode the adversarial images, then classify again
z = encoder(adv_imgs)
reconstructed = decoder(z)
predictions = classifier(reconstructed)

# Plot each example's original, adversarial and reconstructed image side by side
for (img, orig_pred, adv_img, rec, label, pred, adv_pred) in zip(imgs, original_predictions, adv_imgs, reconstructed, labels, predictions, adv_predictions):
    # Replace the previous example's output; wait=True avoids flicker between frames
    IPython.display.clear_output(wait=True)
    fig, axes = plt.subplots(1, 3)
    panels = (
        ("Original image", img),
        ("Adversarial example", adv_img),
        ("Reconstructed image", rec),
    )
    for ax, (title, image) in zip(axes, panels):
        # squeeze drops size-1 axes (e.g. a trailing channel axis, if present),
        # since imshow only accepts (M, N), (M, N, 3) or (M, N, 4) arrays
        ax.imshow(tf.squeeze(image), cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    # Add context. One f-string per line so source indentation does not leak
    # into the printed text (it did with the old backslash continuations).
    print(
        f"True label: {tf.math.argmax(label).numpy()}\n"
        f"Predicted label on original image: {tf.math.argmax(orig_pred).numpy()}\n"
        f"Predicted label on adversarial example: {tf.math.argmax(adv_pred).numpy()}\n"
        f"Predicted label on reconstructed image: {tf.math.argmax(pred).numpy()}"
    )
    plt.show()
    # A new figure is created per example; release it explicitly
    plt.close(fig)
    # Show for some time
    time.sleep(DISPLAY_SECONDS)

# Training progress

In [None]:
# Plot training progress.
# Presumably history.json maps each epoch key to {"classifier_loss": ..., "ae_loss": ...};
# orient="index" turns the top-level keys into rows — confirm against the training script.
# The context manager closes the file (the previous `json.load(open(...))` leaked the handle).
with open("history.json", mode="r") as history_file:
    history = pd.DataFrame.from_dict(json.load(history_file), orient="index")
# A single DataFrame.plot call puts ae_loss on a right-hand y-axis and builds one
# legend covering both series. Previously plt.legend() ran on the secondary (twin)
# axes only, so classifier_loss was missing from the legend.
history[["classifier_loss", "ae_loss"]].plot(secondary_y=["ae_loss"], title="Training progress");

# Plot models

In [None]:
# Plot models: the full encoder, then one example of each custom building block.
tf.keras.utils.plot_model(encoder, show_layer_names=False, to_file="encoder.png");

# The standalone blocks have no input yet, so build them on a single-sample
# input of INPUT_SHAPE before plotting.
sample_shape = (1, *INPUT_SHAPE)
# NOTE(review): meaning of the downsampler / Inception constructor arguments is
# defined in `models`, not visible here.
conv = downsampler(1, 6, 1, 1)
conv.build(sample_shape)
inception = Inception(2)
inception.build(sample_shape)

# Only the Inception block is drawn with its nested sub-layers expanded.
for block, image_path, expand in ((conv, "downsampler.png", False), (inception, "inception.png", True)):
    tf.keras.utils.plot_model(block, show_layer_names=False, to_file=image_path, expand_nested=expand)