In [1]:
import pandas as pd
import numpy as np
import torch

from fairseq_signals.data.ecg.raw_ecg_dataset import FileECGDataset, NpECGDataset, DataframeECGDataset
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score
import torchmetrics.functional.classification as Fc


  warn(


In [2]:
def try_catch_fn(fn, default_val=np.nan):
    """Invoke the zero-argument callable ``fn`` and hand back its result.

    Any ``Exception`` raised by ``fn`` is reported on stdout and replaced by
    ``default_val`` (NaN unless the caller overrides it), so one failing
    metric does not abort the surrounding evaluation.

    Args:
        fn: Callable taking no arguments.
        default_val: Value returned when ``fn`` raises.

    Returns:
        ``fn()`` on success, otherwise ``default_val``.
    """
    try:
        outcome = fn()
    except Exception as err:
        print('Error occured when evaluation metrics: ', err)
        outcome = default_val
    return outcome
    

def metrics_headers(y_labels, categories, single=True, group=True):
    """Build the list of metric column names produced for a label set.

    Args:
        y_labels: Sequence of label names (one per prediction column).
        categories: Mapping ``category name -> list of label names``.
        single: When true, emit ``"{label}_{metric}"`` for every label.
        group: When true, emit ``"{category}_{avg}_{metric}"`` for every
            category that contains at least one label present in ``y_labels``.

    Returns:
        List of header strings: per-label headers first (label order, then
        auroc/auprc/f1score), followed by per-category headers (metric order,
        then micro/macro).
    """
    metric_names = ["auroc", "auprc", "f1score"]
    headers = []
    if single:
        for label in y_labels:
            headers += [f"{label}_{m}" for m in metric_names]

    if group:
        for category, elts in categories.items():
            # Only members that are real label columns count. The previous
            # filter (`s in elts`) was always true, so an unknown member made
            # `y_labels.index` raise and the "empty category" skip never fired.
            if not any(s in y_labels for s in elts):
                continue  # No items in this category
            for m in metric_names:
                for avg in ["micro", "macro"]:
                    headers.append(f"{category}_{avg}_{m}")
    return headers

def calculate_metrics(y_true, y_pred, device, y_labels, categories, single=True, group=True):
    """Compute AUROC, AUPRC and F1 with both torchmetrics and scikit-learn.

    Args:
        y_true: 2-D array of ground-truth labels, indexed ``[:, label]``.
        y_pred: 2-D array of prediction scores, indexed the same way. NaN
            entries are replaced by 0 before scoring.
        device: Torch device used for the torchmetrics tensors.
        y_labels: List of label names; position ``i`` names column ``i``.
        categories: Mapping ``category name -> list of label names``. Members
            that are not in ``y_labels`` are ignored; a category with no known
            member is skipped entirely.
        single: Compute ``"{label}_{metric}"`` scores for every label.
        group: Compute ``"{category}_{avg}_{metric}"`` scores (macro and micro)
            for every non-empty category.

    Returns:
        ``{"tm": {...}, "sk": {...}}`` holding the torchmetrics and
        scikit-learn results under identical keys. A metric that raises is
        reported by ``try_catch_fn`` and stored as its default value.

    NOTE(review): the scikit-learn F1 uses the raw ``y_pred > 0.5`` threshold
    computed here, while the torchmetrics F1 thresholds internally (and, per
    its documentation, applies a sigmoid first when predictions fall outside
    [0, 1]). The two F1 values can therefore differ if ``y_pred`` holds
    logits -- confirm what the model emits.
    """
    y_pred = np.where(np.isnan(y_pred), 0, y_pred)
    y_pred_thresholded = (y_pred > 0.5).astype(int)
    t_y_true = torch.tensor(y_true, device=device, dtype=torch.long)
    t_y_pred = torch.tensor(y_pred, device=device)

    scores = {"tm": {}, "sk": {}}
    if single:
        # `col` rather than `id`, to avoid shadowing the builtin.
        for col, y_label in enumerate(y_labels):
            t_pred = t_y_pred[:, col]
            t_target = t_y_true[:, col]
            pred = y_pred[:, col]
            target = y_true[:, col]
            pred_threshold = y_pred_thresholded[:, col]
            # The lambdas are invoked immediately by try_catch_fn, so capturing
            # the loop variables by reference is safe here.
            scores["tm"][f"{y_label}_auroc"] = try_catch_fn(lambda: Fc.binary_auroc(t_pred, t_target))
            scores["sk"][f"{y_label}_auroc"] = try_catch_fn(lambda: roc_auc_score(target, pred))
            scores["tm"][f"{y_label}_auprc"] = try_catch_fn(lambda: Fc.binary_average_precision(t_pred, t_target))
            scores["sk"][f"{y_label}_auprc"] = try_catch_fn(lambda: average_precision_score(target, pred))
            scores["tm"][f"{y_label}_f1score"] = try_catch_fn(lambda: Fc.binary_f1_score(t_pred, t_target))
            scores["sk"][f"{y_label}_f1score"] = try_catch_fn(lambda: f1_score(target, pred_threshold))

    if group:
        for category, elts in categories.items():
            # Keep only members that are actual label columns. The previous
            # filter (`s in elts`) was always true, so an unknown member made
            # `y_labels.index` raise ValueError.
            idx = [y_labels.index(s) for s in elts if s in y_labels]
            num_labels = len(idx)
            if num_labels == 0:
                continue  # No items in this category
            t_pred = t_y_pred[:, idx]
            t_target = t_y_true[:, idx]
            pred = y_pred[:, idx]
            target = y_true[:, idx]
            pred_threshold = y_pred_thresholded[:, idx]
            for avg in ["macro", "micro"]:
                # torchmetrics' multilabel functions are used for >1 label;
                # a single-label category falls back to the binary variants.
                scores["tm"][f"{category}_{avg}_auroc"] = \
                    try_catch_fn(lambda: Fc.multilabel_auroc(t_pred, t_target, num_labels, average=avg) if num_labels > 1 else Fc.binary_auroc(t_pred, t_target))
                scores["sk"][f"{category}_{avg}_auroc"] = \
                    try_catch_fn(lambda: roc_auc_score(target, pred, average=avg))
                scores["tm"][f"{category}_{avg}_auprc"] = \
                    try_catch_fn(lambda: Fc.multilabel_average_precision(t_pred, t_target, num_labels, average=avg) if num_labels > 1 else Fc.binary_average_precision(t_pred, t_target))
                scores["sk"][f"{category}_{avg}_auprc"] = \
                    try_catch_fn(lambda: average_precision_score(target, pred, average=avg))
                scores["tm"][f"{category}_{avg}_f1score"] = \
                    try_catch_fn(lambda: Fc.multilabel_f1_score(t_pred, t_target, num_labels, average=avg) if num_labels > 1 else Fc.binary_f1_score(t_pred, t_target))
                scores["sk"][f"{category}_{avg}_f1score"] = \
                    try_catch_fn(lambda: f1_score(target, pred_threshold, average=avg))

    return scores

def format_group_metrics(scores, categories, metrics=['auroc', 'f1score', 'auprc']):
    """Lay out per-category scores as a long-format DataFrame.

    Args:
        scores: Mapping with ``"{category}_{avg}_{metric}"`` keys.
        categories: Mapping whose keys are the category names to report.
        metrics: Metric suffixes to include, in output order.

    Returns:
        DataFrame with columns ``Category``, ``Metrics`` (``"{avg}_{metric}"``)
        and ``Scores``; rows are ordered by category, then metric, then
        micro before macro.

    Raises:
        KeyError: if ``scores`` lacks one of the expected keys.
    """
    # Enumerate every (category, metric, averaging) combination once, in the
    # row order of the resulting table.
    combos = [
        (name, metric, avg)
        for name in categories.keys()
        for metric in metrics
        for avg in ['micro', 'macro']
    ]
    table = {
        'Category': [name for name, _, _ in combos],
        'Metrics': [f'{avg}_{metric}' for _, metric, avg in combos],
        'Scores': [scores[f'{name}_{avg}_{metric}'] for name, metric, avg in combos],
    }
    return pd.DataFrame(table)
        


In [6]:
def get_pred_labels(
    header_pkl,
    header_npy,
    DatasetClass=None,
    manifest_path=None,
    parquet_path=None,
    csv_path=None,
    y_labels=None,
    sample_rate=250
):
    """Load saved predictions and the matching ground-truth labels.

    Predictions are memory-mapped read-only from ``header_npy`` using the
    ``'shape'`` and ``'dtype'`` entries stored in ``header_pkl``.

    Labels come from exactly one source, checked in this order:
      1. ``manifest_path``: ``DatasetClass(manifest_path=..., sample_rate=...,
         label=True)`` is iterated and each item's ``'label'`` is collected.
      2. ``csv_path`` (takes precedence) or ``parquet_path``: the ``y_labels``
         columns of the table are used.

    Args:
        header_pkl: Path to the pickled header holding ``shape`` and ``dtype``.
        header_npy: Path to the raw prediction buffer.
        DatasetClass: Dataset class to instantiate for the manifest source.
        manifest_path: Manifest path passed to ``DatasetClass``.
        parquet_path: Parquet file containing the label columns.
        csv_path: CSV file containing the label columns.
        y_labels: Column names to read from the CSV/parquet table.
        sample_rate: Sample rate forwarded to ``DatasetClass``.

    Returns:
        Tuple ``(y_pred, y_true)``: the memory-mapped predictions and a numpy
        array of labels.

    Raises:
        ValueError: if no label source is given, if ``manifest_path`` is given
            without ``DatasetClass``, or if a table source is given without
            ``y_labels``.
    """
    # Validate up front: previously a missing label source surfaced only as an
    # UnboundLocalError on `y_true` at the return statement, and a missing
    # DatasetClass as "'NoneType' object is not callable".
    if manifest_path is not None:
        if DatasetClass is None:
            raise ValueError("DatasetClass is required when manifest_path is given")
    elif parquet_path is None and csv_path is None:
        raise ValueError(
            "No label source: provide manifest_path, parquet_path or csv_path"
        )
    elif y_labels is None:
        raise ValueError("y_labels is required when reading labels from csv/parquet")

    # NOTE(security): allow_pickle=True unpickles the header file, which can
    # execute arbitrary code -- only load headers from trusted locations.
    header = np.load(header_pkl, allow_pickle=True)
    y_pred = np.memmap(
        header_npy,
        mode='r',
        shape=header['shape'],
        dtype=header['dtype']
    )

    if manifest_path is not None:
        dataset = DatasetClass(
            manifest_path=manifest_path,
            sample_rate=sample_rate,
            label=True
        )
        y_true = np.array([d['label'] for d in dataset])
    else:
        # csv_path wins when both table paths are supplied.
        df = pd.read_parquet(parquet_path) if csv_path is None else pd.read_csv(csv_path)
        y_true = df[y_labels].to_numpy()

    return y_pred, y_true


def get_scores(y_labels, categories, header_pkl, header_npy, csv_path=None, parquet_path=None):
    """Load predictions and labels from disk, then score them on CPU.

    Per-label metrics are always computed; per-category metrics are computed
    only when ``categories`` is non-empty.

    Args:
        y_labels: Label column names.
        categories: Mapping ``category name -> list of label names``.
        header_pkl: Path to the pickled prediction header.
        header_npy: Path to the raw prediction buffer.
        csv_path: Optional CSV file holding the ground-truth columns.
        parquet_path: Optional parquet file holding the ground-truth columns.

    Returns:
        The score dictionary produced by ``calculate_metrics``.
    """
    predictions, targets = get_pred_labels(
        header_pkl=header_pkl,
        header_npy=header_npy,
        parquet_path=parquet_path,
        csv_path=csv_path,
        y_labels=y_labels,
    )
    has_groups = len(categories) > 0
    return calculate_metrics(
        y_true=targets,
        y_pred=predictions,
        device="cpu",
        y_labels=y_labels,
        categories=categories,
        single=True,
        group=has_groups,
    )


# Root directory under which every "{ft}_{task}/" results folder lives.
_RESULTS_ROOT = "/media/data1/achilsowa/datasets/fairseq/mhi/results"


def _task_scores(task, y_labels, ft, folder, csv_path=None, parquet_path=None):
    """Score one task's saved test outputs against its label table.

    Outputs are read from ``{_RESULTS_ROOT}/{ft}_{task}/{folder}outputs_test*``;
    ``folder`` is inserted verbatim, so it must end with ``/`` when it names a
    sub-directory. No label categories are used (per-label metrics only).
    """
    run_prefix = f"{_RESULTS_ROOT}/{ft}_{task}/{folder}"
    return get_scores(
        y_labels,
        {},
        f"{run_prefix}outputs_test_header.pkl",
        f"{run_prefix}outputs_test.npy",
        csv_path=csv_path,
        parquet_path=parquet_path,
    )


def acs_scores(ft='finetune', folder=''):
    """Score the ACS task ('Acute_Obstruction') for the ``{ft}_acs`` run."""
    return _task_scores(
        "acs", ['Acute_Obstruction'], ft, folder,
        csv_path="/media/data1/ravram/DeepECG_Datasets/filtered_ACS_test_df.csv",
    )


def afib_scores(ft='finetune', folder=''):
    """Score the AFib task ('label_2y', 'label_5y') for the ``{ft}_afib`` run."""
    return _task_scores(
        "afib", ['label_2y', 'label_5y'], ft, folder,
        parquet_path="/volume/deepecg/ecgs-data/parquet/mhi-afib-test.parquet",
    )


def fevg_scores(ft='finetune', folder=''):
    """Score the LVEF task (two LVEF threshold labels) for the ``{ft}_fevg`` run."""
    return _task_scores(
        "fevg", ['LVEF_UNDER_50', 'LVEF_EQUAL_OR_UNDER_40'], ft, folder,
        csv_path="/media/data1/ravram/DeepECG_Datasets/test_filtered_echo_lite.csv",
    )


def lqts_scores(ft='finetune', folder=''):
    """Score the LQTS task ('LQTS') for the ``{ft}_lqts`` run."""
    return _task_scores(
        "lqts", ['LQTS'], ft, folder,
        csv_path="/media/data1/ravram/DeepECG_Datasets/mhi-lqts-test-lite.csv",
    )


def lqts_type_scores(ft='finetune', folder=''):
    """Score the LQTS-type task ('LQTS_TYPE_1') for the ``{ft}_lqts_type`` run."""
    return _task_scores(
        "lqts_type", ['LQTS_TYPE_1'], ft, folder,
        csv_path="/media/data1/ravram/DeepECG_Datasets/mhi-lqts-type-test-lite.csv",
    )

# Score the LQTS task using the outputs stored under the "le_lqts" results
# folder (ft='le'); the bare `scores` expression displays the returned
# {'tm': ..., 'sk': ...} dictionary.
scores = lqts_scores(ft='le')
scores

{'tm': {'LQTS_auroc': tensor(0.7567),
  'LQTS_auprc': tensor(0.3619),
  'LQTS_f1score': tensor(0.3644)},
 'sk': {'LQTS_auroc': 0.756658175081056,
  'LQTS_auprc': 0.36196432025058084,
  'LQTS_f1score': 0.3417085427135678}}

In [4]:
{'tm': {'LQTS_auroc': tensor(0.7483),
  'LQTS_auprc': tensor(0.2695),
  'LQTS_f1score': tensor(0.3267)},
 'sk': {'LQTS_auroc': 0.7482486104678092,
  'LQTS_auprc': 0.2696566287586327,
  'LQTS_f1score': 0.3054545454545454}}


# NOTE(review): pandas is already imported in the first cell; this re-import is redundant.
import pandas as pd
# Read the ACS test table and keep the 'Acute_Obstruction' column as a
# 2-D (n_rows, 1) array (double brackets select a one-column frame).
# NOTE(review): this reads ACS_test_df.csv, whereas acs_scores() reads
# filtered_ACS_test_df.csv -- confirm which label file is intended.
df = pd.read_csv("/media/data1/ravram/DeepECG_Datasets/ACS_test_df.csv")
npy = df[['Acute_Obstruction']].to_numpy()
npy.shape

  df = pd.read_csv("/media/data1/ravram/DeepECG_Datasets/ACS_test_df.csv")


(8351, 1)

In [5]:
# NOTE(review): numpy is already imported in the first cell; this re-import is redundant.
import numpy as np
# Open the saved ACS fine-tune test predictions: the pickled header supplies
# the shape and dtype needed to memory-map the raw buffer read-only.
header_pkl="/media/data1/achilsowa/datasets/fairseq/mhi/results/finetune_acs/outputs_test_header.pkl"
header_npy="/media/data1/achilsowa/datasets/fairseq/mhi/results/finetune_acs/outputs_test.npy"
# NOTE(security): allow_pickle=True unpickles the file -- trusted input only.
header = np.load(header_pkl, allow_pickle=True)
y_pred = np.memmap(header_npy, 
        mode='r',
        shape=header['shape'],
        dtype=header['dtype']
    )
# Not the last expression of the cell, so this value is not displayed.
y_pred.shape
# Persist the label array next to the predictions. `npy` is defined in a
# different cell (the ACS_test_df.csv one), so this cell depends on that cell
# having been run first.
np.save("/media/data1/achilsowa/datasets/fairseq/mhi/results/finetune_acs/label_test.npy", npy)

In [15]:
# Full list of ECG label names; list position is used as the prediction
# column index by calculate_metrics / metrics_headers.
y_labels =  ['Acute pericarditis', 'QS complex in V1-V2-V3', 'T wave inversion (anterior - V3-V4)', 'Right atrial enlargement','2nd degree AV block - mobitz 1','Left posterior fascicular block','Wolff-Parkinson-White (Pre-excitation syndrome)','Junctional rhythm','Premature ventricular complex',"rSR' in V1-V2",'Right superior axis','ST elevation (inferior - II, III, aVF)','Afib','ST elevation (anterior - V3-V4)','RV1 + SV6 > 11 mm','Sinusal','Monomorph','Delta wave','R/S ratio in V1-V2 >1','Third Degree AV Block','LV pacing','Nonspecific intraventricular conduction delay','ST depression (inferior - II, III, aVF)','Regular','Premature atrial complex','2nd degree AV block - mobitz 2','Left anterior fascicular block','Q wave (septal- V1-V2)','Prolonged QT','Left axis deviation','Left ventricular hypertrophy','ST depression (septal- V1-V2)','Supraventricular tachycardia','Atrial paced','Q wave (inferior - II, III, aVF)','no_qrs','T wave inversion (lateral -I, aVL, V5-V6)','Right bundle branch block','ST elevation (septal - V1-V2)','SV1 + RV5 or RV6 > 35 mm','Right axis deviation','RaVL > 11 mm','Polymorph','Ventricular tachycardia','QRS complex negative in III','ST depression (lateral - I, avL, V5-V6)','1st degree AV block','Lead misplacement','Q wave (posterior - V7-V9)','Atrial flutter','Ventricular paced','ST elevation (posterior - V7-V8-V9)','Ectopic atrial rhythm (< 100 BPM)','Early repolarization','Ventricular Rhythm','Irregularly irregular','Atrial tachycardia (>= 100 BPM)','R complex in V5-V6','ST elevation (lateral - I, aVL, V5-V6)','Brugada','Bi-atrial enlargement','Q wave (lateral- I, aVL, V5-V6)','ST upslopping','T wave inversion (inferior - II, III, aVF)','Regularly irregular','Bradycardia','qRS in V5-V6-I, aVL','Q wave (anterior - V3-V4)','Acute MI','ST depression (anterior - V3-V4)','Right ventricular hypertrophy','T wave inversion (septal- V1-V2)','ST downslopping','Left bundle branch block','Low voltage','U wave','Left atrial enlargement']
# Grouping of label names into categories for the micro/macro group metrics.
# NOTE(review): 'ST depression (anterior - V3-V4)' is listed under both
# 'INFARCT, ISCHEMIA' and 'OTHER', so it contributes to both groups --
# confirm this is intended.
categories = {
    "RHYTHM": ['Ventricular tachycardia','Bradycardia','Brugada','Wolff-Parkinson-White (Pre-excitation syndrome)','Atrial flutter','Ectopic atrial rhythm (< 100 BPM)','Atrial tachycardia (>= 100 BPM)','Sinusal','Ventricular Rhythm','Supraventricular tachycardia','Junctional rhythm','Regular','Regularly irregular','Irregularly irregular','Afib','Premature ventricular complex','Premature atrial complex'],
    "CONDUCTION": ['Left anterior fascicular block','Delta wave','2nd degree AV block - mobitz 2','Left bundle branch block','Right bundle branch block','Left axis deviation','Atrial paced','Right axis deviation','Left posterior fascicular block','1st degree AV block','Right superior axis','Nonspecific intraventricular conduction delay','Third Degree AV Block','2nd degree AV block - mobitz 1','Prolonged QT','U wave','LV pacing','Ventricular paced'],
    "CHAMBER ENLARGEMENT": ['Bi-atrial enlargement','Left atrial enlargement','Right atrial enlargement','Left ventricular hypertrophy','Right ventricular hypertrophy'],
    "PERICARDITIS": ['Acute pericarditis'],
    'INFARCT, ISCHEMIA': ['Q wave (septal- V1-V2)','ST elevation (anterior - V3-V4)','Q wave (posterior - V7-V9)','Q wave (inferior - II, III, aVF)','Q wave (anterior - V3-V4)','ST elevation (lateral - I, aVL, V5-V6)','Q wave (lateral- I, aVL, V5-V6)','ST depression (lateral - I, avL, V5-V6)','Acute MI','ST elevation (septal - V1-V2)','ST elevation (inferior - II, III, aVF)','ST elevation (posterior - V7-V8-V9)','ST depression (inferior - II, III, aVF)','ST depression (anterior - V3-V4)'],
    "OTHER": ['ST downslopping','ST depression (septal- V1-V2)','R/S ratio in V1-V2 >1','RV1 + SV6 > 11 mm','Polymorph',"rSR' in V1-V2",'QRS complex negative in III','qRS in V5-V6-I, aVL','QS complex in V1-V2-V3','R complex in V5-V6','RaVL > 11 mm','T wave inversion (septal- V1-V2)','SV1 + RV5 or RV6 > 35 mm','T wave inversion (inferior - II, III, aVF)','Monomorph','T wave inversion (anterior - V3-V4)','T wave inversion (lateral -I, aVL, V5-V6)','Low voltage','Lead misplacement','ST depression (anterior - V3-V4)','Early repolarization','ST upslopping','no_qrs'],
}


# Overrides the label list and categories assigned just above in this cell:
# from here on only the single 'BP_class' label is scored and no categories
# are defined.
y_labels = ['BP_class']
categories={}

# NOTE(review): `y_true` and `res` are not defined in any visible cell; this
# call relies on kernel state from cells that are not shown -- confirm where
# they come from before re-running on a fresh kernel.
# With categories == {} the group=True branch iterates over nothing, so only
# per-label scores are produced.
scores = calculate_metrics(
    y_true=y_true, 
    y_pred=res, 
    device="cpu", 
    y_labels=y_labels, 
    categories=categories, 
    single=True, 
    group=True
)

In [17]:
# Display only the scikit-learn half of the score dictionary.
scores['sk']

{'BP_class_auroc': 0.49949199220149587,
 'BP_class_auprc': 0.03945404688123759,
 'BP_class_f1score': 0.0}

In [7]:
# Tabulate the scikit-learn group metrics: one row per category / averaging /
# metric combination.
# NOTE(review): this needs a non-empty `categories` and the matching
# "{category}_{avg}_{metric}" keys in scores['sk']; the visible cell that
# builds `scores` resets categories to {} -- execution order looks non-linear.
df = format_group_metrics(scores['sk'], categories)
df

Unnamed: 0,Category,Metrics,Scores
0,RHYTHM,micro_auroc,0.995721
1,RHYTHM,macro_auroc,0.978057
2,RHYTHM,micro_f1score,0.933316
3,RHYTHM,macro_f1score,0.603175
4,RHYTHM,micro_auprc,0.981149
5,RHYTHM,macro_auprc,0.680214
6,CONDUCTION,micro_auroc,0.984699
7,CONDUCTION,macro_auroc,0.974687
8,CONDUCTION,micro_f1score,0.717048
9,CONDUCTION,macro_f1score,0.485248


In [74]:
# Same tabulation as the earlier format_group_metrics cell, re-run against
# whatever `scores` / `categories` were in the kernel at the time.
# NOTE(review): the displayed numbers differ from the earlier table, so
# `scores` came from a different run -- record which predictions produced it.
df = format_group_metrics(scores['sk'], categories)
df


Unnamed: 0,Category,Metrics,Scores
0,RHYTHM,micro_auroc,0.992124
1,RHYTHM,macro_auroc,0.959762
2,RHYTHM,micro_f1score,0.917619
3,RHYTHM,macro_f1score,0.552463
4,RHYTHM,micro_auprc,0.970385
5,RHYTHM,macro_auprc,0.631654
6,CONDUCTION,micro_auroc,0.984017
7,CONDUCTION,macro_auroc,0.973824
8,CONDUCTION,micro_f1score,0.710589
9,CONDUCTION,macro_f1score,0.48865
