# Electrocardiogram Analysis using ECG-FM

The electrocardiogram (ECG) is a low-cost, non-invasive diagnostic test that has been ubiquitous in the assessment and management of cardiovascular disease for decades. ECG-FM is a pretrained, open foundation model for ECG analysis.

In this tutorial, we demonstrate how to perform inference for multi-label classification using a finetuned ECG-FM model. Specifically, we take a model finetuned on the [PhysioNet 2021 v1.0.3 dataset](https://physionet.org/content/challenge-2021/1.0.3/) and run inference on a sample of the [CODE-15% v1.0.0 dataset](https://zenodo.org/records/4916206/) to show how its predictions can be adapted to a new set of labels.

## Overview
0. Installation
1. Prepare checkpoints
2. Prepare data
3. Run inference
4. Interpret results

## 0. Installation

ECG-FM was developed in collaboration with the [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) framework, which implements a collection of deep learning methods for ECG analysis.

Clone [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) and refer to the requirements and installation section in the top-level README. After following those steps, install `pandas` and make the environment accessible within this notebook by running:
```
python3 -m pip install --user pandas
python3 -m pip install --user --upgrade jupyterlab ipywidgets ipykernel
python3 -m ipykernel install --user --name ecg_fm
```

In [1]:
import os
import pandas as pd
import torch

from fairseq_signals.utils.store import MemmapReader



In [2]:
root = '/home/aa2650/datasets/code_15'
fairseq_signals_root = '/home/aa2650/playground/fairseq-signals'
fairseq_signals_root = fairseq_signals_root.rstrip('/')
fairseq_signals_root

'/home/aa2650/playground/fairseq-signals'

## 1. Prepare checkpoints

In [3]:
from huggingface_hub import hf_hub_download

_ = hf_hub_download(
    repo_id='wanglab/ecg-fm-preprint',
    filename='physionet_finetuned.pt',
    local_dir=os.path.join(root, 'notebooks/ckpts'),
)
_ = hf_hub_download(
    repo_id='wanglab/ecg-fm-preprint',
    filename='physionet_finetuned.yaml',
    local_dir=os.path.join(root, 'notebooks/ckpts'),
)


In [4]:
assert os.path.isfile(os.path.join(root, 'notebooks/ckpts/physionet_finetuned.pt'))
assert os.path.isfile(os.path.join(root, 'notebooks/ckpts/physionet_finetuned.yaml'))

## 2. Prepare data

The model being used was finetuned on the [PhysioNet 2021 v1.0.3 dataset](https://physionet.org/content/challenge-2021/1.0.3/). To simplify this tutorial, we have processed a sample of 10 ECGs (14 five-second segments) from the [CODE-15% v1.0.0 dataset](https://zenodo.org/records/4916206/) so that we can demonstrate how to adapt the predictions to a new set of labels.
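The `sample_size` of 2500 shown in the prediction tables later is consistent with 5-second segments sampled at 500 Hz. As a minimal sketch, assuming the signals have already been resampled to 500 Hz and that a trailing remainder shorter than 5 s is simply dropped (both assumptions, not necessarily what the fairseq-signals scripts do), non-overlapping segmentation of a `(leads, samples)` array looks like this:

```python
import numpy as np

SAMPLE_RATE_HZ = 500   # assumed sampling rate after preprocessing
SEGMENT_SECONDS = 5    # segment length used in this tutorial


def segment_ecg(signal: np.ndarray,
                sample_rate: int = SAMPLE_RATE_HZ,
                seconds: int = SEGMENT_SECONDS) -> np.ndarray:
    """Split a (leads, samples) array into non-overlapping windows.

    Returns an array of shape (num_segments, leads, samples_per_segment).
    Trailing samples that do not fill a whole window are discarded
    (an assumption made for this sketch).
    """
    seg_len = sample_rate * seconds                  # 500 * 5 = 2500 samples
    num_segments = signal.shape[1] // seg_len
    trimmed = signal[:, : num_segments * seg_len]
    # (leads, n * seg_len) -> (leads, n, seg_len) -> (n, leads, seg_len)
    return trimmed.reshape(signal.shape[0], num_segments, seg_len).transpose(1, 0, 2)


# Hypothetical 12-lead recordings: 10 s gives two segments, 7 s gives one.
ten_s = np.zeros((12, 10 * SAMPLE_RATE_HZ))
seven_s = np.zeros((12, 7 * SAMPLE_RATE_HZ))
print(segment_ecg(ten_s).shape)
print(segment_ecg(seven_s).shape)
```

Reshaping before transposing keeps each lead's samples contiguous, so segment `k` of every lead covers the same time window.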

To perform inference on a full dataset (or on your own dataset), refer to the flexible, end-to-end, multi-source data preprocessing pipeline described [here](https://github.com/Jwoo5/fairseq-signals/tree/master/scripts/preprocess/ecg). Its README explains how the data is organized, and preprocessing scripts are implemented for several datasets.

### Update manifest

The segmented split must be saved with absolute file paths, so we will update the current relative file paths accordingly.
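A minimal sketch of that path update with pandas, on a hypothetical two-row stand-in for `segmented_split.csv` (the file names and `segmented_dir` are placeholders). Unlike plain string concatenation, the `os.path.isabs` check leaves already-absolute paths untouched, so re-running the conversion is harmless:

```python
import os

import pandas as pd


def to_absolute(path: str, base_dir: str) -> str:
    """Prefix base_dir to a relative path; leave absolute paths unchanged."""
    return path if os.path.isabs(path) else os.path.join(base_dir, path)


# Hypothetical stand-in for segmented_split.csv (the real file has more columns).
segmented_split = pd.DataFrame(
    {"idx": [125533, 125533], "path": ["segment_0.mat", "segment_1.mat"]}
).set_index("idx")

segmented_dir = "/data/code_15/segmented"  # hypothetical location of the segment files

# Apply the conversion to every row of the path column.
segmented_split["path"] = segmented_split["path"].map(lambda p: to_absolute(p, segmented_dir))
print(segmented_split["path"].tolist())
```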

In [None]:
# Segment the raw CODE-15% signals (uncomment the last line to run this step)
segmenting_cmd = f"""
cd {fairseq_signals_root}/scripts/preprocess/ecg

python code_15_signals.py \
    --processed_root "{root}/" \
    --raw_root "{root}/" \
    --manifest_file "{root}/manifest.csv"
"""

# os.system(segmenting_cmd)

In [6]:
labels_cmd = f"""
cd {fairseq_signals_root}/scripts/preprocess/ecg

python code_15_labels.py \
    --processed_root "{root}/" \
    --labels_path "{root}/labels.csv"
"""

os.system(labels_cmd)

0

In [7]:
split_cmd = f"""
cd {fairseq_signals_root}/scripts/preprocess

python splits.py \
    --strategy "random" \
    --processed_root "{root}/" \
    --meta_file "{root}/meta.csv" \
    --segmented_file "{root}/segmented.csv" \
    --fractions "0.1,0.1,0.8" \
    --split_labels "train,valid,test"
"""

os.system(split_cmd)

0
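The prediction table in section 4 shows both segments of exam `125533.0` in the `valid` split, which suggests the split is assigned per exam rather than per segment. A minimal sketch of such a grouped random split, using the same fractions and labels passed to `splits.py` above; `grouped_random_split` is a hypothetical helper for illustration, not the implementation of `splits.py`:

```python
import random


def grouped_random_split(group_ids, fractions=(0.1, 0.1, 0.8),
                         labels=("train", "valid", "test"), seed=0):
    """Assign every group (e.g. an exam idx) to exactly one split.

    All segments sharing a group id land in the same split, so one exam's
    segments never leak across train and test. Returns one label per input.
    """
    assert abs(sum(fractions) - 1.0) < 1e-9, "fractions must sum to 1"
    unique_ids = sorted(set(group_ids))
    random.Random(seed).shuffle(unique_ids)  # seeded for reproducibility

    assignment = {}
    start = 0
    for i, (label, frac) in enumerate(zip(labels, fractions)):
        # The last split takes whatever remains, absorbing rounding error.
        if i == len(labels) - 1:
            end = len(unique_ids)
        else:
            end = start + round(frac * len(unique_ids))
        for gid in unique_ids[start:end]:
            assignment[gid] = label
        start = end
    return [assignment[g] for g in group_ids]


# Hypothetical exam ids: each exam contributes one or two segments.
exam_ids = [i for i in range(100) for _ in range(1 + i % 2)]
splits = grouped_random_split(exam_ids)
print(splits[:5])
```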

In [8]:
# Load the split; the commented-out lines convert relative paths to absolute ones and only need to be run once
segmented_split = pd.read_csv(
    os.path.join(root, 'segmented_split.csv'),
    index_col='idx',
)
# segmented_split['path'] = os.path.join(root, 'segmented/') + segmented_split['path']
# segmented_split.to_csv(os.path.join(root, 'segmented_split.csv'))

In [9]:
os.path.join(root, 'segmented_split.csv')

'/home/aa2650/datasets/code_15/segmented_split.csv'

Run the following command to generate the `test.tsv` file used for inference.

In [10]:
generate_test_tsv = f"""
cd {fairseq_signals_root}/scripts/preprocess && \
python manifests.py \
    --split_file_paths "{root}/segmented_split.csv" \
    --save_dir "{root}/manifests/"
"""
os.system(generate_test_tsv)


0

In [11]:
assert os.path.isfile(os.path.join(root, 'manifests/test.tsv'))

## 3. Run inference

Inside our environment, we can run the following command using Hydra's command-line interface to extract the logits for each segment. A GPU must be available.
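As an alternative to formatting a shell string for `os.system`, the same call can be assembled as an argument list and run with `subprocess.run`, which avoids shell quoting and, with `check=True`, raises an exception if the command fails. A minimal sketch: `build_inference_cmd` is a hypothetical helper, the override keys and values are copied from the command below, and the paths are placeholders:

```python
import os
import shlex
import subprocess


def build_inference_cmd(data_dir, ckpt_dir, results_dir,
                        config_name="physionet_finetuned",
                        num_labels=26, batch_size=10, num_workers=3):
    """Return the fairseq-hydra-inference call as an argument list."""
    overrides = {
        "task.data": data_dir,
        "common_eval.path": os.path.join(ckpt_dir, f"{config_name}.pt"),
        "common_eval.results_path": results_dir,
        "model.num_labels": num_labels,
        "dataset.valid_subset": "test",
        "dataset.batch_size": batch_size,
        "dataset.num_workers": num_workers,
        "dataset.disable_validation": "false",
        "distributed_training.distributed_world_size": 1,
        "distributed_training.find_unused_parameters": "True",
    }
    cmd = ["fairseq-hydra-inference"]
    cmd += [f"{key}={value}" for key, value in overrides.items()]
    cmd += ["--config-dir", ckpt_dir, "--config-name", config_name]
    return cmd


cmd = build_inference_cmd(
    data_dir="/data/code_15/manifests",              # hypothetical paths
    ckpt_dir="/data/code_15/notebooks/ckpts",
    results_dir="/data/code_15/manifests/outputs",
)
print(shlex.join(cmd))  # shell-quoted preview of the command

RUN_INFERENCE = False  # set to True where fairseq-signals and a GPU are available
if RUN_INFERENCE:
    subprocess.run(cmd, check=True)
```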

In [None]:
# Point to the checkpoint and config downloaded in section 1
ckpt_dir = os.path.join(root, 'notebooks/ckpts')

inference_cmd = f"""fairseq-hydra-inference \\
    task.data="{root}/manifests/" \\
    common_eval.path="{ckpt_dir}/physionet_finetuned.pt" \\
    common_eval.results_path="{root}/manifests/outputs" \\
    model.num_labels=26 \\
    dataset.valid_subset="test" \\
    dataset.batch_size=10 \\
    dataset.num_workers=3 \\
    dataset.disable_validation=false \\
    distributed_training.distributed_world_size=1 \\
    distributed_training.find_unused_parameters=True \\
    --config-dir "{ckpt_dir}/" \\
    --config-name physionet_finetuned
"""

os.system(inference_cmd)




In [13]:
assert os.path.isfile("/home/aa2650/datasets/code_15/manifests/outputs/outputs_test.npy")
assert os.path.isfile("/home/aa2650/datasets/code_15/manifests/outputs/outputs_test_header.pkl")

## 4. Interpret results

The logits are ordered in the same way as the samples in the manifest (rows) and the labels in the label definition (columns).
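A small illustration of that ordering, using made-up logits and only four of the 26 PhysioNet 2021 labels. Because the task is multi-label, each column goes through its own sigmoid (which `torch.sigmoid` applies element-wise later in this section), so a row's probabilities need not sum to 1:

```python
import numpy as np

# Hypothetical logits for 3 segments (rows, in manifest order) and
# 4 labels (columns, in label-definition order).
label_names = ["AF", "AFL", "BBB", "Brady"]
logits = np.array([[ 2.0, -3.0, -8.0, -1.0],
                   [-4.0,  0.5, -6.0, -2.0],
                   [ 0.0, -5.0, -7.0,  3.0]])

# Independent sigmoid per label (not a softmax across labels).
probs = 1.0 / (1.0 + np.exp(-logits))

# Column j of probs always refers to label_names[j].
for row, p in enumerate(probs):
    print(row, {name: round(float(v), 3) for name, v in zip(label_names, p)})
```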

### Get predictions on PhysioNet 2021 labels

In [14]:
physionet2021_label_def = pd.read_csv(
    '/home/aa2650/playground/ECG-FM/data/physionet2021/labels/label_def.csv',
    index_col='name',
)
physionet2021_label_names = physionet2021_label_def.index
physionet2021_label_def

name,pos_count_all,pos_percent_all
AF,5230,0.060793
AFL,8271,0.096142
BBB,490,0.005696
Brady,283,0.00329
CLBBB|LBBB,1487,0.017285
CRBBB|RBBB,4794,0.055725
IAVB,3516,0.04087
IRBBB,1854,0.021551
LAD,7614,0.088505
LAnFB,2179,0.025329


In [15]:
# Load the array of computed logits
logits = MemmapReader.from_header("/home/aa2650/datasets/code_15/manifests/outputs/outputs_test.npy")[:]
logits.shape

(477310, 26)

In [16]:
# Construct predictions from logits
pred = pd.DataFrame(
    torch.sigmoid(torch.tensor(logits)).numpy(),
    columns=physionet2021_label_names,
)

# Join in sample information
pred = segmented_split.reset_index().join(pred, how='left').set_index('idx')
pred

idx,save_file,split,path,sample_size,AF,AFL,BBB,Brady,CLBBB|LBBB,CRBBB|RBBB,...,PR,PRWP,PVC|VPB,QAb,RAD,SA,SB,STach,TAb,TInv
125533.0,code_15_1000001.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000042,0.000051,3.048507e-09,1.655384e-10,5.507722e-06,0.951498,...,6.518553e-07,1.655419e-13,3.255876e-06,3.009818e-06,1.778772e-04,1.686691e-02,0.000053,0.013136,0.065433,0.000691
125533.0,code_15_1000001.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000002,0.000037,8.141519e-08,1.286161e-10,1.499412e-07,0.986049,...,1.223556e-08,2.447557e-11,1.311539e-05,5.539162e-07,3.539704e-09,7.462935e-05,0.000005,0.000038,0.002156,0.000054
220450.0,code_15_1000010.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000130,0.000027,5.680386e-07,1.672731e-08,7.913556e-05,0.000073,...,5.719350e-07,2.439094e-14,3.706787e-04,5.181918e-06,1.001005e-12,2.111190e-05,0.000008,0.042590,0.000905,0.002513
220450.0,code_15_1000010.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000438,0.000003,2.339043e-10,5.205239e-07,2.115561e-04,0.000369,...,1.039152e-07,5.417878e-16,4.737802e-06,2.988595e-06,1.414997e-12,8.293831e-06,0.000034,0.014083,0.000227,0.000249
276248.0,code_15_1000026.mat,test,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000254,0.000003,3.895278e-13,1.217283e-04,1.570111e-05,0.003992,...,2.588438e-09,6.000090e-12,3.876199e-05,2.026912e-05,4.860023e-15,2.125675e-07,0.000001,0.000024,0.000987,0.000069
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
314494.0,code_15_999980.mat,test,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.002127,0.000166,5.427211e-11,3.871467e-08,9.678038e-05,0.009578,...,1.466605e-04,1.631258e-07,9.356161e-01,5.365901e-05,9.447211e-08,6.674079e-04,0.000004,0.001266,0.075940,0.027872
131802.0,code_15_999992.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000415,0.000008,2.727484e-05,1.140014e-11,2.378474e-07,0.994584,...,1.947574e-06,1.293249e-12,6.228571e-07,1.966089e-03,3.317626e-06,8.009389e-05,0.000463,0.000050,0.045062,0.130763
131802.0,code_15_999992.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000296,0.000043,2.341222e-06,2.705314e-12,6.083722e-06,0.999943,...,2.818748e-06,7.804447e-14,3.491405e-06,1.473442e-05,5.170162e-08,2.528839e-03,0.000017,0.000027,0.001761,0.077288
20493.0,code_15_999993.mat,train,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000374,0.000003,5.783631e-15,2.409852e-04,1.131458e-04,0.038443,...,3.538421e-09,5.775553e-14,6.717935e-02,3.000936e-08,1.032446e-12,1.853731e-04,0.000559,0.000011,0.002317,0.000380


In [17]:
# Perform a (crude) thresholding of 0.5 for all labels
pred_thresh = pred.copy()
pred_thresh[physionet2021_label_names] = pred_thresh[physionet2021_label_names] > 0.5

# Construct a readable column of predicted labels for each sample
pred_thresh['labels'] = pred_thresh[physionet2021_label_names].apply(
    lambda row: ', '.join(row.index[row]),
    axis=1,
)
pred_thresh['labels']

idx
125533.0    CRBBB|RBBB, IRBBB, NSR
125533.0           CRBBB|RBBB, NSR
220450.0                          
220450.0                       NSR
276248.0                          
                     ...          
314494.0    LQT, PAC|SVPB, PVC|VPB
131802.0         CRBBB|RBBB, IRBBB
131802.0         CRBBB|RBBB, IRBBB
20493.0                        NSR
114944.0                 IAVB, NSR
Name: labels, Length: 477310, dtype: object
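Because the sigmoid is monotonic and equals 0.5 at 0, thresholding probabilities at 0.5 is equivalent to thresholding logits at 0 (up to floating-point rounding for logits extremely close to 0). The sketch below checks this on random logits and rebuilds the comma-separated label string of the cell above; the per-label thresholds are a hypothetical alternative to the single 0.5 cut-off, not values from ECG-FM:

```python
import numpy as np

rng = np.random.default_rng(0)
label_names = np.array(["AF", "CRBBB|RBBB", "IRBBB", "NSR"])
logits = rng.normal(scale=4.0, size=(1000, len(label_names)))
probs = 1.0 / (1.0 + np.exp(-logits))

# sigmoid(x) > 0.5  <=>  x > 0, so the cut-off can be applied to logits directly.
assert np.array_equal(probs > 0.5, logits > 0)

# Hypothetical per-label thresholds (e.g. tuned on a validation split).
thresholds = np.array([0.3, 0.5, 0.7, 0.5])
pred = probs > thresholds  # broadcasts one threshold per column


def to_label_string(row_mask):
    """Join the names of the positive labels, mirroring the pandas apply above."""
    return ", ".join(label_names[row_mask])


print(to_label_string(pred[0]))
```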

### Map predictions to CODE-15 labels

In [18]:
code_15_label_def = pd.read_csv(
    '/home/aa2650/playground/ECG-FM/data/code_15/labels/label_def.csv',
    index_col='name',
)
code_15_label_names = code_15_label_def.index
code_15_label_def

name,pos_count_all,pos_percent_all
is_male,138528,0.402691
1dAVb,5699,0.016567
RBBB,9652,0.028058
LBBB,6011,0.017474
SB,5588,0.016244
ST,7571,0.022008
AF,7008,0.020372
normal_ecg,134497,0.390973


In [19]:
label_mapping = {
    'CRBBB|RBBB': 'RBBB',
    'CLBBB|LBBB': 'LBBB',
    'SB': 'SB',
    'STach': 'ST',
    'AF': 'AF',
}

physionet2021_label_def['name_mapped'] = physionet2021_label_def.index.map(label_mapping)
physionet2021_label_def

name,pos_count_all,pos_percent_all,name_mapped
AF,5230,0.060793,AF
AFL,8271,0.096142,
BBB,490,0.005696,
Brady,283,0.00329,
CLBBB|LBBB,1487,0.017285,LBBB
CRBBB|RBBB,4794,0.055725,RBBB
IAVB,3516,0.04087,
IRBBB,1854,0.021551,
LAD,7614,0.088505,
LAnFB,2179,0.025329,
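The same mapping can be applied to a single sample's set of predicted labels. In this plain-Python sketch (`map_labels` is a hypothetical helper), PhysioNet 2021 labels without a CODE-15 counterpart, such as `IRBBB` or `NSR`, are dropped, mirroring how the unmapped columns are removed from the DataFrame in the next cell:

```python
# Mapping from PhysioNet 2021 label names to CODE-15 label names (copied from the cell above).
label_mapping = {
    'CRBBB|RBBB': 'RBBB',
    'CLBBB|LBBB': 'LBBB',
    'SB': 'SB',
    'STach': 'ST',
    'AF': 'AF',
}


def map_labels(physionet_labels):
    """Translate a set of predicted PhysioNet labels into CODE-15 labels.

    Labels that are not keys of label_mapping are silently dropped.
    """
    return {label_mapping[label] for label in physionet_labels if label in label_mapping}


print(map_labels({"CRBBB|RBBB", "IRBBB", "NSR"}))  # {'RBBB'}
```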


In [20]:
pred_mapped = pred.copy()
pred_mapped.drop(set(physionet2021_label_names) - set(label_mapping.keys()), axis=1, inplace=True)
pred_mapped.rename(label_mapping, axis=1, inplace=True)
pred_mapped

idx,save_file,split,path,sample_size,AF,LBBB,RBBB,SB,ST
125533.0,code_15_1000001.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000042,5.507722e-06,0.951498,0.000053,0.013136
125533.0,code_15_1000001.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000002,1.499412e-07,0.986049,0.000005,0.000038
220450.0,code_15_1000010.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000130,7.913556e-05,0.000073,0.000008,0.042590
220450.0,code_15_1000010.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000438,2.115561e-04,0.000369,0.000034,0.014083
276248.0,code_15_1000026.mat,test,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000254,1.570111e-05,0.003992,0.000001,0.000024
...,...,...,...,...,...,...,...,...,...
314494.0,code_15_999980.mat,test,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.002127,9.678038e-05,0.009578,0.000004,0.001266
131802.0,code_15_999992.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000415,2.378474e-07,0.994584,0.000463,0.000050
131802.0,code_15_999992.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000296,6.083722e-06,0.999943,0.000017,0.000027
20493.0,code_15_999993.mat,train,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000374,1.131458e-04,0.038443,0.000559,0.000011


In [21]:
pred_thresh_mapped = pred_thresh.copy()
pred_thresh_mapped.drop(set(physionet2021_label_names) - set(label_mapping.keys()), axis=1, inplace=True)
pred_thresh_mapped.rename(label_mapping, axis=1, inplace=True)
pred_thresh_mapped['predicted'] = pred_thresh_mapped[list(label_mapping.values())].apply(
    lambda row: ', '.join(row.index[row]),
    axis=1,
)
pred_thresh_mapped


### Compare predicted CODE-15 to actual

In [None]:
code_15_labels = pd.read_csv(os.path.join(root, 'labels.csv'), index_col='idx')
code_15_labels['actual'] = code_15_labels[list(label_mapping.values())].apply(
    lambda row: ', '.join(row.index[row]),
    axis=1,
)
code_15_labels

idx,is_male,1dAVb,RBBB,LBBB,SB,ST,AF,normal_ecg,actual
0,True,False,False,False,False,False,False,True,
1,True,False,False,False,False,False,False,False,
2,True,False,False,False,False,False,True,False,AF
3,True,False,False,False,False,False,False,True,
4,True,False,False,False,False,False,False,False,
...,...,...,...,...,...,...,...,...,...
345774,True,False,False,False,False,False,False,True,
345775,False,False,False,False,False,False,False,False,
345776,False,False,False,False,False,False,False,False,
345777,False,False,False,False,False,False,False,False,


In [None]:
# Compare predicted and actual labels side-by-side
comparison = pred_thresh_mapped[['predicted']].join(code_15_labels[['actual']], how='left')

# Overall (exact-match) accuracy: a sample counts as correct only if its
# predicted label string is identical to the actual one
accuracy = (comparison['predicted'] == comparison['actual']).mean()
print(f"Overall accuracy: {accuracy:.2%}")

Overall accuracy: 78.31%
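The accuracy above is an exact-match measure: a sample is correct only if its full predicted label string equals the actual one, including samples where both are empty. Comparing strings works here because both columns are built in the same label order. The sketch below (hypothetical helpers and toy data) compares label sets instead, so it does not depend on ordering, and also reports Hamming accuracy, the fraction of individual (sample, label) decisions that agree, for contrast:

```python
def parse(label_str):
    """'RBBB, AF' -> {'RBBB', 'AF'}; '' -> empty set (no mapped label)."""
    return {s for s in label_str.split(", ") if s}


def exact_match_accuracy(predicted, actual):
    """Fraction of samples whose predicted label set equals the actual set."""
    pairs = list(zip(predicted, actual))
    return sum(parse(p) == parse(a) for p, a in pairs) / len(pairs)


def hamming_accuracy(predicted, actual, labels):
    """Fraction of (sample, label) decisions that agree."""
    total = correct = 0
    for p, a in zip(predicted, actual):
        p_set, a_set = parse(p), parse(a)
        for label in labels:
            correct += (label in p_set) == (label in a_set)
            total += 1
    return correct / total


labels = ["RBBB", "LBBB", "SB", "ST", "AF"]
predicted = ["RBBB", "", "AF, ST", "SB"]  # hypothetical predictions
actual = ["RBBB", "", "ST, AF", ""]       # hypothetical ground truth
print(exact_match_accuracy(predicted, actual))      # 0.75
print(hamming_accuracy(predicted, actual, labels))  # 0.95
```

Hamming accuracy is usually much higher than exact-match accuracy when most labels are rare, since every correctly predicted absence counts.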


In [None]:
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import MultiLabelBinarizer

y_true_str = comparison['actual']
y_pred_str = comparison['predicted']

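# Note: ''.split(", ") returns [''], so samples without any mapped label
# form an empty-string class (row 0 of the table below)
y_true_list = [labels.split(", ") for labels in y_true_str]
y_pred_list = [labels.split(", ") for labels in y_pred_str]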

mlb = MultiLabelBinarizer()
y_true_bin = mlb.fit_transform(y_true_list)
y_pred_bin = mlb.transform(y_pred_list)

class_names = mlb.classes_

results = []

for i, cls_name in enumerate(class_names):
    y_true_col = y_true_bin[:, i]
    y_pred_col = y_pred_bin[:, i]
    
    tn, fp, fn, tp = confusion_matrix(y_true_col, y_pred_col).ravel()
    total = tp + tn + fp + fn
    prevalence = (tp + fn) / total
    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    f1 = 2 * (precision * recall) / (precision + recall)
    specificity = tn / (tn + fp)
    npv = tn / (tn + fn)
    accuracy = (tp + tn) / total

    results.append({
        'class': cls_name,
        'prevalence': round(prevalence, 3),
        'recall': round(recall, 3),
        'precision': round(precision, 3),
        'f1': round(f1, 3),
        'specificity': round(specificity, 3),
        'accuracy': round(accuracy, 3),
        'npv': round(npv, 3)
    })

metrics_df = pd.DataFrame(results)
print(metrics_df)

  class  prevalence  recall  precision     f1  specificity  accuracy    npv
0             0.905   0.786      0.986  0.875        0.894     0.796  0.306
1    AF       0.020   0.868      0.613  0.719        0.989     0.986  0.997
2  LBBB       0.018   0.922      0.413  0.570        0.976     0.976  0.999
3  RBBB       0.027   0.965      0.222  0.361        0.907     0.909  0.999
4    SB       0.016   0.630      0.110  0.187        0.917     0.912  0.993
5    ST       0.021   0.763      0.548  0.638        0.986     0.981  0.995
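The per-class metrics above divide by zero whenever a class has no predicted or no actual positives. A dependency-free sketch of the same formulas that works on label sets (so empty strings never become a class) and returns NaN for undefined ratios; returning NaN is our choice for this sketch, and scikit-learn's metric functions handle the same case through a `zero_division` parameter instead:

```python
import math


def safe_div(num, den):
    """Return num / den, or NaN when the ratio is undefined."""
    return num / den if den else math.nan


def per_class_metrics(y_true, y_pred, classes):
    """y_true / y_pred: lists of label sets, one set per sample."""
    rows = []
    n = len(y_true)
    for cls in classes:
        tp = sum(cls in t and cls in p for t, p in zip(y_true, y_pred))
        fp = sum(cls not in t and cls in p for t, p in zip(y_true, y_pred))
        fn = sum(cls in t and cls not in p for t, p in zip(y_true, y_pred))
        tn = n - tp - fp - fn
        precision = safe_div(tp, tp + fp)
        recall = safe_div(tp, tp + fn)
        rows.append({
            "class": cls,
            "prevalence": (tp + fn) / n,
            "recall": recall,
            "precision": precision,
            "f1": safe_div(2 * precision * recall, precision + recall),
            "specificity": safe_div(tn, tn + fp),
            "npv": safe_div(tn, tn + fn),
            "accuracy": (tp + tn) / n,
        })
    return rows


def parse(label_str):
    """'RBBB, AF' -> {'RBBB', 'AF'}; '' -> empty set."""
    return {s for s in label_str.split(", ") if s}


# Hypothetical toy data: SB never occurs, so its precision and recall are undefined.
actual = [parse(s) for s in ["RBBB", "", "AF", "AF"]]
predicted = [parse(s) for s in ["RBBB", "RBBB", "AF", ""]]
for row in per_class_metrics(actual, predicted, ["AF", "RBBB", "SB"]):
    print(row)
```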
