## Load libraries

In [1]:
import os
import pandas as pd
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

from alphai_watson.performance import GANPerformanceAnalysis
from alphai_watson.transformer import NullTransformer
from alphai_rickandmorty_oracle.datasource.mnist import MNISTDataSource
from alphai_rickandmorty_oracle.detective import RickAndMortyDetective
from alphai_rickandmorty_oracle.model_mnist import RickAndMorty

  from ._conv import register_converters as _register_converters
  return f(*args, **kwds)
DEBUG:matplotlib:CACHEDIR=/home/ubuntu/.cache/matplotlib
DEBUG:matplotlib.font_manager:Using fontManager instance from /home/ubuntu/.cache/matplotlib/fontList.json
DEBUG:matplotlib.backends:backend agg version v2.2


Enabling weight norm
Uppercase local vars:
	BATCH_SIZE: 50
	CRITIC_ITERS: 5
	DEFAULT_FIT_EPOCHS: 1000
	DEFAULT_LEARN_RATE: 0.0001
	DEFAULT_TRAIN_ITERS: 5000
	DEFAULT_Z_DIM: 128
	DIAGNOSIS_LEARN_RATE: 0.01
	DIM: 64
	DISC_FILTER_SIZE: 5
	LAMBDA: 10
	LAMBDA_2: 2.0
	OUTPUT_DIM: 784


## Define MNIST Datasource

In [2]:
# Pre-built MNIST HDF5 files live under this directory; `abnormal_digit`
# selects which file pair (and hence which held-out ABNORMAL digit) is used.
file_path = '../../tests/resources'

abnormal_digit = 0

# Train and test data file
train_data_file = os.path.join(file_path,
                               'mnist_data_train_abnormalclass-{}.hd5'.format(abnormal_digit))
test_data_file = os.path.join(file_path,
                              'mnist_data_test_abnormalclass-{}.hd5'.format(abnormal_digit))

# Model parameters: each 784-pixel image is presented to the model as a
# grid of n_sensors x n_timesteps (28 x 28).
n_sensors = 28
n_timesteps = 784 // n_sensors


def _mnist_source(source_file):
    """Create an MNISTDataSource for `source_file` with its own NullTransformer."""
    transformer = NullTransformer(number_of_timesteps=n_timesteps,
                                  number_of_sensors=n_sensors)
    return MNISTDataSource(source_file=source_file, transformer=transformer)


train_data_source = _mnist_source(train_data_file)
test_data_source = _mnist_source(test_data_file)

# The detector is trained on the NORMAL class only.
train_data = train_data_source.get_train_data('NORMAL')

DEBUG:root:Start file parsing
DEBUG:root:end file parsing
DEBUG:root:Start file parsing
DEBUG:root:end file parsing


## Define Model

In [3]:
# Directory under which the trained model is saved (`save_path`) and the
# sample plots produced during training are written (`plot_save_path`).
model_dir = './mnist_models'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(model_dir, exist_ok=True)

batch_size = 64
output_dimensions = 784  # flattened 28 x 28 MNIST image
train_iters = 500
plot_save_path = model_dir

model = RickAndMorty(batch_size=batch_size,
                     output_dimensions=output_dimensions,
                     train_iters=train_iters,
                     plot_save_path=plot_save_path)

# The detective receives the model instance together with the same
# hyper-parameters, plus where to save the trained model.
detective = RickAndMortyDetective(model_configuration={
    'model': model,
    'batch_size': batch_size,
    'output_dimensions': output_dimensions,
    'train_iters': train_iters,
    'save_path': os.path.join(model_dir, 'MNIST-abnormalclass-{}'.format(abnormal_digit)),
    'plot_save_path': plot_save_path,
})

detective.train(train_data)

DEBUG:root:Starting session
DEBUG:root:Start training loop...
INFO:root:Initialising Model
INFO:root:Training iteration 0 of 500
DEBUG:matplotlib.font_manager:findfont: Matching :family=sans-serif:style=normal:variant=normal:weight=normal:stretch=normal:size=10.0 to DejaVu Sans ('/opt/anaconda/envs/ai/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans.ttf') with score of 0.050000


