## Load libraries

In [15]:
import os
import pandas as pd
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

from alphai_watson.performance import GANPerformanceAnalysis
from alphai_watson.transformer import NullTransformer
from alphai_rickandmorty_oracle.datasource.mnist import MNISTDataSource
from alphai_rickandmorty_oracle.detective import RickAndMortyDetective

## Define MNIST Datasource

In [2]:
# Directory holding the MNIST fixture files, and the digit treated as the anomaly class
file_path = '../../tests/resources'
abnormal_digit = 0

# Train and test data files (file names embed the abnormal digit)
train_data_file = os.path.join(
    file_path, 'mnist_data_train_abnormalclass-{}.hd5'.format(abnormal_digit))
test_data_file = os.path.join(
    file_path, 'mnist_data_test_abnormalclass-{}.hd5'.format(abnormal_digit))

# Model parameters: each 784-value sample is split into n_timesteps rows of n_sensors values
n_sensors = 28
n_timesteps = 784 // n_sensors


def _build_data_source(source_file):
    """Return an MNISTDataSource for `source_file` with its own NullTransformer.

    The transformer is configured from the notebook-level `n_timesteps`
    and `n_sensors` values.
    """
    shape_transformer = NullTransformer(number_of_timesteps=n_timesteps,
                                        number_of_sensors=n_sensors)
    return MNISTDataSource(source_file=source_file, transformer=shape_transformer)


train_data_source = _build_data_source(train_data_file)
test_data_source = _build_data_source(test_data_file)

# Only samples labelled 'NORMAL' are used for training
train_data = train_data_source.get_train_data('NORMAL')

DEBUG:root:Start file parsing
DEBUG:root:end file parsing
DEBUG:root:Start file parsing
DEBUG:root:end file parsing


## Define Model

In [3]:
# Directory where model checkpoints and generated sample plots are written
model_dir = './mnist_models'
# exist_ok=True keeps this cell idempotent on re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(model_dir, exist_ok=True)

# Model configuration: the checkpoint path embeds the abnormal digit so that
# models trained for different anomaly classes do not overwrite each other.
detective = RickAndMortyDetective(model_configuration={
    'batch_size': 64,
    'output_dimensions': 784,
    'train_iters': 300,
    'save_path' : '{}/MNIST-abnormalclass-{}'.format(model_dir, abnormal_digit),
    'plot_save_path' : model_dir
})

# Train on the samples loaded into `train_data`
detective.train(train_data)

DEBUG:root:Starting session
DEBUG:root:Start training loop...
INFO:root:Initialising Model
INFO:root:Training iteration 0 of 300
DEBUG:matplotlib.font_manager:findfont: Matching :family=sans-serif:style=normal:variant=normal:weight=normal:stretch=normal:size=10.0 to DejaVu Sans ('/opt/anaconda/envs/ai/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans.ttf') with score of 0.050000


iter 0	train disc cost	4.037641525268555	time	1.510230541229248
iter 1	train disc cost	3.8670847415924072	time	0.5102806091308594
iter 2	train disc cost	3.4915263652801514	time	0.39618349075317383
iter 3	train disc cost	3.1009411811828613	time	0.39540553092956543
iter 4	train disc cost	2.637146234512329	time	0.39788055419921875


INFO:root:Training iteration 10 of 300
INFO:root:Training iteration 20 of 300
INFO:root:Training iteration 30 of 300
INFO:root:Training iteration 40 of 300
INFO:root:Training iteration 50 of 300
INFO:root:Training iteration 60 of 300
INFO:root:Training iteration 70 of 300
INFO:root:Training iteration 80 of 300
INFO:root:Training iteration 90 of 300
INFO:root:Saving fake samples to png: [[4.0527284e-01 2.4000376e-01 3.0115387e-01 ... 8.9863385e-04
  9.3013444e-04 9.0060709e-03]
 [4.2366964e-01 2.9641557e-01 3.3525011e-01 ... 1.7667129e-03
  1.3773292e-03 1.5553150e-02]
 [4.2145753e-01 2.7865052e-01 3.3675086e-01 ... 3.0118148e-03
  2.3665472e-03 2.4681127e-02]
 ...
 [4.1139325e-01 2.3799224e-01 3.2147908e-01 ... 1.1823700e-03
  9.7324932e-04 8.1918174e-03]
 [4.1284037e-01 2.6070544e-01 3.1745490e-01 ... 6.2933320e-04
  3.6564804e-04 6.5768426e-03]
 [4.1086984e-01 2.5339261e-01 3.1878462e-01 ... 1.0698074e-03
  9.6572412e-04 6.6419370e-03]]
