In [1]:
# Mirror the ltafdb 1.0.0 files from the `physionet-open` S3 bucket into
# ./data/ltafdb/. `--no-sign-request` sends unsigned requests, so no AWS
# credentials are used.
!mkdir -p ./data/
!aws s3 sync --no-sign-request s3://physionet-open/ltafdb/1.0.0/ ./data/ltafdb/

In [2]:
import os
import gc

import wfdb
import numpy as np
import polars as pl
from tqdm.notebook import tqdm
# from scipy.signal import butter, filtfilt

import plotly.express as px
import plotly.graph_objects as go

In [3]:
# Fetch the dataset through wfdb unless a local copy is already present.
# The presence of './data/ltafdb/75.dat' is used as the "already downloaded" marker.
os.makedirs('./data/ltafdb', exist_ok=True)
if os.path.exists('./data/ltafdb/75.dat'):
    print('Dataset already downloaded')
else:
    wfdb.dl_database('ltafdb', './data/ltafdb')

Dataset already downloaded


In [4]:
def list_records(path: str) -> list[str]:
    """Recursively collect the records stored under ``path``.

    A record is recognised by its ``.dat`` file; the returned entry is the
    path of that file with the ``.dat`` extension removed.

    Args:
        path: Directory to search. Sub-directories are walked as well.

    Returns:
        Record paths (without extension) in sorted order. The list is empty
        when no ``.dat`` file is found.
    """
    suffix = ".dat"
    records = []

    for root, _, files in os.walk(path):
        for file in files:
            if file.endswith(suffix):
                # Strip only the trailing extension: str.replace would also
                # remove a ".dat" occurring earlier in the file name.
                records.append(os.path.join(root, file[:-len(suffix)]))

    # os.walk yields entries in a filesystem-dependent order; sorting keeps
    # downstream processing (including seeded shuffles) reproducible.
    return sorted(records)

def load_record(record: str) -> tuple[wfdb.Annotation, pl.DataFrame, np.ndarray, dict]:
    """Load one record together with its ``qrs`` and ``atr`` annotations.

    Args:
        record: Record path without file extension (the form produced by
            ``list_records``), e.g. ``"./data/ltafdb/100"``.

    Returns:
        A 4-tuple ``(qrs_annotations, annotation_df, signals, fields)``:

        * ``qrs_annotations`` - the object returned by
          ``wfdb.rdann(record, "qrs")``.
        * ``annotation_df`` - the ``atr`` annotations as a DataFrame with the
          columns ``symbol``, ``aux`` (the annotation's ``aux_note``) and
          ``position`` (the annotation's ``sample`` index).
        * ``signals`` - the signal array returned by ``wfdb.rdsamp``; the
          notebook indexes it as ``signals[sample, channel]``.
        * ``fields`` - the metadata dict returned by ``wfdb.rdsamp``; the
          notebook reads the sampling frequency from ``fields['fs']``.
    """
    qrs_annotations = wfdb.rdann(record, "qrs")
    annotations = wfdb.rdann(record, "atr")
    signals, fields = wfdb.rdsamp(record)

    annotation_df = pl.from_dict({
        'symbol': annotations.symbol,
        'aux': annotations.aux_note,
        'position': annotations.sample,
    })

    return qrs_annotations, annotation_df, signals, fields

## Annotations

