In [1]:
import os
from glob import glob

import numpy as np
from scipy.signal import butter, filtfilt, find_peaks


# ==========================
# 1. PATHS – EDIT THIS PART
# ==========================

# Root folder of the Koch dataset on the local machine. Raw string so the
# Windows backslashes are not interpreted as escape sequences.
KOCH_ROOT = r"C:\Users\RAZER\Documents\GitHub\Game_of_SQUIDE\SQUIDE\dATASET\dataset_2\koch"

# Sub-folders of KOCH_ROOT, one per supplementary archive.
MOESM6_DIR = os.path.join(KOCH_ROOT, "MOESM6")  # BSPM (ECG-like)
MOESM7_DIR = os.path.join(KOCH_ROOT, "MOESM7")  # part 1 of MCG (A1, V1, X1)
MOESM8_DIR = os.path.join(KOCH_ROOT, "MOESM8")  # part 2 of MCG (Z1, Z2, Z3)

# Output archive written by main(); a relative path, so it is created in the
# current working directory.
OUT_FILE = "koch_pairs.npz"
# Sampling frequency in Hz (1 ms between samples); used to size the beat
# windows and passed to the R-peak detector.
FS = 1000  # sampling frequency: 1 ms interval


# ==========================
# 2. LOADING FUNCTIONS
# ==========================

def load_bspm_ecg(moesm6_dir):
    """
    Load BSPM (ECG-like) signals from MOESM6/BSPM_data/BSPM_CH_E*.txt

    Files are ordered by the channel number embedded in the file name
    (E1, E2, ..., E10, ...) rather than lexicographically. A plain string
    sort places "E10" before "E2" when the names are not zero-padded, which
    would make any positional channel index unreliable. The respiration
    channel (E33) is located by its number and removed from both the data
    and the returned file list, so the list has one entry per returned row.

    Args:
        moesm6_dir: path to the MOESM6 folder that contains "BSPM_data".

    Returns:
        ecg_bspm: (C_ecg, T) array scaled by -1.0e-6, respiration removed
        file_list: list of txt file paths, one per row of ecg_bspm

    Raises:
        RuntimeError: if no BSPM_CH_E*.txt file is found.
    """
    import re  # local import: only needed here to parse channel numbers

    respiration_channel = 33  # E33 is respiration, not an ECG lead

    def channel_number(path):
        # Channel number taken from the file name, or None if there is none.
        match = re.search(r"BSPM_CH_E(\d+)", os.path.basename(path), re.IGNORECASE)
        return int(match.group(1)) if match else None

    def natural_key(path):
        number = channel_number(path)
        # Numbered files first, in numeric order; anything else afterwards by name.
        return (number is None, -1 if number is None else number, path)

    bspm_folder = os.path.join(moesm6_dir, "BSPM_data")
    txt_files = sorted(glob(os.path.join(bspm_folder, "BSPM_CH_E*.txt")), key=natural_key)
    if not txt_files:
        raise RuntimeError(f"No BSPM_CH_E*.txt files found in {bspm_folder}")

    channels = [np.loadtxt(f, dtype=np.float32) for f in txt_files]

    # Align lengths: truncate every channel to the shortest one
    min_len = min(len(ch) for ch in channels)
    channels = [ch[:min_len] for ch in channels]
    ecg_bspm = np.stack(channels, axis=0)  # (C, T)

    print(f"[BSPM] Loaded {ecg_bspm.shape[0]} channels, length {ecg_bspm.shape[1]} samples")

    # Scale as per paper: multiply by -1.0 * 10^(-6) to get mV
    ecg_bspm = -1.0e-6 * ecg_bspm

    # Channel 33 is respiration. Prefer locating it by file name; if no file
    # carries that number, fall back to the 33rd row (index 32, 0-based).
    numbers = [channel_number(f) for f in txt_files]
    if respiration_channel in numbers:
        drop_idx = numbers.index(respiration_channel)
    elif ecg_bspm.shape[0] >= respiration_channel:
        drop_idx = respiration_channel - 1
    else:
        drop_idx = None

    if drop_idx is not None:
        keep = np.arange(ecg_bspm.shape[0])
        keep = keep[keep != drop_idx]
        ecg_bspm = ecg_bspm[keep, :]
        # Keep the file list in step with the rows that are returned.
        txt_files = [f for i, f in enumerate(txt_files) if i != drop_idx]
        print(f"[BSPM] Dropped channel 33 (respiration). New shape: {ecg_bspm.shape}")

    return ecg_bspm, txt_files