INFO:root:Training iteration 100 of 300


iter 99	train disc cost	-0.17775702476501465	time	0.3984001310248124


INFO:root:Training iteration 110 of 300
INFO:root:Training iteration 120 of 300
INFO:root:Training iteration 130 of 300
INFO:root:Training iteration 140 of 300
INFO:root:Training iteration 150 of 300
INFO:root:Training iteration 160 of 300
INFO:root:Training iteration 170 of 300
INFO:root:Training iteration 180 of 300
INFO:root:Training iteration 190 of 300
INFO:root:Saving fake samples to png: [[1.3422823e-02 1.2413806e-03 1.3170124e-03 ... 4.1285674e-07
  1.4309966e-06 3.2356380e-05]
 [3.5850018e-02 4.9649724e-03 3.9995620e-03 ... 1.2199120e-06
  3.2905812e-06 8.2162398e-05]
 [2.7280426e-02 4.1583166e-03 3.9525921e-03 ... 4.6929413e-06
  1.0976233e-05 2.6004211e-04]
 ...
 [1.7542917e-02 1.6914802e-03 2.1081713e-03 ... 6.3477677e-07
  1.7679511e-06 4.1119751e-05]
 [1.8112201e-02 2.2311225e-03 2.3569188e-03 ... 1.0988447e-07
  2.8662546e-07 1.6359489e-05]
 [2.0872641e-02 2.6567273e-03 2.9331450e-03 ... 4.9102954e-07
  1.4841416e-06 2.5545311e-05]]
INFO:root:Training iteration 200 of 300

iter 199	train disc cost	-0.11982018500566483	time	0.39054878234863283


INFO:root:Training iteration 210 of 300
INFO:root:Training iteration 220 of 300
INFO:root:Training iteration 230 of 300
INFO:root:Training iteration 240 of 300
INFO:root:Training iteration 250 of 300
INFO:root:Training iteration 260 of 300
INFO:root:Training iteration 270 of 300
INFO:root:Training iteration 280 of 300
INFO:root:Training iteration 290 of 300
INFO:root:Saving fake samples to png: [[4.98269255e-05 3.14861177e-06 4.12205463e-06 ... 3.08888071e-09
  2.56038248e-08 3.21569985e-07]
 [3.43296066e-04 2.66075112e-05 2.26329685e-05 ... 5.83007687e-09
  4.29774119e-08 6.37947721e-07]
 [2.54455837e-04 2.81456851e-05 3.18975617e-05 ... 9.42354603e-08
  4.41803792e-07 6.04461684e-06]
 ...
 [9.63019411e-05 6.20263563e-06 9.92296373e-06 ... 4.84061990e-09
  3.32511227e-08 4.44717614e-07]
 [5.13663290e-05 4.15056684e-06 5.90219543e-06 ... 2.51597299e-10
  2.14234119e-09 6.83629935e-08]
 [1.03021135e-04 8.50947254e-06 1.23131476e-05 ... 2.31774000e-09
  1.88606144e-08 1.94293534e-07]]
DE

iter 299	train disc cost	-0.1573355793952942	time	0.3923739767074585


## Evaluate Results

### Load trained model

In [5]:
# Optional: rebuild the detective from the saved checkpoint ('load_path')
# instead of reusing the `detective` object trained earlier in this notebook.
# NOTE(review): left disabled. The recorded output below shows a ValueError
# ("Variable ano_z already exists") from when this ran in a kernel that had
# already built a detective. Presumably a kernel restart (skipping the training
# cell) is needed before enabling it; confirm.
# detective = RickAndMortyDetective(model_configuration={
#     'batch_size': 64,
#     'output_dimensions': 784,
#     'train_iters': 300,
#     'load_path' : '{}/MNIST-abnormalclass-{}'.format(model_dir, abnormal_digit),
#     'plot_save_path' : model_dir
# })

