In [1]:
import ast
from functools import partial
from typing import Optional

import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer

In [2]:
ptbxl = pd.read_csv(r"../data/raw/ptbxl_database.csv", index_col="ecg_id")

In [3]:
unique_diagnosis = ptbxl.scp_codes.unique()
unique_diagnosis, len(unique_diagnosis)

(array(["{'NORM': 100.0, 'LVOLT': 0.0, 'SR': 0.0}",
        "{'NORM': 80.0, 'SBRAD': 0.0}", "{'NORM': 100.0, 'SR': 0.0}", ...,
        "{'IMI': 100.0, 'ISCLA': 50.0, 'ABQRS': 0.0, 'SVARR': 0.0}",
        "{'IMI': 80.0, 'ISCLA': 100.0, 'PVC': 100.0, 'ABQRS': 0.0, 'SR': 0.0}",
        "{'NDT': 100.0, 'PVC': 100.0, 'VCLVH': 0.0, 'STACH': 0.0}"],
       dtype=object),
 5466)

In [4]:
# Convert dict string to dict object
ptbxl.scp_codes = ptbxl.scp_codes.apply(ast.literal_eval)

In [5]:
ptbxl.scp_codes

ecg_id
1                 {'NORM': 100.0, 'LVOLT': 0.0, 'SR': 0.0}
2                             {'NORM': 80.0, 'SBRAD': 0.0}
3                               {'NORM': 100.0, 'SR': 0.0}
4                               {'NORM': 100.0, 'SR': 0.0}
5                               {'NORM': 100.0, 'SR': 0.0}
                               ...                        
