In [None]:
!pip install -qU wfdb streamlit


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.8/163.8 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m45.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m61.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m55.1 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.2 which is incompatible.
cudf-cu12 25.6.0 requires pandas<2.2.4dev0,>=2.0, but you have pandas 2.3.2 which is incompatible.
dask-cudf-cu12 25.6.0 requir

## Download Database

In [None]:
# wfdb.dl_database("mitdb", dl_dir="mitdb_local")
# wfdb.dl_database("afdb",  dl_dir="afdb_local")

## Test with small dataset

In [None]:
# patients' ID

# MITDB ID
mit_ids = ["106", "201", "203"]
# AFDB ID
af_ids  = []


USE_LOCAL = False
MIT_LOCAL = "./mitdb_local"
AF_LOCAL  = "./afdb_local"

## Create functions: Filtering, R-peak detection, feature extraction, waveform annotation

In [None]:
import numpy as np
import pandas as pd
import wfdb.processing as wproc
from scipy.signal import butter, filtfilt
from scipy.stats import variation

def bandpass(x, fs, lo=0.5, hi=40.0, order=3):  # Filter out noise
    b, a = butter(order, [lo/(fs/2), hi/(fs/2)], btype='band')
    return filtfilt(b, a, x)

def extract_features(seg, fs):
    seg_f = bandpass(seg, fs)

    # Detect R peak
    try:
        r_idx = wproc.xqrs_detect(sig=seg_f, fs=fs)
    except Exception:
        r_idx = np.array([], dtype=int)

    if len(r_idx) > 1:
        rr = np.diff(r_idx) / fs
        mean_hr = float(60.0 / rr.mean())
        cvrr    = float(variation(rr)) if rr.std() > 0 else 0.0
        rmssd   = float(np.sqrt(np.mean(np.square(np.diff(rr))))) if len(rr) > 2 else 0.0
    else:
        mean_hr = cvrr = rmssd = 0.0

    qrs_w = float((np.percentile(np.diff(r_idx),95)/fs)) if len(r_idx)>2 else 0.0
    r_amp = float(seg_f[r_idx].mean()) if len(r_idx)>0 else 0.0
    energy= float(np.sum(seg_f**2)/len(seg_f))
    return [mean_hr, cvrr, rmssd, qrs_w, r_amp, energy, int(len(r_idx))]

In [None]:
def window_label(t0, t1, ann, fs, pvc_ratio_threshold=0.2):
    # Abnormal beat-level annotations
    abnormal_beats = {'A', 'L', 'R', 'F', '/', 'f', 'Q'}
    # Critical rhythm-level annotations (always abnormal if present)
    critical_rhythms = {"AFIB", "AFL", "VT", "VFL"}

    start, end = int(t0 * fs), int(t1 * fs)
    m = (ann.sample >= start) & (ann.sample < end)
    beats = np.array(ann.symbol)[m] if np.any(m) else np.array([])

    # Rhythm annotations (aux_note)
    aux = [a for i, a in enumerate(ann.aux_note) if a and (start <= ann.sample[i] < end)]
    aux_norm = [a.upper() for a in aux]  # normalize to uppercase strings

    # 1) Check for critical rhythm-level events
    if any(any(r in a for r in critical_rhythms) for a in aux_norm):
        # If AFIB is present, return AF (keep AF as a separate label)
        if any("AFIB" in a for a in aux_norm):
            return "AF"
        return "Abnormal"

    # 2) PVC detection (based on ratio and count)
    if len(beats) > 0:
        pvc_count = np.sum(beats == 'V')
        pvc_ratio = pvc_count / len(beats)
        if pvc_ratio >= pvc_ratio_threshold and pvc_count >= 2:
            return "PVC"

    # 3) Other abnormal beats
    if any(b in abnormal_beats for b in beats):
        return "Abnormal"

    # 4) Otherwise, label as Normal
    return "Normal"


## Windowing

In [None]:
import os
import wfdb


WIN, STRIDE = 10, 5     # 10-second window, 5-second stride
LEAD        = 0         # Use lead 0
rows = []

# === Add for CNN ===
X_raw = []
y_bin = []


datasets = [
    ("mitdb", mit_ids, MIT_LOCAL),
    ("afdb",  af_ids,  AF_LOCAL),
]

for db_name, ids, local_dir in datasets:
    # Read record and annotation
    for rid in ids:
        if USE_LOCAL:
            # Read local database
            rec_path = os.path.join(local_dir, str(rid))
            rec = wfdb.rdrecord(rec_path)
            ann = wfdb.rdann(rec_path, "atr")
        else:
            # Read online database
            rec = wfdb.rdrecord(str(rid), pn_dir=db_name)
            ann = wfdb.rdann(str(rid), "atr", pn_dir=db_name)

        fs  = rec.fs
        sig = rec.p_signal[:, LEAD]

        win_samp, stride_samp = int(WIN*fs), int(STRIDE*fs)
        for start in range(0, len(sig)-win_samp+1, stride_samp):
            end = start + win_samp
            seg = sig[start:end]
            t0, t1 = start/fs, end/fs

            feats = extract_features(seg, fs)
            label = window_label(t0, t1, ann, fs)

            rows.append({
                "db": db_name, "record": str(rid),
                "t0": t0, "t1": t1,
                "mean_HR": feats[0], "CVRR": feats[1], "RMSSD": feats[2],
                "QRS_width": feats[3], "R_amp": feats[4], "Energy": feats[5],
                "R_count": feats[6], "label": label
            })
            # === Collect raw signal window for CNN ===
            if feats[6] >= 2:  # same filter as df["R_count"] >= 2
                seg_f = bandpass(seg, fs)
                m, s = np.mean(seg_f), np.std(seg_f)
                if s > 0:
                    seg_z = (seg_f - m) / s
                    X_raw.append(seg_z.astype(np.float32))
                    y_bin.append(0 if label == "Normal" else 1)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Learning initial signal parameters...
Found 8 beats during learning. Initializing using learned parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Found 8 beats during learning. Initializing using learned parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Found 8 beats during learning. Initializing using learned parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Failed to find 8 beats during learning.
Initializing using default parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Failed to find 8 beats during learning.
Initializing using default parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Failed to find 8 beats during learning.
Initializing using default parameters
Running QRS detection

