## Load libraries

In [1]:
import os
import pandas as pd
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

from alphai_watson.performance import GANPerformanceAnalysis
from alphai_watson.transformer import NullTransformer
from alphai_rickandmorty_oracle.datasource.mnist import MNISTDataSource
from alphai_rickandmorty_oracle.detective import RickAndMortyDetective
from alphai_rickandmorty_oracle.model_mnist import RickAndMorty

  from ._conv import register_converters as _register_converters
  return f(*args, **kwds)
DEBUG:matplotlib:CACHEDIR=/home/ubuntu/.cache/matplotlib
DEBUG:matplotlib.font_manager:Using fontManager instance from /home/ubuntu/.cache/matplotlib/fontList.json
DEBUG:matplotlib.backends:backend agg version v2.2


Enabling weight norm
Uppercase local vars:
	BATCH_SIZE: 50
	CRITIC_ITERS: 5
	DEFAULT_FIT_EPOCHS: 1000
	DEFAULT_LEARN_RATE: 0.0001
	DEFAULT_TRAIN_ITERS: 5000
	DEFAULT_Z_DIM: 128
	DIAGNOSIS_LEARN_RATE: 0.01
	DIM: 64
	DISC_FILTER_SIZE: 5
	LAMBDA: 10
	LAMBDA_2: 2.0
	OUTPUT_DIM: 784


## Define MNIST train and test data sources

In [2]:
# Directory holding the pre-split MNIST HDF5 resource files
file_path = '../../tests/resources'

# Digit whose resource files are selected via the `abnormalclass-<digit>` suffix
abnormal_digit = 0

# Train and test data file
train_data_file = os.path.join(file_path, 'mnist_data_train_abnormalclass-{}.hd5'.format(abnormal_digit))
test_data_file = os.path.join(file_path, 'mnist_data_test_abnormalclass-{}.hd5'.format(abnormal_digit))

# Model parameters: the 784 values of a sample are described to the
# transformer as n_sensors x n_timesteps (28 x 28)
n_sensors = 28
n_timesteps = 784 // n_sensors


def _make_mnist_source(source_file):
    """Return an MNISTDataSource for `source_file` with its own NullTransformer."""
    null_transformer = NullTransformer(number_of_timesteps=n_timesteps,
                                       number_of_sensors=n_sensors)
    return MNISTDataSource(source_file=source_file, transformer=null_transformer)


# Build the train source first, then the test source
train_data_source = _make_mnist_source(train_data_file)
test_data_source = _make_mnist_source(test_data_file)

# Only the NORMAL class is used for training
train_data = train_data_source.get_train_data('NORMAL')

DEBUG:root:Start file parsing
DEBUG:root:end file parsing
DEBUG:root:Start file parsing
DEBUG:root:end file parsing


In [3]:
# Seed NumPy's global RNG so any NumPy-driven randomness is repeatable under
# Restart Kernel -> Run All (this replaces a scratch cell that only displayed
# two unseeded random draws).
# NOTE(review): the model may also draw from another framework's RNG
# (e.g. TensorFlow), which this call does not seed -- confirm.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

## Define and train the model

In [None]:
# Output directory for the trained detective (save_path) and the sample
# images written during training (plot_save_path).
model_dir = './mnist_models'
# exist_ok=True avoids the check-then-create race and keeps the cell idempotent
os.makedirs(model_dir, exist_ok=True)

batch_size = 64
output_dimensions = 784  # flattened sample size (28 * 28)
train_iters = 500
plot_save_path = model_dir
# Built with os.path.join, like the data file paths above
model_save_path = os.path.join(model_dir, 'MNIST-abnormalclass-{}'.format(abnormal_digit))

model = RickAndMorty(batch_size=batch_size,
                     output_dimensions=output_dimensions,
                     train_iters=train_iters,
                     plot_save_path=plot_save_path)

detective = RickAndMortyDetective(model_configuration={
    'model': model,
    'batch_size': batch_size,
    'output_dimensions': output_dimensions,
    'train_iters': train_iters,
    'save_path': model_save_path,
    'plot_save_path': plot_save_path,
})

# Train on the NORMAL-only training split
detective.train(train_data)

