Autoencoder Anomaly Testing
===

This is rebuilt from the "Collecting Network Statistics" notebook. The goal of this notebook is to collect together a set of in-distribution and out-of-distribution images and confirm that the model can distinguish them with a high degree of accuracy.


## Setup

We begin by importing our dependencies.

In [1]:
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import math

from pytorch_msssim import ssim,ms_ssim,SSIM
from model import SplitAutoencoder,ExtensibleEncoder,ExtensibleDecoder
from CustomDataSet import CustomDataSet,CustomDataSetWithError
from GaussianNoise import AddGaussianNoiseAndRescale,Rescale
import os

from sklearn.metrics import roc_auc_score,roc_curve,auc

import GPUtil

Set our seed and other configurations for reproducibility.

In [2]:
seed = 42
#seed = 2662
# Seed every RNG this notebook draws from: torch (transforms, shuffled
# DataLoaders, the mixed-set coin flips) and NumPy, which performs the
# train/holdout split below. Without the NumPy seed the "holdout" indices
# change on every run.
torch.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# basis: "holdout" draws adversarial images from the held-out indices,
# anything else from the reference (train) indices.
# distribution: "screening" loads the screening images, anything else the lesions.
basis = "holdout"
distribution = "screening"

# Run on the GPU when one is available.
platform = "cuda" if torch.cuda.is_available() else "cpu"
#platform = "cpu"
print(platform)

cuda


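We set the image size, the batch size, the model meta-parameters (which also identify the saved model file to load), and the fraction of images used as the reference set. The batch size has to be reasonably low, as we can't fit a large number of these images into VRAM on my laptop.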

Image size can be set here as I'm automatically resizing the images in my extraction code.

In [3]:
# Input geometry; images are already resized to this by the extraction code.
width = 256
height = 256
image_size = width * height

# Kept small so the images fit in VRAM.
batch_size = 1

# Meta-parameters of the trained model. Together they also name the model file.
l2_decay = 0.0
dropout_rate = 0.1
code_sides = [20]
convolution_filters = 32

model_extension = "_".join(
    str(part) for part in (width, code_sides[0], convolution_filters, dropout_rate, l2_decay)
)
full_extension = "_".join(("", basis, distribution, model_extension))

model_path = "../../Data/OPTIMAM_NEW/model" + full_extension + ".pt"

# -1 uses every image; a positive value draws a random subset of that size.
#image_count = 500
image_count = -1

# Fraction of the images used as the reference ("train") set; the rest is held out.
validation_split = 0.95

## Gather Base Distribution Information

First we run the model on the entire original distribution and gather statistics on the loss values, encodings etc.

In [4]:
from torchvision.transforms import ToTensor,Normalize
# Normalize(mean=0, std=65535) divides pixel values by 65535.
# NOTE(review): this assumes 16-bit PNGs that ToTensor leaves unscaled -- confirm.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(0.0,65535.0)
    ])

if distribution=="screening":
    root_dir = "../../Data/OPTIMAM_NEW/png_images/casewise/ScreeningMammography/" + str(width)
else:
    root_dir = "../../Data/OPTIMAM_NEW/png_images/lesions"
train_dataset = CustomDataSet(root_dir, transform)

# base_idx lists the full-dataset indices in play. Every split index below is
# expressed against the FULL dataset, so the same indices can later be reused
# on the adversarial dataset built over the same root_dir.
if image_count == -1:
    train_dataset_subset = train_dataset
    image_count = len(train_dataset)
    base_idx = np.arange(image_count)
else:
    base_idx = np.random.choice(len(train_dataset), image_count, replace=False)
    train_dataset_subset = torch.utils.data.Subset(train_dataset, base_idx)

dataset_size = len(train_dataset_subset)

# Positions (within base_idx) of the reference set; validation_split is the
# fraction kept for reference, the remainder becomes the holdout set.
train_pick = np.random.choice(dataset_size, int(dataset_size * validation_split), replace=False)
train_subset_idx = base_idx[train_pick]

mask = np.ones(dataset_size, dtype=bool)
mask[train_pick] = False
holdout_subset_idx = base_idx[mask]

t_subset = torch.utils.data.Subset(train_dataset, train_subset_idx)

train_loader = torch.utils.data.DataLoader(
    t_subset, batch_size=batch_size, shuffle=False
)

h_subset = torch.utils.data.Subset(train_dataset, holdout_subset_idx)

holdout_loader = torch.utils.data.DataLoader(
    h_subset, batch_size=batch_size, shuffle=False
)

In [5]:
# Show the reference-set indices followed by the held-out indices.
for split_indices in (train_subset_idx, holdout_subset_idx):
    print(split_indices)

[2934 1524 1451 ... 2597 3421 2618]
[  26   27   60   66   91  103  109  112  123  169  190  211  232  406
  422  430  468  470  480  496  520  526  558  559  562  570  575  580
  603  678  685  697  703  704  705  716  728  741  780  799  825  843
  858  911  944 1007 1016 1050 1061 1081 1110 1118 1137 1193 1229 1237
 1255 1265 1302 1354 1365 1385 1393 1409 1420 1457 1469 1470 1498 1513
 1538 1546 1565 1581 1594 1607 1625 1636 1640 1685 1718 1744 1750 1753
 1788 1805 1832 1859 1870 1891 1919 1932 1934 1948 1979 2000 2019 2022
 2023 2039 2044 2052 2100 2145 2150 2176 2183 2184 2194 2228 2258 2282
 2326 2336 2337 2367 2369 2377 2449 2481 2484 2503 2523 2529 2539 2545
 2552 2595 2610 2619 2652 2664 2671 2677 2678 2717 2720 2740 2820 2827
 2894 2896 2908 2928 2941 2953 2972 2997 3003 3004 3007 3019 3021 3081
 3087 3113 3118 3165 3192 3197 3203 3205 3225 3232 3238 3279 3281 3298
 3334 3352 3388 3391 3405 3408 3417 3429 3511 3520 3521 3568 3578 3592
 3641 3663]


In [6]:
# Use the device chosen in the configuration cell.
device = torch.device(platform)

# Entries in one flattened code (code_sides[0] squared).
code_size = code_sides[0] ** 2

# Reconstruction loss: mean-squared error.
criterion = nn.MSELoss()

# One slot per reference image, filled by position during the reference pass.
reference_count = len(t_subset)
features = [None] * reference_count
losses = [None] * reference_count
encodings = [None] * reference_count
outputs = [None] * reference_count

In [7]:
# reload the saved model
model = torch.load(model_path,map_location=device)
model.eval()

FileNotFoundError: [Errno 2] No such file or directory: '../../Data/OPTIMAM_NEW/model_holout_screening_256_20_32_0.1_0.0.pt'

We run the autoencoder over the reference (training) subset and store the inputs, reconstructions, encodings and reconstruction losses.

In [None]:
with torch.no_grad():
    for count, batch_features in enumerate(train_loader):
        # load it to the active device
        batch_features = batch_features.to(device)

        # CPU copy of the input, reused for the SSIM statistics later
        features[count] = batch_features.cpu()

        # compute reconstructions
        code = model.encoder(batch_features)
        output = model.decoder(code)

        outputs[count] = output.cpu()

        # Take the first (and, with batch_size == 1, only) code in the batch and
        # flatten it to a code_size vector. The old code called .reshape()
        # without assigning the result, so nothing was flattened. np.cov and the
        # Mahalanobis step need 1-D encodings.
        encodings[count] = code.cpu().numpy()[0].reshape(code_size)

        # compute training reconstruction loss
        error_criterion = criterion(output,batch_features)

        losses[count] = error_criterion.cpu().numpy()

And calculate the encoding statistics:

In [None]:
encodings_n = np.stack(encodings)
encodings_t = torch.from_numpy(encodings_n)
encodings_mean = torch.mean(encodings_t,0)

encodings_mean_np = encodings_mean.numpy()
print(encodings_n)
# np.cov treats rows as variables, hence the transpose (one row per code unit).
encodings_covariance = np.cov(encodings_n.T)
print(encodings_covariance)
# Pseudo-inverse rather than inv: if any code unit never varies across the
# reference set, the covariance is singular. In that case inv either raises
# LinAlgError or returns numerically meaningless values. For a
# well-conditioned covariance, pinv agrees with inv. The covariance is
# symmetric, so hermitian=True applies.
encodings_inv_covariance = np.linalg.pinv(encodings_covariance, hermitian=True)

In [None]:
# Reference-set reconstruction MSE summary. mse_mean and mse_std later define
# the MSE classification threshold.
mse_min, mse_max = np.amin(losses), np.amax(losses)
mse_mean, mse_std = np.mean(losses), np.std(losses)
_mse_summary = "/".join(str(value) for value in (mse_min, mse_mean, mse_max, mse_std))
print("MSE Min/Mean/Max/SD:" + _mse_summary   )

In [None]:
features_n = np.stack(features)
outputs_n = np.stack(outputs)
print(features_n.shape)
print(outputs_n.shape)

# The SSIM module is configured for 3 channels, so each single-channel image
# is repeated along the channel axis before comparison.
ssim_module = SSIM(data_range=1.0, size_average=False, channel=3)


def _grey_to_rgb_tensor(image):
    """Reshape one image to (1, 1, H, W), repeat to 3 channels, wrap as a tensor."""
    return torch.from_numpy(image.reshape(1, 1, height, width).repeat(3, 1))


pre_ssims = [
    ssim_module(_grey_to_rgb_tensor(features_n[k]), _grey_to_rgb_tensor(outputs_n[k])).item()
    for k in range(len(encodings))
]

ssim_min, ssim_max = np.amin(pre_ssims), np.amax(pre_ssims)
ssim_mean, ssim_sd = np.mean(pre_ssims), np.std(pre_ssims)
print("SSIM Min/Mean/Max/SD:" + "/".join(str(value) for value in (ssim_min, ssim_mean, ssim_max, ssim_sd)))

Now we save the compiled statistics to CSV files.

In [None]:
with torch.no_grad():
    np_losses = np.asarray(losses)
    np_pre_ssims = np.asarray(pre_ssims)
    # One row per reference image: its loss followed by its encoding.
    np_compiled = np.concatenate((np_losses[:, np.newaxis], encodings), axis=1)

    suffix = full_extension

    base_outputs = (
        ('base_encodings', encodings),
        ('base_losses', np_losses),
        ('base_ssim', np_pre_ssims),
        ('base_combined', np_compiled),
    )
    for stem, table in base_outputs:
        np.savetxt(stem + suffix + '.csv', table, delimiter=',', fmt='%10.5f', newline='\n')

## Adversarials

We have 2 Datasets (mammographic and non-mammographic) and 3 DataLoaders - Clean Mammo, Distorted Mammo, and Non-Mammo. The goal here is to build an analogously large set of OOD images and test to what degree the autoencoder is capable of detecting the distortions.

The first method builds a large set drawn from all the datasets, each image labelled In-Distribution or Out-Of-Distribution, and determines the accuracy of the model as a classifier. The second generates a set of distorted mammographic images at specified distances from the distribution, along with a value roughly analogous to that distortion level. This second method is intended to determine the range in distribution space at which the model becomes able to distinguish the two, as well as the degree of "grey area" between in and out of distribution (as detected).

In [None]:
with torch.no_grad():
    # Each group of distortions is applied, as a whole, with this probability.
    trigger_chance = 0.4

    # Geometric distortions.
    PIL_transforms = torchvision.transforms.RandomApply(
        [
            torchvision.transforms.RandomAffine(degrees=10, translate=(0.2, 0.2), shear=25),
            torchvision.transforms.RandomVerticalFlip(),
            torchvision.transforms.RandomHorizontalFlip(),
        ],
        p=trigger_chance,
    )

    # Intensity distortions: erase a patch, then add faint Gaussian noise.
    # NOTE(review): the erase value is drawn once here, so every erased patch in
    # the run uses the same constant -- confirm that is intended.
    erase_value = torch.rand(1).item()
    tensor_transforms = torchvision.transforms.RandomApply(
        [
            torchvision.transforms.RandomErasing(p=1.0, value=erase_value),
            torchvision.transforms.Lambda(lambda x: x + (torch.randn_like(x) * 0.001)),
        ],
        p=trigger_chance,
    )

    adversarial_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(0.0, 65535.0),
            PIL_transforms,
            tensor_transforms,
            Rescale(),
        ]
    )

    adversarial_image_count = image_count
    adversarial_dataset = CustomDataSetWithError(root_dir, adversarial_transform)

    # Distort the images belonging to the chosen basis.
    basis_indices = holdout_subset_idx if basis == "holdout" else train_subset_idx
    a_subset = torch.utils.data.Subset(adversarial_dataset, basis_indices)

    adversarial_loader = torch.utils.data.DataLoader(a_subset, shuffle=True)
    
    

Build the first (mixed) set:

In [None]:
with torch.no_grad():
    adversarial_iterator = iter(adversarial_loader)
    genuine_iterator = iter(holdout_loader)
    mixed_set_scale = 300

    mixed_set = []
    mixed_set_np = []
    mixed_class_set = []
    mixed_error_set = []
    for i in range(mixed_set_scale):
        r = torch.rand(1)
        if r.item() > 0.5:
            # Keep drawing until an image that actually received a distortion
            # (non-zero error) comes out. Restart the loader when it runs dry.
            error = 0.0
            while error == 0.0:
                try:
                    adversarial_t = next(adversarial_iterator)
                except StopIteration:
                    adversarial_iterator = iter(adversarial_loader)
                    continue
                adversarial = adversarial_t[0].cpu()
                adversarial_np = adversarial.numpy().reshape(height, width)
                error = adversarial_t[1].item()
            mixed_class_set.append(1) # positive class, since we're trying to detect adversarials
            mixed_set_np.append(adversarial_np)
            mixed_set.append(adversarial)
            mixed_error_set.append(error)
            print("Positive: Min " + str(np.amin(adversarial_np)) + " Max " + str(np.amax(adversarial_np)) + " Error " + str(error))
        else:
            # The holdout set can be smaller than the number of genuine draws
            # needed, so restart it instead of letting StopIteration escape.
            try:
                genuine = next(genuine_iterator)
            except StopIteration:
                genuine_iterator = iter(holdout_loader)
                genuine = next(genuine_iterator)
            genuine = genuine.cpu()
            genuine_np = genuine.numpy().reshape(height, width)
            mixed_class_set.append(0) # negative class
            mixed_set.append(genuine)
            mixed_set_np.append(genuine_np)
            mixed_error_set.append(0.0) # genuine, so no drift
            print("Negative: Min " + str(np.amin(genuine_np)) + " Max " + str(np.amax(genuine_np)) + " Error 0.0")

    mixed_code_set = []
    mixed_code_np_set = []
    mixed_reconstruction_set = []
    mixed_losses = []

Now run the model on the mixed set:

In [None]:
with torch.no_grad():
    # Start from empty result lists so that re-running this cell does not
    # append a second copy of every result. Later cells assume exactly
    # mixed_set_scale entries and index them by position.
    mixed_code_set = []
    mixed_code_np_set = []
    mixed_reconstruction_set = []
    mixed_losses = []
    for mixed_item in mixed_set:
        mixed_example = mixed_item.to(device)

        n_code = model.encoder(mixed_example)
        reconstruction = model(mixed_example)

        mixed_code_set.append(n_code.cpu())
        mixed_code_np_set.append(n_code.cpu().numpy())
        mixed_reconstruction_set.append(reconstruction.cpu())

        # Same reconstruction MSE as in the reference pass
        error_criterion = criterion(reconstruction,mixed_example)
        loss = error_criterion.cpu().numpy()
        mixed_losses.append(loss)

Next, measure the loss and feature statistics for the adversarials:            

#### Mean Squared Error (MSE):

First calculate the MSE for all the reconstructed images.

In [None]:
mixed_losses_np = np.asarray(mixed_losses)
# Reconstruction-MSE summary over the whole mixed set.
post_mse_min, post_mse_max = np.amin(mixed_losses_np), np.amax(mixed_losses_np)
post_mse_mean, post_mse_std = np.mean(mixed_losses_np), np.std(mixed_losses_np)
_post_mse_summary = "/".join(map(str, (post_mse_min, post_mse_mean, post_mse_max, post_mse_std)))
print("Prediction MSE Min/Mean/Max/SD:" + _post_mse_summary)

And attempt to predict classes based on MSE:

In [None]:
# Flag anything at least two standard deviations above the reference-set mean MSE.
mse_threshold = mse_mean + (2 * mse_std)

mixed_class_np = np.asarray(mixed_class_set)
# 0 = distribution (MSE strictly below threshold), 1 = adversarial (at/above
# threshold, or NaN). Only the first mixed_set_scale scores are classified.
predicted_class_np = np.where(mixed_losses_np[:mixed_set_scale] < mse_threshold, 0, 1)
predicted_class = predicted_class_np.tolist()
print(mixed_class_np)
print(predicted_class_np)

tp = np.sum(mixed_class_np * predicted_class_np)
tn = np.sum((1 - mixed_class_np) * (1 - predicted_class_np))
fp = np.sum((1 - mixed_class_np) * predicted_class_np)
fn = np.sum(mixed_class_np * (1 - predicted_class_np))

hits = tn + tp
accuracy = hits / mixed_set_scale
print("Accuracy:" + str(accuracy))

# With no predicted (or no actual) positives these ratios are 0/0. Report 0.0
# instead of a NaN plus a RuntimeWarning (scikit-learn's zero_division convention).
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0

f1_score = 2 * ((precision * recall) / (precision + recall)) if (precision + recall) > 0 else 0.0

print("Precision:" + str(precision))
print("Recall:" + str(recall))
print("F1:" + str(f1_score))

# Indices of misclassified mixed-set images (plotted further down)
fails = np.where(mixed_class_np != predicted_class_np)
print(fails)

In [None]:
# Reconstruction MSE used directly as the score: higher = more likely adversarial.
fpr, tpr, thresh = roc_curve(mixed_class_np, mixed_losses_np)
roc_auc = auc(fpr, tpr)
for curve_values in (fpr, tpr, thresh):
    print(curve_values)

lw = 2
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')  # chance line
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('MSE basis ROC Curve')
plt.legend(loc="lower right")

out_path = "mse_roc_graph_output" + full_extension + ".eps"
plt.savefig(out_path)
plt.show()

#### STRUCTURAL SIMILARITY INDEX (SSIM)

Generate RGB versions of the base and recreated images.

In [None]:
def _stack_as_rgb(images):
    """Stack the single-channel images as (N, 1, H, W) and repeat to 3 channels."""
    return np.stack(images).reshape(mixed_set_scale, 1, height, width).repeat(3, 1)


mixed_set_s = _stack_as_rgb(mixed_set)
mixed_recon_s = _stack_as_rgb(mixed_reconstruction_set)
print(mixed_set_s.shape)
print(mixed_recon_s.shape)

And use that for Structural Similarity Index:

In [None]:
# Per-image (size_average=False) SSIM and MS-SSIM of each input vs. its reconstruction.
_inputs_t = torch.from_numpy(mixed_set_s)
_recons_t = torch.from_numpy(mixed_recon_s)
ssim_t = ssim(_inputs_t, _recons_t, data_range=1.0, size_average = False)
ms_ssim_t = ms_ssim(_inputs_t, _recons_t, data_range=1.0, size_average = False)

In [None]:
# SSIM summary over the whole mixed set.
_ssim_values = ssim_t.numpy()
post_ssim_min, post_ssim_max = np.amin(_ssim_values), np.amax(_ssim_values)
post_ssim_mean, post_ssim_std = np.mean(_ssim_values), np.std(_ssim_values)
print("Prediction SSIM Min/Mean/Max/SD:" + "/".join(map(str, (post_ssim_min, post_ssim_mean, post_ssim_max, post_ssim_std))))

In [None]:
with torch.no_grad():
    np_post_losses = np.asarray(mixed_losses_np)
    np_post_ssims = np.asarray(ssim_t)
    np_distances = np.asarray(mixed_error_set)

    suffix = full_extension

    # One CSV per per-image quantity of the mixed set.
    for stem, column in (
        ('mixed_losses', np_post_losses),
        ('mixed_ssim', np_post_ssims),
        ('mixed_distance', np_distances),
    ):
        np.savetxt(stem + suffix + '.csv', column, delimiter=',', fmt='%10.5f', newline='\n')

Attempt to predict the class based on the SSIM:

In [None]:
# Higher SSIM = more similar, so flag anything two SDs below the reference mean.
ssim_threshold = ssim_mean - (2 * ssim_sd)
#ssim_threshold = ssim_min
# 0 = distribution (SSIM strictly above threshold), 1 = adversarial.
# Only the first mixed_set_scale scores are classified.
ssim_scores = np.asarray(ssim_t)[:mixed_set_scale]
predicted_class_ssim_np = np.where(ssim_scores > ssim_threshold, 0, 1)
predicted_class_ssim = predicted_class_ssim_np.tolist()
print(mixed_class_np)
print(predicted_class_ssim_np)
tp_ssim = np.sum(mixed_class_np * predicted_class_ssim_np)
tn_ssim = np.sum((1 - mixed_class_np) * (1 - predicted_class_ssim_np))
fp_ssim = np.sum((1 - mixed_class_np) * predicted_class_ssim_np)
fn_ssim = np.sum(mixed_class_np * (1 - predicted_class_ssim_np))

hits_ssim = tn_ssim + tp_ssim
accuracy_ssim = hits_ssim / mixed_set_scale
print("Accuracy:" + str(accuracy_ssim))

# Guard 0/0 (no predicted or no actual positives): report 0.0, not NaN.
precision_ssim = tp_ssim / (tp_ssim + fp_ssim) if (tp_ssim + fp_ssim) > 0 else 0.0
recall_ssim = tp_ssim / (tp_ssim + fn_ssim) if (tp_ssim + fn_ssim) > 0 else 0.0

if (precision_ssim + recall_ssim) > 0:
    f1_score_ssim = 2 * ((precision_ssim * recall_ssim) / (precision_ssim + recall_ssim))
else:
    f1_score_ssim = 0.0

print("Precision:" + str(precision_ssim))
print("Recall:" + str(recall_ssim))
print("F1:" + str(f1_score_ssim))

In [None]:
# ROC score is 1 - SSIM: a worse reconstruction (lower SSIM) gives a higher
# score, i.e. more evidence for the positive (adversarial) class.
anomaly_score = 1 - ssim_t
fpr, tpr, thresh = roc_curve(mixed_class_np, anomaly_score)
roc_auc = auc(fpr, tpr)
for curve_values in (fpr, tpr, thresh):
    print(curve_values)

lw = 2
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')  # chance line
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('SSIM Basis ROC Curve')
plt.legend(loc="lower right")

out_path = "ssim_roc_graph_output" + full_extension + ".eps"
plt.savefig(out_path)
plt.show()

#### Mahalanobis Distance

We calculate the Mahalanobis distance for every result in the set.

In [None]:
with torch.no_grad():
    # Squared Mahalanobis distance (no square root) of each mixed-set code from
    # the reference encodings: diag((x - mu) S^-1 (x - mu)^T), using the
    # inverse covariance computed earlier.
    mahalanobis_set = []
    for code_vec in mixed_code_np_set[:len(mixed_code_set)]:
        centred = code_vec - encodings_mean_np
        quadratic = np.dot(np.dot(centred, encodings_inv_covariance), centred.T)
        mahalanobis_set.append(quadratic.diagonal())
    mahalanobis_set_np = np.stack(mahalanobis_set)

In [None]:
# NOTE(review): hand-picked cut-off on the SQUARED Mahalanobis distance (the
# distances are not square-rooted). Confirm it against the distances of the
# reference encodings rather than tuning it on the mixed set.
mahal_threshold = 240
# 0 = distribution (distance strictly below threshold), 1 = adversarial.
# Each entry holds one distance per batch item (batch_size == 1); flatten them.
mahal_scores = np.asarray(mahalanobis_set, dtype=float).ravel()[:mixed_set_scale]
predicted_class_mahal_np = np.where(mahal_scores < mahal_threshold, 0, 1)
predicted_class_mahal = predicted_class_mahal_np.tolist()
print(mixed_class_np)
print(predicted_class_mahal_np)
tp_mahal = np.sum(mixed_class_np * predicted_class_mahal_np)
tn_mahal = np.sum((1 - mixed_class_np) * (1 - predicted_class_mahal_np))
fp_mahal = np.sum((1 - mixed_class_np) * predicted_class_mahal_np)
fn_mahal = np.sum(mixed_class_np * (1 - predicted_class_mahal_np))

hits_mahal = tn_mahal + tp_mahal
accuracy_mahal = hits_mahal / mixed_set_scale
print("Accuracy:" + str(accuracy_mahal))

# Guard 0/0 (no predicted or no actual positives): report 0.0, not NaN.
precision_mahal = tp_mahal / (tp_mahal + fp_mahal) if (tp_mahal + fp_mahal) > 0 else 0.0
recall_mahal = tp_mahal / (tp_mahal + fn_mahal) if (tp_mahal + fn_mahal) > 0 else 0.0

if (precision_mahal + recall_mahal) > 0:
    f1_score_mahal = 2 * ((precision_mahal * recall_mahal) / (precision_mahal + recall_mahal))
else:
    f1_score_mahal = 0.0

print("Precision:" + str(precision_mahal))
print("Recall:" + str(recall_mahal))
print("F1:" + str(f1_score_mahal))

In [None]:
# Squared Mahalanobis distance as the score: larger = more likely adversarial.
fpr, tpr, thresh = roc_curve(mixed_class_np, mahalanobis_set)
roc_auc = auc(fpr, tpr)

lw = 2
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')  # chance line
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Mahalanobis Basis ROC Curve')
plt.legend(loc="lower right")

out_path = "mahal_roc_graph_output" + full_extension + ".eps"
plt.savefig(out_path)
plt.show()

## Plot Outputs

And plot the first 10 results:

In [None]:
def _hide_ticks(axes):
    """Hide both axes (ticks and labels) of a subplot."""
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)


with torch.no_grad():
    # (N, H, W, 3) copies of inputs and reconstructions, as imshow expects.
    mixed_set_g = np.stack(mixed_set).reshape(mixed_set_scale, height, width, 1).repeat(3, 3)
    mixed_recon_g = np.stack(mixed_reconstruction_set).reshape(mixed_set_scale, height, width, 1).repeat(3, 3)
    print(mixed_set_g.shape)
    print(mixed_recon_g.shape)

    #number = mixed_set_scale
    number = 10
    plt.figure(figsize=(25, 9))
    for column in range(number):
        # Row 1: the input image
        ax = plt.subplot(3, number, column + 1)
        copyback = mixed_set_g[column]
        print(copyback.shape)
        plt.imshow(copyback)
        _hide_ticks(ax)

        # Row 2: its code, shown as a code_sides[0] x code_sides[0] image
        ax = plt.subplot(3, number, column + number + 1)
        plt.imshow(mixed_code_set[column].cpu().reshape(code_sides[0], code_sides[0]))
        plt.gray()
        _hide_ticks(ax)

        # Row 3: the reconstruction
        ax = plt.subplot(3, number, column + (number * 2) + 1)
        plt.imshow(mixed_recon_g[column])
        _hide_ticks(ax)

    for extension in (".png", ".eps"):
        out_path = "adv_output" + full_extension + extension
        plt.savefig(out_path)
    plt.show()

And plot the failures from the MSE calculation:

In [None]:
with torch.no_grad():
    # One column per MSE misclassification: input, code, reconstruction.
    # Skip when there are none (an empty figure used to be saved) or too many
    # to fit legibly in one row.
    if 0 < len(fails[0]) < 30:
        number = len(fails[0])
        plt.figure(figsize=(25, 9))
        for index, image_index in enumerate(fails[0]):
            # display original
            ax = plt.subplot(3, number, index + 1)
            copyback = mixed_set_g[image_index]
            print(copyback.shape)
            plt.imshow(copyback)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

            # display codes
            ax = plt.subplot(3, number, index + number + 1)
            code_copyback = mixed_code_set[image_index].cpu()
            plt.imshow(code_copyback.reshape(code_sides[0],code_sides[0]))
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

            # display reconstruction
            ax = plt.subplot(3, number, index + (number*2) + 1)
            recon_copyback = mixed_recon_g[image_index]
            plt.imshow(recon_copyback)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

        out_path = "fail_output" + full_extension + ".png"
        plt.savefig(out_path)
        out_path = "fail_output" + full_extension + ".eps"
        plt.savefig(out_path)
        plt.show()

Finally, plot projected error against reconstructed MSE:

In [None]:
# Partition the mixed set by the MSE-threshold prediction (0 = distribution).
pred_1 = np.where(predicted_class_np != 0)
pred_0 = np.where(predicted_class_np == 0)

mixed_error_set_np = np.asarray(mixed_error_set)
mixed_error_0, mixed_error_1 = mixed_error_set_np[pred_0], mixed_error_set_np[pred_1]

post_losses_0, post_losses_1 = np_post_losses[pred_0], np_post_losses[pred_1]

In [None]:
# Error recorded while building the mixed set (0.0 for genuine images) vs.
# reconstruction MSE, coloured by the MSE-threshold prediction. The legend
# says which colour is which.
plt.figure(figsize=(6,6))
plt.title('Reconstruction relationship')
plt.xlabel('Original Error')
plt.ylabel('Reconstruction Error')
plt.plot(mixed_error_0, post_losses_0, 'o', color='red', label='Predicted in-distribution')
plt.plot(mixed_error_1, post_losses_1, 'o', color='blue', label='Predicted adversarial')
plt.legend()
out_path = "reconstruction_graph_output" + full_extension + ".eps"
plt.savefig(out_path)
plt.show()

#### Distribution histogram

The high ROC AUC can be explained by looking at the comparative distribution histograms.

In [None]:
losses_np = np.stack(losses)
# Sizes of: predicted-distribution, predicted-adversarial, reference losses.
for loss_group in (post_losses_0, post_losses_1, losses_np):
    print(len(loss_group))

In [None]:
plt.figure(figsize=(6,6))
plt.title("Predicted Distribution histogram - MSE")
plt.xlabel("MSE")
# density=True normalises each histogram, so the y axis is a density, not a count.
plt.ylabel("Frequency Density")
# Semi-transparent so overlapping histograms do not hide each other.
plt.hist(losses_np,density=True,bins=80,color='green',alpha=0.5,label='Base')
plt.hist(post_losses_0,density=True,bins=80,color='blue',alpha=0.5,label='Predicted in-distribution')
plt.hist(post_losses_1,density=True,bins=80,color='red',alpha=0.5,label='Predicted adversarial')
plt.legend()
plt.show()

In [None]:
# Partition the mixed set by its true class (0 = genuine, 1 = adversarial).
base_1 = np.where(mixed_class_np != 0)
base_0 = np.where(mixed_class_np == 0)

mixed_error_set_np = np.asarray(mixed_error_set)
mixed_error_0, mixed_error_1 = mixed_error_set_np[base_0], mixed_error_set_np[base_1]

base_losses_0, base_losses_1 = mixed_losses_np[base_0], mixed_losses_np[base_1]

In [None]:
# Reconstruction MSE of the reference set, the genuine mixed-set images and
# the adversarial mixed-set images.
# NOTE(review): matplotlib 3.9 renamed boxplot's `labels` to `tick_labels`;
# `labels` still works there but is deprecated.
mse_groups = (losses, base_losses_0, base_losses_1)
group_names = ('Base','Real','Adversarial')

plt.figure(figsize=(6,8))
plt.title("Base Distribution boxplot - MSE")
plt.xlabel("Distribution")
plt.ylabel("MSE")
plt.boxplot(mse_groups, labels=group_names)
out_path = "distribution_boxplot_output" + full_extension + ".eps"
plt.savefig(out_path)
plt.show()

In [None]:
plt.figure(figsize=(6,6))
plt.title("Actual Distribution histogram - MSE")
plt.xlabel("MSE")
plt.ylabel("Frequency Density")

# Reference-set MSE vs. MSE of the truly adversarial mixed-set images.
# Semi-transparent and labelled so the overlap is readable.
plt.hist(losses_np,density=True,bins=80,color='green',alpha=0.5,label='Base')
plt.hist(base_losses_1,density=True,bins=80,color='red',alpha=0.5,label='Adversarial')
plt.legend()
out_path = "distribution_hist_output" + full_extension + ".eps"
plt.savefig(out_path)
plt.show()