# Electrocardiogram Analysis using ECG-FM

The electrocardiogram (ECG) is a low-cost, non-invasive diagnostic test that has been ubiquitous in the assessment and management of cardiovascular disease for decades. ECG-FM is a pretrained, open foundation model for ECG analysis.

In this tutorial, we show how to perform inference for multi-label classification using a finetuned ECG-FM model. Specifically, we take a model finetuned on the [PhysioNet 2021 v1.0.3 dataset](https://physionet.org/content/challenge-2021/1.0.3/) and run inference on a sample of the [CODE-15% v1.0.0 dataset](https://zenodo.org/records/4916206/) to show how to adapt the predictions to a new set of labels.

## Overview
0. Installation
1. Prepare checkpoints
2. Prepare data
3. Run inference
4. Interpret results

## 0. Installation

ECG-FM was developed in collaboration with the [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) framework, which implements a collection of deep learning methods for ECG analysis.

Clone [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) and refer to the requirements and installation section in the top-level README. After following those steps, install `pandas` and make the environment accessible within this notebook by running:
```
python3 -m pip install --user pandas
python3 -m pip install --user --upgrade jupyterlab ipywidgets ipykernel
python3 -m ipykernel install --user --name ecg_fm
```

In [1]:
import os
import pandas as pd
import torch

from fairseq_signals.utils.store import MemmapReader



In [2]:
root = '/home/aa2650/datasets/code_15'
fairseq_signals_root = '/home/aa2650/playground/fairseq-signals'
fairseq_signals_root = fairseq_signals_root.rstrip('/')
fairseq_signals_root

'/home/aa2650/playground/fairseq-signals'

## 1. Prepare checkpoints

In [3]:
from huggingface_hub import hf_hub_download

_ = hf_hub_download(
    repo_id='wanglab/ecg-fm-preprint',
    filename='physionet_finetuned.pt',
    local_dir=os.path.join(root, 'notebooks/ckpts'),
)
_ = hf_hub_download(
    repo_id='wanglab/ecg-fm-preprint',
    filename='physionet_finetuned.yaml',
    local_dir=os.path.join(root, 'notebooks/ckpts'),
)

physionet_finetuned.pt:   0%|          | 0.00/1.08G [00:00<?, ?B/s]

physionet_finetuned.yaml:   0%|          | 0.00/3.56k [00:00<?, ?B/s]

In [4]:
assert os.path.isfile(os.path.join(root, 'notebooks/ckpts/physionet_finetuned.pt'))
assert os.path.isfile(os.path.join(root, 'notebooks/ckpts/physionet_finetuned.yaml'))

## 2. Prepare data

The model used here was finetuned on the [PhysioNet 2021 v1.0.3 dataset](https://physionet.org/content/challenge-2021/1.0.3/). To simplify this tutorial, we have processed a sample of 10 ECGs (14 five-second segments) from the [CODE-15% v1.0.0 dataset](https://zenodo.org/records/4916206/) so that we can demonstrate how to adapt the predictions to a new set of labels.

To perform inference on a full dataset, or on your own dataset, refer to the flexible, end-to-end, multi-source data preprocessing pipeline described [here](https://github.com/Jwoo5/fairseq-signals/tree/master/scripts/preprocess/ecg). Its README explains how the data is organized, and preprocessing scripts are already implemented for several datasets.

### Update manifest

The segmented split must be saved with absolute file paths, so we will update the current relative file paths accordingly.

In [None]:
import os

In [None]:

# Generates segmented folder with all samples to train and infer

# CODE_15_ROOT="/home/aa2650/playground/ECG-FM/experiment"

# segmenting_cmd = f"""
# cd /home/aa2650/playground/fairseq-signals/scripts/preprocess/ecg

# python code_15_signals.py \
#     --processed_root "/home/aa2650/datasets/code_15/" \
#     --raw_root "/home/aa2650/datasets/code_15/" \
#     --manifest_file "/home/aa2650/datasets/code_15/manifest.csv"
# """

# os.system(segmenting_cmd)

In [6]:
labels_cmd = f"""
cd {fairseq_signals_root}/scripts/preprocess/ecg

python code_15_labels.py \
    --processed_root "{root}/" \
    --labels_path "{root}/labels.csv"
"""

os.system(labels_cmd)

0

In [3]:
split_cmd = f"""
cd {fairseq_signals_root}/scripts/preprocess

python splits.py \
    --strategy "random" \
    --processed_root "{root}/" \
    --meta_file "{root}/meta.csv" \
    --segmented_file "{root}/segmented.csv" \
    --fractions "0.01,0.01,0.98" \
    --split_labels "train,valid,test"
"""

os.system(split_cmd)



0
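
The `--fractions` and `--split_labels` arguments above send roughly 1% of the ECGs to `train`, 1% to `valid`, and 98% to `test`, so almost all of the data ends up in the subset used for inference. The following is a minimal sketch of a random split by fractions, not the implementation in `splits.py`; the function name `assign_splits` and the `seed` parameter are hypothetical.

```python
import random

def assign_splits(ids, fractions, labels, seed=0):
    """Randomly assign each ID to a split according to the given fractions."""
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    shuffled = list(ids)
    random.Random(seed).shuffle(shuffled)

    assignment = {}
    start = 0
    for i, (fraction, label) in enumerate(zip(fractions, labels)):
        # The last split takes whatever remains so that no ID is lost to rounding.
        if i == len(labels) - 1:
            end = len(shuffled)
        else:
            end = start + round(fraction * len(shuffled))
        for sample_id in shuffled[start:end]:
            assignment[sample_id] = label
        start = end
    return assignment

splits = assign_splits(range(1000), [0.01, 0.01, 0.98], ["train", "valid", "test"])
counts = {label: list(splits.values()).count(label) for label in ["train", "valid", "test"]}
print(counts)  # {'train': 10, 'valid': 10, 'test': 980}
```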

In [None]:
# The code below only needs to be run once
# segmented_split = pd.read_csv(
#     os.path.join('/home/aa2650/datasets/code_15/segmented_split.csv'),
#     index_col='idx',
# )
# segmented_split['path'] = ('/home/aa2650/datasets/code_15/segmented/') + segmented_split['path']
# segmented_split.to_csv(os.path.join('/home/aa2650/datasets/code_15/segmented_split.csv'))
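
The commented cell above should run only once because plain string concatenation would prepend the directory a second time on a re-run. A minimal sketch of a guard that makes the update safe to repeat is shown here; the helper name `to_absolute`, the directory, and the file name are hypothetical.

```python
import posixpath

def to_absolute(path, base_dir):
    """Prefix base_dir only when the path is still relative."""
    if posixpath.isabs(path):
        return path
    return posixpath.join(base_dir, path)

base_dir = "/data/code_15/segmented"  # hypothetical directory
relative = "code_15_0_0.mat"          # hypothetical file name

once = to_absolute(relative, base_dir)
twice = to_absolute(once, base_dir)   # a second pass leaves the path unchanged

print(once)           # /data/code_15/segmented/code_15_0_0.mat
print(twice == once)  # True
```

With pandas, the same function could be applied to the `path` column with `segmented_split['path'].map(...)`.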

In [9]:
os.path.join(root, 'segmented_split.csv')

'/home/aa2650/datasets/code_15/segmented_split.csv'
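
Note that `os.path.join` discards all earlier components as soon as it meets an absolute one, so a file name with a leading slash silently drops `root`. A small sketch of that behavior, using `posixpath` (the module behind `os.path` on Linux and macOS) so that it runs identically on any platform:

```python
import posixpath

root = "/home/aa2650/datasets/code_15"

# A leading slash makes the second component absolute, so `root` is discarded.
discarded = posixpath.join(root, "/segmented_split.csv")

# Without the leading slash the components are joined as intended.
joined = posixpath.join(root, "segmented_split.csv")

print(discarded)  # /segmented_split.csv
print(joined)     # /home/aa2650/datasets/code_15/segmented_split.csv
```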

Run the following command to generate the `test.tsv` file used for inference.

In [4]:
generate_test_tsv = f"""
cd {fairseq_signals_root}/scripts/preprocess && \
python manifests.py \
    --split_file_paths "/home/aa2650/datasets/code_15/segmented_split.csv" \
    --save_dir "/home/aa2650/datasets/code_15/manifests/"
"""
os.system(generate_test_tsv)


0

In [11]:
assert os.path.isfile(os.path.join(root, 'manifests/test.tsv'))

## 3. Run inference

Inside our environment, we can run the following command through Hydra's command line interface to extract the logits for each segment. A GPU must be available.

In [24]:
inference_cmd = f"""fairseq-hydra-inference \\
    task.data="/home/aa2650/datasets/code_15/manifests/" \\
    common_eval.path="/home/aa2650/playground/ECG-FM/ckpts/physionet_finetuned.pt" \\
    common_eval.results_path="/home/aa2650/datasets/code_15/manifests/outputs" \\
    model.num_labels=26 \\
    dataset.valid_subset="test" \\
    dataset.batch_size=10 \\
    dataset.num_workers=3 \\
    dataset.disable_validation=false \\
    distributed_training.distributed_world_size=1 \\
    distributed_training.find_unused_parameters=True \\
    --config-dir "/home/aa2650/playground/ECG-FM/ckpts/" \\
    --config-name physionet_finetuned
"""

os.system(inference_cmd)



[2025-03-25 18:48:26,929][fairseq_cli.inference][INFO] - loading model from /home/aa2650/playground/ECG-FM/ckpts/physionet_finetuned.pt
[2025-03-25 18:48:29,311][fairseq_signals.utils.checkpoint_utils][INFO] - Loaded a checkpoint in 2.38s
[2025-03-25 18:48:29,313][fairseq_cli.inference][INFO] - num. shared model params: 90,393,242 (num. trained: 90,393,242)
[2025-03-25 18:48:29,314][fairseq_cli.inference][INFO] - num. expert model params: 0 (num. trained: 0)
[2025-03-25 18:48:29,538][fairseq_cli.inference][INFO] - {'_name': None,
 'checkpoint': {'_name': None, 'save_dir': '<REDACTED>', 'restore_file': 'checkpoint_last.pt', 'finetune_from_model': None, 'reset_dataloader': False, 'reset_lr_scheduler': False, 'reset_meters': False, 'reset_optimizer': False, 'optimizer_overrides': '{}', 'save_interval': 1, 'save_interval_updates': 0, 'keep_interval_updates': -1, 'keep_interval_updates_pattern': -1, 'keep_last_epochs': 0, 'keep_best_checkpoints': -1, 'no_save': False, 'no_epoch_checkpoints'

0
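
`os.system` only returns the exit status (the `0` above), so a failed run is easy to miss. One alternative is to build the command as an argument list and run it with `subprocess.run(..., check=True)`, which raises on a non-zero exit code. This is a sketch under that assumption; the helper `build_inference_cmd` is hypothetical, and only a few of the overrides from the cell above are repeated.

```python
import shlex
import subprocess

def build_inference_cmd(overrides, config_dir, config_name):
    """Build an argument list for fairseq-hydra-inference from Hydra overrides."""
    cmd = ["fairseq-hydra-inference"]
    cmd += [f"{key}={value}" for key, value in overrides.items()]
    cmd += ["--config-dir", config_dir, "--config-name", config_name]
    return cmd

cmd = build_inference_cmd(
    overrides={
        "task.data": "/home/aa2650/datasets/code_15/manifests/",
        "dataset.valid_subset": "test",
        "dataset.batch_size": 10,
    },
    config_dir="/home/aa2650/playground/ECG-FM/ckpts/",
    config_name="physionet_finetuned",
)

# Print the command in a shell-safe form for inspection.
print(shlex.join(cmd))

# Uncomment to run; check=True raises CalledProcessError on a non-zero exit code.
# subprocess.run(cmd, check=True)
```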

In [25]:
assert os.path.isfile("/home/aa2650/datasets/code_15/manifests/outputs/outputs_test.npy")
assert os.path.isfile("/home/aa2650/datasets/code_15/manifests/outputs/outputs_test_header.pkl")

## 4. Interpret results

The rows of the logits array are ordered the same as the samples in the manifest, and its columns are ordered the same as the labels in the label definition.

### Get predictions on PhysioNet 2021 labels

In [26]:
physionet2021_label_def = pd.read_csv(
    '/home/aa2650/playground/ECG-FM/data/physionet2021/labels/label_def.csv',
    index_col='name',
)
physionet2021_label_names = physionet2021_label_def.index
physionet2021_label_def

name,pos_count_all,pos_percent_all
AF,5230,0.060793
AFL,8271,0.096142
BBB,490,0.005696
Brady,283,0.00329
CLBBB|LBBB,1487,0.017285
CRBBB|RBBB,4794,0.055725
IAVB,3516,0.04087
IRBBB,1854,0.021551
LAD,7614,0.088505
LAnFB,2179,0.025329


In [27]:
# Load the array of computed logits
logits = MemmapReader.from_header("/home/aa2650/datasets/code_15/manifests/outputs/outputs_test.npy")[:]
logits.shape

(381858, 26)

In [28]:
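# Construct predictions from logits
pred = pd.DataFrame(
    torch.sigmoid(torch.tensor(logits)).numpy(),
    columns=physionet2021_label_names,
)

# Load the sample information (the earlier cell that read it is commented out)
segmented_split = pd.read_csv(
    os.path.join(root, 'segmented_split.csv'),
    index_col='idx',
)

# Join in sample information. The join is positional, so it assumes that
# row i of `segmented_split` is the sample behind row i of `logits`
pred = segmented_split.reset_index().join(pred, how='left').set_index('idx')
pred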

idx,save_file,split,path,sample_size,AF,AFL,BBB,Brady,CLBBB|LBBB,CRBBB|RBBB,...,PR,PRWP,PVC|VPB,QAb,RAD,SA,SB,STach,TAb,TInv
125533.0,code_15_1000001.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000254,3.347727e-06,3.895278e-13,1.217283e-04,0.000016,0.003992,...,2.588438e-09,6.000090e-12,0.000039,0.000020,4.860023e-15,2.125675e-07,1.218711e-06,0.000024,0.000987,0.000069
125533.0,code_15_1000001.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000254,3.122761e-07,2.224263e-14,1.082924e-07,0.000041,0.087624,...,4.981121e-09,4.079006e-11,0.014940,0.000014,8.163946e-15,2.183712e-05,5.298577e-07,0.000034,0.008454,0.000098
220450.0,code_15_1000010.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000262,1.424171e-07,6.284038e-15,9.419722e-02,0.000286,0.019146,...,2.140992e-10,8.157639e-12,0.033128,0.000002,2.333001e-10,3.773365e-04,3.230675e-02,0.001986,0.011495,0.009156
220450.0,code_15_1000010.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000604,3.346062e-06,1.750463e-13,3.096962e-04,0.000004,0.000065,...,2.234003e-07,1.520999e-14,0.008913,0.000015,4.718596e-12,1.037054e-05,2.061068e-04,0.000818,0.049041,0.001916
276248.0,code_15_1000026.mat,test,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000146,2.015475e-08,3.704078e-12,9.619567e-08,0.000002,0.012680,...,2.287794e-10,9.846712e-12,0.000011,0.000072,2.169309e-11,1.241030e-05,7.119846e-06,0.000191,0.000057,0.000003
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
314494.0,code_15_999980.mat,test,/home/aa2650/datasets/code_15/segmented/code_1...,2500,,,,,,,...,,,,,,,,,,
131802.0,code_15_999992.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,,,,,,,...,,,,,,,,,,
131802.0,code_15_999992.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,,,,,,,...,,,,,,,,,,
20493.0,code_15_999993.mat,train,/home/aa2650/datasets/code_15/segmented/code_1...,2500,,,,,,,...,,,,,,,,,,
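
The join above is positional: prediction row `i` is attached to row `i` of `segmented_split`. That only holds when the frame contains exactly the samples in the manifest, in the same order. Here the logits have 381,858 rows while the joined frame has 477,310, which is why the trailing rows are `NaN` and why rows from the `valid` and `train` splits appear next to predictions computed on `test`. One option is to restrict the samples to the `test` split before joining. The sketch below illustrates the idea on toy data; it assumes the manifest keeps the row order of `segmented_split`, which should be verified against `test.tsv`, and the helper `align_predictions` is hypothetical.

```python
def align_predictions(samples, predictions, subset="test"):
    """Pair each prediction with the sample it was computed for.

    Assumes predictions follow the order of the `subset` rows in `samples`.
    """
    subset_rows = [row for row in samples if row["split"] == subset]
    if len(subset_rows) != len(predictions):
        raise ValueError(
            f"{len(subset_rows)} {subset} rows but {len(predictions)} predictions"
        )
    return [dict(row, pred=pred) for row, pred in zip(subset_rows, predictions)]

# Toy data standing in for segmented_split and the sigmoid outputs.
samples = [
    {"idx": 0, "split": "valid"},
    {"idx": 1, "split": "test"},
    {"idx": 2, "split": "train"},
    {"idx": 3, "split": "test"},
]
predictions = [0.9, 0.1]

aligned = align_predictions(samples, predictions)
print([(row["idx"], row["pred"]) for row in aligned])  # [(1, 0.9), (3, 0.1)]
```

The length check turns a silent misalignment into an immediate error.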


In [29]:
# Apply a (crude) threshold of 0.5 to all labels; NaN rows compare as False
pred_thresh = pred.copy()
pred_thresh[physionet2021_label_names] = pred_thresh[physionet2021_label_names] > 0.5

# Construct a readable column of predicted labels for each sample
pred_thresh['labels'] = pred_thresh[physionet2021_label_names].apply(
    lambda row: ', '.join(row.index[row]),
    axis=1,
)
pred_thresh['labels']

idx
125533.0       
125533.0       
220450.0       
220450.0       
276248.0    NSR
           ... 
314494.0       
131802.0       
131802.0       
20493.0        
114944.0       
Name: labels, Length: 477310, dtype: object
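
The thresholding step can be sketched without pandas or torch. Since the sigmoid of 0 is exactly 0.5, a probability threshold of 0.5 is equivalent to checking whether the logit is positive. Comparisons with `NaN` are always false, so rows without predictions end up with an empty label string. The logit values below are made up for illustration.

```python
import math

def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def predicted_labels(logits, label_names, threshold=0.5):
    """Return the names whose probability exceeds the threshold."""
    return [name for name, logit in zip(label_names, logits) if sigmoid(logit) > threshold]

label_names = ["AF", "SB", "STach"]  # a subset of the label names, for illustration
logits = [-8.3, 0.4, -2.1]           # made-up values

print(predicted_labels(logits, label_names))  # ['SB']
print(", ".join(predicted_labels([float("nan")] * 3, label_names)) == "")  # True
```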

### Map predictions to CODE-15 labels

In [30]:
code_15_label_def = pd.read_csv(
    "/home/aa2650/playground/ECG-FM/data/code_15/labels/label_def.csv",
    index_col='name',
)
code_15_label_names = code_15_label_def.index
code_15_label_def

name,pos_count_all,pos_percent_all
is_male,138528,0.402691
1dAVb,5699,0.016567
RBBB,9652,0.028058
LBBB,6011,0.017474
SB,5588,0.016244
ST,7571,0.022008
AF,7008,0.020372
normal_ecg,134497,0.390973


In [31]:
label_mapping = {
    'CRBBB|RBBB': 'RBBB',
    'CLBBB|LBBB': 'LBBB',
    'SB': 'SB',
    'STach': 'ST',
    'AF': 'AF',
}

physionet2021_label_def['name_mapped'] = physionet2021_label_def.index.map(label_mapping)
physionet2021_label_def

name,pos_count_all,pos_percent_all,name_mapped
AF,5230,0.060793,AF
AFL,8271,0.096142,
BBB,490,0.005696,
Brady,283,0.00329,
CLBBB|LBBB,1487,0.017285,LBBB
CRBBB|RBBB,4794,0.055725,RBBB
IAVB,3516,0.04087,
IRBBB,1854,0.021551,
LAD,7614,0.088505,
LAnFB,2179,0.025329,
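
Before relying on the mapping, it can help to check it against the target label set. The sketch below, which uses only the label names shown in this notebook, verifies that every mapped name exists among the CODE-15 labels and lists the CODE-15 labels that the mapping leaves uncovered.

```python
label_mapping = {
    "CRBBB|RBBB": "RBBB",
    "CLBBB|LBBB": "LBBB",
    "SB": "SB",
    "STach": "ST",
    "AF": "AF",
}
code_15_label_names = ["is_male", "1dAVb", "RBBB", "LBBB", "SB", "ST", "AF", "normal_ecg"]

mapped = set(label_mapping.values())

# Mapped names that do not exist in the target label set (should be empty).
unknown_targets = sorted(mapped - set(code_15_label_names))

# Target labels with no source label mapped onto them.
uncovered = sorted(set(code_15_label_names) - mapped)

print(unknown_targets)  # []
print(uncovered)        # ['1dAVb', 'is_male', 'normal_ecg']
```

The uncovered labels take no part in the comparisons that follow, which consider only the five mapped labels.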


In [32]:
pred_mapped = pred.copy()
pred_mapped.drop(set(physionet2021_label_names) - set(label_mapping.keys()), axis=1, inplace=True)
pred_mapped.rename(label_mapping, axis=1, inplace=True)
pred_mapped

idx,save_file,split,path,sample_size,AF,LBBB,RBBB,SB,ST
125533.0,code_15_1000001.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000254,0.000016,0.003992,1.218711e-06,0.000024
125533.0,code_15_1000001.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000254,0.000041,0.087624,5.298577e-07,0.000034
220450.0,code_15_1000010.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000262,0.000286,0.019146,3.230675e-02,0.001986
220450.0,code_15_1000010.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000604,0.000004,0.000065,2.061068e-04,0.000818
276248.0,code_15_1000026.mat,test,/home/aa2650/datasets/code_15/segmented/code_1...,2500,0.000146,0.000002,0.012680,7.119846e-06,0.000191
...,...,...,...,...,...,...,...,...,...
314494.0,code_15_999980.mat,test,/home/aa2650/datasets/code_15/segmented/code_1...,2500,,,,,
131802.0,code_15_999992.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,,,,,
131802.0,code_15_999992.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,,,,,
20493.0,code_15_999993.mat,train,/home/aa2650/datasets/code_15/segmented/code_1...,2500,,,,,


In [33]:
pred_thresh_mapped = pred_thresh.copy()
pred_thresh_mapped.drop(set(physionet2021_label_names) - set(label_mapping.keys()), axis=1, inplace=True)
pred_thresh_mapped.rename(label_mapping, axis=1, inplace=True)
# Select the mapped columns with an explicit list rather than a dict view
pred_thresh_mapped['predicted'] = pred_thresh_mapped[list(label_mapping.values())].apply(
    lambda row: ', '.join(row.index[row]),
    axis=1,
)
pred_thresh_mapped

idx,save_file,split,path,sample_size,AF,LBBB,RBBB,SB,ST,labels,predicted
125533.0,code_15_1000001.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,False,False,False,False,False,,
125533.0,code_15_1000001.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,False,False,False,False,False,,
220450.0,code_15_1000010.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,False,False,False,False,False,,
220450.0,code_15_1000010.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,False,False,False,False,False,,
276248.0,code_15_1000026.mat,test,/home/aa2650/datasets/code_15/segmented/code_1...,2500,False,False,False,False,False,NSR,
...,...,...,...,...,...,...,...,...,...,...,...
314494.0,code_15_999980.mat,test,/home/aa2650/datasets/code_15/segmented/code_1...,2500,False,False,False,False,False,,
131802.0,code_15_999992.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,False,False,False,False,False,,
131802.0,code_15_999992.mat,valid,/home/aa2650/datasets/code_15/segmented/code_1...,2500,False,False,False,False,False,,
20493.0,code_15_999993.mat,train,/home/aa2650/datasets/code_15/segmented/code_1...,2500,False,False,False,False,False,,


### Compare predicted CODE-15 to actual

In [34]:
code_15_labels = pd.read_csv('/home/aa2650/datasets/code_15/labels.csv', index_col='idx')
code_15_labels['actual'] = code_15_labels[list(label_mapping.values())].apply(
    lambda row: ', '.join(row.index[row]),
    axis=1,
)
code_15_labels

idx,is_male,1dAVb,RBBB,LBBB,SB,ST,AF,normal_ecg,actual
0,True,False,False,False,False,False,False,True,
1,True,False,False,False,False,False,False,False,
2,True,False,False,False,False,False,True,False,AF
3,True,False,False,False,False,False,False,True,
4,True,False,False,False,False,False,False,False,
...,...,...,...,...,...,...,...,...,...
345774,True,False,False,False,False,False,False,True,
345775,False,False,False,False,False,False,False,False,
345776,False,False,False,False,False,False,False,False,
345777,False,False,False,False,False,False,False,False,


In [35]:
# Join predicted and actual labels side by side
comparison = pred_thresh_mapped[['predicted']].join(code_15_labels[['actual']], how='left')

# Calculate overall accuracy (exact match of the label strings)
accuracy = (comparison['predicted'] == comparison['actual']).mean()
print(f"Overall accuracy: {accuracy:.2%}")

Overall accuracy: 70.65%
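
Exact-match accuracy counts a sample as correct only when every mapped label agrees, and a sample with no predicted and no actual labels also counts as a match. Because most samples carry none of the five mapped labels, this figure is dominated by such empty matches. The comparison above works on label strings, which is valid here because both columns are built from the same column order. A sketch of an order-independent variant on label sets, using made-up data and a hypothetical helper name:

```python
def exact_match_accuracy(predicted, actual):
    """Fraction of samples whose predicted label set equals the actual label set."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must have the same length")
    matches = sum(set(p) == set(a) for p, a in zip(predicted, actual))
    return matches / len(predicted)

# Made-up label sets for four samples.
predicted = [[], ["AF"], ["RBBB", "SB"], []]
actual = [[], ["AF"], ["SB"], ["LBBB"]]

print(exact_match_accuracy(predicted, actual))  # 0.5
```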


In [36]:
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import MultiLabelBinarizer

y_true_str = comparison['actual']
y_pred_str = comparison['predicted']

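# Split each label string into a list. An empty string becomes [''], so
# samples with no label form their own '' class in the printed table
y_true_list = [labels.split(", ") for labels in y_true_str]
y_pred_list = [labels.split(", ") for labels in y_pred_str]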

mlb = MultiLabelBinarizer()
y_true_bin = mlb.fit_transform(y_true_list)
y_pred_bin = mlb.transform(y_pred_list)

class_names = mlb.classes_

results = []

for i, cls_name in enumerate(class_names):
    y_true_col = y_true_bin[:, i]
    y_pred_col = y_pred_bin[:, i]
    
    # Pass labels explicitly so the matrix is 2x2 even if a class never occurs
    tn, fp, fn, tp = confusion_matrix(y_true_col, y_pred_col, labels=[0, 1]).ravel()
    total = tp + tn + fp + fn
    prevalence = (tp + fn) / total
    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    f1 = 2 * (precision * recall) / (precision + recall)
    specificity = tn / (tn + fp)
    npv = tn / (tn + fn)
    accuracy = (tp + tn) / total

    results.append({
        'class': cls_name,
        'prevalence': round(prevalence, 3),
        'recall': round(recall, 3),
        'precision': round(precision, 3),
        'f1': round(f1, 3),
        'specificity': round(specificity, 3),
        'accuracy': round(accuracy, 3),
        'npv': round(npv, 3)
    })

metrics_df = pd.DataFrame(results)
print(metrics_df)

  class  prevalence  recall  precision     f1  specificity  accuracy    npv
0             0.905   0.777      0.905  0.836        0.226     0.724  0.097
1    AF       0.020   0.025      0.022  0.023        0.977     0.958  0.980
2  LBBB       0.018   0.029      0.016  0.021        0.968     0.952  0.982
3  RBBB       0.027   0.096      0.028  0.043        0.907     0.885  0.973
4    SB       0.016   0.070      0.015  0.025        0.926     0.912  0.984
5    ST       0.021   0.024      0.022  0.023        0.976     0.956  0.979
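
The per-class metrics above divide by counts that can be zero, for example when a class is never predicted (`tp + fp == 0`). The following sketch computes the same quantities from the four confusion-matrix cells with a guarded division that returns `NaN` instead of raising or warning. The helper names and the counts are made up for illustration.

```python
def safe_div(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")

def binary_metrics(tp, fp, fn, tn):
    """Per-class metrics from the four cells of a binary confusion matrix."""
    total = tp + fp + fn + tn
    recall = safe_div(tp, tp + fn)
    precision = safe_div(tp, tp + fp)
    return {
        "prevalence": safe_div(tp + fn, total),
        "recall": recall,
        "precision": precision,
        "f1": safe_div(2 * precision * recall, precision + recall),
        "specificity": safe_div(tn, tn + fp),
        "npv": safe_div(tn, tn + fn),
        "accuracy": safe_div(tp + tn, total),
    }

# Made-up counts for one class.
metrics = binary_metrics(tp=8, fp=2, fn=2, tn=88)
print({name: round(value, 3) for name, value in metrics.items()})
```

For rare classes, high accuracy, specificity, and NPV mostly reflect the large number of true negatives, so recall, precision, and F1 are the more informative columns.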