def load_mcg_group(folder):
    """
    Read every MCG_CH_*.txt file of one MCG group folder (e.g. MCG_CH_A1,
    MCG_CH_Z1) and stack them as rows of a single array.

    Files are taken in sorted path order; all traces are cut to the length
    of the shortest one before stacking.

    Returns:
        arr: (C_group, T)
        files: list of txt paths, in the same order as the rows of arr

    Raises:
        RuntimeError: if the folder holds no MCG_CH_*.txt file.
    """
    pattern = os.path.join(folder, "MCG_CH_*.txt")
    txt_files = sorted(glob(pattern))
    if len(txt_files) == 0:
        raise RuntimeError(f"No MCG_CH_*.txt in {folder}")

    traces = [np.loadtxt(path, dtype=np.float32) for path in txt_files]

    # Truncate to the common length so the traces can be stacked
    shortest = min(map(len, traces))
    arr = np.stack([trace[:shortest] for trace in traces], axis=0)

    print(f"[MCG group] {folder} -> {arr.shape[0]} channels, {arr.shape[1]} samples")
    return arr, txt_files


def load_full_mcg(moesm7_dir, moesm8_dir):
    """
    Gather the MCG channels of all known groups and join them along the
    channel axis: MCG_CH_A1, V1, X1 from MOESM7, then MCG_CH_Z1, Z2, Z3
    from MOESM8.

    A group whose folder does not exist is reported with a warning and
    skipped. All groups are cut to the shortest group length before joining.

    Returns:
        mcg_full: (C_mcg_total, T)
        file_list: list of paths, in channel order

    Raises:
        RuntimeError: if none of the group folders exists.
    """
    # (parent directory, group sub-folders) in the order they are concatenated
    layout = (
        (moesm7_dir, ("MCG_CH_A1", "MCG_CH_V1", "MCG_CH_X1")),
        (moesm8_dir, ("MCG_CH_Z1", "MCG_CH_Z2", "MCG_CH_Z3")),
    )

    group_arrays = []
    group_files = []
    for parent, group_names in layout:
        for group_name in group_names:
            folder = os.path.join(parent, group_name)
            if not os.path.isdir(folder):
                print(f"[WARN] Missing folder: {folder}")
                continue
            arr, files = load_mcg_group(folder)
            group_arrays.append(arr)
            group_files.extend(files)

    if len(group_arrays) == 0:
        raise RuntimeError("No MCG groups found in MOESM7 or MOESM8")

    # Groups may differ in length; cut them all to the shortest one
    shortest = min(arr.shape[1] for arr in group_arrays)
    mcg_full = np.concatenate([arr[:, :shortest] for arr in group_arrays], axis=0)  # (C_total, T)
    print(f"[MCG] Total concatenated shape: {mcg_full.shape}")
    return mcg_full, group_files


# ==========================
# 3. R-PEAK DETECTION
# ==========================

def bandpass_filter(signal, fs=1000, low=5.0, high=15.0, order=2):
    """
    Band-pass `signal` between `low` and `high` (same unit as `fs`) with a
    Butterworth design of the given order, applied forward and backward
    via filtfilt.
    """
    nyquist = 0.5 * fs
    # butter expects the band edges as fractions of the Nyquist frequency
    band = [low / nyquist, high / nyquist]
    coeffs = butter(order, band, btype="band")
    return filtfilt(*coeffs, signal)


def detect_r_peaks(signal, fs=1000, distance_ms=300, height_factor=0.4):
    """
    Very basic R-peak detector for clean signals.

    The signal is band-passed to 5-15 Hz, then local maxima are kept that
    reach at least `height_factor` times the maximum of the filtered trace
    and lie at least `distance_ms` milliseconds apart.

    Returns the sample indices of the detected peaks.
    """
    band_limited = bandpass_filter(signal, fs=fs, low=5.0, high=15.0, order=2)

    min_gap = int(distance_ms * fs / 1000.0)  # minimum spacing in samples
    threshold = height_factor * np.max(band_limited)

    peak_idx, _props = find_peaks(band_limited, height=threshold, distance=min_gap)
    return peak_idx


