In [2]:
from typing import Tuple

import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import warnings

sns.set_theme()
from tensorflow import keras

from constants import TRAINING_PARTITIONS, ALL_PARTITIONS, annotation_mapping, amino_acid_mapping, reverse_annotation_mapping
from constants import TYPES, KINGDOMS, METRIC_KINGDOMS, METRIC_TYPES
from metrics.metrics import *
from utils.Dataset import Dataset
from utils.helpers import getDatasetPath
from utils.encoding import categoricalToSequence, oneHotToCategorical, sequenceToCategorical, categoricalToOneHot
from serialization import Serializer

# Load models

In [4]:
run_timestamp = "20211113-0233"
base_path = f"../results/{run_timestamp}/"
cv_models = {}
for k in TRAINING_PARTITIONS:
    cv_models[k] = keras.models.load_model(base_path + f"models/holdout-fold_{k}_final.h5")

# Bootstrap

## Functions

In [6]:
def getRelevantData(df: pd.DataFrame, query: str) -> Tuple[np.ndarray, np.ndarray]:
    relevant_data = df.query(query)
    y_pred = np.array([sequenceToCategorical(seq, annotation_mapping) for seq in relevant_data["prediction"]])
    y_true = np.array([sequenceToCategorical(seq, annotation_mapping) for seq in relevant_data["annotation"]])

    return (y_pred, y_true)

In [8]:
def binarize(array: np.ndarray, label: int) -> np.ndarray:
    binarized = array.copy()
    binarized[array == label] = 1
    binarized[array != label] = 0

    return binarized

In [10]:
def bootstrap(test_data: pd.DataFrame, num_bootstraps: int):
    bootstraps = {i: test_data.sample(frac=1, replace=True, axis=0) for i in range(num_bootstraps)}

    metrics = [MCC(), Recall("macro"), Precision("macro")]
    metrics_dict = {}
    for i in tqdm(range(num_bootstraps), leave=False, desc="Excluding overall metrics"):
        metrics_dict[i] = {}
        for metric in metrics:
            # All sequences
            y_pred, y_true = getRelevantData(bootstraps[i], "index == index")
            metrics_dict[i][metric.name] = {
                "overall": {
                    "overall": metric(y_true, y_pred)
                }
            }

            # By kingdom
            for kingdom in KINGDOMS:
                y_pred, y_true = getRelevantData(bootstraps[i], f"kingdom == '{kingdom}'")
                metrics_dict[i][metric.name]["overall"][kingdom] = metric(y_true, y_pred)

    # Manually exclude non-existing labels (L & T) from eukarya
    for i in tqdm(range(num_bootstraps), leave=False, desc="Excluding non-existent pathways"):
        y_pred, y_true = getRelevantData(bootstraps[i], "kingdom == 'EUKARYA'")
        metrics_dict[i]["precision"]["overall"]["EUKARYA"] = precision_score(y_true.flatten(), y_pred.flatten(), average="macro", labels=[0,3,4,5])
        metrics_dict[i]["recall"]["overall"]["EUKARYA"] = recall_score(y_true.flatten(), y_pred.flatten(), average="macro", labels=[0,3,4,5])

    # By label
    for i in tqdm(range(num_bootstraps), leave=False, desc="Computing per-label metrics"):
        for label in range(6):
            y_true, y_pred = getRelevantData(bootstraps[i], f"index == index")
            y_true = binarize(y_true, label)
            y_pred = binarize(y_pred, label)

            metrics_dict[i]["mcc"][label] = {"overall": matthews_corrcoef(y_true.flatten(), y_pred.flatten())}
            metrics_dict[i]["precision"][label] = {"overall": precision_score(y_true.flatten(), y_pred.flatten())}
            metrics_dict[i]["recall"][label] = {"overall": recall_score(y_true.flatten(), y_pred.flatten())}

    # By label x kingdom
    for i in tqdm(range(num_bootstraps), leave=False, desc="Computing per-label & kingdom metrics"):
        for label in range(6):
            for kingdom in KINGDOMS:
                y_true, y_pred = getRelevantData(bootstraps[i], f"kingdom == '{kingdom}'")
                y_true = binarize(y_true, label)
                y_pred = binarize(y_pred, label)

                metrics_dict[i]["mcc"][label][kingdom] = matthews_corrcoef(y_true.flatten(), y_pred.flatten())
                metrics_dict[i]["precision"][label][kingdom] = precision_score(y_true.flatten(), y_pred.flatten())
                metrics_dict[i]["recall"][label][kingdom] = recall_score(y_true.flatten(), y_pred.flatten())

    return metrics_dict