In [None]:
df = pd.DataFrame(rows)
# Filter out windows with too few beats
df = df[df["R_count"]>=2].reset_index(drop=True)

print("Rows:", len(df))
print(df["label"].value_counts())
df.head(20)

Rows: 1080
label
Normal      616
PVC         376
Abnormal     44
AF           44
Name: count, dtype: int64


Unnamed: 0,db,record,t0,t1,mean_HR,CVRR,RMSSD,QRS_width,R_amp,Energy,R_count,label
0,mitdb,106,0.0,10.0,60.372671,0.054863,0.052134,1.050556,2.237665,0.140133,10,Normal
1,mitdb,106,5.0,15.0,63.570961,0.057368,0.061064,1.023333,2.26988,0.155462,10,Normal
2,mitdb,106,10.0,20.0,69.364162,0.064052,0.036075,0.953472,2.287595,0.143953,11,Normal
3,mitdb,106,15.0,25.0,74.25,0.032856,0.022872,0.855556,2.275705,0.133523,12,Normal
4,mitdb,106,20.0,30.0,79.827533,0.054971,0.030186,0.801944,2.299192,0.141329,13,Normal
5,mitdb,106,25.0,35.0,71.618037,0.140782,0.06628,0.988472,2.116861,0.12923,11,Normal
6,mitdb,106,30.0,40.0,64.631957,0.070756,0.04231,1.001806,1.595677,0.072019,11,Normal
7,mitdb,106,35.0,45.0,67.057606,0.047477,0.037616,0.941667,1.209281,0.045495,10,Normal
8,mitdb,106,40.0,50.0,61.970035,0.053471,0.053845,1.055,1.106784,0.035106,10,Normal
9,mitdb,106,45.0,55.0,65.36651,0.084302,0.048957,1.050556,1.112534,0.036165,10,Normal


## Save into CSV

In [None]:
from google.colab import files

OUT_CSV = "ecg_features.csv"
df.to_csv(OUT_CSV, index=False)

files.download(OUT_CSV)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## Binary Classification

In [None]:
df["label"] = np.where(df["label"] == "Normal", "Normal", "Abnormal")

print(df["label"].value_counts())
df.head(20)

label
Normal      616
Abnormal    464
Name: count, dtype: int64


Unnamed: 0,db,record,t0,t1,mean_HR,CVRR,RMSSD,QRS_width,R_amp,Energy,R_count,label
0,mitdb,106,0.0,10.0,60.372671,0.054863,0.052134,1.050556,2.237665,0.140133,10,Normal
1,mitdb,106,5.0,15.0,63.570961,0.057368,0.061064,1.023333,2.26988,0.155462,10,Normal
2,mitdb,106,10.0,20.0,69.364162,0.064052,0.036075,0.953472,2.287595,0.143953,11,Normal
3,mitdb,106,15.0,25.0,74.25,0.032856,0.022872,0.855556,2.275705,0.133523,12,Normal
4,mitdb,106,20.0,30.0,79.827533,0.054971,0.030186,0.801944,2.299192,0.141329,13,Normal
5,mitdb,106,25.0,35.0,71.618037,0.140782,0.06628,0.988472,2.116861,0.12923,11,Normal
6,mitdb,106,30.0,40.0,64.631957,0.070756,0.04231,1.001806,1.595677,0.072019,11,Normal
7,mitdb,106,35.0,45.0,67.057606,0.047477,0.037616,0.941667,1.209281,0.045495,10,Normal
8,mitdb,106,40.0,50.0,61.970035,0.053471,0.053845,1.055,1.106784,0.035106,10,Normal
9,mitdb,106,45.0,55.0,65.36651,0.084302,0.048957,1.050556,1.112534,0.036165,10,Normal


In [None]:
X = df[["mean_HR","CVRR","RMSSD","QRS_width","R_amp","Energy","R_count"]]
y = df["label"]

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)

## Logistic Regression

In [None]:
from sklearn.metrics import classification_report
from sklearn.linear_model import LogisticRegression

log_clf = LogisticRegression(
    max_iter=1000, class_weight="balanced",
    solver="lbfgs", multi_class="auto"
)
log_clf.fit(X_train, y_train)
y_pred_log = log_clf.predict(X_test)

print("=== Logistic Regression ===")
print(classification_report(y_test, y_pred_log))


=== Logistic Regression ===
              precision    recall  f1-score   support

    Abnormal       0.84      0.88      0.86       139
      Normal       0.91      0.88      0.89       185

    accuracy                           0.88       324
   macro avg       0.87      0.88      0.87       324
weighted avg       0.88      0.88      0.88       324





## Random Forest

In [None]:
from sklearn.ensemble import RandomForestClassifier


clf = RandomForestClassifier(n_estimators=300, class_weight="balanced", random_state=42)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

print("=== Random Forest ===")
print(classification_report(y_test, y_pred))

=== Random Forest ===
              precision    recall  f1-score   support

    Abnormal       0.88      0.91      0.89       139
      Normal       0.93      0.90      0.92       185

    accuracy                           0.90       324
   macro avg       0.90      0.90      0.90       324
weighted avg       0.91      0.90      0.90       324



## CNN

In [None]:
# ===================== Binary PyTorch 1D CNN  =====================
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, precision_recall_curve

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# ---- Assumes X_raw (shape [N,T] or [N,T,1]) and y_bin (0/1) already exist ----
X_raw = np.asarray(X_raw, dtype=np.float32)
y_bin = np.asarray(y_bin, dtype=np.int64)

# ---- Split with safe stratification ----
vals, cnts = np.unique(y_bin, return_counts=True)
strat = y_bin if (len(vals) == 2 and np.min(cnts) >= 2) else None
if strat is None:
    print("[Warn] cannot stratify train/test (one class missing or too few).")

Xtr, Xte, ytr, yte = train_test_split(
    X_raw, y_bin, test_size=0.30, stratify=strat, random_state=42
)

vals_tr, cnts_tr = np.unique(ytr, return_counts=True)
strat_val = ytr if (len(vals_tr) == 2 and np.min(cnts_tr) >= 2) else None
if strat_val is None:
    print("[Warn] cannot stratify train/val (one class missing or too few).")

