In [2]:
from typing import Tuple

import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import matthews_corrcoef, recall_score, precision_score
from tqdm import tqdm
import warnings

sns.set_theme()
from tensorflow import keras

from constants import TRAINING_PARTITIONS, ALL_PARTITIONS, annotation_mapping, amino_acid_mapping, reverse_annotation_mapping
from constants import TYPES, KINGDOMS, METRIC_KINGDOMS, METRIC_TYPES
from metrics.metrics import *
from utils.Dataset import Dataset
from utils.helpers import getDatasetPath
from utils.encoding import categoricalToSequence, oneHotToCategorical, sequenceToCategorical, categoricalToOneHot
from serialization import Serializer

# Load

In [4]:
# Load model
run_timestamp = "20211113-1438"
base_path = f"../results/{run_timestamp}/"
final_model = keras.models.load_model(base_path + f"models/final_model.h5")

# Load test set
test_data = Dataset(getDatasetPath()).getFolds([0])

# Make predictions

In [6]:
test_x = np.array([categoricalToOneHot(sequenceToCategorical(seq, amino_acid_mapping), amino_acid_mapping) for seq in test_data["sequence"]])
predictions = final_model.predict(test_x)
y_true = np.array([sequenceToCategorical(seq, annotation_mapping) for seq in test_data["annotation"]])
y_pred = np.array([oneHotToCategorical(pred) for pred in predictions])

In [8]:
test_data["prediction"] = np.array([categoricalToSequence(pred, reverse_annotation_mapping) for pred in y_pred])

In [10]:
num_bootstraps = 256
warnings.filterwarnings("ignore")
bootstraps = {i: test_data.sample(frac=1, replace=True, axis=0) for i in range(num_bootstraps)}

In [11]:
bootstraps[0].columns

Index(['sequence', 'annotation', 'prediction'], dtype='object')

In [13]:
def getRelevantData(df: pd.DataFrame, query: str) -> Tuple[np.ndarray, np.ndarray]:
    relevant_data = df.query(query)
    y_pred = np.array([sequenceToCategorical(seq, annotation_mapping) for seq in relevant_data["prediction"]])
    y_true = np.array([sequenceToCategorical(seq, annotation_mapping) for seq in relevant_data["annotation"]])

    return (y_pred, y_true)

In [14]:
metrics = [MCC(), Recall("macro"), Precision("macro")]
final_metrics_dict = {}
for i in tqdm(range(num_bootstraps), desc="Computing metrics"):
    final_metrics_dict[i] = {}
    for metric in metrics:
        # All sequences
        y_pred, y_true = getRelevantData(bootstraps[i], "index == index")
        final_metrics_dict[i][metric.name] = {
            "overall": {
                "overall": metric(y_true, y_pred)
            }
        }

        # By kingdom
        for kingdom in KINGDOMS:
            y_pred, y_true = getRelevantData(bootstraps[i], f"kingdom == '{kingdom}'")
            final_metrics_dict[i][metric.name]["overall"][kingdom] = metric(y_true, y_pred)

# Manually exclude non-existing labels (L & T) from eukarya
for i in tqdm(range(num_bootstraps), desc="Excluding non-existent pathways"):
    y_pred, y_true = getRelevantData(bootstraps[i], "kingdom == 'EUKARYA'")
    final_metrics_dict[i]["precision"]["overall"]["EUKARYA"] = precision_score(y_true.flatten(), y_pred.flatten(), average="macro", labels=[0,3,4,5])
    final_metrics_dict[i]["recall"]["overall"]["EUKARYA"] = recall_score(y_true.flatten(), y_pred.flatten(), average="macro", labels=[0,3,4,5])

Computing metrics: 100%|██████████| 256/256 [02:38<00:00,  1.61it/s]
Excluding non-existent pathways: 100%|██████████| 256/256 [00:29<00:00,  8.72it/s]


In [16]:
def binarize(array: np.ndarray, label: int) -> np.ndarray:
    binarized = array.copy()
    binarized[array == label] = 1
    binarized[array != label] = 0

    return binarized