## Execution

In [11]:
NUM_BOOTSTRAPS_PER_FOLD = 256

cv_metrics_dict = {}

warnings.filterwarnings("ignore")
dataset = Dataset(getDatasetPath())
for k, model in tqdm(cv_models.items(), desc="Fold progress"):
    test_data = dataset.getFolds([k])

    test_x = np.array([categoricalToOneHot(sequenceToCategorical(seq, amino_acid_mapping), amino_acid_mapping) for seq in test_data["sequence"]])
    predictions = model.predict(test_x)
    y_true = np.array([sequenceToCategorical(seq, annotation_mapping) for seq in test_data["annotation"]])
    y_pred = np.array([oneHotToCategorical(pred) for pred in predictions])

    test_data["prediction"] = np.array([categoricalToSequence(pred, reverse_annotation_mapping) for pred in y_pred])

    cv_metrics_dict[k] = bootstrap(test_data, NUM_BOOTSTRAPS_PER_FOLD)


Computing overall metrics...


100%|██████████| 256/256 [02:38<00:00,  1.61it/s]


Exlcluding non-existent pathways from eukarya...


100%|██████████| 256/256 [00:28<00:00,  8.89it/s]


Computing per-label metrics...


100%|██████████| 256/256 [05:41<00:00,  1.33s/it]


Computing per-label & kingdom metrics...


100%|██████████| 256/256 [05:12<00:00,  1.22s/it]


Computing overall metrics...


100%|██████████| 256/256 [02:35<00:00,  1.65it/s]


Exlcluding non-existent pathways from eukarya...


100%|██████████| 256/256 [00:29<00:00,  8.79it/s]


Computing per-label metrics...


100%|██████████| 256/256 [05:34<00:00,  1.31s/it]


Computing per-label & kingdom metrics...


100%|██████████| 256/256 [05:13<00:00,  1.23s/it]


Computing overall metrics...


100%|██████████| 256/256 [02:34<00:00,  1.66it/s]


Exlcluding non-existent pathways from eukarya...


100%|██████████| 256/256 [00:28<00:00,  8.89it/s]


Computing per-label metrics...


100%|██████████| 256/256 [05:31<00:00,  1.30s/it]


Computing per-label & kingdom metrics...


100%|██████████| 256/256 [05:07<00:00,  1.20s/it]


Computing overall metrics...


100%|██████████| 256/256 [02:35<00:00,  1.64it/s]


Exlcluding non-existent pathways from eukarya...


100%|██████████| 256/256 [00:29<00:00,  8.70it/s]


Computing per-label metrics...


100%|██████████| 256/256 [05:42<00:00,  1.34s/it]


Computing per-label & kingdom metrics...


100%|██████████| 256/256 [05:17<00:00,  1.24s/it]


## Convert to DataFrame

In [None]:
labels = {0: 'S', 1: 'T', 2: 'L', 3: 'I', 4: 'M', 5: 'O', "overall": "overall"}

cv_final_metrics = pd.DataFrame([
    (partition, bootstrap, metric, label_text, kingdom, cv_metrics_dict[partition][bootstrap][metric][label][kingdom])
    for bootstrap in range(NUM_BOOTSTRAPS_PER_FOLD)
    for partition in TRAINING_PARTITIONS
    for metric in ["mcc", "precision", "recall"]
    for label, label_text in labels.items()
    for kingdom in METRIC_KINGDOMS
])

cv_final_metrics.columns = ["holdout_fold", "bootstrap", "metric", "label", "kingdom", "value"]
cv_final_metrics.set_index(["holdout_fold", "bootstrap", "metric", "label", "kingdom"])

In [14]:
cv_final_metrics

Unnamed: 0,holdout_fold,bootstrap,metric,label,kingdom,value
0,1,0,mcc,S,EUKARYA,0.908898
1,1,0,mcc,S,ARCHAEA,0.696838
2,1,0,mcc,S,POSITIVE,0.769908
3,1,0,mcc,S,NEGATIVE,0.621653
4,1,0,mcc,S,overall,0.840507
...,...,...,...,...,...,...
107515,4,255,recall,overall,EUKARYA,0.869244
107516,4,255,recall,overall,ARCHAEA,0.846012
107517,4,255,recall,overall,POSITIVE,0.804781
107518,4,255,recall,overall,NEGATIVE,0.869732


# Store results

In [15]:
Serializer.save(cv_final_metrics, "cv_metrics")