DEBUG:root:Starting session
DEBUG:root:Start training loop...
INFO:root:Initialising Model
INFO:root:Training iteration 0 of 500
DEBUG:matplotlib.font_manager:findfont: Matching :family=sans-serif:style=normal:variant=normal:weight=normal:stretch=normal:size=10.0 to DejaVu Sans ('/opt/anaconda/envs/ai/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans.ttf') with score of 0.050000


iter 0	train disc cost	4.448973178863525	time	1.5029473304748535
iter 1	train disc cost	4.078559875488281	time	0.5121018886566162
iter 2	train disc cost	3.6282434463500977	time	0.39653921127319336
iter 3	train disc cost	3.2144811153411865	time	0.3991217613220215
iter 4	train disc cost	2.8161511421203613	time	0.3966710567474365


INFO:root:Saving fake samples to png: [[4.4489777e-01 3.3024967e-01 3.0298778e-01 ... 3.8713755e-05
  1.0057024e-04 8.6436031e-04]
 [4.4287777e-01 3.4936598e-01 3.0335483e-01 ... 1.1523318e-04
  2.7655842e-04 2.0732493e-03]
 [4.2080510e-01 3.3166012e-01 2.5407052e-01 ... 1.9293455e-05
  5.3430918e-05 4.3736753e-04]
 ...
 [4.3437430e-01 3.3959484e-01 2.6679581e-01 ... 5.1089341e-05
  1.9232633e-04 1.2045383e-03]
 [4.4132638e-01 3.0945745e-01 2.5402907e-01 ... 1.1197264e-05
  6.4925989e-05 3.4866590e-04]
 [4.3560097e-01 3.1633222e-01 2.5955385e-01 ... 4.6262194e-05
  1.6314983e-04 1.4763810e-03]]
INFO:root:Training iteration 100 of 500


iter 99	train disc cost	-0.22218067944049835	time	0.39823265075683595


INFO:root:Saving fake samples to png: [[3.79183702e-02 2.03327113e-03 2.00124853e-03 ... 2.49480195e-07
  1.07168842e-06 2.22474023e-06]
 [4.70315814e-02 3.28988186e-03 2.87445448e-03 ... 1.21397920e-06
  4.57488795e-06 1.02515687e-05]
 [3.08882501e-02 1.86576496e-03 1.40997313e-03 ... 1.08692355e-07
  4.29164288e-07 8.87650742e-07]
 ...
 [2.58488152e-02 1.46688230e-03 1.19498896e-03 ... 5.12603492e-07
  2.94917504e-06 6.11738460e-06]
 [2.61848271e-02 1.14410440e-03 1.00085442e-03 ... 3.92370545e-08
  3.87484789e-07 3.87008726e-07]
 [3.95275727e-02 1.95313338e-03 1.69305189e-03 ... 4.37539541e-07
  2.30128967e-06 6.50285892e-06]]
INFO:root:Training iteration 200 of 500


iter 199	train disc cost	-0.1313803642988205	time	0.39064417362213133


INFO:root:Saving fake samples to png: [[2.8243230e-04 1.6987351e-05 1.2207779e-05 ... 3.6036621e-10
  9.4901544e-09 1.6202771e-08]
 [3.6671275e-04 2.6121745e-05 1.6468655e-05 ... 3.3105367e-09
  6.2246286e-08 1.2402785e-07]
 [1.7804011e-04 1.3447468e-05 6.8548011e-06 ... 1.0548542e-10
  3.6842918e-09 6.2038010e-09]
 ...
 [1.0084806e-04 8.0837926e-06 4.8154225e-06 ... 1.2428658e-09
  4.2814428e-08 7.7799029e-08]
 [6.7747183e-05 3.5963369e-06 2.2597510e-06 ... 2.4845927e-11
  1.9463833e-09 1.2929158e-09]
 [2.5689817e-04 1.3056019e-05 7.8546536e-06 ... 9.2804642e-10
  2.5639551e-08 5.8758950e-08]]
INFO:root:Training iteration 300 of 500


iter 299	train disc cost	-0.15441879630088806	time	0.3929048252105713