Xtr, Xva, ytr, yva = train_test_split(
    Xtr, ytr, test_size=0.20, stratify=strat_val, random_state=42
)

print("Train counts:", dict(zip(*np.unique(ytr, return_counts=True))))
print("Valid counts:", dict(zip(*np.unique(yva, return_counts=True))))
print("Test  counts:", dict(zip(*np.unique(yte, return_counts=True))))

# ---- Torch helpers ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

def to_torch_chw(x_np: np.ndarray) -> torch.Tensor:
    """[N,T] or [N,T,1] -> [N,1,T] float32 tensor"""
    if x_np.ndim == 2:
        x_np = x_np[:, None, :]        # [N,1,T]
    elif x_np.ndim == 3:                # [N,T,1] -> [N,1,T]
        x_np = np.transpose(x_np, (0, 2, 1))
    return torch.from_numpy(x_np.astype(np.float32))

Xtr_t = to_torch_chw(Xtr).to(device)
Xva_t = to_torch_chw(Xva).to(device)
Xte_t = to_torch_chw(Xte).to(device)

# Targets as float for BCEWithLogitsLoss
ytr_t = torch.from_numpy(ytr.astype(np.float32)).to(device)
yva_t = torch.from_numpy(yva.astype(np.float32)).to(device)
yte_t = torch.from_numpy(yte.astype(np.float32)).to(device)

train_dl = DataLoader(TensorDataset(Xtr_t, ytr_t), batch_size=64, shuffle=True,  drop_last=False)
val_dl   = DataLoader(TensorDataset(Xva_t, yva_t), batch_size=256, shuffle=False, drop_last=False)
test_dl  = DataLoader(TensorDataset(Xte_t, yte_t), batch_size=256, shuffle=False, drop_last=False)

# ---- 1D CNN model (compact & Keras-like) ----
class Torch1DCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv1d(1, 32, 7, padding=3),  nn.BatchNorm1d(32),  nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(32, 64, 5, padding=2), nn.BatchNorm1d(64),  nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(64, 128, 3, padding=1),nn.BatchNorm1d(128), nn.ReLU(), nn.MaxPool1d(2),
            nn.Dropout(0.3),
            nn.AdaptiveAvgPool1d(1)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 64), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(64, 1)   # 1 logit for BCEWithLogitsLoss
        )
    def forward(self, x):
        x = self.features(x)
        return self.classifier(x).squeeze(1)  # [N]

model = Torch1DCNN().to(device)

# ---- Imbalance handling: bias init + pos_weight ----
pos = int(np.sum(ytr == 1)); neg = int(np.sum(ytr == 0))
prior = (pos + 1) / (pos + neg + 2)                     # Laplace-smoothed prior in (0,1)
logit_bias = float(np.log(prior / (1.0 - prior)))
with torch.no_grad():
    model.classifier[-1].bias.fill_(logit_bias)

pos_w = torch.tensor([(neg + 1) / (pos + 1)], dtype=torch.float32, device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_w)      # balance positives
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# ---- Train w/ early stopping on val loss ----
has_val = len(val_dl.dataset) > 0
best_loss, best_state, patience, bad = float("inf"), None, 5, 0
epochs = 30

for epoch in range(1, epochs + 1):
    model.train()
    run = 0.0
    for xb, yb in train_dl:
        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        run += loss.item() * xb.size(0)
    train_loss = run / len(train_dl.dataset)

    model.eval()
    if has_val:
        vrun = 0.0
        with torch.no_grad():
            for xb, yb in val_dl:
                vrun += criterion(model(xb), yb).item() * xb.size(0)
        val_loss = vrun / len(val_dl.dataset)
        print(f"Epoch {epoch:02d}/{epochs} - train {train_loss:.4f} - val {val_loss:.4f}")
        metric = val_loss
    else:
        print(f"Epoch {epoch:02d}/{epochs} - train {train_loss:.4f} - val N/A")
        metric = train_loss  # fallback

    if metric < best_loss - 1e-4:
        best_loss = metric
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        bad = 0
    else:
        bad += 1
        if bad >= patience:
            break

if best_state is not None:
    model.load_state_dict(best_state)

# ---- Pick threshold on validation (max F1 on PR); fallback to 0.5 if no val ----
if has_val:
    model.eval()
    with torch.no_grad():
        va_logits = torch.cat([model(xb) for xb, _ in val_dl]).cpu().numpy()
    va_prob = 1.0 / (1.0 + np.exp(-va_logits))
    prec, rec, thr = precision_recall_curve(yva, va_prob)
    f1 = 2 * prec * rec / (prec + rec + 1e-9)
    best_thr = float(thr[np.argmax(f1)]) if len(thr) > 0 else 0.5
else:
    best_thr = 0.5
print(f"[Info] Decision threshold: {best_thr:.3f}")

# ---- Final test ----
model.eval()
with torch.no_grad():
    te_logits = torch.cat([model(xb) for xb, _ in test_dl]).cpu().numpy()
te_prob = 1.0 / (1.0 + np.exp(-te_logits))
y_pred = (te_prob >= best_thr).astype(int)

print("=== 1D CNN (raw waveform, binary) — PyTorch ===")
print(classification_report(yte, y_pred, digits=4))


Train counts: {np.int64(0): np.int64(344), np.int64(1): np.int64(260)}
Valid counts: {np.int64(0): np.int64(87), np.int64(1): np.int64(65)}
Test  counts: {np.int64(0): np.int64(185), np.int64(1): np.int64(139)}
Using device: cpu
Epoch 01/30 - train 0.7757 - val 0.7830
Epoch 02/30 - train 0.7252 - val 0.7534
Epoch 03/30 - train 0.6672 - val 0.7263
Epoch 04/30 - train 0.5987 - val 0.7001
Epoch 05/30 - train 0.5483 - val 0.8544
Epoch 06/30 - train 0.4903 - val 0.6715
Epoch 07/30 - train 0.4624 - val 0.5397
Epoch 08/30 - train 0.4406 - val 0.4983
Epoch 09/30 - train 0.4079 - val 0.4770
Epoch 10/30 - train 0.4078 - val 0.7399
Epoch 11/30 - train 0.4058 - val 0.4925
Epoch 12/30 - train 0.3954 - val 0.4690
Epoch 13/30 - train 0.3986 - val 0.5778
Epoch 14/30 - train 0.3774 - val 0.4822
Epoch 15/30 - train 0.3694 - val 0.4420
Epoch 16/30 - train 0.3618 - val 0.5375
Epoch 17/30 - train 0.3698 - val 0.7824
Epoch 18/30 - train 0.3687 - val 0.5786
Epoch 19/30 - train 0.3655 - val 0.4733
Epoch 20/30