21833    {'NDT': 100.0, 'PVC': 100.0, 'VCLVH': 0.0, 'ST...
21834             {'NORM': 100.0, 'ABQRS': 0.0, 'SR': 0.0}
21835                           {'ISCAS': 50.0, 'SR': 0.0}
21836                           {'NORM': 100.0, 'SR': 0.0}
21837                           {'NORM': 100.0, 'SR': 0.0}
Name: scp_codes, Length: 21837, dtype: object

In [6]:
def probs_to_tuple(probs: dict[str, int], threshold: int = 20) -> Optional[tuple[str]]:
    """
    Convert dict of diagnoses and their probabilities to
    tuple of diagnoses with probabilities >= given threshold.
    If result include diagnose with "NORM" or empty, return NA for later drop
    """

    result = tuple([key for key, value in probs.items() if value >= threshold])

    is_diagnose_with_norm = ("NORM" in result) and (len(result) > 1)

    if not result or is_diagnose_with_norm:
        return None

    return result

In [7]:
# Some tests for above function

assert probs_to_tuple({"NORM": 100, "1": 19, "2": 10}) == ("NORM",)
assert probs_to_tuple({"NORM": 40, "1": 50, "2": 100}) == None
assert probs_to_tuple({"NORM": 40, "1": 50, "2": 20}) == None
assert probs_to_tuple({"1": 50, "2": 20}) == ("1", "2")

In [8]:
probs_to_tuple_15 = partial(probs_to_tuple, threshold=100)

In [9]:
ptbxl["diagnoses"] = ptbxl.scp_codes.apply(probs_to_tuple_15)
ptbxl.diagnoses

ecg_id
1           (NORM,)
2              None
3           (NORM,)
4           (NORM,)
5           (NORM,)
            ...    
21833    (NDT, PVC)
21834       (NORM,)
21835          None
21836       (NORM,)
21837       (NORM,)
Name: diagnoses, Length: 21837, dtype: object

In [10]:
ptbxl.diagnoses.unique(), len(ptbxl.diagnoses.unique())

(array([('NORM',), None, ('AFLT',), ..., ('LVH', 'ISC_', '2AVB'),
        ('IMI', 'NDT', '1AVB'), ('ISCIL', 'RAO/RAE')], dtype=object),
 1229)

In [11]:
len(ptbxl[ptbxl.diagnoses.isna()])

4083

In [12]:
# Almost half of the dataset is NORM

In [13]:
scp_statements = pd.read_csv(r"../data/raw/scp_statements.csv", index_col=0)
scp_statements.head()

Unnamed: 0,description,diagnostic,form,rhythm,diagnostic_class,diagnostic_subclass,Statement Category,SCP-ECG Statement Description,AHA code,aECG REFID,CDISC Code,DICOM Code
NDT,non-diagnostic T abnormalities,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,non-diagnostic T abnormalities,,,,
NST_,non-specific ST changes,1.0,1.0,,STTC,NST_,Basic roots for coding ST-T changes and abnorm...,non-specific ST changes,145.0,MDC_ECG_RHY_STHILOST,,
DIG,digitalis-effect,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,suggests digitalis-effect,205.0,,,
LNGQT,long QT-interval,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,long QT-interval,148.0,,,
NORM,normal ECG,1.0,,,NORM,NORM,Normal/abnormal,normal ECG,1.0,,,F-000B7


In [14]:
class_to_superclass_mapping = dict(zip(
    scp_statements.index, scp_statements.diagnostic_class
))

len(class_to_superclass_mapping)

71

In [15]:
def aggregate_diagnostic(
        diagnoses: Optional[tuple[str]], mapping: dict[str, str]
) -> Optional[tuple[Optional[str]]]:
    """
    Return values of encountered keys from the given mapping.
    """

    if not diagnoses:
        return None

    superclasses = tuple({
        superclass
        if pd.notna(superclass := mapping.get(diagnose))
        else None
        for diagnose in diagnoses
    })

    if None in superclasses:
        return None

    return superclasses

In [16]:
print(aggregate_diagnostic(("DIG", "NDT", None), class_to_superclass_mapping))

None


In [17]:
aggregate_diagnostic_class_to_superclass = partial(
    aggregate_diagnostic,
    mapping=class_to_superclass_mapping,
)

ptbxl["superclass"] = ptbxl.diagnoses.apply(aggregate_diagnostic_class_to_superclass)
ptbxl.superclass.unique(), len(ptbxl.superclass.unique())

(array([('NORM',), None, ('STTC',), ('HYP',), ('CD',), ('CD', 'STTC'),
        ('MI',), ('STTC', 'HYP'), ('CD', 'HYP'), ('CD', 'MI', 'STTC'),
        ('CD', 'MI'), ('STTC', 'MI', 'HYP'), ('HYP', 'STTC'),
        ('CD', 'STTC', 'HYP'), ('MI', 'STTC'), ('MI', 'HYP'),
        ('CD', 'HYP', 'STTC'), ('MI', 'HYP', 'STTC'),
        ('STTC', 'CD', 'MI', 'HYP'), ('CD', 'MI', 'HYP'),
        ('CD', 'STTC', 'MI', 'HYP'), ('CD', 'MI', 'HYP', 'STTC')],
       dtype=object),
 22)

In [18]:
classes = tuple(scp_statements[scp_statements.diagnostic_class.notna()].index)
classes, len(classes)

(('NDT',
  'NST_',
  'DIG',
  'LNGQT',
  'NORM',
  'IMI',
  'ASMI',
  'LVH',
  'LAFB',
  'ISC_',
  'IRBBB',
  '1AVB',
  'IVCD',
  'ISCAL',
  'CRBBB',
  'CLBBB',
  'ILMI',
  'LAO/LAE',
  'AMI',
  'ALMI',
  'ISCIN',
  'INJAS',
  'LMI',
  'ISCIL',
  'LPFB',
  'ISCAS',
  'INJAL',
  'ISCLA',
  'RVH',
  'ANEUR',
  'RAO/RAE',
  'EL',
  'WPW',
  'ILBBB',
  'IPLMI',
  'ISCAN',
  'IPMI',
  'SEHYP',
  'INJIN',
  'INJLA',
  'PMI',
  '3AVB',
  'INJIL',
  '2AVB'),
 44)

In [19]:
# classes = tuple(scp_statements.index)
# classes, len(classes) # should be 71

In [20]:
superclasses = list(scp_statements.diagnostic_class.unique())
superclasses = tuple(filter(lambda diagnose: isinstance(diagnose, str), superclasses))
superclasses, len(superclasses)

(('STTC', 'NORM', 'MI', 'HYP', 'CD'), 5)

In [21]:
ptbxl.head()

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,diagnoses,superclass
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,...,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,"(NORM,)","(NORM,)"
2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,...,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr,,
3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,...,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr,"(NORM,)","(NORM,)"
4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,...,,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr,"(NORM,)","(NORM,)"
5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,...,,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr,"(NORM,)","(NORM,)"


In [22]:
ptbxl.dropna(inplace=True, subset=["diagnoses", "superclass"])
ptbxl.head()

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,diagnoses,superclass
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1


In [23]:
classes_mlb = MultiLabelBinarizer()
superclasses_mlb = MultiLabelBinarizer()

classes_mlb.fit([classes])
superclasses_mlb.fit([superclasses])

MultiLabelBinarizer()

In [24]:
print(classes_mlb.classes_)
print(superclasses_mlb.classes_)

['1AVB' '2AVB' '3AVB' 'ALMI' 'AMI' 'ANEUR' 'ASMI' 'CLBBB' 'CRBBB' 'DIG'
 'EL' 'ILBBB' 'ILMI' 'IMI' 'INJAL' 'INJAS' 'INJIL' 'INJIN' 'INJLA' 'IPLMI'
 'IPMI' 'IRBBB' 'ISCAL' 'ISCAN' 'ISCAS' 'ISCIL' 'ISCIN' 'ISCLA' 'ISC_'
 'IVCD' 'LAFB' 'LAO/LAE' 'LMI' 'LNGQT' 'LPFB' 'LVH' 'NDT' 'NORM' 'NST_'
 'PMI' 'RAO/RAE' 'RVH' 'SEHYP' 'WPW']
['CD' 'HYP' 'MI' 'NORM' 'STTC']


In [25]:
ptbxl["mlb_diagnose"] = [tuple(diagnose) for diagnose in classes_mlb.transform(ptbxl.diagnoses.to_numpy())]
print(ptbxl.mlb_diagnose)

Series([], Name: mlb_diagnose, dtype: float64)


In [26]:
ptbxl["mlb_superclass"] = [tuple(superclass) for superclass in superclasses_mlb.transform(ptbxl.superclass.to_numpy())]
print(ptbxl.mlb_superclass)

Series([], Name: mlb_superclass, dtype: float64)


In [27]:
ptbxl.tail()

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,diagnoses,superclass,mlb_diagnose,mlb_superclass
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1


In [28]:
train = ptbxl[ptbxl.strat_fold < 9]
validation = ptbxl[ptbxl.strat_fold == 9]
test = ptbxl[ptbxl.strat_fold == 10]

In [29]:
len(train), len(validation), len(test)

(0, 0, 0)