def normalize_beat(x):
    """
    Robustly standardize one beat window, channel by channel.

    Each row has its median subtracted and is divided by its median
    absolute deviation; 1e-6 is added to the MAD so a flat row does not
    divide by zero.

    x: (C, T)
    Returns an array of the same shape.
    """
    center = np.median(x, axis=1, keepdims=True)
    deviation = x - center
    scale = np.median(np.abs(deviation), axis=1, keepdims=True) + 1e-6
    return deviation / scale


# ==========================
# 4. MAIN
# ==========================

def main():
    """
    Build the paired ECG/MCG beat file.

    Steps: load the BSPM (ECG) and MCG recordings, cut both to a common
    length, detect R-peaks on ECG channel 0, extract a 2-second window
    (500 ms before to 1500 ms after each peak) from every channel of both
    recordings, normalize each window per channel, and save everything to
    OUT_FILE.

    Raises:
        RuntimeError: if no complete window could be extracted (no R-peak
            detected, or every peak lies too close to a recording edge).
    """
    # Load ECG (BSPM)
    ecg_bspm, ecg_files = load_bspm_ecg(MOESM6_DIR)

    # Load MCG (all groups from MOESM7 + MOESM8)
    mcg, mcg_files = load_full_mcg(MOESM7_DIR, MOESM8_DIR)

    # Align time length between ECG and MCG
    T = min(ecg_bspm.shape[1], mcg.shape[1])
    ecg_bspm = ecg_bspm[:, :T]
    mcg = mcg[:, :T]
    print(f"[ALIGN] ECG shape: {ecg_bspm.shape}, MCG shape: {mcg.shape}")

    # Choose one ECG channel for R-peaks (channel 0 for now)
    r_lead = ecg_bspm[0, :]
    r_peaks = detect_r_peaks(r_lead, fs=FS, distance_ms=300, height_factor=0.4)
    print(f"[R-PEAKS] Detected {len(r_peaks)} peaks")

    # Extract 2-second windows around each R-peak
    pre = int(0.5 * FS)   # 500 ms before R
    post = int(1.5 * FS)  # 1500 ms after R
    win_len = pre + post  # 2000 samples

    ecg_beats = []
    mcg_beats = []

    for r in r_peaks:
        start = r - pre
        end = r + post
        # Skip peaks whose window would run past either end. Both arrays are
        # exactly T samples long, so every remaining slice has win_len samples.
        if start < 0 or end > T:
            continue

        ecg_win = ecg_bspm[:, start:end]  # (C_ecg, win_len)
        mcg_win = mcg[:, start:end]       # (C_mcg, win_len)

        ecg_beats.append(normalize_beat(ecg_win).astype(np.float32))
        mcg_beats.append(normalize_beat(mcg_win).astype(np.float32))

    # np.stack fails with an unhelpful message on an empty list; say why instead.
    if not ecg_beats:
        raise RuntimeError(
            f"No complete {win_len}-sample beat window could be extracted "
            f"({len(r_peaks)} R-peaks detected in {T} samples)"
        )

    ecg_beats = np.stack(ecg_beats, axis=0)  # (N, C_ecg, 2000)
    mcg_beats = np.stack(mcg_beats, axis=0)  # (N, C_mcg, 2000)

    print(f"[RESULT] ECG beats: {ecg_beats.shape}, MCG beats: {mcg_beats.shape}")

    # Save to NPZ. The dtype=object entries are stored pickled, so readers
    # must pass allow_pickle=True to access them.
    np.savez_compressed(
        OUT_FILE,
        ecg_beats=ecg_beats,
        mcg_beats=mcg_beats,
        fs=np.array([FS], dtype=np.int32),
        ecg_channel_files=np.array(ecg_files, dtype=object),
        mcg_channel_files=np.array(mcg_files, dtype=object),
        description=np.array(["Koch ECG (BSPM) + SQUID MCG, paired 2s beats"], dtype=object),
    )
    print(f"[SAVE] Saved Koch ECG–MCG pairs to {OUT_FILE}")


# Entry point: build and save the paired-beat file when this code is run
# directly rather than imported.
if __name__ == "__main__":
    main()