iter 0	train disc cost	4.175182342529297	time	1.471437931060791
iter 1	train disc cost	3.9061384201049805	time	0.5063550472259521
iter 2	train disc cost	3.3714137077331543	time	0.39545321464538574
iter 3	train disc cost	3.286400079727173	time	0.39770030975341797
iter 4	train disc cost	2.917823553085327	time	0.3981661796569824


INFO:root:Training iteration 10 of 500
INFO:root:Training iteration 20 of 500
INFO:root:Training iteration 30 of 500
INFO:root:Training iteration 40 of 500
INFO:root:Training iteration 50 of 500
INFO:root:Training iteration 60 of 500
INFO:root:Training iteration 70 of 500
INFO:root:Training iteration 80 of 500
INFO:root:Training iteration 90 of 500
INFO:root:Saving fake samples to png: [[3.6590439e-01 2.0064412e-01 2.2247709e-01 ... 1.8403502e-05
  4.3979974e-04 1.0389431e-02]
 [3.6786816e-01 1.9464168e-01 2.2587135e-01 ... 1.8845632e-05
  6.1430095e-04 1.1284796e-02]
 [3.5417131e-01 1.8135026e-01 2.1309792e-01 ... 1.5082329e-05
  4.3484016e-04 1.1254670e-02]
 ...
 [3.7132490e-01 2.0138793e-01 2.2871152e-01 ... 1.9919158e-05
  4.0194459e-04 1.4225681e-02]
 [3.8211232e-01 2.3122169e-01 2.6105803e-01 ... 6.7956571e-05
  1.2835683e-03 1.7909879e-02]
 [3.8765338e-01 2.5124362e-01 2.6587242e-01 ... 7.2952018e-05
  1.2529304e-03 2.2525899e-02]]
INFO:root:Training iteration 100 of 500


iter 99	train disc cost	-0.10305924713611603	time	0.3980048505883468


INFO:root:Training iteration 110 of 500
INFO:root:Training iteration 120 of 500
INFO:root:Training iteration 130 of 500
INFO:root:Training iteration 140 of 500
INFO:root:Training iteration 150 of 500
INFO:root:Training iteration 160 of 500
INFO:root:Training iteration 170 of 500
INFO:root:Training iteration 180 of 500
INFO:root:Training iteration 190 of 500
INFO:root:Saving fake samples to png: [[1.82593502e-02 1.21762720e-03 1.02231523e-03 ... 1.64841950e-07
  1.44603746e-07 1.23878664e-04]
 [1.84627231e-02 1.25326158e-03 1.12420146e-03 ... 1.08136724e-07
  1.22157161e-07 9.43169871e-05]
 [1.38942627e-02 8.05405667e-04 7.11829402e-04 ... 9.05294684e-08
  8.63299618e-08 9.86486411e-05]
 ...
 [1.78632550e-02 1.23435154e-03 1.06488972e-03 ... 1.26109143e-07
  1.23452566e-07 1.34308037e-04]
 [2.38237530e-02 2.66049756e-03 2.41387403e-03 ... 9.62696163e-07
  1.21864230e-06 3.62134393e-04]
 [2.46977340e-02 2.72607757e-03 2.30587879e-03 ... 7.52619144e-07
  6.18390345e-07 3.40026512e-04]]
INFO:root:Training iteration 200 of 500

iter 199	train disc cost	-0.09602370858192444	time	0.3898245167732239


INFO:root:Training iteration 210 of 500
INFO:root:Training iteration 220 of 500
INFO:root:Training iteration 230 of 500
INFO:root:Training iteration 240 of 500
INFO:root:Training iteration 250 of 500
INFO:root:Training iteration 260 of 500
INFO:root:Training iteration 270 of 500
INFO:root:Training iteration 280 of 500
INFO:root:Training iteration 290 of 500
INFO:root:Saving fake samples to png: [[1.2616682e-04 9.2172268e-06 4.2545576e-06 ... 1.2342832e-09
  7.1833389e-10 6.2240701e-07]
 [1.2507998e-04 9.7687989e-06 4.9492269e-06 ... 8.2685081e-10
  6.9700767e-10 4.6209942e-07]
 [6.7839632e-05 4.6997757e-06 2.2851850e-06 ... 5.0295207e-10
  4.2566042e-10 3.8583869e-07]
 ...
 [1.2484168e-04 1.0123809e-05 4.9501327e-06 ... 7.9530665e-10
  6.1686123e-10 6.0367023e-07]
 [2.3056871e-04 3.3919721e-05 1.8642464e-05 ... 1.5431477e-08
  1.6574484e-08 3.9200067e-06]
 [2.0853025e-04 2.7070591e-05 1.3555482e-05 ... 9.3362704e-09
  6.7054646e-09 2.9007347e-06]]
