In [None]:
import ast
from functools import partial

import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer

In [None]:
# Load the PTB-XL record table; the "ecg_id" column becomes the index.
ptbxl = pd.read_csv(
    r"../data/raw/ptbxl_database.csv",
    index_col="ecg_id",
)

In [None]:
# Distinct raw scp_codes values (still strings at this point) and their count.
unique_diagnosis = ptbxl["scp_codes"].unique()
unique_diagnosis, len(unique_diagnosis)

In [None]:
# Parse the stringified dicts in `scp_codes` into real dict objects.
# Only strings are parsed: re-running this cell after the column was already
# converted leaves the dicts untouched instead of making ast.literal_eval
# raise ValueError on a non-string input.
ptbxl["scp_codes"] = ptbxl["scp_codes"].apply(
    lambda codes: ast.literal_eval(codes) if isinstance(codes, str) else codes
)

In [None]:
ptbxl["scp_codes"]

In [None]:
def probs_to_tuple(probs: dict[str, float], threshold: float = 20) -> tuple[str, ...]:
    """
    Reduce a {diagnosis: probability} dict to a tuple of diagnosis codes.

    A diagnosis other than "NORM" is kept when its probability is both
    >= ``threshold`` and >= the probability of "NORM" (taken as 0 when "NORM"
    is absent). Kept diagnoses appear in the dict's iteration order.

    If no diagnosis is kept, ``("NORM",)`` is returned. Note that this fallback
    also applies when "NORM" is absent from ``probs`` and every other
    probability is below ``threshold`` (including an empty dict).

    Args:
        probs: Mapping of diagnosis code to probability (int or float values).
        threshold: Minimum probability a non-NORM diagnosis needs to be kept.

    Returns:
        Tuple of kept diagnosis codes, or ``("NORM",)`` when none qualify.
    """
    norm_prob = probs.get("NORM", 0)

    kept = tuple(
        key
        for key, value in probs.items()
        # ">=" against NORM: on a tie the other diagnosis wins over NORM.
        if value >= threshold and value >= norm_prob and key != "NORM"
    )

    return kept if kept else ("NORM",)

In [None]:
# Inline tests for probs_to_tuple (default threshold = 20)

# NORM dominates every other diagnosis -> only NORM remains
assert probs_to_tuple({"NORM": 100, "1": 20, "2": 30}) == ("NORM",)
# diagnoses above both the threshold and NORM replace NORM
assert probs_to_tuple({"NORM": 40, "1": 50, "2": 100}) == ("1", "2")
# diagnoses below NORM are dropped
assert probs_to_tuple({"NORM": 40, "1": 50, "2": 20}) == ("1",)
# without NORM only the threshold matters
assert probs_to_tuple({"1": 50, "2": 20}) == ("1", "2")
# a tie with NORM keeps the other diagnosis
assert probs_to_tuple({"NORM": 50, "1": 50}) == ("1",)
# fallback: nothing qualifies (even without a NORM key) -> ("NORM",)
assert probs_to_tuple({"1": 10}) == ("NORM",)
assert probs_to_tuple({}) == ("NORM",)
# custom threshold
assert probs_to_tuple({"1": 10, "2": 30}, threshold=5) == ("1", "2")

In [None]:
# Minimum probability a non-NORM diagnosis needs in order to be kept by
# probs_to_tuple; the "_15" suffix of the partial below mirrors this value.
DIAGNOSIS_THRESHOLD = 15

probs_to_tuple_15 = partial(probs_to_tuple, threshold=DIAGNOSIS_THRESHOLD)

In [None]:
# Collapse each record's {code: probability} dict into a tuple of diagnosis codes.
ptbxl["diagnoses"] = ptbxl["scp_codes"].apply(probs_to_tuple_15)
ptbxl["diagnoses"]

In [None]:
# Distinct diagnosis tuples and how many there are (computed once).
distinct_diagnoses = ptbxl["diagnoses"].unique()
distinct_diagnoses, len(distinct_diagnoses)

In [None]:
# Number of records whose only label is NORM.
int((ptbxl["diagnoses"] == ("NORM",)).sum())

In [None]:
# Almost half of the records in the dataset are labeled NORM only

In [None]:
# SCP statement table; its first column is used as the index.
scp_statements = pd.read_csv(
    r"../data/raw/scp_statements.csv",
    index_col=0,
)
scp_statements.head()