[BSPM] Loaded 33 channels, length 100000 samples
[BSPM] Dropped channel 33 (respiration). New shape: (32, 100000)
[MCG group] C:\Users\RAZER\Documents\GitHub\Game_of_SQUIDE\SQUIDE\dATASET\dataset_2\koch\MOESM7\MCG_CH_A1 -> 17 channels, 100000 samples
[MCG group] C:\Users\RAZER\Documents\GitHub\Game_of_SQUIDE\SQUIDE\dATASET\dataset_2\koch\MOESM7\MCG_CH_V1 -> 19 channels, 100000 samples
[MCG group] C:\Users\RAZER\Documents\GitHub\Game_of_SQUIDE\SQUIDE\dATASET\dataset_2\koch\MOESM7\MCG_CH_X1 -> 15 channels, 100000 samples
[MCG group] C:\Users\RAZER\Documents\GitHub\Game_of_SQUIDE\SQUIDE\dATASET\dataset_2\koch\MOESM8\MCG_CH_Z1 -> 17 channels, 100000 samples
[MCG group] C:\Users\RAZER\Documents\GitHub\Game_of_SQUIDE\SQUIDE\dATASET\dataset_2\koch\MOESM8\MCG_CH_Z2 -> 16 channels, 100000 samples
[MCG group] C:\Users\RAZER\Documents\GitHub\Game_of_SQUIDE\SQUIDE\dATASET\dataset_2\koch\MOESM8\MCG_CH_Z3 -> 16 channels, 100000 samples
[MCG] Total concatenated shape: (100, 100000)
[ALIGN] ECG shape:

In [2]:
import numpy as np

# np.load on an .npz returns a lazy NpzFile that keeps the archive open.
# The context manager closes it once the arrays have been read, so no file
# handle is left behind (an open handle can block re-writing the file).
# allow_pickle=True is needed because the archive stores object-dtype arrays
# (file lists, description); only load files you trust with this flag.
with np.load("koch_pairs.npz", allow_pickle=True) as data:
    ecg_beats = data["ecg_beats"]
    mcg_beats = data["mcg_beats"]

    print("ECG beats:", ecg_beats.shape)  # (127, 32, 2000)
    print("MCG beats:", mcg_beats.shape)  # (127, 100, 2000)
    print("fs:", data["fs"])
    print("desc:", data["description"])


ECG beats: (127, 32, 2000)
MCG beats: (127, 100, 2000)
fs: [1000]
desc: ['Koch ECG (BSPM) + SQUID MCG, paired 2s beats']


In [6]:
import numpy as np
import torch
from torch.utils.data import Dataset


class KochPairedBeatsDataset(Dataset):
    """
    Paired ECG/MCG beat windows read from an .npz file.

    Item `idx` is the tuple (ecg, mcg) of tensors built from row `idx` of
    the "ecg_beats" and "mcg_beats" arrays. With augment=True each item gets
    independent small Gaussian noise per modality and ONE random circular
    time shift that is applied to both modalities, so the pair stays
    time-aligned.

    NOTE(review): augmentation draws from NumPy's global RNG; when used with
    DataLoader(num_workers > 0), confirm that workers are seeded differently.
    """

    # Augmentation strength
    _NOISE_STD = 0.01  # standard deviation of the additive Gaussian noise
    _MAX_SHIFT = 20    # largest circular shift, in samples, in either direction

    def __init__(self, npz_path="koch_pairs.npz", augment=False):
        """
        npz_path: path to the koch_pairs.npz file
        augment: if True, apply very light noise / jitter augmentations
        """
        # Read the arrays eagerly and close the archive straight away.
        # allow_pickle=True is kept from the original call; only load files
        # you trust with this flag.
        with np.load(npz_path, allow_pickle=True) as data:
            self.ecg_beats = data["ecg_beats"]  # (N, C_ecg, T)
            self.mcg_beats = data["mcg_beats"]  # (N, C_mcg, T)
            self.fs = int(data["fs"][0])
        self.augment = augment

    def __len__(self):
        """Number of beats (first axis of ecg_beats)."""
        return self.ecg_beats.shape[0]

    def _augment(self, x, shift=None):
        """
        x: numpy array (C, T)
        shift: circular shift in samples along the time axis; if None, one
               is drawn uniformly from [-_MAX_SHIFT, +_MAX_SHIFT]
        very light augmentations: small Gaussian noise and tiny time shift

        Returns a new array; x itself is not modified.
        """
        # small noise
        x = x + self._NOISE_STD * np.random.randn(*x.shape).astype(np.float32)

        # tiny circular time shift up to ±_MAX_SHIFT samples
        if shift is None:
            shift = np.random.randint(-self._MAX_SHIFT, self._MAX_SHIFT + 1)
        if shift != 0:
            x = np.roll(x, shift, axis=1)

        return x

    def __getitem__(self, idx):
        """Return the (ecg, mcg) tensor pair for beat `idx`."""
        ecg = self.ecg_beats[idx]  # (C_ecg, T)
        mcg = self.mcg_beats[idx]  # (C_mcg, T)

        if self.augment:
            # Draw the shift once and share it: independent shifts would move
            # the ECG and MCG of the same beat relative to each other.
            shift = np.random.randint(-self._MAX_SHIFT, self._MAX_SHIFT + 1)
            ecg = self._augment(ecg, shift=shift)
            mcg = self._augment(mcg, shift=shift)

        # convert to torch tensors (dtype follows the stored arrays)
        ecg = torch.from_numpy(ecg)  # shape (C_ecg, T)
        mcg = torch.from_numpy(mcg)  # shape (C_mcg, T)

        return ecg, mcg


