## Load libraries

In [None]:
import os
import logging

import pandas as pd
import numpy as np
import h5py
from sklearn.metrics import classification_report

from alphai_watson.performance import GANPerformanceAnalysis
from alphai_watson.transformer import NullTransformer
from alphai_rickandmorty_oracle.datasource.mnist import MNISTDataSource
from alphai_rickandmorty_oracle.detective import RickAndMortyDetective
from alphai_rickandmorty_oracle.model import RickAndMorty
from alphai_rickandmorty_oracle.networks.mnist import MNISTGanArchitecture

from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec

# Raise the root logger's threshold to CRITICAL so that only critical
# messages are emitted while the notebook runs.
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

# IPython magic: render matplotlib figures inline in the cell output.
%matplotlib inline


## Define MNIST Datasource

In [None]:
# Directory holding the HDF5 files and the digit class treated as the anomaly.
file_path = '../../tests/resources'
abnormal_digit = 0

# The train and test file names both embed the abnormal digit.
train_data_file = os.path.join(
    file_path, 'mnist_data_train_abnormalclass-{}.hd5'.format(abnormal_digit)
)
test_data_file = os.path.join(
    file_path, 'mnist_data_test_abnormalclass-{}.hd5'.format(abnormal_digit)
)

# Transformer geometry: 784 values per sample split over n_sensors columns
# (presumably a flattened 28x28 MNIST image -- TODO confirm).
n_sensors = 28
n_timesteps = 784 // n_sensors


def _make_data_source(source_file):
    """Return an MNISTDataSource reading `source_file` through a fresh NullTransformer."""
    return MNISTDataSource(
        source_file=source_file,
        transformer=NullTransformer(number_of_timesteps=n_timesteps,
                                    number_of_sensors=n_sensors),
    )


train_data_source = _make_data_source(train_data_file)
test_data_source = _make_data_source(test_data_file)

# Fetch the 'NORMAL' samples from the training file.
train_data = train_data_source.get_train_data('NORMAL')

### Plot input images to verify correctness

In [None]:
# Pick one training sample at random.
sample_index = np.random.randint(0, len(train_data.data))
image = train_data.data[sample_index]

# Alternative: look at one specific ABNORMAL sample from the test file instead.
# h_train = h5py.File(test_data_file)
# image = np.array(h_train.get('ABNORMAL'))[156]

# Display the sample as a 28x28 greyscale picture.
image_2d = np.reshape(image, (28, 28))
plt.imshow(image_2d, cmap='gray')
plt.show()

## Define Model & Train

In [None]:
# Output directory; it is used both for the saved model and for the plots.
model_dir = './mnist_models'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
plot_save_path = model_dir

# Training settings.
batch_size = 64
train_iters = 100

# Dimensions handed to the GAN architecture (flat size and 2-D plotting shape).
output_dimensions = 784
plot_dimensions = (28, 28)

architecture = MNISTGanArchitecture(output_dimensions, plot_dimensions)

model = RickAndMorty(
    architecture=architecture,
    batch_size=batch_size,
    train_iters=train_iters,
    plot_save_path=plot_save_path,
)

# Configuration consumed by the detective; 'save_path' embeds the abnormal digit.
detective_configuration = {
    'model': model,
    'output_dimensions': output_dimensions,
    'batch_size': batch_size,
    'train_iters': train_iters,
    'save_path': '{}/MNIST-abnormalclass-{}'.format(model_dir, abnormal_digit),
    'plot_save_path': plot_save_path,
}

detective = RickAndMortyDetective(model_configuration=detective_configuration)

In [None]:
# Train the detective on `train_data` (the 'NORMAL' samples loaded earlier).
# NOTE(review): runtime depends on `train_iters` -- expect this cell to be the slow one.
detective.train(train_data)

## Evaluate Results

### Load trained model

In [None]:
# Optional: uncomment to rebuild the detective with a 'load_path' instead of a
# 'save_path' -- presumably this restores a previously saved model rather than
# retraining (TODO confirm against RickAndMortyDetective).
# detective = RickAndMortyDetective(model_configuration={
#     'model': model,
#     'batch_size': batch_size,
#     'output_dimensions': output_dimensions,
#     'train_iters': train_iters,
#     'load_path' : '{}/MNIST-abnormalclass-{}'.format(model_dir, abnormal_digit),
#     'plot_save_path' : plot_save_path
# })

### Visualise generated samples

In [None]:
# Grid layout for the generated samples.
n_row = 4
n_col = 8
n_tiles = n_row * n_col