INFO:root:Saving fake samples to png: [[7.8635430e-06 7.0277033e-06 4.9657700e-07 ... 1.4894560e-11
  3.1708267e-10 3.6440068e-10]
 [9.4492143e-06 9.7603242e-06 6.3952535e-07 ... 1.9128797e-10
  3.0117091e-09 4.1360604e-09]
 [3.2676401e-06 4.3497307e-06 1.9650075e-07 ... 3.9576714e-12
  1.1605189e-10 1.3282216e-10]
 ...
 [2.5825943e-06 4.4938229e-06 2.2927611e-07 ... 6.8780773e-11
  2.0602064e-09 2.6496394e-09]
 [7.8839071e-07 1.0932749e-06 4.7537647e-08 ... 6.4420046e-13
  4.0255677e-11 1.8083916e-11]
 [4.8201186e-06 4.4037497e-06 2.3408617e-07 ... 4.8743856e-11
  1.1223431e-09 1.5982106e-09]]
INFO:root:Training iteration 400 of 500


iter 399	train disc cost	-0.17418472468852997	time	0.3950315308570862


## Evaluate Results

### Load a previously trained model (optional)

In [None]:
# Set to True to replace `detective` with one restored from `load_path`
# instead of the one trained above. It is off by default, so Run All uses the
# freshly trained model. It reuses the notebook's hyper-parameters rather than
# hard-coded literals (the earlier commented-out version used train_iters=300,
# while training uses 500).
# NOTE(review): presumably these values must match the saved model -- confirm.
LOAD_PRETRAINED_MODEL = False

if LOAD_PRETRAINED_MODEL:
    detective = RickAndMortyDetective(model_configuration={
        'batch_size': batch_size,
        'output_dimensions': output_dimensions,
        'train_iters': train_iters,
        'load_path': os.path.join(model_dir, 'MNIST-abnormalclass-{}'.format(abnormal_digit)),
        'plot_save_path': model_dir,
    })

### Load test data


In [None]:
# Get test data
test_data_normal = test_data_source.get_train_data('NORMAL')
test_data_abnormal = test_data_source.get_train_data('ABNORMAL')
test_data = test_data_source.get_train_data('ALL')

# Ground truth for ABNORMAL data is 1 , ground truth for NORMAL data is 0
n1 = np.ones(len(test_data_abnormal.data))
n2 = np.zeros(len(test_data_normal.data))
expected_truth = np.hstack((n1, n2))

### Calculate ROC Score

In [None]:
# Score every sample of the 'ALL' test split with the trained detective
detection_result = detective.detect(test_data)

# Compare the detection scores with the 0/1 ground truth; the analyser is
# constructed with an empty configuration dict
performance_analyser = GANPerformanceAnalysis({})
roc_score = performance_analyser.analyse(detection_result=detection_result.data,
                                         expected_truth=expected_truth)

print('ROC Score: {}'.format(roc_score))

### Generate classification report

In [None]:
# Threshold the detection scores with np.rint: values below 0.5 become 0 and
# values above 0.5 become 1. An exact 0.5 becomes 0, because rint rounds half
# to even. The result is then compared with the ground truth.
# labels=[0, 1] fixes the row order to (NORMAL, ABNORMAL), matching the index
# below, and always yields two rows even if a class is missing from the data.
predicted_labels = np.rint(detection_result.data)
clf_rep = precision_recall_fscore_support(expected_truth, predicted_labels, labels=[0, 1])
out_dict = {
    "precision": clf_rep[0].round(2),
    "recall": clf_rep[1].round(2),
    "f1-score": clf_rep[2].round(2),
    "support": clf_rep[3],
}
df_out = pd.DataFrame(out_dict, index=['NORMAL', 'ABNORMAL'])

# Summary row: unweighted (macro) mean of each score column, total support
avg_tot = df_out.apply(
    lambda col: round(col.mean(), 2) if col.name != "support" else round(col.sum(), 2)
).to_frame().T
avg_tot.index = ["avg/total"]
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0; use concat
df_out = pd.concat([df_out, avg_tot])

# Set to True to also save the classification report to CSV
SAVE_CLASSIFICATION_REPORT = False
if SAVE_CLASSIFICATION_REPORT:
    df_out.to_csv('classification_report_digit-{}.csv'.format(abnormal_digit), sep=';')

df_out