In [7]:
#from koch_dataset import KochPairedBeatsDataset
from torch.utils.data import DataLoader

# Smoke test: build the dataset without augmentation and report its size.
ds = KochPairedBeatsDataset("koch_pairs.npz", augment=False)
print("Num beats:", len(ds))

# Pull one shuffled batch of 8 beats to check the collated tensor shapes.
loader = DataLoader(ds, batch_size=8, shuffle=True)
ecg, mcg = next(iter(loader))
print("ECG batch shape:", ecg.shape)  # expect: torch.Size([8, 32, 2000])
print("MCG batch shape:", mcg.shape)  # expect: torch.Size([8, 100, 2000])


Num beats: 127
ECG batch shape: torch.Size([8, 32, 2000])
MCG batch shape: torch.Size([8, 100, 2000])


# 1. PTB ECG – Prepare Beat-Level Dataset

In [None]:
import os
from glob import glob

import numpy as np
import wfdb
from scipy.signal import find_peaks, butter, filtfilt


def list_ptb_records(ptb_root):
    """
    Collect every PTB record under `ptb_root`.

    ptb_root: path to PTB root folder, e.g. 'data/ptbdb'
    Looks for '*.hea' header files inside each 'patient*' entry and returns
    their full paths without the extension, e.g. '.../patient001/s0010_re'.
    Patients and records are both listed in sorted order.
    """
    patient_dirs = sorted(glob(os.path.join(ptb_root, 'patient*')))
    return [
        os.path.splitext(header)[0]  # drop the '.hea' extension
        for patient_dir in patient_dirs
        for header in sorted(glob(os.path.join(patient_dir, '*.hea')))
    ]


def bandpass_filter(signal, fs=1000, low=5.0, high=15.0, order=2):
    """
    Butterworth band-pass used ahead of QRS detection.

    signal: 1D numpy array
    fs: sampling frequency
    low, high: pass-band edges, in the same unit as fs
    order: order passed to the Butterworth design
    The filter is run forward and backward (filtfilt).
    """
    nyq = 0.5 * fs
    # band edges expressed as fractions of the Nyquist frequency
    normalized_band = [edge / nyq for edge in (low, high)]
    b, a = butter(order, normalized_band, btype='band')
    filtered = filtfilt(b, a, signal)
    return filtered
    

def detect_r_peaks(ecg_lead, fs=1000, distance_ms=300, height_factor=0.5):
    """
    Minimal single-lead R-peak detector.

    ecg_lead: 1D numpy array
    fs: sampling frequency
    distance_ms: minimum distance between peaks, in milliseconds
    height_factor: peaks must reach this fraction of the filtered maximum
    Returns indices of R-peaks.
    """
    qrs_band = bandpass_filter(ecg_lead, fs=fs, low=5.0, high=15.0)

    # convert the minimum spacing from milliseconds to samples
    min_samples_apart = int(distance_ms * fs / 1000.0)
    # rough amplitude threshold relative to the largest filtered value
    min_height = height_factor * np.max(qrs_band)

    return find_peaks(qrs_band, distance=min_samples_apart, height=min_height)[0]


In [None]:
import os
import numpy as np
import wfdb
from collections import defaultdict

