In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn import metrics
import joblib
import datetime
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
from model.NN import NN, DatasetUtil
import time

# Values accepted by the interactive prompts below.
_VALID_LOOKAHEAD_DAYS = (5, 7, 15, 30, 45, 60, 90, 120)
_VALID_COVERAGE_CODES = ('A', 'B', 'C', 'D')

# Look-ahead horizon in days; selects the '<n>_days_lookahead' directory for
# both the test data and the trained model.
try:
    n_days_lookahead = int(input('Please input the length of days lookahead in {5, 7, 15, 30, 45, 60, 90, 120}: '))
except ValueError:
    # Non-numeric input is rejected the same way as an out-of-range number
    # instead of surfacing a traceback.
    print('Input does not meet requirements.')
    raise SystemExit from None

if n_days_lookahead not in _VALID_LOOKAHEAD_DAYS:
    print('Input does not meet requirements.')
    # SystemExit is raised directly: the exit() helper is injected by the
    # site module and is not guaranteed to exist in every interpreter setup.
    raise SystemExit

# Coverage code of the test data to evaluate on.
data_type = input('Please specify the coverage of the data {A - Manufacturer 1, B - Manufacturer 2, C - Manufacturer 1 & 2, D - Unbalanced}: ')

if data_type not in _VALID_COVERAGE_CODES:
    print('Input does not meet requirements.')
    raise SystemExit

# Maps a coverage code to the directory name used under ../data and
# ../trained_model.
dit_str = {'A': 'mc1', 'B': 'mc2', 'C': 'mc1_mc2', 'D': 'unbalanced'}

# Coverage code of the trained model to load; it may differ from data_type.
model_type = input('Please input the type of trained model to use {A - Manufacturer 1, B - Manufacturer 2, C - Manufacturer 1 & 2, D - Unbalanced}: ')

if model_type not in _VALID_COVERAGE_CODES:
    print('Input does not meet requirements.')
    raise SystemExit

def loadData():
    """Load the test features and labels selected at the prompts.

    Reads ``smart_test.npy`` and ``test_labels.npy`` from
    ``../data/<coverage>/<n>_days_lookahead/``, where ``<coverage>`` is
    ``dit_str[data_type]`` and ``<n>`` is ``n_days_lookahead`` (all three are
    module-level names).

    Returns:
        A ``(features, labels)`` tuple of NumPy arrays, each cast to
        ``float32``.
    """
    # Both files live in the same directory, so build the prefix once.
    data_dir = '../data/' + dit_str[data_type] + '/' + str(n_days_lookahead) + '_days_lookahead/'

    features = np.load(data_dir + 'smart_test.npy', allow_pickle=True)
    labels = np.load(data_dir + 'test_labels.npy', allow_pickle=True)

    return features.astype('float32'), labels.astype('float32')


def get_all_metrics(true, predicted, score):
    """Print binary-classification metrics and show the ROC curve.

    Args:
        true: Ground-truth labels, expected to be 0/1.
        predicted: Hard 0/1 predictions, same length as ``true``.
        score: Continuous scores, same length as ``true``; passed to
            ``roc_curve`` to build the curve and the AUC.

    Returns:
        None. Results are printed and the plot is displayed with
        ``plt.show()``.

    NOTE(review): the confusion-matrix cells below are named as if label 0
    were the positive ("failed") class, whereas ``roc_curve`` treats label 1
    as positive for 0/1 labels. The mapping is kept as it was; confirm it
    against the label encoding of the dataset.
    """
    # Pin the label order so the matrix is always 2x2, even when one of the
    # classes never occurs in ``true``/``predicted``.
    cm = metrics.confusion_matrix(true, predicted, labels=[0, 1])
    fpr_list, tpr_list, _ = roc_curve(true, score)
    roc_auc = auc(fpr_list, tpr_list)

    # Rows are true labels, columns are predicted labels.
    TP = cm[0][0]
    FN = cm[0][1]
    FP = cm[1][0]
    TN = cm[1][1]

    precision_of_failed = TP / (TP + FP)
    precision_of_healthy = TN / (TN + FN)
    tpr = TP / (TP + FN)
    fpr = FP / (TN + FP)
    auc_score = roc_auc
    f1_score = 2 * precision_of_failed * tpr / (precision_of_failed + tpr)

    print('precision of failed: ', precision_of_failed)
    print('precision of healthy: ', precision_of_healthy)
    print('tpr: ', tpr)
    print('fpr: ', fpr)
    print('auc: ', auc_score)
    print('f1-score: ', f1_score)
    print('roc curve: ')

    # The curve needs a label, otherwise plt.legend() has nothing to show and
    # only emits a warning.
    plt.plot(fpr_list, tpr_list, label='ROC curve (AUC = %0.4f)' % roc_auc)
    plt.xlim([-0.05, 1.05])
    plt.ylim([-0.05, 1.05])
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.title('ROC Curve')
    plt.legend(loc="lower right")
    plt.show()
    