ValueError: Variable ano_z already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

  File "/home/ubuntu/alpha-i/code/detective-gan-rick-and-morty/alphai_rickandmorty_oracle/model_mnist.py", line 69, in __init__
    self.ano_z = tf.get_variable('ano_z', shape=[1, self.z_dim], dtype=tf.float32, initializer=z_init)
  File "/home/ubuntu/alpha-i/code/detective-gan-rick-and-morty/alphai_rickandmorty_oracle/detective.py", line 54, in __init__
    plot_save_path=plot_save_path, load_path=load_path)
  File "<ipython-input-3-18d99b8064b3>", line 11, in <module>
    'plot_save_path' : model_dir


### Load test data


In [8]:
# Pull the three views of the test set in one pass: normal-only, abnormal-only, everything
test_data_normal, test_data_abnormal, test_data = (
    test_data_source.get_train_data(label_kind)
    for label_kind in ('NORMAL', 'ABNORMAL', 'ALL')
)

# Ground truth: 1 for every ABNORMAL sample, 0 for every NORMAL sample.
# NOTE(review): this layout assumes the 'ALL' sample lists abnormal samples
# first, followed by normal samples; confirm against MNISTDataSource.
n1 = np.ones(len(test_data_abnormal.data))
n2 = np.zeros(len(test_data_normal.data))
expected_truth = np.concatenate((n1, n2))

### Calculate ROC Score

In [11]:
# Score every test sample with the trained detective
detection_result = detective.detect(test_data)

# Compare the detector output against the ground-truth labels
performance_analysis = GANPerformanceAnalysis({})
roc_score = performance_analysis.analyse(
    detection_result=detection_result.data,
    expected_truth=expected_truth,
)

print(roc_score)

INFO:root:Running detector on <alphai_watson.datasource.Sample object at 0x7f04784a7908>
Exception ignored in: <bound method RickAndMorty.__del__ of <alphai_rickandmorty_oracle.model_mnist.RickAndMorty object at 0x7f04e5450c18>>
Traceback (most recent call last):
  File "/home/ubuntu/alpha-i/code/detective-gan-rick-and-morty/alphai_rickandmorty_oracle/model_mnist.py", line 74, in __del__
    self.tf_session.close()
  File "/opt/anaconda/envs/ai/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1627, in close
    self._default_session.__exit__(None, None, None)
  File "/opt/anaconda/envs/ai/lib/python3.6/contextlib.py", line 88, in __exit__
    next(self.gen)
  File "/opt/anaconda/envs/ai/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 4345, in get_controller
    type(default))
AssertionError: Nesting violated for default stack of <class 'tensorflow.python.client.session.InteractiveSession'> objects
INFO:root:Detection completed in 1.3423072732

0.4639372268149566


### Generate classification report

In [17]:
# Compare the ground truth against np.rint(detection_result.data), which rounds
# each score to the nearest integer (below 0.5 -> 0, above 0.5 -> 1).
predicted_labels = np.rint(detection_result.data)
clf_rep = precision_recall_fscore_support(expected_truth, predicted_labels)

# clf_rep holds (precision, recall, f1-score, support), each with one entry per class
out_dict = {
    "precision": clf_rep[0].round(2),
    "recall": clf_rep[1].round(2),
    "f1-score": clf_rep[2].round(2),
    "support": clf_rep[3],
}
# Ground truth uses 0 for NORMAL and 1 for ABNORMAL, hence this row order
df_out = pd.DataFrame(out_dict, index=['NORMAL', 'ABNORMAL'])

# Summary row: mean of each metric column, but the sum for the support column
avg_tot = (
    df_out
    .apply(lambda x: round(x.mean(), 2) if x.name != "support" else round(x.sum(), 2))
    .to_frame()
    .T
)
avg_tot.index = ["avg/total"]

# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
# pd.concat produces the same stacked frame on old and new pandas alike.
df_out = pd.concat([df_out, avg_tot])
print(df_out)

# Save Classification report to CSV (Optional)
# df_out.to_csv('classification_report_digit-{}.csv'.format(abnormal_digit), sep=';')

           f1-score  precision  recall  support
NORMAL         0.41       0.59    0.32  12620.0
ABNORMAL       0.42       0.33    0.60   6903.0
avg/total      0.42       0.46    0.46  19523.0