In [17]:
# By label
for i in tqdm(range(num_bootstraps), desc="Computing metrics by label"):
    for label in range(6):
        y_true, y_pred = getRelevantData(bootstraps[i], f"index == index")
        y_true = binarize(y_true, label)
        y_pred = binarize(y_pred, label)

        final_metrics_dict[i]["mcc"][label] = {"overall": matthews_corrcoef(y_true.flatten(), y_pred.flatten())}
        final_metrics_dict[i]["precision"][label] = {"overall": precision_score(y_true.flatten(), y_pred.flatten())}
        final_metrics_dict[i]["recall"][label] = {"overall": recall_score(y_true.flatten(), y_pred.flatten())}

# By label x kingdom
for i in tqdm(range(num_bootstraps), desc="Computing metrics by label & kingdom"):
    for label in range(6):
        for kingdom in KINGDOMS:
            y_true, y_pred = getRelevantData(bootstraps[i], f"kingdom == '{kingdom}'")
            y_true = binarize(y_true, label)
            y_pred = binarize(y_pred, label)

            final_metrics_dict[i]["mcc"][label][kingdom] = matthews_corrcoef(y_true.flatten(), y_pred.flatten())
            final_metrics_dict[i]["precision"][label][kingdom] = precision_score(y_true.flatten(), y_pred.flatten())
            final_metrics_dict[i]["recall"][label][kingdom] = recall_score(y_true.flatten(), y_pred.flatten())

Computing metrics by label: 100%|██████████| 256/256 [05:45<00:00,  1.35s/it]
Computing metrics by label & kingdom: 100%|██████████| 256/256 [05:21<00:00,  1.26s/it]


In [18]:
final_metrics_dict

{0: {'mcc': {'overall': {'overall': 0.9128594117885255,
    'EUKARYA': 0.8963643782523398,
    'ARCHAEA': 0.857196637384657,
    'POSITIVE': 0.8613630054558251,
    'NEGATIVE': 0.8923283217076973},
   0: {'overall': 0.8641911169963459,
    'EUKARYA': 0.8869541491652372,
    'ARCHAEA': 0.7565692383251578,
    'POSITIVE': 0.7626645188806288,
    'NEGATIVE': 0.8306917088972373},
   1: {'overall': 0.876441064874769,
    'EUKARYA': 0.0,
    'ARCHAEA': 0.9580420363628382,
    'POSITIVE': 0.7710020672126835,
    'NEGATIVE': 0.9445438456951883},
   2: {'overall': 0.9011479401229923,
    'EUKARYA': 0.0,
    'ARCHAEA': 0.753149139925573,
    'POSITIVE': 0.9287306268336135,
    'NEGATIVE': 0.9235852876511856},
   3: {'overall': 0.9410732598231368,
    'EUKARYA': 0.9206164721881849,
    'ARCHAEA': 0.9195004875772199,
    'POSITIVE': 0.8666501979942246,
    'NEGATIVE': 0.8840913406209492},
   4: {'overall': 0.7563111122782894,
    'EUKARYA': 0.7522313928793566,
    'ARCHAEA': 0.7664753202348559,
  

In [20]:
labels = {0: 'S', 1: 'T', 2: 'L', 3: 'I', 4: 'M', 5: 'O', "overall": "overall"}
final_metrics = pd.DataFrame([
    (metric.name, bootstrap, labels[label], kingdom, final_metrics_dict[bootstrap][metric.name][label][kingdom])
    for metric in metrics
    for bootstrap in range(num_bootstraps)
    for label in ["overall", *range(6)]
    for kingdom in METRIC_KINGDOMS
])

final_metrics.columns = ["metric", "bootstrap", "label", "kingdom", "value"]
final_metrics.set_index(["metric", "bootstrap", "label", "kingdom"])

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,value
metric,bootstrap,label,kingdom,Unnamed: 4_level_1
mcc,0,overall,EUKARYA,0.896364
mcc,0,overall,ARCHAEA,0.857197
mcc,0,overall,POSITIVE,0.861363
mcc,0,overall,NEGATIVE,0.892328
mcc,0,overall,overall,0.912859
...,...,...,...,...
precision,255,O,EUKARYA,0.899769
precision,255,O,ARCHAEA,0.925065
precision,255,O,POSITIVE,0.952448
precision,255,O,NEGATIVE,0.958812


# Store results

In [21]:
Serializer.save(final_metrics, "final_metrics")