In [1]:
# Load IPython's autoreload extension; mode 2 reloads all imported modules before each
# cell executes, so edits to local project code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [18]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from sklearn.metrics import roc_curve, accuracy_score, classification_report, auc, roc_auc_score
from LID_project.datasets.dataset import DataGenerator

In [None]:
def val2onehot(val_array, classes):
    '''
    Convert integer class indices into a one-hot encoded matrix.

    Inputs :
        val_array: 1-D sequence of integer class indices, each in [0, classes)
        classes: total number of classes (number of columns of the result)
    Output:
        float np array of shape (n_samples x classes) with a single 1 per row

    Raises IndexError if an index is not an integer or lies outside [0, classes).
    Negative indices are rejected explicitly: plain numpy indexing would silently
    wrap them around to the last columns and produce wrong labels.
    '''
    indices = np.asarray(val_array).reshape(-1)
    n_samples = indices.size
    labels = np.zeros((n_samples, classes))
    if n_samples == 0:
        # Nothing to mark; also avoids indexing with an empty float array.
        return labels
    if not np.issubdtype(indices.dtype, np.integer):
        raise IndexError('val_array must contain integer class indices')
    if indices.min() < 0 or indices.max() >= classes:
        raise IndexError(
            'class indices must lie in [0, %d), got values in [%d, %d]'
            % (classes, indices.min(), indices.max())
        )
    # Vectorised equivalent of setting labels[row, index] = 1 for every row.
    labels[np.arange(n_samples), indices] = 1
    return labels

def EER(true_targets_onehot, predictions):
    '''
    Per-class equal error rate (EER), computed one-vs-rest from the ROC curve.

    Inputs :
        true_targets_onehot: one hot encoding of true values, shape (n_samples x n_classes)
        predictions: per-class scores predicted by the model (e.g. the output of
            model.predict), shape (n_samples x n_classes)
    Output:
        per class EER score, np vector of length (n_classes,). For every class this is
        the false positive rate at the ROC point where |FNR - FPR| is smallest.
        The entry is NaN for a class whose ROC curve is undefined (the class has no
        positive or no negative samples in true_targets_onehot).

    !!!! Be careful !!!! Check that the target_to_class dictionary is the same for the
    training and the test data; if not, the scores are attributed to incorrect classes.
    Use predictions = model.predict(data_gen_val, verbose=1) to get the predictions,
    and true_targets_onehot = val2onehot(data_gen_val.getTargets(), 8) to get the true targets

    '''
    true_targets_onehot = np.asarray(true_targets_onehot)
    predictions = np.asarray(predictions)
    scores = []
    for class_idx in range(true_targets_onehot.shape[1]):
        fpr, tpr, _ = roc_curve(true_targets_onehot[:, class_idx], predictions[:, class_idx])
        fnr = 1 - tpr
        gap = np.absolute(fnr - fpr)
        if np.all(np.isnan(gap)):
            # roc_curve yields NaN rates when a class is absent (or is the only class);
            # np.nanargmin would raise ValueError on the all-NaN gap, so report NaN instead.
            scores.append(np.nan)
            continue
        # Operating point where false negative and false positive rates are closest.
        balanced_thres_pos = np.nanargmin(gap)
        scores.append(fpr[balanced_thres_pos])
    return np.array(scores)

In [86]:
# Restore the trained LSTM network from its HDF5 checkpoint.
model = load_model(filepath='best_of_the_best_model_lstm.h5')

In [4]:
# Both generators share one configuration; only the data split differs.
# NOTE(review): the original trailing comments were truncated ("has to be changed with ...");
# presumably net/feat have to match the loaded model -- TODO confirm.
generator_options = dict(batch_size=1, shuffle=False, net='lstm', feat='plp')
data_gen = DataGenerator('train', **generator_options)
data_gen_test = DataGenerator('test', **generator_options)

In [None]:
# Make the test generator use the training generator's target_to_class dictionary,
# so the class columns of the predictions and of the true targets refer to the same classes.
data_gen_test.target_to_class = data_gen.target_to_class

In [88]:
# Run the loaded model over the test generator to get the per-class predictions.
predictions = model.predict(x=data_gen_test, verbose=1)



In [None]:
# Save predictions in case something goes wrong later on.
# Passing the path (instead of an open() handle that is never closed) lets np.save
# open, flush and close the file itself.
np.save('pred.npy', predictions)

In [None]:
# One-hot encode the true test targets.
# 8 is the number of classes; presumably it must equal predictions.shape[1] -- TODO confirm.
true_targets_onehot = val2onehot(val_array=data_gen_test.getTargets(), classes=8)

In [60]:
# One-hot encode the arg-max class of every prediction row, so each sample gets exactly
# one predicted class (comparing the scores against the row maximum marks several
# classes at once when scores tie), then print per-class precision / recall / F1.
predicted_onehot = np.eye(predictions.shape[1], dtype=int)[np.argmax(predictions, axis=1)]
print(classification_report(true_targets_onehot, predicted_onehot, digits=3)) # get classification report

In [91]:
# Per-class EER on the test set, followed by its mean over all classes.
l = EER(true_targets_onehot=true_targets_onehot, predictions=predictions)
print(l, np.mean(l), sep='\n')