In [None]:
import numpy as np
import scipy as sp
import scipy.io as sio
import scipy.signal as sig
import pywt
import os
import glob
import itertools
import matplotlib
import pandas as pd
import re
import tensorflow as tf
from tensorflow.contrib.layers import fully_connected
import tensorflow.contrib.rnn as recurrent
import sklearn.preprocessing
import matplotlib.pyplot as plt
#
%matplotlib inline

In [None]:
from codes.pre_processing import *
from codes.segmentation import *
from codes.utils import *
from codes.training import *
from codes.model import *

A utility function to plot a confusion matrix, adapted from the scikit-learn example:
http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py

In [None]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print and plot a confusion matrix, annotating every cell with its value.

    Adapted from the scikit-learn confusion-matrix example.

    Parameters
    ----------
    cm : array-like of shape (n_classes, n_classes)
        Confusion matrix. It is drawn with rows labelled 'True class' and
        columns labelled 'Predicted class'.
    classes : sequence of str
        Tick labels, one per row/column of `cm`.
    normalize : bool, default False
        If True, divide each row by its sum. Rows whose sum is zero are
        shown as all zeros instead of NaN.
    title : str, default 'Confusion matrix'
        Figure title.
    cmap : matplotlib colormap, default plt.cm.Blues
        Colormap passed to `plt.imshow`.

    Draws into the current matplotlib figure and returns None.
    """
    cm = np.asarray(cm)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # rows with zero support (e.g. a class never present in the ground
        # truth) would otherwise divide by zero and turn into NaN
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # 'd' only works for integer counts; fall back to 2 decimals otherwise
    if normalize or not np.issubdtype(cm.dtype, np.integer):
        fmt = '.2f'
    else:
        fmt = 'd'
    # cells above half the maximum get white text so it stays readable
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    # set the axis labels before tight_layout so the layout makes room for them
    plt.ylabel('True class')
    plt.xlabel('Predicted class')
    plt.tight_layout()

In [None]:
# Locate the label reference table and split the recordings into
# training / validation file lists (train_percentage=90).
data_root_dir = '../training_set'
ref_file = os.path.join(data_root_dir, 'REFERENCE.csv')
train_file_list, val_file_list = test_val_split_v2(data_root_dir,
                                                   train_percentage=90)

In [None]:
# Load the label reference table and extract its columns as NumPy arrays.
# `Series.as_matrix` was deprecated in pandas 0.23 and removed in 1.0;
# `.values` returns the same ndarray on both old and new pandas.
# The frame is named `ref_df` so it is not confused with the results frame
# `df` built after prediction.
ref_df = pd.read_csv(ref_file, delimiter=',')
#
RECORDS = ref_df.Recording.values
LABEL_1 = ref_df.First_label.values
LABEL_2 = ref_df.Second_label.values
LABEL_3 = ref_df.Third_label.values
#
N = len(RECORDS)

In [None]:
# build the graph
# Build the graph and keep handles to its feedable inputs
# (inputs, labels, seq_length) and its output tensors (logits, accuracy).
(inputs, labels, seq_length,
 logits, accuracy) = build_model_graph()

In [None]:
# directory that load_model() restores the trained weights from
model_dir = "./model"

In [None]:
# prediction
# Run the restored model over every validation recording and collect one
# record-level prediction per file next to its ground-truth class.
#
# NOTE(review): max_level, window_size, window_size_for_threshold and
# search_radius are not defined in any visible cell; presumably they are
# module-level names pulled in by one of the `from codes.* import *`
# lines -- confirm.
sub_id = list()
sub_actual = list()
sub_predict = list()
#
with tf.Session() as sess:
    # load the model
    load_model(model_dir, sess)
    #
    for val_file in val_file_list:
        # the record id ('A' followed by digits) is embedded in the file path
        match = re.search('A[0-9]+', val_file)
        if match is None:
            raise ValueError('no record id (A<digits>) found in: ' + val_file)
        record = match.group(0)
        # exactly one REFERENCE.csv row must describe this record; with zero
        # or several matches the label lookup would silently yield an array
        row_idx = np.flatnonzero(RECORDS == record)
        if row_idx.size != 1:
            raise ValueError('expected 1 REFERENCE.csv row for %s, found %d'
                             % (record, row_idx.size))
        parent_label = LABEL_1[row_idx[0]]
        sub_id.append(record)
        # shift First_label by one to match the 0-based argmax class index
        sub_actual.append(parent_label - 1)
        peaks, features = peak_detector_with_refinement(
            val_file, 'sym8', max_level, window_size,
            window_size_for_threshold, search_radius)
        segs, labs, lens = extract_ecg_segments_v2(peaks, val_file, parent_label, 1000)
        feed = {inputs: segs, labels: labs, seq_length: lens}
        # fetch both tensors in a single sess.run (one forward pass, not two);
        # acc is not used below
        logits_val, acc = sess.run([logits, accuracy], feed_dict=feed)
        # prediction_v2 (from codes.*) presumably reduces the per-row argmax
        # classes to a single record-level class -- confirm
        sub_predict.append(prediction_v2(np.argmax(logits_val, axis=1)))
        print('processed: ' + record)
#
data_dict = {'id': sub_id, 'actual_class': sub_actual, 'predicted_class': sub_predict}
df = pd.DataFrame(data=data_dict)

In [None]:
from sklearn.metrics import f1_score, confusion_matrix
#
# record-level ground truth vs. prediction (0-based class indices)
GT_labels = df['actual_class'].tolist()
predict_labels = df['predicted_class'].tolist()
#
# Pin the label set to all 9 classes named in `class_names` for the plot.
# Without `labels=`, sklearn only keeps classes that occur in either list,
# so a class missing from this validation split would shrink the matrix and
# misalign its rows/columns with the tick labels.
N_CLASSES = 9
con_mtx = confusion_matrix(GT_labels, predict_labels,
                           labels=np.arange(N_CLASSES))

In [None]:
# Record-level F1: micro-averaged, support-weighted, then one score per class.
for average in ('micro', 'weighted', None):
    print(f1_score(GT_labels, predict_labels, average=average))

In [None]:
# Row-normalized confusion matrix on a fresh figure.
class_names = ['Normal', 'AF', 'I-AVB', 'LBBB', 'RBBB', 'PAC', 'PVC', 'STD', 'STE']
plt.figure()
plot_confusion_matrix(con_mtx,
                      classes=class_names,
                      normalize=True,
                      title='Confusion Matrix')