## Test: load an ECG file from MIT-BIH, take the first two minutes of data, and run it through our trained models.

In [None]:
import numpy as np
import pandas as pd
import torch

def predict_ecg_segment(
    sig, fs,
    rf=None,                # sklearn RandomForestClassifier (optional)
    cnn_model=None,         # PyTorch 1D CNN (optional)
    device=None,
    win_sec=10, stride_sec=5,
    rcount_min=2,           # minimum R peaks for CNN window (match training)
    thr_bin=0.5             # decision threshold for binary CNN
):
    """
    Slice a raw ECG (1D array) into sliding windows and return per-window predictions
    plus a segment-level summary.

    - RF: uses hand-crafted features (works for binary Normal/Abnormal or tri-class
      Normal/PVC/AFib depending on the model you trained).
    - CNN: uses raw waveform windows (supports binary or tri-class depending on the model).

    Returns:
      out_df  : DataFrame with per-window t0, t1, R_count, rf_pred, cnn_pred,
                and probability columns.
      summary : dict with majority vote, e.g. {'rf_pred_majority': ..., 'cnn_pred_majority': ...}
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    sig = np.asarray(sig, dtype=np.float32).ravel()
    N = len(sig)
    win_samp    = int(win_sec * fs)
    stride_samp = int(stride_sec * fs)

    # --- helper: same preprocessing as training (bandpass + z-score) ---
    def _prep_cnn_seg(seg):
        seg_f = bandpass(seg, fs)
        m, s = float(seg_f.mean()), float(seg_f.std())
        if s <= 0:
            return None
        return ((seg_f - m) / s).astype(np.float32)

    # --- ensure we can produce at least one window ---
    if N <= 0:
        empty = pd.DataFrame(columns=["t0","t1","R_count","rf_pred","cnn_pred"])
        return empty, {"rf_pred_majority": None, "cnn_pred_majority": None}

    if N < win_samp:
        start_indices = [0]
        win_samp = N
        stride_samp = N
    else:
        start_indices = range(0, N - win_samp + 1, stride_samp)

    rows = []
    rf_X,  rf_ix  = [], []
    cnn_X, cnn_ix = [], []

    # --------- windowing, feature extraction, CNN inputs ----------
    for start in start_indices:
        end = start + win_samp
        seg = sig[start:end]
        t0, t1 = start / fs, end / fs

        feats = extract_features(seg, fs)   # [mean_hr, cvrr, rmssd, qrs_w, r_amp, energy, r_count]
        rcount = int(feats[6])

        rows.append({"t0": t0, "t1": t1, "R_count": rcount, "rf_pred": None, "cnn_pred": None})

        if rf is not None:
            rf_X.append([feats[0], feats[1], feats[2], feats[3], feats[4], feats[5], feats[6]])
            rf_ix.append(len(rows) - 1)

        if cnn_model is not None and rcount >= rcount_min:
            seg_z = _prep_cnn_seg(seg)
            if seg_z is not None:
                cnn_X.append(seg_z)
                cnn_ix.append(len(rows) - 1)

    out_df = pd.DataFrame(rows)

    # -------------------- RF inference --------------------
    if rf is not None and len(rf_X) > 0:
        rf_X = np.asarray(rf_X, dtype=np.float32)
        proba = rf.predict_proba(rf_X)                  # [M, K]
        seen  = np.array(rf.classes_, dtype=object)     # e.g., ['Normal','Abnormal'] or ['Normal','PVC','AFib']

        # infer if RF is binary or tri-class
        if set(seen.tolist()) <= {"Normal", "Abnormal"} or len(seen) == 2:
            label_order = np.array(["Normal", "Abnormal"], dtype=object)
        else:
            label_order = np.array(["Normal", "PVC", "AFib"], dtype=object)

        proba_full = np.zeros((proba.shape[0], len(label_order)), dtype=np.float32)
        for j, cls in enumerate(seen):
            k = np.where(label_order == cls)[0]
            if len(k) == 1:
                proba_full[:, k[0]] = proba[:, j]

        rf_idx = proba_full.argmax(1)
        out_df.loc[rf_ix, "rf_pred"] = label_order[rf_idx]
        # probability columns
        for k, lbl in enumerate(label_order):
            out_df.loc[rf_ix, f"rf_p_{lbl}"] = proba_full[:, k]

    # -------------------- CNN inference --------------------
    if cnn_model is not None and len(cnn_X) > 0:
        cnn_model.eval()
        X = np.stack(cnn_X, axis=0)[:, None, :]         # [M, 1, T]
        X_t = torch.tensor(X, dtype=torch.float32, device=device)
        with torch.no_grad():
            logits = cnn_model(X_t)

        # multi-class or binary
        if logits.ndim == 2 and logits.shape[1] > 1:
            # multi-class (e.g., 3 classes: Normal / PVC / AFib)
            prob = torch.softmax(logits, dim=1).cpu().numpy()
            label_order = np.array(["Normal", "PVC", "AFib"], dtype=object)
            pred_idx = prob.argmax(1)
            out_df.loc[cnn_ix, "cnn_pred"] = label_order[pred_idx]
            for k, lbl in enumerate(label_order):
                out_df.loc[cnn_ix, f"cnn_p_{lbl}"] = prob[:, k]
        else:
            # binary (Normal / Abnormal)
            logit = logits.squeeze(1) if logits.ndim == 2 else logits
            p_pos = torch.sigmoid(logit).cpu().numpy().ravel()     # 1 = Abnormal
            pred  = (p_pos >= thr_bin).astype(int)
            out_df.loc[cnn_ix, "cnn_pred"] = np.where(pred == 1, "Abnormal", "Normal")
            out_df.loc[cnn_ix, "cnn_p_Normal"]   = 1.0 - p_pos
            out_df.loc[cnn_ix, "cnn_p_Abnormal"] = p_pos

    # -------------------- segment-level majority vote --------------------
    summary = {}
    for col in ["rf_pred", "cnn_pred"]:
        if col in out_df and out_df[col].notna().any():
            summary[col + "_majority"] = out_df[col].dropna().mode().iat[0]
        else:
            summary[col + "_majority"] = None

    return out_df, summary


In [None]:
import numpy as np
import torch
import wfdb

# ===== Your existing settings (adjust as needed) =====
REC_ID  = "100"        # which record to read from PhysioNet
DB_NAME = "mitdb"      # PhysioNet database name
LEAD    = 0            # which lead/channel to use
WIN, STRIDE, RCOUNT_MIN = 10, 5, 2   # window (s), stride (s), min R peaks for CNN
START_SEC, DUR_SEC = 0, 120          # start at 0s, take 120s (2 minutes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===== Read record online from PhysioNet =====
rec = wfdb.rdrecord(REC_ID, pn_dir=DB_NAME)
fs  = rec.fs
sig = rec.p_signal[:, LEAD].astype(np.float32)

# ===== Take a continuous segment (default: 2 minutes) =====
start = int(START_SEC * fs)
end   = start + int(DUR_SEC * fs)
seg   = sig[start:end]

# If the segment is shorter than one window, pad to at least one window
need = int(WIN * fs) - len(seg)
if need > 0:
    seg = np.pad(seg, (0, need), mode="edge")

# ===== Inference (use RF/CNN as you like; set to None to skip) =====
df_pred, summary = predict_ecg_segment(
    seg, fs,
    rf=clf,                 # your trained RandomForest; set to None if not using RF
    cnn_model=model,        # your trained PyTorch 1D CNN; set to None if not using CNN
    device=device,
    win_sec=WIN, stride_sec=STRIDE, rcount_min=RCOUNT_MIN,
    thr_bin=0.5             # decision threshold if your CNN is binary
)

# ===== Results =====
print("=== Summary (majority vote) ===")
print(summary)

print("\n=== First few windows ===")
show_cols = [c for c in ["t0","t1","R_count","rf_pred","cnn_pred"] if c in df_pred.columns]
print(df_pred[show_cols].head())

if "rf_pred" in df_pred:
    print("\n=== Counts (RF) ===")
    print(df_pred["rf_pred"].value_counts(dropna=True))

if "cnn_pred" in df_pred:
    print("\n=== Counts (CNN) ===")
    print(df_pred["cnn_pred"].value_counts(dropna=True))



Learning initial signal parameters...
Found 8 beats during learning. Initializing using learned parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Found 8 beats during learning. Initializing using learned parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Found 8 beats during learning. Initializing using learned parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Found 8 beats during learning. Initializing using learned parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Found 8 beats during learning. Initializing using learned parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Found 8 beats during learning. Initializing using learned parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Found 8 beats during learn



## Multiple Classifcation (Normal/ PVC/ AFib)

In [None]:
# patients' ID

# MITDB ID
mit_ids = ["106", "201", "203"]
# AFDB ID
af_ids  = []


USE_LOCAL = False
MIT_LOCAL = "./mitdb_local"
AF_LOCAL  = "./afdb_local"

import numpy as np
import pandas as pd
import wfdb.processing as wproc
from scipy.signal import butter, filtfilt
from scipy.stats import variation

def bandpass(x, fs, lo=0.5, hi=40.0, order=3):  # Filter out noise
    b, a = butter(order, [lo/(fs/2), hi/(fs/2)], btype='band')
    return filtfilt(b, a, x)

def extract_features(seg, fs):
    seg_f = bandpass(seg, fs)

    # Detect R peak
    try:
        r_idx = wproc.xqrs_detect(sig=seg_f, fs=fs)
    except Exception:
        r_idx = np.array([], dtype=int)

    if len(r_idx) > 1:
        rr = np.diff(r_idx) / fs
        mean_hr = float(60.0 / rr.mean())
        cvrr    = float(variation(rr)) if rr.std() > 0 else 0.0
        rmssd   = float(np.sqrt(np.mean(np.square(np.diff(rr))))) if len(rr) > 2 else 0.0
    else:
        mean_hr = cvrr = rmssd = 0.0

    qrs_w = float((np.percentile(np.diff(r_idx),95)/fs)) if len(r_idx)>2 else 0.0
    r_amp = float(seg_f[r_idx].mean()) if len(r_idx)>0 else 0.0
    energy= float(np.sum(seg_f**2)/len(seg_f))
    return [mean_hr, cvrr, rmssd, qrs_w, r_amp, energy, int(len(r_idx))]

In [None]:
def window_label(t0, t1, ann, fs, pvc_ratio_threshold=0.2):

    start, end = int(t0 * fs), int(t1 * fs)
    m = (ann.sample >= start) & (ann.sample < end)
    beats = np.array(ann.symbol)[m] if np.any(m) else np.array([])

    # Rhythm annotations (aux_note)
    aux = [a for i, a in enumerate(ann.aux_note) if a and (start <= ann.sample[i] < end)]
    aux_norm = [a.upper() for a in aux]  # normalize to uppercase strings

    # 1) Check for critical rhythm-level events
    # If AFIB is present, return AF (keep AF as a separate label)
    if any("AFIB" in a for a in aux_norm):
          return "AFib"

    # 2) PVC detection (based on ratio and count)
    if len(beats) > 0:
        pvc_count = np.sum(beats == 'V')
        pvc_ratio = pvc_count / len(beats)
        if pvc_ratio >= pvc_ratio_threshold and pvc_count >= 2:
            return "PVC"

    # 4) Otherwise, label as Normal
    return "Normal"


In [None]:
# ---- Sliding window config ----
WIN, STRIDE = 10, 5
LEAD        = 0
RCOUNT_MIN  = 2

rows = []

# === For CNN (tri-class) ===
X_raw = []          # [N, T] z-scored waveform
y_tri = []          # [N]   0:Normal, 1:PVC, 2:AFib
META  = []          # for record-level split (db, rid)
label2id = {"Normal": 0, "PVC": 1, "AFib": 2}


datasets = [
    ("mitdb", mit_ids, MIT_LOCAL),
    ("afdb",  af_ids,  AF_LOCAL),
]

# ---- Load -> Window -> Features -> Label ----
for db_name, ids, local_dir in datasets:
    for rid in ids:

        if USE_LOCAL:
            rec_path = os.path.join(local_dir, str(rid))
            rec = wfdb.rdrecord(rec_path)
            ann = wfdb.rdann(rec_path, "atr")
        else:
            rec = wfdb.rdrecord(str(rid), pn_dir=db_name)
            ann = wfdb.rdann(str(rid), "atr", pn_dir=db_name)

        fs  = rec.fs
        sig = rec.p_signal[:, LEAD]

        win_samp, stride_samp = int(WIN*fs), int(STRIDE*fs)
        for start in range(0, len(sig)-win_samp+1, stride_samp):
            end = start + win_samp
            seg = sig[start:end]
            t0, t1 = start/fs, end/fs

            feats = extract_features(seg, fs)
            label = window_label(t0, t1, ann, fs) # "Normal" / "PVC" / "AFib" / "Abnormal"

            rows.append({
                "db": db_name, "record": str(rid),
                "t0": t0, "t1": t1,
                "mean_HR": feats[0], "CVRR": feats[1], "RMSSD": feats[2],
                "QRS_width": feats[3], "R_amp": feats[4], "Energy": feats[5],
                "R_count": feats[6], "label": label
            })
            # === collect raw window for CNN (tri-class) ===
            if (feats[6] >= RCOUNT_MIN) and (label in label2id):
                seg_f = bandpass(seg, fs)
                m, s  = np.mean(seg_f), np.std(seg_f)
                if s > 0:
                    seg_z = (seg_f - m) / s
                    X_raw.append(seg_z.astype(np.float32))
                    y_tri.append(label2id[label])
                    META.append((db_name, str(rid)))

# ---- Build DataFrame & keep tri-class only ----
df = pd.DataFrame(rows)


df = df[df["R_count"] >= RCOUNT_MIN].reset_index(drop=True)

df = df[df["label"].isin(["Normal", "PVC", "AFib"])].reset_index(drop=True)

print("Total windows kept:", len(df))
print(df["label"].value_counts())


feature_cols = ["mean_HR","CVRR","RMSSD","QRS_width","R_amp","Energy","R_count"]
X = df[feature_cols]
y = df["label"]

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.30, stratify=y, random_state=42)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Found 8 beats during learning. Initializing using learned parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Failed to find 8 beats during learning.
Initializing using default parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Failed to find 8 beats during learning.
Initializing using default parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Failed to find 8 beats during learning.
Initializing using default parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Failed to find 8 beats during learning.
Initializing using default parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Failed to find 8 beats during le

In [None]:
# --- SMOTE for AFib + PVC (TRAIN set only) ---
from imblearn.over_sampling import SMOTE
from collections import Counter

print("Before SMOTE:", Counter(y_train))

# Upsample both AFib and PVC to the size of the largest class in TRAIN
max_n = y_train.value_counts().max()
need = {}
for cls in ["AFib", "PVC"]:
    if (y_train == cls).any():              # only resample classes present in train
        need[cls] = max_n                   # or use int(0.8 * max_n) for partial balancing

if need:
    # Use small k for very small classes to avoid errors
    sm = SMOTE(sampling_strategy=need, k_neighbors=1, random_state=42)
    X_tr_bal, y_tr_bal = sm.fit_resample(X_train, y_train)
else:
    print("No AFib/PVC in training split — skipping SMOTE.")
    X_tr_bal, y_tr_bal = X_train, y_train

print("After SMOTE:", Counter(y_tr_bal))

# (optional) keep column names for debugging/feature importance
# X_tr_bal = pd.DataFrame(X_tr_bal, columns=feature_cols)
# y_tr_bal = pd.Series(y_tr_bal, name="label")

# ---- Train on the SMOTE-balanced TRAIN set (do NOT use class_weight) ----
from sklearn.ensemble import RandomForestClassifier

clf = RandomForestClassifier(
    n_estimators=500,
    max_depth=None,
    class_weight=None,     # important: avoid double compensation
    random_state=42,
    n_jobs=-1
)
clf.fit(X_tr_bal, y_tr_bal)

Before SMOTE: Counter({'Normal': 458, 'PVC': 267, 'AFib': 31})
After SMOTE: Counter({'Normal': 458, 'PVC': 458, 'AFib': 458})


In [None]:
# ---- Evaluate on the untouched test set ----
y_pred = clf.predict(X_test)
print("\n=== Random Forest (oversampled train) ===")
print(classification_report(y_test, y_pred, digits=4))
print("Confusion matrix:\n",
      confusion_matrix(y_test, y_pred, labels=["Normal","PVC","AFib"]))


=== Random Forest (oversampled train) ===
              precision    recall  f1-score   support

        AFib     0.3500    0.5385    0.4242        13
      Normal     0.9337    0.9337    0.9337       196
         PVC     0.8981    0.8435    0.8700       115

    accuracy                         0.8858       324
   macro avg     0.7273    0.7719    0.7426       324
weighted avg     0.8976    0.8858    0.8906       324

Confusion matrix:
 [[183  10   3]
 [  8  97  10]
 [  5   1   7]]


## CNN

In [None]:
X_raw = np.asarray(X_raw, dtype=np.float32)   # [N, T]
y_tri = np.asarray(y_tri, dtype=np.int64)     # [N]
META  = np.asarray(META)

print("Collected windows:", len(y_tri))
if len(y_tri) == 0:
    raise RuntimeError("No windows collected. Add AFDB IDs or relax filters.")

# CNN expects [N, 1, T]
X_raw = X_raw[:, None, :]

Collected windows: 1080


In [None]:
# -------------------- Record-level split (avoid leakage) --------------------
recs = np.array([f"{d}:{r}" for d, r in META])
uniq = np.unique(recs)
train_recs, test_recs = train_test_split(uniq, test_size=0.30, random_state=42)
train_recs, val_recs  = train_test_split(train_recs, test_size=0.20, random_state=42)

tr_idx = np.isin(recs, train_recs)
va_idx = np.isin(recs, val_recs)
te_idx = np.isin(recs, test_recs)

Xtr, Xva, Xte = X_raw[tr_idx], X_raw[va_idx], X_raw[te_idx]
ytr, yva, yte = y_tri[tr_idx], y_tri[va_idx], y_tri[te_idx]

print("Train class counts:", np.bincount(ytr, minlength=3))
print("Valid class counts:", np.bincount(yva, minlength=3))
print("Test  class counts:", np.bincount(yte, minlength=3))

# -------------------- Oversample TRAIN (resample PVC & AFib) --------------------
def oversample_by_index(X, y, classes_to_upsample=(1, 2), target="max", seed=42):
    """
    Oversample chosen classes by duplicating indices.
    X: [N, 1, T], y: [N], classes_to_upsample: classes to upsample (1:PVC, 2:AFib)
    target: 'max' -> match the largest class count;
            float in (0,1] -> ratio of max (e.g., 0.8);
            int -> fixed target count per chosen class.
    """
    rng = np.random.default_rng(seed)
    y = np.asarray(y)
    counts = np.bincount(y, minlength=3)
    max_n = int(counts.max())

    if target == "max":
        tgt = max_n
    elif isinstance(target, (float, np.floating)):
        tgt = int(max_n * float(target))
    else:
        tgt = int(target)

    idx_all = []
    for c in range(3):
        idx_c = np.where(y == c)[0]
        # upsample only chosen classes and only if below target
        if (c in classes_to_upsample) and (len(idx_c) > 0) and (len(idx_c) < tgt):
            extra = rng.choice(idx_c, size=tgt - len(idx_c), replace=True)
            idx_c = np.concatenate([idx_c, extra])
        idx_all.append(idx_c)

    idx_bal = np.concatenate(idx_all)
    rng.shuffle(idx_bal)
    return X[idx_bal], y[idx_bal]

Xtr_bal, ytr_bal = oversample_by_index(Xtr, ytr, classes_to_upsample=(1, 2), target="max", seed=42)
print("After oversample (train):", np.bincount(ytr_bal, minlength=3))

# -------------------- Tensors & DataLoaders --------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Xtr_t = torch.tensor(Xtr_bal, dtype=torch.float32).to(device)
ytr_t = torch.tensor(ytr_bal, dtype=torch.long).to(device)
Xva_t = torch.tensor(Xva,     dtype=torch.float32).to(device)
yva_t = torch.tensor(yva,     dtype=torch.long).to(device)
Xte_t = torch.tensor(Xte,     dtype=torch.float32).to(device)
yte_t = torch.tensor(yte,     dtype=torch.long).to(device)

train_dl = DataLoader(TensorDataset(Xtr_t, ytr_t), batch_size=64, shuffle=True,  drop_last=False)
val_dl   = DataLoader(TensorDataset(Xva_t, yva_t), batch_size=256, shuffle=False, drop_last=False)
test_dl  = DataLoader(TensorDataset(Xte_t, yte_t), batch_size=256, shuffle=False, drop_last=False)

# -------------------- 1D-CNN Model --------------------
class ECG1DCNN(nn.Module):
    def __init__(self, n_classes=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv1d(1, 32, 7, padding=3), nn.BatchNorm1d(32), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(32, 64, 5, padding=2), nn.BatchNorm1d(64), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(64, 128, 3, padding=1), nn.BatchNorm1d(128), nn.ReLU(), nn.MaxPool1d(2),
            nn.Dropout(0.3),
        )
        self.gap = nn.AdaptiveAvgPool1d(1)  # ≈ GlobalAveragePooling1D
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 64), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(64, n_classes)  # logits
        )
    def forward(self, x):
        x = self.features(x)
        x = self.gap(x)
        return self.classifier(x)

model = ECG1DCNN(n_classes=3).to(device)
criterion = nn.CrossEntropyLoss()                        # already balanced via oversampling
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# -------------------- Train with Early Stopping (val loss) --------------------
best_loss, best_state, patience, bad = float("inf"), None, 5, 0
for epoch in range(1, 31):
    # train
    model.train()
    run = 0.0
    for xb, yb in train_dl:
        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        run += loss.item() * xb.size(0)
    train_loss = run / len(train_dl.dataset)

    # validate
    model.eval()
    vrun = 0.0
    with torch.no_grad():
        for xb, yb in val_dl:
            vrun += criterion(model(xb), yb).item() * xb.size(0)
    val_loss = vrun / len(val_dl.dataset)
    print(f"Epoch {epoch:02d}/30 - train {train_loss:.4f} - val {val_loss:.4f}")

    if val_loss < best_loss - 1e-4:
        best_loss = val_loss
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        bad = 0
    else:
        bad += 1
        if bad >= patience:
            break

if best_state is not None:
    model.load_state_dict(best_state)

# -------------------- Evaluate on TEST --------------------
model.eval()
with torch.no_grad():
    logits = torch.cat([model(xb) for xb, _ in test_dl], dim=0).cpu().numpy()
y_pred = logits.argmax(1)

print("\n=== PyTorch 1D CNN (Normal / PVC / AFib) ===")
print(classification_report(yte, y_pred, target_names=["Normal","PVC","AFib"], digits=4))
print("Confusion matrix (labels=[0,1,2]):\n", confusion_matrix(yte, y_pred, labels=[0,1,2]))



Train class counts: [222 133   5]
Valid class counts: [249  72  39]
Test  class counts: [183 177   0]
After oversample (train): [222 222 222]
Epoch 01/30 - train 1.0723 - val 1.1162
Epoch 02/30 - train 1.0077 - val 1.1868
Epoch 03/30 - train 0.9239 - val 1.2163
Epoch 04/30 - train 0.8425 - val 1.2025
Epoch 05/30 - train 0.7367 - val 1.1834
Epoch 06/30 - train 0.6463 - val 1.4665

=== PyTorch 1D CNN (Normal / PVC / AFib) ===
              precision    recall  f1-score   support

      Normal     0.0000    0.0000    0.0000       183
         PVC     0.4930    1.0000    0.6604       177
        AFib     0.0000    0.0000    0.0000         0

    accuracy                         0.4917       360
   macro avg     0.1643    0.3333    0.2201       360
weighted avg     0.2424    0.4917    0.3247       360

Confusion matrix (labels=[0,1,2]):
 [[  0 182   1]
 [  0 177   0]
 [  0   0   0]]


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:
# ===================== PyTorch 1D CNN (tri-class) =====================
import numpy as np, torch, torch.nn as nn, torch.optim as optim, os
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

# --- to tensors / shapes ---
X_raw = np.asarray(X_raw, dtype=np.float32)   # [N, T]
y_tri = np.asarray(y_tri, dtype=np.int64)     # [N]
META  = np.asarray(META)

if X_raw.ndim != 2 or len(y_tri) == 0:
    raise RuntimeError("No CNN windows collected. Make sure AFib/PVC exist and RCOUNT_MIN is satisfied.")

X_raw = X_raw[:, None, :]  # -> [N, 1, T]  (PyTorch Conv1d expects [N, C, T])

# --- record-level split to avoid leakage ---
recs = np.array([f"{d}:{r}" for d, r in META])
uniq = np.unique(recs)
train_recs, test_recs = train_test_split(uniq, test_size=0.30, random_state=42)
train_recs, val_recs  = train_test_split(train_recs, test_size=0.20, random_state=42)

tr_idx = np.isin(recs, train_recs)
va_idx = np.isin(recs, val_recs)
te_idx = np.isin(recs, test_recs)

Xtr, Xva, Xte = X_raw[tr_idx], X_raw[va_idx], X_raw[te_idx]
ytr, yva, yte = y_tri[tr_idx], y_tri[va_idx], y_tri[te_idx]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def to_tensor(x, y):
    return torch.tensor(x, dtype=torch.float32).to(device), torch.tensor(y, dtype=torch.long).to(device)

Xtr_t, ytr_t = to_tensor(Xtr, ytr)
Xva_t, yva_t = to_tensor(Xva, yva)
Xte_t, yte_t = to_tensor(Xte, yte)

train_ds = TensorDataset(Xtr_t, ytr_t)
val_ds   = TensorDataset(Xva_t, yva_t)
test_ds  = TensorDataset(Xte_t, yte_t)

# --- balance training with a WeightedRandomSampler (batch-level balancing) ---
num_classes = 3
cls_counts = np.bincount(ytr, minlength=num_classes)
cls_weights = 1.0 / np.maximum(cls_counts, 1)
w_per_sample = cls_weights[ytr]
sampler = WeightedRandomSampler(torch.tensor(w_per_sample, dtype=torch.float32), num_samples=len(w_per_sample), replacement=True)

train_dl = DataLoader(train_ds, batch_size=64, sampler=sampler, drop_last=False)
val_dl   = DataLoader(val_ds, batch_size=256, shuffle=False)
test_dl  = DataLoader(test_ds, batch_size=256, shuffle=False)

# --- 1D CNN model (mirrors your Keras骨架) ---
class ECG1DCNN(nn.Module):
    def __init__(self, n_classes=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv1d(1, 32, 7, padding=3), nn.BatchNorm1d(32), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(32, 64, 5, padding=2), nn.BatchNorm1d(64), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(64, 128, 3, padding=1), nn.BatchNorm1d(128), nn.ReLU(), nn.MaxPool1d(2),
            nn.Dropout(0.3),
        )
        self.gap = nn.AdaptiveAvgPool1d(1)  # ≈ GlobalAveragePooling1D
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 64), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(64, n_classes)        # logits
        )
    def forward(self, x):
        x = self.features(x)
        x = self.gap(x)
        x = self.classifier(x)
        return x

model = ECG1DCNN(n_classes=3).to(device)
criterion = nn.CrossEntropyLoss()                      # use sampler; no extra class weights
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# --- training loop with early stopping (on val loss) ---
best_loss, best_state, patience, bad = float("inf"), None, 5, 0
for epoch in range(1, 31):
    # train
    model.train()
    run = 0.0
    for xb, yb in train_dl:
        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        run += loss.item() * xb.size(0)
    train_loss = run / len(train_ds)

    # validate
    model.eval()
    vrun = 0.0
    with torch.no_grad():
        for xb, yb in val_dl:
            vrun += criterion(model(xb), yb).item() * xb.size(0)
    val_loss = vrun / len(val_ds)
    print(f"Epoch {epoch:02d}/30 - train {train_loss:.4f} - val {val_loss:.4f}")

    if val_loss < best_loss - 1e-4:
        best_loss, best_state, bad = val_loss, {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}, 0
    else:
        bad += 1
        if bad >= patience:
            break

if best_state is not None:
    model.load_state_dict(best_state)

# --- evaluation (test) ---
model.eval()
with torch.no_grad():
    logits = torch.cat([model(xb) for xb, _ in test_dl], dim=0).cpu().numpy()
y_pred = logits.argmax(1)

print("\n=== PyTorch 1D CNN (Normal / PVC / AFib) ===")
print(classification_report(yte, y_pred, target_names=["Normal","PVC","AFib"], digits=4))
print("Confusion matrix (labels=[0,1,2]):\n", confusion_matrix(yte, y_pred, labels=[0,1,2]))

# =====================================================================


Epoch 01/30 - train 1.0870 - val 1.1523
Epoch 02/30 - train 1.0502 - val 1.1363
Epoch 03/30 - train 1.0117 - val 1.1445
Epoch 04/30 - train 0.9658 - val 1.1538
Epoch 05/30 - train 0.9260 - val 1.1374
Epoch 06/30 - train 0.8506 - val 1.0868
Epoch 07/30 - train 0.8201 - val 1.1057
Epoch 08/30 - train 0.7579 - val 1.0872
Epoch 09/30 - train 0.6984 - val 1.3838
Epoch 10/30 - train 0.6716 - val 1.3210
Epoch 11/30 - train 0.6221 - val 1.1927

=== PyTorch 1D CNN (Normal / PVC / AFib) ===
              precision    recall  f1-score   support

      Normal     1.0000    0.3443    0.5122       183
         PVC     0.8855    0.6554    0.7532       177
        AFib     0.0000    0.0000    0.0000         0

    accuracy                         0.4972       360
   macro avg     0.6285    0.3332    0.4218       360
weighted avg     0.9437    0.4972    0.6307       360

Confusion matrix (labels=[0,1,2]):
 [[ 63  15 105]
 [  0 116  61]
 [  0   0   0]]


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