# Extract normalized 2-second, 12-lead beat windows from every PTB record
# under PTB_ROOT and save them to OUT_FILE.
PTB_ROOT = "data/ptbdb"    # TODO: change to your path
# NOTE(review): OUT_FILE is also the name of the Koch output constant earlier
# in this file; in a shared notebook namespace this assignment overrides it.
OUT_FILE = "ptb_beats.npz" # output
# All beats are stacked into one array, so every record must share one
# sampling rate; records at any other rate are skipped.
EXPECTED_FS = 1000

# standard 12 leads in PTB header
STANDARD_LEADS = ['i', 'ii', 'iii', 'avr', 'avl', 'avf',
                  'v1', 'v2', 'v3', 'v4', 'v5', 'v6']

all_beats = []
all_record_names = []
all_subject_ids = []
all_labels = []  # you can later fill with MI/non-MI, etc.

rec_paths = list_ptb_records(PTB_ROOT)
print(f"Found {len(rec_paths)} records")

for rec_path in rec_paths:
    # Best effort: an unreadable record is reported and skipped.
    try:
        record = wfdb.rdrecord(rec_path)
    except Exception as e:
        print(f"Failed to read {rec_path}: {e}")
        continue

    sig = record.p_signal  # shape (T, n_channels)
    fs = int(record.fs)    # should be 1000
    if fs != EXPECTED_FS:
        # A different rate would give windows of a different length and
        # make the final np.stack fail.
        print(f"Unexpected fs={fs} in {rec_path}, skipping record.")
        continue
    ch_names = [ch.lower() for ch in record.sig_name]

    # map lead names to indices; a record lacking any standard lead is skipped
    missing = [ln for ln in STANDARD_LEADS if ln not in ch_names]
    if missing:
        print(f"Lead {missing[0]} not found in {rec_path}, skipping record.")
        continue
    lead_indices = [ch_names.index(ln) for ln in STANDARD_LEADS]

    ecg_12 = sig[:, lead_indices].T  # shape (12, T)
    T = ecg_12.shape[1]

    # use lead II for R-peak detection
    lead2_index = ch_names.index('ii')
    lead2 = sig[:, lead2_index]
    r_peaks = detect_r_peaks(lead2, fs=fs, distance_ms=300, height_factor=0.4)

    if len(r_peaks) < 5:
        print(f"Few R-peaks in {rec_path}, skipping.")
        continue

    # 2-second window: [-500ms, +1500ms] → 2000 samples
    pre = int(0.5 * fs)
    post = int(1.5 * fs)
    win_len = pre + post

    for r in r_peaks:
        start = r - pre
        end = r + post
        # Skip peaks whose window would run past either end of the record;
        # every remaining slice is exactly win_len samples long.
        if start < 0 or end > T:
            continue
        beat = ecg_12[:, start:end]  # (12, 2000)

        # simple per-beat normalization (median + MAD)
        m = np.median(beat, axis=1, keepdims=True)
        mad = np.median(np.abs(beat - m), axis=1, keepdims=True) + 1e-6
        beat_norm = (beat - m) / mad

        all_beats.append(beat_norm.astype(np.float32))
        all_record_names.append(os.path.basename(rec_path))
        # subject id is patient folder name
        all_subject_ids.append(os.path.basename(os.path.dirname(rec_path)))
        all_labels.append(0)  # placeholder; you can parse .hea for diagnosis later

# np.stack fails with an unhelpful message on an empty list; say why instead.
if not all_beats:
    raise RuntimeError(
        f"No beats extracted from {len(rec_paths)} records under {PTB_ROOT}; "
        "check PTB_ROOT and the record contents"
    )

all_beats = np.stack(all_beats, axis=0)  # (N, 12, 2000)
all_record_names = np.array(all_record_names)
all_subject_ids = np.array(all_subject_ids)
all_labels = np.array(all_labels)

print("PTB beats shape:", all_beats.shape)

np.savez_compressed(
    OUT_FILE,
    beats=all_beats,
    record_names=all_record_names,
    subject_ids=all_subject_ids,
    labels=all_labels,
    fs=np.array([EXPECTED_FS], dtype=np.int32),
    description=np.array(["PTB 12-lead beats, 2s windows around R-peaks"], dtype=object)
)
print(f"Saved PTB beats to {OUT_FILE}")


# 2. Koch ECG–MCG – Prepare Paired Beat Dataset