# print('-------------------- NN --------------------')
# print('------------------ Loading Data ------------------')
# Load the test split chosen at the prompts.
X, y = loadData()

# Constructor arguments for NN.
# NOTE(review): presumably these must match the values used when the
# checkpoint was trained - confirm.
input_size = 330
hidden_size = 512

output_size = 1
num_layers = 2  # not passed to NN below
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = NN(input_size=input_size, hidden_size=hidden_size, output_size=output_size).to(device)
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only
# machine (and vice versa).
model.load_state_dict(
    torch.load(
        '../trained_model/' + dit_str[model_type] + '/' + str(n_days_lookahead) + '_days_lookahead/nn.pth',
        map_location=device,
    )
)
# Put dropout / batch-norm layers (if the model has any) into inference mode.
model.eval()

# Flatten every sample into a single feature vector.
X = X.reshape((len(X), -1))

test_dataset = DatasetUtil(X, y)
# Sample order does not affect the aggregate metrics, so keep it deterministic.
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

y_pred_list = []
y_true_list = []
y_score_list = []
with torch.no_grad():
    for batch_X, batch_y in test_loader:
        score = model(batch_X.to(device))
        # Column 0 holds the score that is thresholded at 0.5.
        batch_score = score[:, 0]
        y_pred_list.extend((batch_score > 0.5).long().cpu().tolist())
        y_true_list.extend(batch_y.tolist())
        y_score_list.extend(batch_score.cpu().tolist())

get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))


  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))
  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))


precision of failed:  0.8153983353151011
precision of healthy:  0.8148367952522255
tpr:  0.8146718146718147
fpr:  0.18443718443718443
auc:  0.8859693293165227
f1-score:  0.8150349130886941
1011.3034248352051


  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))
  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))


precision of failed:  0.7872679814385151
precision of healthy:  0.8185911868768093
tpr:  0.8279701082812262
fpr:  0.2237303644959585
auc:  0.8730761409508514
f1-score:  0.8071062216605962
912.6970767974854


  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))
  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))


precision of failed:  0.8054344667247894
precision of healthy:  0.8176186158960411
tpr:  0.8211851851851852
fpr:  0.19837037037037036
auc:  0.8801857229080932
f1-score:  0.8132335680751174
1101.9506454467773


  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))
  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))


precision of failed:  0.801417347095979
precision of healthy:  0.7810659388018962
tpr:  0.7734165923282783
fpr:  0.19164436515016353
auc:  0.860694358567359
f1-score:  0.7871680411591131
932.0840835571289


  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))
  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))


precision of failed:  0.7595712098009189
precision of healthy:  0.8111129651259803
tpr:  0.8281724347298117
fpr:  0.2621432908318154
auc:  0.8556033255253522
f1-score:  0.7923898046619707
918.3132648468018


  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))
  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))


precision of failed:  0.7876733921815889
precision of healthy:  0.7618364418938307
tpr:  0.7506384257172901
fpr:  0.20234339792699413
auc:  0.847220910608801
f1-score:  0.7687100992231367
921.1153984069824


  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))
  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))


precision of failed:  0.7862993298585257
precision of healthy:  0.7659427306681422
tpr:  0.7573149741824441
fpr:  0.20582329317269077
auc:  0.8462946223875915
f1-score:  0.7715350332432236
969.6543216705322
precision of failed:  0.7844763738705377
precision of healthy:  0.7480945614261724
tpr:  0.7308860060723158
fpr:  0.2008004416229644
auc:  0.8366846101390591
f1-score:  0.7567335857683789
1007.1835517883301


  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))
  get_all_metrics(np.asarray(y_true_list).astype('int'), np.asarray(y_pred_list), np.asarray(y_score_list))
