In [3]:
from typing import Tuple
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import pandas as pd
import pickle
from tensorflow import keras
import seaborn as sns
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, multilabel_confusion_matrix

from evaluation.serialization import Serializer
from constants import TYPES, annotation_mapping, reverse_annotation_mapping, amino_acid_mapping
from utils.encoding import sequenceToCategorical, categoricalToOneHot, categoricalToSequence, oneHotToCategorical
from utils.Dataset import Dataset
from utils.helpers import getDatasetPath

In [5]:
run_timestamp = "20211113-1438"
base_path = f"../results/{run_timestamp}/"
final_model = keras.models.load_model(base_path + f"models/final_model.h5")


test_data = Dataset(getDatasetPath()).getFolds([0])


test_x = np.array([categoricalToOneHot(sequenceToCategorical(seq, amino_acid_mapping), amino_acid_mapping) for seq in test_data["sequence"]])
predictions = final_model.predict(test_x)

test_data["prediction"] = np.array([categoricalToSequence(oneHotToCategorical(pred), reverse_annotation_mapping) for pred in predictions])

In [6]:
def getRelevantData(query: str) -> Tuple[np.ndarray, np.ndarray]:
    relevant_data = test_data.query(query)
    y_pred = np.array([sequenceToCategorical(seq, annotation_mapping) for seq in relevant_data["prediction"]])
    y_true = np.array([sequenceToCategorical(seq, annotation_mapping) for seq in relevant_data["annotation"]])

    return (y_pred, y_true)

In [7]:
cv_training_metrics = Serializer.load("cv_training_metrics")
cv_final_metrics = Serializer.load("cv_metrics")
final_training_metrics = Serializer.load("final_training_metrics")
final_metrics = Serializer.load("final_metrics")

In [111]:
(y_pred, y_true) = getRelevantData("index == index")

annotation_labels = [c for c in "STLIMO"]
cm = pd.DataFrame(
    confusion_matrix(y_true.flatten(), y_pred.flatten()),
    index = annotation_labels,
    columns = annotation_labels
)

In [125]:
# Compute metrics per class
n = cm.values.sum()
tp = np.diag(cm)
fp = np.array([np.sum(cm[label]) - cm[label][label] for label in annotation_labels])
fn = np.array([np.sum(cm[:][label]) - cm[label][label] for label in annotation_labels])
tn = np.array([cm.drop(label, axis=0).drop(label, axis=1).values.sum() for label in annotation_labels])

In [132]:
cm["S"]

S    13110
T       52
L      337
I      448
M      236
O      558
Name: S, dtype: int64

In [129]:
for i, label in enumerate(annotation_labels):
    print(label)
    prec = (fp[i])
    print("Precision: " + prec)
    print("Recall: " + rec)

True

In [113]:
prec = (tp.sum() + fp.sum()) / n
print(tp.sum())
print(fp.sum())
print(fn.sum())
print(tn.sum())
print(n)

275735
11545
11545
1424855
287280


In [None]:
# Add sum columns
cm.loc["Total", :] = cm.sum(axis=0)
cm.loc[:, "Total"] = cm.sum(axis=1)
cm = cm.astype(int)

In [45]:
cm

Unnamed: 0,S,T,L,I,M,O,Total
S,13110,210,412,471,360,784,15347
T,52,3014,33,91,0,101,3291
L,337,57,5442,31,0,96,5963
I,448,145,42,199473,619,2022,202749
M,236,0,64,648,5634,680,7262
O,558,106,89,2282,571,49062,52668
Total,14741,3532,6082,202996,7184,52745,287280
