In [2]:
import os
import pandas as pd
import torch

from fairseq_signals.utils.store import MemmapReader



In [3]:
root = '/home/aa2650/datasets/code_15/subset'
experiment_root = '/home/aa2650/datasets/code_15/experiments/subset'
fairseq_signals_root = '/home/aa2650/playground/fairseq-signals'
fairseq_signals_root = fairseq_signals_root.rstrip('/')
fairseq_signals_root

data_split = "80-10-10"

The segmented split must be saved with absolute file paths, so we will update the current relative file paths accordingly.

In [4]:
# The bottom part only needs to be run once
segmented_path = f'/home/aa2650/datasets/code_15/subset/test_segmented_split.csv'
segmented_split = pd.read_csv(segmented_path,
    index_col='idx',
)
# segmented_split['path'] = ('/home/aa2650/datasets/code_15/segmented/') + segmented_split['path']
# segmented_split.to_csv(os.path.join('/home/aa2650/datasets/code_15/segmented_split.csv'))

Run the follow commands togenerate the `test.tsv` file used for inference.

In [5]:
each_experiment_path = os.path.join(experiment_root, data_split, "100")
os.makedirs(each_experiment_path, exist_ok=True)

model_path = f"/home/aa2650/playground/ECG-FM/experiments/subset/80-10-10/checkpoint100.pt"

print(each_experiment_path)


/home/aa2650/datasets/code_15/experiments/subset/80-10-10/100


In [6]:
# inference_cmd = f"""fairseq-hydra-inference \\
#     task.data="/home/aa2650/datasets/code_15/subset/manifests" \\
#     common_eval.path="{model_path}" \\
#     common_eval.results_path="{each_experiment_path}" \\
#     model.num_labels=6 \\
#     dataset.valid_subset="test" \\
#     dataset.batch_size=10 \\
#     dataset.num_workers=3 \\
#     dataset.disable_validation=false \\
#     distributed_training.distributed_world_size=1 \\
#     distributed_training.find_unused_parameters=True \\
#     --config-dir "/home/aa2650/playground/ECG-FM/ckpts/" \\
#     --config-name physionet_finetuned
# """

# os.system(inference_cmd)

In [7]:
assert os.path.isfile(f"{each_experiment_path}/outputs_test.npy")
assert os.path.isfile(f"{each_experiment_path}/outputs_test_header.pkl")

## 4. Interpret results

The logits are ordered same as the samples in the manifest and labels in the label definition.

### Get predictions on PhysioNet 2021 labels

In [8]:
# os.path.join('/home/aa2650/playground/ECG-FM/data/code_15/labels/label_def.csv'),

code15_label_def = pd.read_csv(
    os.path.join('/home/aa2650/datasets/code_15/subset/label_def.csv'),
     index_col='name',
)
code15_label_names = code15_label_def.index
code15_label_names

Index(['RBBB', 'LBBB', 'SB', 'ST', 'AF', 'normal_ecg'], dtype='object', name='name')

In [9]:
# Load the array of computed logits
logits = MemmapReader.from_header(f"{each_experiment_path}/outputs_test.npy")[:]
logits.shape

(5873, 6)

In [10]:
# Construct predictions from logits
pred = pd.DataFrame(
    torch.sigmoid(torch.tensor(logits)).numpy(),
    columns=code15_label_names,
)

pred = segmented_split.reset_index().join(pred, how='left').set_index('idx')
pred

Unnamed: 0_level_0,save_file,split,path,sample_size,RBBB,LBBB,SB,ST,AF,normal_ecg
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
4329,code_15_1000730.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,1.616569e-12,1.000000e+00,1.097811e-09,5.402769e-13,5.644047e-13,1.867926e-22
4329,code_15_1000730.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,3.787874e-10,1.000000e+00,3.522709e-10,7.411588e-10,7.766477e-14,2.296589e-22
29077,code_15_100123.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,4.869405e-11,1.000000e+00,2.810555e-08,1.657921e-05,1.031574e-15,1.557310e-20
12710,code_15_1001938.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,8.658945e-10,2.831584e-11,1.160468e-10,1.000000e+00,7.337080e-06,1.771482e-13
35541,code_15_1002557.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,1.309301e-14,1.229286e-09,2.842760e-18,8.499558e-16,5.832023e-13,1.000000e+00
...,...,...,...,...,...,...,...,...,...,...
15687,code_15_997787.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,1.000000e+00,6.145794e-10,8.690409e-15,7.385872e-11,4.300155e-14,5.777580e-18
15687,code_15_997787.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,1.000000e+00,1.612668e-11,8.021295e-15,4.514144e-11,3.816720e-12,6.175830e-17
37132,code_15_998911.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,1.650549e-11,4.849524e-11,8.455527e-20,1.554712e-18,7.928680e-14,1.000000e+00
418,code_15_998961.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,1.000000e+00,3.975210e-11,4.861097e-12,4.259096e-18,2.385377e-17,4.884292e-13


