# Instructions

In this tutorial, we will perform multi-label classification using an ECG-FM model finetuned on the [MIMIC-IV-ECG v1.0 dataset](https://physionet.org/content/mimic-iv-ecg/1.0/). It outlines the data and model loading, as well as inference, same-sample prediction aggregation, and visualizations for embeddings and saliency maps.

ECG-FM was developed in collaboration with the [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) framework, which implements a collection of deep learning methods for ECG analysis.

This notebook segments each ECG into 5 s inputs and performs a label-specific aggregation of the predictions from each sample.

This document serves largely as a quickstart introduction. Much of this functionality is also available via the [fairseq-signals scripts](https://github.com/bowang-lab/ECG-FM/blob/main/notebooks/infer_cli.ipynb), as well as the [ECG-FM scripts](https://github.com/bowang-lab/ECG-FM/tree/main/scripts).

## Installation

Begin by cloning [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) and refer to the installation section in the top-level README. For example, the following commands are sufficient at the present moment:
```
# Creating `fairseq` environment:
conda create --name fairseq python=3.10.6
source activate fairseq
git clone https://github.com/Jwoo5/fairseq-signals
cd fairseq-signals
python3 -m pip install pip==24.0
python3 -m pip install -e .
```

In [None]:
# You may require the following imports depending on what functionality you run
!pip install huggingface-hub
!pip install pandas
!pip install ecg-transform==0.1.3
!pip install umap-learn
!pip install plotly

In [1]:
import os

# Project root: the parent directory of the notebook's working directory.
root = os.path.split(os.getcwd())[0]

## Download checkpoints

Checkpoints are available on [HuggingFace](https://huggingface.co/wanglab/ecg-fm). The finetuned model can be downloaded using the following command (as written, it fetches the model's configuration file, `mimic_iv_ecg_finetuned.yaml`):

In [2]:
from huggingface_hub import hf_hub_download
import os

# Fetch the finetuned model's YAML config into <root>/ckpts. The returned local
# path is bound to `_` so the notebook does not echo it.
# NOTE(review): this cell downloads only the config file. The cross-validation
# cell loads `mimic_iv_ecg_physionet_pretrained.pt` by a path relative to the
# current working directory, and this cell does not fetch that file. Confirm
# where that checkpoint is expected to come from.
_ = hf_hub_download(
    filename='mimic_iv_ecg_finetuned.yaml',
    repo_id='wanglab/ecg-fm',
    local_dir=os.path.join(root, 'ckpts'),
)

## Infer

In [None]:
import os
import torch
import numpy as np
import random
from itertools import chain
from typing import List, Dict
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, matthews_corrcoef
import torch.nn as nn
import torch.optim as optim

# ECG-FM imports
from ecg_transform.inp import ECGInput, ECGInputSchema
from ecg_transform.sample import ECGSample, ECGMetadata
from ecg_transform.t.base import ECGTransform
from ecg_transform.t.common import (
    HandleConstantLeads,
    LinearResample,
    ReorderLeads,
)
from ecg_transform.t.scale import Standardize
from ecg_transform.t.cut import SegmentNonoverlapping

from fairseq_signals.models import build_model_from_checkpoint
from fairseq_signals.models.classification.ecg_transformer_classifier import ECGTransformerClassificationModel

# -----------------------------
# Constants
# -----------------------------
# Standard 12-lead order the ECG-FM inputs are arranged in.
ECG_FM_LEAD_ORDER = [
    'I', 'II', 'III',
    'aVR', 'aVL', 'aVF',
    'V1', 'V2', 'V3', 'V4', 'V5', 'V6',
]
SAMPLE_RATE = 500  # Hz
N_SAMPLES = 5 * SAMPLE_RATE  # samples in one 5 s input segment

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# -----------------------------
# Load Data (Change dataset path for other downstream tasks)
# -----------------------------
# One .pt file per lead, keyed "LEAD_<name>" in ECG_FM_LEAD_ORDER order.
lead_file_paths = {
    f"LEAD_{lead}": f"data_aspire_PAP/LEAD_{lead}.pt" for lead in ECG_FM_LEAD_ORDER
}
labels_file_path = "data_aspire_PAP/labels.pt"

# map_location="cpu": ECGFromPTDataset calls `.numpy()` on these tensors, and
# StratifiedKFold converts `labels` to NumPy. Both fail for CUDA tensors, so
# load everything onto the CPU regardless of the device the files were saved
# from. Tensors already saved on CPU are unaffected.
ecg_lead_tensors = {
    lead: torch.load(path, map_location="cpu") for lead, path in lead_file_paths.items()
}
labels = torch.load(labels_file_path, map_location="cpu")

# -----------------------------
# Schema and Transforms
# -----------------------------
# Input contract for ECG-FM: leads in ECG_FM_LEAD_ORDER at SAMPLE_RATE Hz.
# NOTE(review): `required_num_samples` is presumably the per-input length the
# model expects (one 5 s segment). Confirm against the ecg_transform docs.
ECG_FM_SCHEMA = ECGInputSchema(
    expected_lead_order=ECG_FM_LEAD_ORDER,
    required_num_samples=N_SAMPLES,
    sample_rate=SAMPLE_RATE,
)

# Preprocessing pipeline, presumably applied by ECGSample in list order.
ECG_FM_TRANSFORMS = [
    # Put leads into ECG-FM order; strategy 'raise' presumably errors on a missing lead.
    ReorderLeads(missing_lead_strategy='raise', expected_order=ECG_FM_LEAD_ORDER),
    # Resample to SAMPLE_RATE.
    LinearResample(desired_sample_rate=SAMPLE_RATE),
    # Handle flat leads before standardisation (strategy 'zero').
    HandleConstantLeads(strategy='zero'),
    Standardize(),
    # Cut the recording into back-to-back windows of N_SAMPLES samples.
    SegmentNonoverlapping(segment_length=N_SAMPLES),
]

class ECGFromPTDataset(Dataset):
    """Map-style dataset over per-lead tensors loaded from ``.pt`` files.

    Each item stacks the ``idx``-th entry of every lead tensor along a new
    leading axis, wraps the result in ``ECGInput``/``ECGMetadata``, runs it
    through ``ECGSample`` with the given schema and transforms, and returns the
    transformed signal as a float tensor together with its label.

    Args:
        ecg_leads_dict: mapping ``"LEAD_<name>" -> tensor``. Indexing a tensor
            with ``idx`` is assumed to yield one 1-D recording (time axis) --
            TODO confirm.
        labels: per-recording labels; ``len(labels)`` is the dataset size.
        schema: passed through to ``ECGSample``.
        transforms: passed through to ``ECGSample``.
    """

    def __init__(self, ecg_leads_dict, labels, schema, transforms):
        self.ecg_leads = ecg_leads_dict
        self.labels = labels
        self.schema = schema
        self.transforms = transforms
        # Dict insertion order fixes the row order of the stacked array.
        self.lead_names = [*ecg_leads_dict]
        self.num_samples = len(labels)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        rows = [self.ecg_leads[name][idx].numpy() for name in self.lead_names]
        signal = np.stack(rows, axis=0)
        n_time = signal.shape[1]
        # NOTE(review): the sample rate is hard-coded to SAMPLE_RATE, so the
        # stored signals are assumed to already be at that rate -- confirm.
        meta = ECGMetadata(
            sample_rate=SAMPLE_RATE,
            num_samples=n_time,
            lead_names=[name.replace("LEAD_", "") for name in self.lead_names],
            input_start=0,
            input_end=n_time,
            unit=None,
        )
        meta.file = f"sample_{idx}"

        transformed = ECGSample(ECGInput(signal, meta), self.schema, self.transforms)
        return {"source": torch.from_numpy(transformed.out).float(), "label": self.labels[idx]}

def collate_fn(samples: List[Dict]):
    """Merge per-recording samples into one segment-level batch.

    Every sample's ``"source"`` is concatenated along dim 0. Dim 0 is assumed
    to be the segment axis, so each sample contributes ``source.shape[0]``
    rows. Its label is then repeated once per contributed row so that labels
    stay aligned with rows.

    Returns:
        ``{"net_input": {"source": Tensor}, "label": Tensor}`` where both
        tensors have the same length along dim 0.
    """
    sources = [item["source"] for item in samples]
    seg_counts = torch.tensor([src.shape[0] for src in sources])
    per_record = torch.tensor([item["label"] for item in samples])
    return {
        "net_input": {"source": torch.cat(sources, dim=0)},
        "label": torch.repeat_interleave(per_record, seg_counts),
    }

def pt_data_loader(dataset, batch_size=64, num_workers=0, shuffle=True):
    """Wrap ``dataset`` in a ``DataLoader`` that batches with ``collate_fn``.

    Args:
        dataset: map-style dataset yielding ``{"source", "label"}`` dicts.
        batch_size: recordings per batch. ``collate_fn`` expands recordings
            into segments, so the resulting tensor batch may have more rows.
        num_workers: number of loader worker processes (0 loads in the main
            process).
        shuffle: reshuffle the order every epoch. Defaults to ``True`` for
            backward compatibility; pass ``False`` for evaluation loaders when
            a fixed, reproducible prediction order is wanted.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``dataset``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )

# -----------------------------
# Load Pretrained ECG-FM Model
# -----------------------------
# NOTE(review): this is the *pretrained* checkpoint, not the finetuned one. It
# is loaded by a path relative to the current working directory, not from the
# <root>/ckpts directory used by the download cell -- confirm the file exists
# there. Also confirm that build_model_from_checkpoint on this checkpoint
# yields a classifier with a 2-class head, as the annotation below and the
# binary metrics in evaluate() assume.
# This module-level `model` is not used by the cross-validation loop, which
# builds a fresh `clf` from `ckpt_path` for every fold.
ckpt_path = "mimic_iv_ecg_physionet_pretrained.pt"
model: ECGTransformerClassificationModel = build_model_from_checkpoint(
    checkpoint_path=ckpt_path,
)
model.to(device).eval()

# -----------------------------
# Training and Evaluation
# -----------------------------
def set_seed(seed_value):
    """Seed every RNG this notebook relies on so runs are repeatable.

    Seeds Python's ``random``, NumPy's global generator and PyTorch's CPU
    generator. When CUDA is available it also seeds all GPUs and switches
    cuDNN to deterministic, non-benchmarking mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)

def train_one_epoch(model, dataloader, criterion, optimizer):
    """Run one training pass over ``dataloader`` and return the mean loss.

    Batches come from ``collate_fn``, which expands every recording into its
    segments, so a batch of B recordings can hold more than B rows. The loss
    is therefore averaged over the rows actually scored. The previous divisor,
    ``len(dataloader.dataset)``, counted recordings and inflated the reported
    loss by the average number of segments per recording.

    Args:
        model: called as ``model(source=x)`` and exposing
            ``get_logits(output_dict)``; already moved to ``device``.
        dataloader: iterable of batches shaped like ``collate_fn`` output.
        criterion: loss function, assumed to return a per-row mean
            (``nn.CrossEntropyLoss()`` does by default).
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        float: mean training loss per scored row, or ``nan`` if the loader
        yields no rows.
    """
    model.train()
    total_loss = 0.0
    total_items = 0
    for batch in dataloader:
        x = batch["net_input"]["source"].to(device)
        y = batch["label"].to(device)
        optimizer.zero_grad()
        output_dict = model(source=x)
        logits = model.get_logits(output_dict)
        # Defensive alignment kept from the original. collate_fn already
        # repeats labels per segment, so a mismatch here indicates an upstream
        # shape problem rather than an expected case.
        if logits.shape[0] != y.shape[0]:
            min_len = min(logits.shape[0], y.shape[0])
            logits = logits[:min_len]
            y = y[:min_len]
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        batch_items = y.size(0)
        # criterion returns a batch mean; re-weight by row count so the epoch
        # average is exact even when batches differ in size.
        total_loss += loss.item() * batch_items
        total_items += batch_items
    if total_items == 0:
        return float("nan")
    return total_loss / total_items

def evaluate(model, dataloader):
    """Score ``model`` on ``dataloader`` with binary classification metrics.

    Predictions are made per row of each batch, i.e. per segment after
    ``collate_fn`` expansion; no per-recording aggregation happens here. The
    predicted class is the argmax of the softmax. ROC-AUC uses the softmax
    probability of class 1, so a 2-class output is assumed.

    Returns:
        tuple: ``(accuracy, roc_auc, f1, mcc)`` over all scored rows.
    """
    model.eval()
    labels_seen, preds_seen, pos_scores = [], [], []
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch["net_input"]["source"].to(device)
            targets = batch["label"].to(device)
            logits = model.get_logits(model(source=inputs))
            probs = torch.softmax(logits, dim=1)
            preds = probs.argmax(dim=1)
            # Same defensive truncation as in training: keep only rows that
            # have both a label and a prediction.
            keep = min(len(targets), len(preds))
            labels_seen.extend(targets[:keep].cpu().numpy())
            preds_seen.extend(preds[:keep].cpu().numpy())
            pos_scores.extend(probs[:keep, 1].cpu().numpy())

    metrics = (
        accuracy_score(labels_seen, preds_seen),
        roc_auc_score(labels_seen, pos_scores),
        f1_score(labels_seen, preds_seen),
        matthews_corrcoef(labels_seen, preds_seen),
    )
    return metrics

# -----------------------------
# K-Fold Cross Validation
# -----------------------------
# Build the dataset once; each fold wraps it in index-based Subsets.
dataset = ECGFromPTDataset(ecg_lead_tensors, labels, ECG_FM_SCHEMA, ECG_FM_TRANSFORMS)
# Stratify folds on the per-recording labels. Splitting is by recording index,
# so all segments of one recording land in the same fold.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# One (acc, auc, f1, mcc) tuple per fold, summarised after the loop.
fold_results = []

# np.zeros(len(labels)) is a placeholder X: the split is driven by `labels`.
for fold, (train_idx, test_idx) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    print(f"FOLD {fold}")
    # pt_data_loader shuffles both loaders. For the test loader this only
    # changes the order of predictions, not the resulting metrics.
    train_loader = pt_data_loader(Subset(dataset, train_idx), batch_size=64)
    test_loader = pt_data_loader(Subset(dataset, test_idx), batch_size=64)

    # Fresh model from the same checkpoint each fold, so no weights carry
    # over between folds.
    clf = build_model_from_checkpoint(checkpoint_path=ckpt_path)
    clf.eval()  # has no effect on training: train_one_epoch calls model.train()
    clf.to(device)

    optimizer = optim.Adam(clf.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    # Fixed budget of 10 epochs; there is no validation-based early stopping.
    for epoch in range(10):
        loss = train_one_epoch(clf, train_loader, criterion, optimizer)
        print(f"Epoch {epoch+1} Loss: {loss:.4f}")

    # Metrics are computed per segment (see evaluate), on this fold's held-out recordings.
    acc, auc, f1, mcc = evaluate(clf, test_loader)
    fold_results.append((acc, auc, f1, mcc))
    print(f"Fold {fold} Results: Acc={acc:.4f}, AUC={auc:.4f}, F1={f1:.4f}, MCC={mcc:.4f}\n")

# -----------------------------
# Results Summary
# -----------------------------
# Transpose the per-fold (acc, auc, f1, mcc) tuples into one tuple per metric.
accuracies, aucs, f1s, mccs = zip(*fold_results)
# Report mean ± std across folds; np.std defaults to ddof=0 (population std).
print(f'Mean Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}')
print(f'Mean AUC: {np.mean(aucs):.4f} ± {np.std(aucs):.4f}')
print(f'Mean F1: {np.mean(f1s):.4f} ± {np.std(f1s):.4f}')
print(f'Mean MCC: {np.mean(mccs):.4f} ± {np.std(mccs):.4f}')