In [None]:
# Statement code (index) -> diagnostic_class value. Some values are not
# strings (e.g. NaN); those are filtered out where the mapping is used.
class_to_superclass_mapping = scp_statements["diagnostic_class"].to_dict()

len(class_to_superclass_mapping)

In [None]:
def aggregate_diagnostic(
    diagnoses: tuple[str, ...], mapping: dict[str, str]
) -> tuple[str, ...]:
    """
    Translate diagnoses through ``mapping`` and return the distinct results.

    Each diagnosis is looked up with ``mapping.get``; only string results are
    kept, so missing keys and non-string values (such as NaN) are skipped.
    Duplicates collapse to one entry.

    Args:
        diagnoses: Diagnosis codes to translate (may contain unknown codes or None).
        mapping: Diagnosis code -> aggregated label.

    Returns:
        The distinct labels as a sorted tuple, or ``("NONE",)`` when nothing maps.
    """
    found = {
        label
        for diagnose in diagnoses
        if isinstance(label := mapping.get(diagnose), str)
    }

    # Sort: iteration order of a set of strings depends on hash randomization
    # and insertion history, so tuple(set) is not reproducible and equal label
    # sets could otherwise yield different tuples.
    return tuple(sorted(found)) if found else ("NONE",)

In [None]:
# Quick check on a sample tuple that also contains a None entry.
aggregate_diagnostic(("DIG", "NDT", None), mapping=class_to_superclass_mapping)

In [None]:
# Bind the statement -> superclass mapping so the function takes one argument.
aggregate_diagnostic_class_to_superclass = partial(
    aggregate_diagnostic, mapping=class_to_superclass_mapping
)

ptbxl["superclass"] = ptbxl["diagnoses"].apply(aggregate_diagnostic_class_to_superclass)

# Distinct superclass tuples and their count (computed once).
distinct_superclass_tuples = ptbxl["superclass"].unique()
distinct_superclass_tuples, len(distinct_superclass_tuples)

In [None]:
# Every SCP statement code in the table is a candidate class label.
classes = tuple(scp_statements.index)
# The notebook expects exactly 71 statements; fail loudly if the file differs.
assert len(classes) == 71, f"expected 71 SCP statements, got {len(classes)}"
classes, len(classes)

In [None]:
# Distinct diagnostic superclasses plus the "NONE" fallback label; entries
# that are not strings (e.g. NaN) are dropped.
superclasses = tuple(
    label
    for label in [*scp_statements.diagnostic_class.unique(), "NONE"]
    if isinstance(label, str)
)
superclasses, len(superclasses)

In [None]:
# Fit each binarizer on a single "sample" holding the full label vocabulary,
# so classes_ contains every label regardless of what the records use.
classes_mlb = MultiLabelBinarizer().fit([classes])
superclasses_mlb = MultiLabelBinarizer().fit([superclasses])
superclasses_mlb

In [None]:
print(classes_mlb.classes_, superclasses_mlb.classes_, sep="\n")

In [None]:
# Multi-hot encode each record's diagnoses over all SCP classes; one tuple per row.
ptbxl["mlb_diagnose"] = list(
    map(tuple, classes_mlb.transform(ptbxl["diagnoses"].to_numpy()))
)
print(ptbxl["mlb_diagnose"])

In [None]:
# Multi-hot encode each record's superclasses; one tuple per row.
ptbxl["mlb_superclass"] = list(
    map(tuple, superclasses_mlb.transform(ptbxl["superclass"].to_numpy()))
)
print(ptbxl["mlb_superclass"])

In [None]:
# Last five records, now with the derived label columns.
ptbxl.iloc[-5:]

In [None]:
# Split by stratified fold: folds below 9 -> train, fold 9 -> validation,
# fold 10 -> test.
train = ptbxl[ptbxl.strat_fold < 9]
validation = ptbxl[ptbxl.strat_fold == 9]
test = ptbxl[ptbxl.strat_fold == 10]

# The three conditions are disjoint, so this holds iff every record lands in a
# split; it catches NaN or out-of-range fold values that would be silently dropped.
assert len(train) + len(validation) + len(test) == len(ptbxl), (
    "some records have strat_fold values outside the expected folds"
)

In [None]:
# Number of records in each split: (train, validation, test).
tuple(map(len, (train, validation, test)))