INFO:root:Training iteration 300 of 500

iter 299	train disc cost	-0.15806829929351807	time	0.39258101224899294


INFO:root:Training iteration 310 of 500
INFO:root:Training iteration 320 of 500
INFO:root:Training iteration 330 of 500
INFO:root:Training iteration 340 of 500
INFO:root:Training iteration 350 of 500
INFO:root:Training iteration 360 of 500
INFO:root:Training iteration 370 of 500
INFO:root:Training iteration 380 of 500
INFO:root:Training iteration 390 of 500
INFO:root:Saving fake samples to png: [[6.5214967e-06 1.1285581e-06 5.1809610e-07 ... 4.8095364e-11
  6.1417711e-11 4.8849937e-08]
 [5.5073747e-06 1.0383102e-06 5.2269468e-07 ... 3.1440441e-11
  6.2029902e-11 3.5039825e-08]
 [2.3246071e-06 4.0648894e-07 1.9286152e-07 ... 1.6636281e-11
  3.5477975e-11 2.6076915e-08]
 ...
 [1.0192445e-05 2.0462285e-06 1.0287082e-06 ... 3.2015082e-11
  5.2672519e-11 4.7861526e-08]
 [9.2403470e-06 3.1483296e-06 1.6982580e-06 ... 9.6510400e-10
  2.3360696e-09 4.5467416e-07]
 [1.2791144e-05 3.8813751e-06 1.9629636e-06 ... 4.7657250e-10
  7.4118001e-10 2.8243051e-07]]
INFO:root:Training iteration 400 of 500

iter 399	train disc cost	-0.1799074113368988	time	0.3946001148223877


INFO:root:Training iteration 410 of 500
INFO:root:Training iteration 420 of 500
INFO:root:Training iteration 430 of 500
INFO:root:Training iteration 440 of 500
INFO:root:Training iteration 450 of 500
INFO:root:Training iteration 460 of 500
INFO:root:Training iteration 470 of 500
INFO:root:Training iteration 480 of 500
INFO:root:Training iteration 490 of 500
INFO:root:Saving fake samples to png: [[1.5769763e-06 4.6158851e-07 2.4898793e-07 ... 1.1538572e-11
  2.6810233e-11 1.7124965e-08]
 [1.5969948e-06 5.1380403e-07 3.0642107e-07 ... 7.6636119e-12
  2.7309934e-11 1.2468958e-08]
 [7.3049659e-07 2.2393375e-07 1.2927948e-07 ... 3.7187601e-12
  1.5354823e-11 8.6514209e-09]
 ...
 [2.6833559e-06 9.2102783e-07 5.3548160e-07 ... 7.5889763e-12
  2.2995736e-11 1.6645195e-08]
 [2.4162291e-06 1.3812421e-06 8.7966345e-07 ... 2.8326178e-10
  1.1866772e-09 1.8616582e-07]
 [4.9948508e-06 2.5287932e-06 1.5136678e-06 ... 1.2514119e-10
  3.6341918e-10 1.0745638e-07]]
DEBUG:root:Training complete.


iter 499	train disc cost	-0.19283579289913177	time	0.38950712203979493


## Evaluate Results

### Load trained model (optional)

In [4]:
# Optional: set to True to replace the `detective` trained above with one built
# from the files written to its `save_path` (passed here as `load_path`).
# Left False so that Restart & Run All evaluates the model trained in this run.
# The configuration reuses the training settings; the previous commented-out
# version used train_iters=300, which disagreed with the 500 iterations trained.
# NOTE(review): presumably 'load_path' makes RickAndMortyDetective restore the
# saved model rather than build a new one -- confirm against its implementation.
LOAD_PRETRAINED = False