In [11]:
# Perform a (crude) thresholding of 0.5 for all labels
pred_thresh = pred.copy()
pred_thresh[code15_label_names] = pred_thresh[code15_label_names] > 0.5

# Construct a readable column of predicted labels for each sample
pred_thresh['labels'] = pred_thresh[code15_label_names].apply(
    lambda row: ', '.join(row.index[row]),
    axis=1,
)
pred_thresh['labels']

idx
4329           LBBB
4329           LBBB
29077          LBBB
12710            ST
35541    normal_ecg
            ...    
15687          RBBB
15687          RBBB
37132    normal_ecg
418            RBBB
37915    normal_ecg
Name: labels, Length: 5873, dtype: object

In [12]:
code_15_label_def = pd.read_csv("/home/aa2650/playground/ECG-FM/data/code_15/labels/label_def.csv",
     index_col='name',
)
code_15_label_names = code_15_label_def.index
code_15_label_def

Unnamed: 0_level_0,pos_count_all,pos_percent_all
name,Unnamed: 1_level_1,Unnamed: 2_level_1
is_male,138528,0.402691
1dAVb,5699,0.016567
RBBB,9652,0.028058
LBBB,6011,0.017474
SB,5588,0.016244
ST,7571,0.022008
AF,7008,0.020372
normal_ecg,134497,0.390973


In [13]:
label_mapping = {
    'RBBB': 'RBBB',
    'LBBB': 'LBBB',
    'SB': 'SB',
    'ST': 'ST',
    'AF': 'AF',
    'normal_ecg': 'normal_ecg'
}

code15_label_def['name_mapped'] = code15_label_def.index.map(label_mapping)
code15_label_def

Unnamed: 0_level_0,pos_count_all,pos_percent_all,name_mapped
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
RBBB,9672,0.224518,RBBB
LBBB,6026,0.139883,LBBB
SB,5605,0.13011,SB
ST,7584,0.176049,ST
AF,7033,0.163258,AF
normal_ecg,9500,0.220525,normal_ecg


In [14]:
pred_mapped = pred.copy()
pred_mapped.drop(set(code15_label_names) - set(label_mapping.keys()), axis=1, inplace=True)
pred_mapped.rename(label_mapping, axis=1, inplace=True)
pred_mapped

Unnamed: 0_level_0,save_file,split,path,sample_size,RBBB,LBBB,SB,ST,AF,normal_ecg
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
4329,code_15_1000730.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,1.616569e-12,1.000000e+00,1.097811e-09,5.402769e-13,5.644047e-13,1.867926e-22
4329,code_15_1000730.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,3.787874e-10,1.000000e+00,3.522709e-10,7.411588e-10,7.766477e-14,2.296589e-22
29077,code_15_100123.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,4.869405e-11,1.000000e+00,2.810555e-08,1.657921e-05,1.031574e-15,1.557310e-20
12710,code_15_1001938.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,8.658945e-10,2.831584e-11,1.160468e-10,1.000000e+00,7.337080e-06,1.771482e-13
35541,code_15_1002557.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,1.309301e-14,1.229286e-09,2.842760e-18,8.499558e-16,5.832023e-13,1.000000e+00
...,...,...,...,...,...,...,...,...,...,...
15687,code_15_997787.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,1.000000e+00,6.145794e-10,8.690409e-15,7.385872e-11,4.300155e-14,5.777580e-18
15687,code_15_997787.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,1.000000e+00,1.612668e-11,8.021295e-15,4.514144e-11,3.816720e-12,6.175830e-17
37132,code_15_998911.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,1.650549e-11,4.849524e-11,8.455527e-20,1.554712e-18,7.928680e-14,1.000000e+00
418,code_15_998961.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,1.000000e+00,3.975210e-11,4.861097e-12,4.259096e-18,2.385377e-17,4.884292e-13