As per the PhysioNet annotation reference (<https://physionet.org/static/lightwave/doc/annotations.html>),
the following annotation codes are used (only the ones present in this dataset
are listed):

### Beat Annotations

| Annotation | Description                       |
| ---------- | --------------------------------- |
| N          | Normal Beat                       |
| A          | Atrial Premature Beat             |
| V          | Premature Ventricular Contraction |
| Q          | Unclassified Beat                 |

### Non-Beat Annotations

| Annotation | Description        |
| ---------- | ------------------ |
| +          | Rhythm Change      |
| "          | Comment Annotation |


In [5]:
# Load one record and plot a short excerpt with its annotations overlaid.
qrs_ann, annotations, signal, fields = load_record("./data/ltafdb/100")

# Excerpt boundaries (in samples) and the signal channel to display.
start_sample = 106000
end_sample = 107000
channel_sample = 0

sample_sig = signal[start_sample:end_sample, channel_sample]

# Keep the annotations that fall inside the excerpt, then re-express their
# positions in seconds relative to the start of the excerpt.
sample_ann = annotations.filter(
    pl.col('position') >= start_sample,
    pl.col('position') <= end_sample,
)
sample_ann = sample_ann.with_columns(
    (pl.col('position') - start_sample) / fields['fs']
)

fig = px.line(
    x=np.arange(len(sample_sig), dtype=np.float64) / fields['fs'],
    y=sample_sig,
    title="Sample signal",
)
fig.update_layout(xaxis_title="Time (s)", yaxis_title="ECG [mV]")

# One scatter trace per annotation symbol, drawn at a fixed height of 2.0.
for symbol, position in sample_ann.group_by('symbol').agg(pl.col('position')).iter_rows():
    fig.add_trace(
        go.Scatter(
            x=position,
            y=np.repeat(2.0, len(position)),
            mode='markers+text',
            textposition='top center',
            text=np.repeat(symbol, len(position)),
            name=symbol,
        )
    )

# Arrow annotations (at height 3.0) for the "qrs" annotations inside the excerpt.
for pos, sym in zip(qrs_ann.sample, qrs_ann.symbol):
    if start_sample <= pos <= end_sample:
        fig.add_annotation(
            x=(pos - start_sample) / fields['fs'],
            y=3.0,
            text=f"QRS - {sym}",
            showarrow=True,
        )

fig.show()

In [6]:
def extract_windows(
    signal: np.ndarray,
    annotations: pl.DataFrame,
    fs: int,
    section_size: int,
    section_stride: int,
    channel: int = 0,
    min_bpm: float = 30,
    max_bpm: float = 240,
) -> tuple[list[np.ndarray], list[list[tuple[str, int]]]]:
    """Cut one signal channel into fixed-size windows with their beat labels.

    Only annotations whose symbol is ``N``, ``A`` or ``V`` are considered. For
    every window the beats with ``start <= position < start + section_size``
    are collected, and the window is kept only if the beat rate implied by
    that count lies within ``[min_bpm, max_bpm]``.

    Args:
        signal: Array indexed as ``signal[sample, channel]``.
        annotations: DataFrame with at least the columns ``symbol`` and
            ``position`` (sample index of the annotation).
        fs: Sampling frequency in samples per second.
        section_size: Window length in samples.
        section_stride: Distance in samples between consecutive window starts.
        channel: Signal channel to extract (default ``0``).
        min_bpm: Windows with a lower beat rate are dropped (default ``30``).
        max_bpm: Windows with a higher beat rate are dropped (default ``240``).

    Returns:
        ``(signal_samples, signal_labels)``: two parallel lists. Each entry of
        ``signal_samples`` is a copy of one window of the selected channel;
        the matching entry of ``signal_labels`` is the list of
        ``(symbol, position)`` pairs of the beats inside that window, with
        ``position`` still expressed as an index into the full signal.
    """
    ann_filtered = annotations.filter(pl.col('symbol').is_in(('N', 'A', 'V')))
    ann_pos = ann_filtered['position'].to_numpy()
    ann_labels = ann_filtered['symbol'].to_numpy()

    # A stable sort keeps annotations that share a position in their original
    # relative order, so the output is deterministic.
    sorted_pos = np.argsort(ann_pos, kind='stable')
    ann_pos = ann_pos[sorted_pos]
    ann_labels = ann_labels[sorted_pos]

    signal_samples = []
    signal_labels = []

    window_seconds = section_size / fs

    # NOTE: the range stops before ``len(signal) - section_size``, so a window
    # ending exactly on the last sample is never generated.
    for start_idx in range(0, len(signal) - section_size, section_stride):
        end_idx = start_idx + section_size

        # Beats inside the half-open interval [start_idx, end_idx).
        start_idx_pos = np.searchsorted(ann_pos, start_idx, side='left')
        end_idx_pos = np.searchsorted(ann_pos, end_idx, side='left')
        sample_labels = list(zip(ann_labels[start_idx_pos:end_idx_pos], ann_pos[start_idx_pos:end_idx_pos]))

        # Filter by reasonable BPM in section
        bpm = len(sample_labels) / window_seconds * 60
        if bpm < min_bpm or bpm > max_bpm:
            continue

        # Add sample and labels
        signal_samples.append(signal[start_idx:end_idx, channel].copy())
        signal_labels.append(sample_labels)

    # Discard first and last section.
    # NOTE: this trims the first and last *kept* windows (after the BPM
    # filter), which are not necessarily the first and last of the record.
    signal_samples = signal_samples[1:-1]
    signal_labels = signal_labels[1:-1]

    return signal_samples, signal_labels

def binarize_af_labels(sample_labels: list[list[tuple[str, int]]], p: float = 0.5) -> list[bool]:
    """Turn per-window beat labels into one boolean label per window.

    A window is labelled ``True`` when the fraction of its beats annotated
    ``A`` or ``V`` is at least ``p``.

    NOTE(review): the notebook's annotation table describes ``A`` and ``V`` as
    premature beats, so this label measures the share of such beats in the
    window - confirm this is the intended proxy for AF.

    Args:
        sample_labels: One list of ``(symbol, position)`` pairs per window.
        p: Minimum fraction of ``A``/``V`` beats for a positive label.

    Returns:
        One boolean per window, in input order. A window without any beat is
        labelled ``False`` instead of raising ``ZeroDivisionError``.
    """
    flags = []
    for labels in sample_labels:
        if len(labels) == 0:
            # No beats means the fraction is undefined; treat as negative.
            flags.append(False)
            continue

        flagged = sum(1 for label in labels if label[0] in ('A', 'V'))
        flags.append(flagged / len(labels) >= p)

    return flags


# Window the record loaded above: 5 s sections, advancing 1.25 s at a time.
signal_samples, signal_labels = extract_windows(
    signal,
    annotations,
    fields['fs'],
    section_size=5 * fields['fs'],
    section_stride=5 * fields['fs'] // 4,
)
print(f'Number of sections: {len(signal_samples)}')

Number of sections: 63071


In [7]:
# Build two datasets from every record: all windows, and a per-record
# class-balanced subset. Both are written to zstd-compressed parquet files.
SEED = 42
all_samples = []
balanced_samples = []
for record in tqdm(list_records("./data/ltafdb"), desc="Loading Records", unit="Record"):
    _, ann, sig, fields = load_record(record)
    win_signal, win_labels = extract_windows(sig, ann, fields['fs'], 
                                             section_size=5 * fields['fs'], section_stride=5 * fields['fs'] // 4)
    if not win_signal:
        # No window survived for this record; there is nothing to add to
        # either dataset.
        continue

    win_labels_bin = binarize_af_labels(win_labels, p=0.5)

    sample_df = pl.DataFrame([
        np.repeat(int(os.path.basename(record)), len(win_labels)),
        np.repeat(int(fields['fs']), len(win_labels)),
        win_signal,
        win_labels_bin,
    ], schema=['record', 'fs', 'signal', 'label'])
    all_samples.append(sample_df)

    af_df = sample_df.filter(pl.col('label') == True)
    non_af_df = sample_df.filter(pl.col('label') == False)

    # Balance the record by downsampling the larger class to the size of the
    # smaller one. Sampling more rows than a frame holds (without replacement)
    # raises, which happened whenever the positive class was the larger one.
    n_per_class = min(af_df.height, non_af_df.height)
    if af_df.height > n_per_class:
        af_df = af_df.sample(n_per_class, shuffle=True, seed=SEED)

    sample_df = pl.concat([
        af_df, 
        non_af_df.sample(n_per_class, shuffle=True, seed=SEED),
    ]).sample(fraction=1, shuffle=True, seed=SEED)

    balanced_samples.append(sample_df)


balanced_df = pl.concat(balanced_samples, how='vertical').sample(fraction=1, shuffle=True, seed=SEED)
balanced_df.write_parquet("./data/balanced.pqt.zst", compression='zstd', compression_level=5)

# This DF, uncompressed, consumes around 25GB of RAM, so be careful :)
all_samples_df = pl.concat(all_samples, how='vertical')
all_samples_df.write_parquet("./data/samples.pqt.zst", compression='zstd', compression_level=5)

# Release the large intermediates before the following cells run.
del balanced_samples
del balanced_df
del all_samples_df
del all_samples
gc.collect()

Loading Records:   0%|          | 0/84 [00:00<?, ?Record/s]

18