if LOAD_PRETRAINED:
    detective = RickAndMortyDetective(model_configuration={
        'batch_size': batch_size,
        'output_dimensions': output_dimensions,
        'train_iters': train_iters,
        'load_path': os.path.join(model_dir, 'MNIST-abnormalclass-{}'.format(abnormal_digit)),
        'plot_save_path': model_dir,
    })

### Load test data


In [5]:
# Get test data: each class separately (used to size the label vector) and the
# combined set the detector is scored on. The data source API names this
# accessor `get_train_data` even though it reads the test file here.
test_data_normal = test_data_source.get_train_data('NORMAL')
test_data_abnormal = test_data_source.get_train_data('ABNORMAL')
test_data = test_data_source.get_train_data('ALL')

# Ground truth: ABNORMAL samples are labelled 1, NORMAL samples 0.
# NOTE(review): this assumes 'ALL' yields every ABNORMAL sample first, followed
# by every NORMAL sample. Confirm against MNISTDataSource; a mismatch would
# silently scramble every metric below.
labels_abnormal = np.ones(len(test_data_abnormal.data))
labels_normal = np.zeros(len(test_data_normal.data))
expected_truth = np.concatenate((labels_abnormal, labels_normal))

# At minimum the labels must pair up one-to-one with the samples being scored.
assert len(expected_truth) == len(test_data.data), (
    'expected_truth has {} labels but test_data has {} samples'.format(
        len(expected_truth), len(test_data.data)))

### Calculate ROC Score

In [6]:
# Score every test sample with the trained detector.
detection_result = detective.detect(test_data)

# Compare the detector output with the 0/1 ground truth
# (GANPerformanceAnalysis is constructed with an empty config dict).
performance_analysis = GANPerformanceAnalysis({})
roc_score = performance_analysis.analyse(detection_result=detection_result.data,
                                         expected_truth=expected_truth)

# NOTE(review): if this is a ROC AUC, as the section heading suggests, a value
# at or below 0.5 is no better than chance. Check the label order assumed when
# building expected_truth, and the direction of the detector's scores.
print(roc_score)

INFO:root:Running detector on <alphai_watson.datasource.Sample object at 0x7fa37868e390>
INFO:root:Detection completed in 1.0699957609176636


0.4709560176528132


### Generate classification report

In [7]:
# Turn detector scores into hard 0/1 predictions and tabulate per-class
# precision / recall / F1 / support.
# NOTE(review): assumes detection_result.data holds scores where > 0.5 means
# ABNORMAL. For scores in [-0.5, 1.5) this matches the previous np.rint rounding.
decision_threshold = 0.5
predicted_labels = (np.asarray(detection_result.data) > decision_threshold).astype(int)

# labels=[0, 1] pins the output order (and length) to NORMAL, ABNORMAL, matching
# the index below, even if one class is absent from the predictions.
precision, recall, f1, support = precision_recall_fscore_support(
    expected_truth, predicted_labels, labels=[0, 1])

out_dict = {
    "precision": precision.round(2),
    "recall": recall.round(2),
    "f1-score": f1.round(2),
    "support": support,
}
df_out = pd.DataFrame(out_dict, index=['NORMAL', 'ABNORMAL'])

# Summary row: unweighted mean of each metric across the two classes, and the
# total support (this is NOT a support-weighted average).
avg_tot = (df_out.apply(lambda col: round(col.sum(), 2) if col.name == "support"
                        else round(col.mean(), 2))
           .to_frame().T)
avg_tot.index = ["avg/total"]
# DataFrame.append was removed in pandas 2.0; pd.concat is the supported form.
df_out = pd.concat([df_out, avg_tot])
print(df_out)

# Save Classification report to CSV (Optional)
# df_out.to_csv('classification_report_digit-{}.csv'.format(abnormal_digit), sep=';')

           f1-score  precision  recall  support
NORMAL         0.33       0.57    0.23  12620.0
ABNORMAL       0.44       0.33    0.69   6903.0
avg/total      0.38       0.45    0.46  19523.0