In [15]:
pred_thresh_mapped = pred_thresh.copy()
pred_thresh_mapped.drop(set(code15_label_names) - set(label_mapping.keys()), axis=1, inplace=True)
pred_thresh_mapped.rename(label_mapping, axis=1, inplace=True)
pred_thresh_mapped['predicted'] = pred_thresh_mapped[label_mapping.values()].apply(
    lambda row: ', '.join(row.index[row]),
    axis=1,
)
pred_thresh_mapped

Unnamed: 0_level_0,save_file,split,path,sample_size,RBBB,LBBB,SB,ST,AF,normal_ecg,labels,predicted
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
4329,code_15_1000730.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,False,True,False,False,False,False,LBBB,LBBB
4329,code_15_1000730.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,False,True,False,False,False,False,LBBB,LBBB
29077,code_15_100123.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,False,True,False,False,False,False,LBBB,LBBB
12710,code_15_1001938.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,False,False,False,True,False,False,ST,ST
35541,code_15_1002557.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,False,False,False,False,False,True,normal_ecg,normal_ecg
...,...,...,...,...,...,...,...,...,...,...,...,...
15687,code_15_997787.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,True,False,False,False,False,False,RBBB,RBBB
15687,code_15_997787.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,True,False,False,False,False,False,RBBB,RBBB
37132,code_15_998911.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,False,False,False,False,False,True,normal_ecg,normal_ecg
418,code_15_998961.mat,test,/home/aa2650/datasets/code_15/subset/segmented...,2500,True,False,False,False,False,False,RBBB,RBBB


In [16]:
true_labels = pd.read_csv(os.path.join('/home/aa2650/datasets/code_15/subset/labels.csv'), index_col='idx')
true_labels['actual'] = true_labels[label_mapping.values()].apply(
    lambda row: ', '.join(row.index[row]),
    axis=1,
)
true_labels

Unnamed: 0_level_0,RBBB,LBBB,SB,ST,AF,normal_ecg,actual
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,False,False,False,False,True,False,AF
1,False,False,False,False,True,False,AF
2,False,False,True,False,False,False,SB
3,True,False,False,False,False,False,RBBB
4,False,False,False,False,True,False,AF
...,...,...,...,...,...,...,...
43074,False,False,False,False,False,True,normal_ecg
43075,False,False,False,False,False,True,normal_ecg
43076,False,False,False,False,False,True,normal_ecg
43077,False,False,False,False,False,True,normal_ecg


In [17]:
# Visualize predicted and actual labels side-by-side
pred_thresh_mapped[['predicted']].join(true_labels[['actual']], how='left')

# Calculate accuracy
# Compare predicted and actual labels
comparison = pred_thresh_mapped[['predicted']].join(true_labels[['actual']], how='left')

# Calculate overall accuracy (exact match)
accuracy = (comparison['predicted'] == comparison['actual']).mean()
print(f"Overall accuracy: {accuracy:.2%}")

Overall accuracy: 91.08%


In [18]:
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import MultiLabelBinarizer

y_true_str = comparison['actual']
y_pred_str = comparison['predicted']

y_true_list = [labels.split(", ") for labels in y_true_str]
y_pred_list = [labels.split(", ") for labels in y_pred_str]

mlb = MultiLabelBinarizer()
y_true_bin = mlb.fit_transform(y_true_list)
y_pred_bin = mlb.transform(y_pred_list)

class_names = mlb.classes_

results = []

for i, cls_name in enumerate(class_names):
    y_true_col = y_true_bin[:, i]
    y_pred_col = y_pred_bin[:, i]
    
    tn, fp, fn, tp = confusion_matrix(y_true_col, y_pred_col).ravel()
    total = tp + tn + fp + fn
    prevalence = (tp + fn) / total
    sensitivity = tp / (tp + fn)
    precision = tp / (tp + fp)
    f1 = 2 * (precision * sensitivity) / (precision + sensitivity)
    specificity = tn / (tn + fp)
    npv = tn / (tn + fn)
    accuracy = (tp + tn) / total

    results.append({
        'class': cls_name,
        'prevalence': round(prevalence, 3),
        'f1': round(f1, 3),
        'accuracy': round(accuracy, 3),
    })

metrics_df = pd.DataFrame(results)
print(metrics_df)

        class  prevalence     f1  accuracy
0          AF       0.158  0.924     0.977
1        LBBB       0.147  0.948     0.985
2        RBBB       0.211  0.965     0.985
3          SB       0.134  0.908     0.975
4          ST       0.178  0.931     0.975
5  normal_ecg       0.224  0.947     0.976