# Two inches per tile, with almost no padding between tiles.
plt.figure(figsize=(2 * n_col, 2 * n_row))
gs = gridspec.GridSpec(n_row, n_col)
gs.update(wspace=0.025, hspace=0.05)

generated_samples = detective.model.generate_fake_samples()

# Draw the first n_row * n_col generated samples, one per grid cell.
for tile_index in range(n_tiles):
    tile = plt.subplot(gs[tile_index])
    tile.axis('off')
    tile.imshow(generated_samples[tile_index], cmap='gray')

### Load test data


In [None]:
# Pull each subset of the test file through the same data source.
test_data_normal = test_data_source.get_train_data('NORMAL')
test_data_abnormal = test_data_source.get_train_data('ABNORMAL')
test_data = test_data_source.get_train_data('ALL')

# Ground-truth labels: 0 for every ABNORMAL sample, 1 for every NORMAL sample.
# NOTE(review): this ordering assumes 'ALL' yields the ABNORMAL samples first,
# followed by the NORMAL ones -- confirm against MNISTDataSource.
abnormal_labels = np.zeros(len(test_data_abnormal.data))
normal_labels = np.ones(len(test_data_normal.data))
expected_truth = np.concatenate((abnormal_labels, normal_labels))

### Calculate ROC Score

In [None]:
# Run the detective over the full test set.
detection_result = detective.detect(test_data)

# Compare the detection output against the ground-truth labels.
performance_analysis = GANPerformanceAnalysis({})
roc_score = performance_analysis.analyse(
    detection_result=detection_result.data,
    expected_truth=expected_truth,
)

print('ROC Score: {}'.format(roc_score))

### Generate classification report

In [None]:
# The decision threshold is the mean detection value over the training data.
train_results = detective.detect(train_data).data
threshold = np.mean(train_results)

# Values at or above the threshold are labelled 1 (NORMAL), the rest 0 (ABNORMAL).
prediction = [
    1 if score >= threshold else 0
    for score in detection_result.data
]

# Label 0 maps to 'ABNORMAL' and label 1 to 'NORMAL'.
target_names = ['ABNORMAL', 'NORMAL']
report = classification_report(expected_truth, prediction, target_names=target_names)
print(report)

## Root Cause Analysis

In [None]:
# Number of images sampled from each class.
n_img = 4


def _sample_images(images, count):
    """Return `count` distinct images chosen at random along the first axis.

    `images` is indexed as `images[idx, :, :]`, so it must be a 3-D array.
    """
    chosen = np.random.choice(images.shape[0], count, replace=False)
    return images[chosen, :, :]


def _plot_root_cause(title, real_images, best_fakes):
    """Plot real images, their diagnosed counterparts and the squared difference.

    Draws a 3-row figure with one column per image:
    row 0 shows the real image, row 1 the matching entry of `best_fakes`,
    and row 2 the element-wise squared difference between the two.

    Args:
        title: Figure title.
        real_images: Sequence of 2-D images.
        best_fakes: Sequence of 2-D images, aligned with `real_images`.
    """
    # squeeze=False keeps `ax` two-dimensional even for a single column;
    # otherwise `ax[0, 0]` raises IndexError when only one image is plotted.
    fig, ax = plt.subplots(3, len(real_images), figsize=(16, 12), squeeze=False)
    fig.suptitle(title, fontsize=28)

    # Row labels are attached to the left-most column only.
    for row, label in enumerate(('Real', 'Best fake', 'Squared Distance')):
        ax[row, 0].set_ylabel(label, fontsize=20)

    for col, (real, fake) in enumerate(zip(real_images, best_fakes)):
        ax[0, col].imshow(real, cmap='gray')
        ax[1, col].imshow(fake, cmap='gray')
        # Fixed colour range so the difference maps are comparable across columns.
        ax[2, col].imshow(np.square(real - fake), cmap='YlOrRd', vmin=0, vmax=1)


# Sample both classes first, then diagnose, in the same order as before so the
# random draws are consumed identically.
normal_test_images = _sample_images(test_data_normal.data, n_img)
abnormal_test_images = _sample_images(test_data_abnormal.data, n_img)

# NOTE(review): `diagnose` presumably returns the generated image closest to
# its input ("best fake") -- confirm against RickAndMortyDetective.
normal_best_fakes = [detective.diagnose(nor_img) for nor_img in normal_test_images]
abnormal_best_fakes = [detective.diagnose(abn_img) for abn_img in abnormal_test_images]

_plot_root_cause('Normal class', normal_test_images, normal_best_fakes)
_plot_root_cause('Abnormal class', abnormal_test_images, abnormal_best_fakes)