In [None]:
import os, ast, math, warnings
from tqdm import tqdm
import numpy as np
import pandas as pd
import wfdb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from sklearn.metrics import roc_auc_score, f1_score, classification_report
from torch.optim import AdamW

from torch_geometric.data import Data, DataLoader as GeoDataLoader, Batch
from torch_geometric.nn import SAGEConv, global_mean_pool

# AMP compatibility (works with older and newer PyTorch)
try:
    from torch.amp import autocast as _autocast, GradScaler as _GradScaler
except Exception:
    from torch.cuda.amp import autocast as _autocast, GradScaler as _GradScaler

warnings.filterwarnings("ignore")

In [None]:
BASE_PATH = "/kaggle/input/ptb-xl-dataset/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1/"
METADATA_CSV = os.path.join(BASE_PATH, "ptbxl_database.csv")
SCP_FILE = os.path.join(BASE_PATH, "scp_statements.csv")
TARGET_SR = 100
BATCH_SIZE = 32
EPOCHS = 20
LR = 1e-3
WEIGHT_DECAY = 1e-4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_WORKERS = 2
PATIENCE = 7  # early stopping patience

print("Device:", DEVICE)

In [None]:
scp_df = pd.read_csv(SCP_FILE, index_col=0)
diagnostic_scp = scp_df[scp_df.diagnostic == 1]
classes = sorted(diagnostic_scp.diagnostic_class.unique().tolist())
NUM_CLASSES = len(classes)
print("Classes:", classes)

def aggregate_superclass_from_scp_field(scp_codes_field):
    try:
        d = ast.literal_eval(scp_codes_field) if isinstance(scp_codes_field, str) else scp_codes_field
    except Exception:
        return []
    out = []
    for k in (d.keys() if isinstance(d, dict) else []):
        if k in diagnostic_scp.index:
            out.append(diagnostic_scp.loc[k].diagnostic_class)
    return list(set(out))

In [None]:
def create_anatomical_adjacency():
    rel = [
        (0,1),(1,2),
        (0,4),(1,4),(2,5),(0,3),(3,4),(4,5),
        (6,7),(7,8),(8,9),(9,10),(10,11),
        (1,6),(5,10),(4,11)
    ]
    edges = []
    for a,b in rel:
        edges.append([a,b]); edges.append([b,a])
    return torch.tensor(edges, dtype=torch.long).t().contiguous()

EDGE_INDEX = create_anatomical_adjacency()

In [None]:
def file_exists_for_row(base_path, fname):
    # fname may already contain 'records100/00000/00001_lr.dat' or without extension
    full = os.path.join(base_path, fname)
    base, ext = os.path.splitext(full)
    # Check base.dat + base.hea
    if os.path.exists(base + ".dat") and os.path.exists(base + ".hea"):
        return True, base
    # maybe 'fname' already had .dat extension included
    if os.path.exists(full) and (full.endswith(".dat") or full.endswith(".hea")):
        # normalize to base without extension
        base2 = os.path.splitext(full)[0]
        if os.path.exists(base2 + ".dat") and os.path.exists(base2 + ".hea"):
            return True, base2
    # fallback: return False and proposed base
    return False, base

In [None]:
class PTBXLGraphDataset(Dataset):
    def __init__(self, df: pd.DataFrame, base_path: str, use_lr: bool = True, augment: bool=False):
        self.rows = []
        missing = 0
        for _, row in df.iterrows():
            fname = row['filename_lr'] if use_lr else row['filename_hr']
            ok, basefile = file_exists_for_row(base_path, fname)
            if ok:
                self.rows.append((row, basefile))
            else:
                missing += 1
        if missing>0:
            print(f"Warning: {missing} records missing files (skipped).")
        self.edge_index = EDGE_INDEX
        self.augment = augment

    def __len__(self): return len(self.rows)

    def __getitem__(self, idx):
        row, basefile = self.rows[idx]
        # wfdb accepts basefile without extension
        rec, _ = wfdb.rdsamp(basefile)
        sig = rec.astype(np.float32)  # (T, 12)
        # some records may be transposed; ensure shape (T,12)
        if sig.ndim == 1:
            sig = sig[:, None]
        if sig.shape[1] != 12 and sig.shape[0] == 12:
            sig = sig.T
        if sig.shape[1] != 12:
            # pad/truncate to 12 leads
            if sig.shape[1] < 12:
                sig = np.pad(sig, ((0,0),(0, 12 - sig.shape[1])), mode='constant')
            else:
                sig = sig[:, :12]

        # normalize per lead
        sig = (sig - sig.mean(axis=0)) / (sig.std(axis=0) + 1e-8)
        processed = sig.T  # (12, T)

        # augment (mild)
        if self.augment:
            if np.random.rand() < 0.25:
                processed = processed + 0.01 * np.random.randn(*processed.shape)
            if np.random.rand() < 0.25:
                shift = int(np.random.randint(-int(0.05 * processed.shape[1]), int(0.05 * processed.shape[1]) + 1))
                processed = np.roll(processed, shift, axis=1)

        # target: try 'diagnostic_superclass' field first (if stored), else build from scp_codes
        labels_list = []
        if 'diagnostic_superclass' in row.index and isinstance(row['diagnostic_superclass'], (list, tuple, np.ndarray)):
            labels_list = list(row['diagnostic_superclass'])
        else:
            try:
                labels_list = aggregate_superclass_from_scp_field(row['scp_codes'])
            except Exception:
                labels_list = []

        y = np.zeros(NUM_CLASSES, dtype=np.float32)
        for c in labels_list:
            if c in classes:
                y[classes.index(c)] = 1.0

        # Data.x: shape (num_nodes=12, seq_len)
        data = Data(x=torch.tensor(processed, dtype=torch.float), edge_index=self.edge_index, y=torch.tensor(y, dtype=torch.float))
        return data

In [None]:
class Residual1D(nn.Module):
    def __init__(self, in_ch, out_ch, kernel=7):
        super().__init__()
        pad = (kernel-1)//2
        self.conv = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel, padding=pad),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(),
            nn.Conv1d(out_ch, out_ch, kernel, padding=pad),
            nn.BatchNorm1d(out_ch),
        )
        self.res = nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.conv(x) + self.res(x))

class PerNodeCNN(nn.Module):
    def __init__(self, out_dim=48):
        super().__init__()
        self.net = nn.Sequential(
            Residual1D(1, 16),
            Residual1D(16, 32),
            Residual1D(32, 48)
        )
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(48, out_dim)

    def forward(self, x):
        # x: (total_nodes, seq_len)
        x = x.unsqueeze(1)          # (total_nodes, 1, seq_len)
        x = self.net(x)            # (total_nodes, channels, seq_len)
        x = self.pool(x).squeeze(-1)  # (total_nodes, channels)
        return self.fc(x)          # (total_nodes, out_dim)

class GNNHybrid(nn.Module):
    def __init__(self, num_classes, node_feat_dim=48, gnn_hidden=64):
        super().__init__()
        self.node_cnn = PerNodeCNN(node_feat_dim)
        self.gnn1 = SAGEConv(node_feat_dim, gnn_hidden)
        self.gnn2 = SAGEConv(gnn_hidden, gnn_hidden)
        self.classifier = nn.Sequential(nn.Linear(gnn_hidden, 128), nn.ReLU(), nn.Dropout(0.25), nn.Linear(128, num_classes))

    def forward(self, data):
        # data.x: (total_nodes, seq_len)
        x = self.node_cnn(data.x.to(DEVICE))   # (total_nodes, node_feat_dim)
        edge_index = data.edge_index.to(DEVICE)
        batch = data.batch.to(DEVICE) if hasattr(data, 'batch') else torch.zeros(x.size(0), dtype=torch.long, device=DEVICE)
        x = F.relu(self.gnn1(x, edge_index))
        x = F.relu(self.gnn2(x, edge_index))
        g = global_mean_pool(x, batch)   # (batch_size, gnn_hidden)
        out = self.classifier(g)
        return out

In [None]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
    def forward(self, logits, targets):
        bce = self.bce(logits, targets)
        pt = torch.exp(-bce)
        loss = self.alpha * (1 - pt)**self.gamma * bce
        return loss.mean()

def reshape_batch_y_for_loss(batch_y, num_graphs):
    """Ensure batch_y becomes (num_graphs, num_classes). Robust to PyG batching quirks."""
    if not isinstance(batch_y, torch.Tensor):
        batch_y = torch.tensor(batch_y)
    # If already shaped (num_graphs, num_classes) -> good
    if batch_y.dim() == 2 and batch_y.shape[0] == num_graphs:
        return batch_y
    # If 1D flattened of length num_graphs * num_classes
    if batch_y.dim() == 1 and batch_y.numel() == num_graphs * NUM_CLASSES:
        return batch_y.view(num_graphs, NUM_CLASSES)
    # If 2D but shape (num_graphs * something, ) -> reshape
    try:
        return batch_y.reshape(num_graphs, -1)
    except Exception:
        # As last resort, repeat/trim
        y = batch_y.flatten()
        y = y[:num_graphs * NUM_CLASSES]
        return y.view(num_graphs, NUM_CLASSES)

def evaluate_model(model, loader, thresholds=None):
    model.eval()
    Ys, Ps = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(DEVICE)
            # model expects a Batch
            try:
                ctx = _autocast(device_type=DEVICE.type, enabled=(DEVICE.type == "cuda"))
            except TypeError:
                ctx = _autocast(enabled=(DEVICE.type == "cuda"))
            with ctx:
                logits = model(batch)
                probs = torch.sigmoid(logits).cpu().numpy()
            # obtain y shape robustly
            y_np = batch.y.cpu().numpy()
            if y_np.ndim == 1 and y_np.size == batch.num_graphs * NUM_CLASSES:
                y_np = y_np.reshape(batch.num_graphs, NUM_CLASSES)
            elif y_np.ndim == 2 and y_np.shape[0] != batch.num_graphs:
                y_np = y_np.reshape(batch.num_graphs, -1)
            Ys.append(y_np)
            Ps.append(probs)
    if len(Ys) == 0:
        return float('nan'), float('nan'), None, None
    Y = np.vstack(Ys); P = np.vstack(Ps)
    # AUROC robust
    try:
        auc = roc_auc_score(Y, P, average='macro')
    except Exception:
        aucs = []
        for i in range(Y.shape[1]):
            try:
                aucs.append(roc_auc_score(Y[:,i], P[:,i]))
            except Exception:
                pass
        auc = float(np.mean(aucs)) if len(aucs) else float('nan')
    if thresholds is None:
        preds = (P > 0.5).astype(int)
    else:
        preds = (P > np.array(thresholds)).astype(int)
    f1 = f1_score(Y, preds, average='macro', zero_division=0)
    return auc, f1, Y, P

def tune_thresholds(y_true, y_pred):
    thresholds = []
    for i in range(y_true.shape[1]):
        best_t, best_f1 = 0.5, -1.0
        for t in np.linspace(0.1, 0.9, 17):
            f = f1_score(y_true[:,i], (y_pred[:,i] > t).astype(int), zero_division=0)
            if f > best_f1:
                best_f1, best_t = f, t
        thresholds.append(best_t)
    return thresholds

In [None]:
def build_loaders(metadata_csv=METADATA_CSV, base_path=BASE_PATH):
    df = pd.read_csv(metadata_csv, index_col='ecg_id')
    # ensure diagnostic_superclass exists
    if 'diagnostic_superclass' not in df.columns:
        df['diagnostic_superclass'] = df['scp_codes'].apply(aggregate_superclass_from_scp_field)
    # prefer official folds if available
    if 'strat_fold' in df.columns:
        train_df = df[df.strat_fold.isin(range(1,9))].copy()
        val_df   = df[df.strat_fold == 9].copy()
        test_df  = df[df.strat_fold == 10].copy()
    else:
        # fallback random splits
        from sklearn.model_selection import train_test_split
        train_val, test_df = train_test_split(df, test_size=0.1, random_state=42, shuffle=True)
        train_df, val_df = train_test_split(train_val, test_size=0.1, random_state=42, shuffle=True)

    train_ds = PTBXLGraphDataset(train_df, base_path, use_lr=True, augment=True)
    val_ds   = PTBXLGraphDataset(val_df, base_path, use_lr=True, augment=False)
    test_ds  = PTBXLGraphDataset(test_df, base_path, use_lr=True, augment=False)

    train_loader = GeoDataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    val_loader = GeoDataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
    test_loader = GeoDataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
    return train_loader, val_loader, test_loader

In [None]:
def train_and_save():
    train_loader, val_loader, test_loader = build_loaders()
    model = GNNHybrid(num_classes=NUM_CLASSES).to(DEVICE)
    print(model)
    criterion = FocalLoss(alpha=1.0, gamma=2.0)
    optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3, verbose=True)

    # GradScaler robust init
    try:
        scaler = _GradScaler('cuda' if DEVICE.type == 'cuda' else 'cpu')
    except TypeError:
        scaler = _GradScaler(enabled=(DEVICE.type=='cuda'))

    best_auc = -1.0
    patience_cnt = 0

    for epoch in range(1, EPOCHS+1):
        model.train()
        running_loss = 0.0
        n_graphs = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=False)
        for batch in pbar:
            batch = batch.to(DEVICE)
            optimizer.zero_grad(set_to_none=True)
            # robust autocast usage
            try:
                ctx = _autocast(device_type=DEVICE.type, enabled=(DEVICE.type=="cuda"))
            except TypeError:
                ctx = _autocast(enabled=(DEVICE.type=="cuda"))
            with ctx:
                logits = model(batch)  # (batch_size, num_classes)
                # ensure batch.y shaped (batch_size, num_classes)
                target = batch.y
                target = reshape_batch_y_for_loss(target, batch.num_graphs).to(DEVICE)
                loss = criterion(logits, target)
            scaler.scale(loss).backward()
            # gradient clipping (helpful)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item() * batch.num_graphs
            n_graphs += batch.num_graphs
            pbar.set_postfix({'loss': f"{running_loss/max(1,n_graphs):.4f}"})

        # validation
        val_auc, val_f1, val_Y, val_P = evaluate_model(model, val_loader)
        scheduler.step(val_auc if not math.isnan(val_auc) else 0.0)
        print(f"Epoch {epoch}  train_loss={running_loss/max(1,n_graphs):.4f}  val_auc={val_auc:.4f}  val_f1={val_f1:.4f}")

        if not math.isnan(val_auc) and val_auc > best_auc + 1e-5:
            best_auc = val_auc
            patience_cnt = 0
            torch.save(model.state_dict(), "best_model_gnn.pth")
            print("  -> saved best_model_gnn.pth")
        else:
            patience_cnt += 1
            if patience_cnt >= PATIENCE:
                print("Early stopping triggered.")
                break

    # load best and tune thresholds
    print("Loading best model for final evaluation and threshold tuning...")
    model.load_state_dict(torch.load("best_model_gnn.pth", map_location=DEVICE))
    _, _, Yv, Pv = evaluate_model(model, val_loader)
    thresholds = tune_thresholds(Yv, Pv)
    test_auc, test_f1, Yt, Pt = evaluate_model(model, test_loader, thresholds=thresholds)
    print(f"[TEST] AUROC={test_auc:.4f}  MacroF1={test_f1:.4f}")
    print(classification_report(Yt, (Pt > np.array(thresholds)).astype(int), target_names=classes, zero_division=0))
    torch.save(model.state_dict(), "ptbxl_gnn_hybrid_final.pth")
    print("Final model saved as ptbxl_gnn_hybrid_final.pth")
    return model, thresholds

In [None]:
def predict_dathea(basefile_noext, model, thresholds=None):
    base = os.path.splitext(basefile_noext)[0]
    rec, _ = wfdb.rdsamp(base)
    sig = rec.astype(np.float32)
    if sig.ndim == 1: sig = sig[:, None]
    if sig.shape[1] != 12 and sig.shape[0] == 12: sig = sig.T
    if sig.shape[1] != 12:
        sig = np.pad(sig, ((0,0),(0, max(0, 12 - sig.shape[1]))))
        sig = sig[:, :12]
    sig = (sig - sig.mean(axis=0)) / (sig.std(axis=0) + 1e-8)
    processed = sig.T
    data = Data(x=torch.tensor(processed, dtype=torch.float), edge_index=EDGE_INDEX)
    batch = Batch.from_data_list([data]).to(DEVICE)
    model.eval()
    with torch.no_grad():
        try:
            ctx = _autocast(device_type=DEVICE.type, enabled=(DEVICE.type=="cuda"))
        except TypeError:
            ctx = _autocast(enabled=(DEVICE.type=="cuda"))
        with ctx:
            logits = model(batch)
            probs = torch.sigmoid(logits).cpu().numpy()[0]
    if thresholds is None:
        preds = (probs > 0.5).astype(int)
    else:
        preds = (probs > np.array(thresholds)).astype(int)
    result = {classes[i]: float(probs[i]) for i in range(NUM_CLASSES)}
    predicted = [classes[i] for i in range(NUM_CLASSES) if preds[i] == 1]
    return result, predicted

model, thresholds = train_and_save()

Device: cuda
Classes: ['CD', 'HYP', 'MI', 'NORM', 'STTC']
GNNHybrid(
  (node_cnn): PerNodeCNN(
    (net): Sequential(
      (0): Residual1D(
        (conv): Sequential(
          (0): Conv1d(1, 16, kernel_size=(7,), stride=(1,), padding=(3,))
          (1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Conv1d(16, 16, kernel_size=(7,), stride=(1,), padding=(3,))
          (4): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (res): Conv1d(1, 16, kernel_size=(1,), stride=(1,))
        (act): ReLU()
      )
      (1): Residual1D(
        (conv): Sequential(
          (0): Conv1d(16, 32, kernel_size=(7,), stride=(1,), padding=(3,))
          (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Conv1d(32, 32, kernel_size=(7,), stride=(1,), padding=(3,))
          (4): BatchNorm1d(32, eps=1e-05, momentum=0

                                                                          

Epoch 1  train_loss=0.1113  val_auc=0.8368  val_f1=0.4589
  -> saved best_model_gnn.pth


                                                                          

Epoch 2  train_loss=0.0950  val_auc=0.8526  val_f1=0.4924
  -> saved best_model_gnn.pth


                                                                          

Epoch 3  train_loss=0.0891  val_auc=0.8705  val_f1=0.6429
  -> saved best_model_gnn.pth


                                                                          

Epoch 4  train_loss=0.0842  val_auc=0.8885  val_f1=0.6144
  -> saved best_model_gnn.pth


                                                                          

Epoch 5  train_loss=0.0807  val_auc=0.8870  val_f1=0.6104


                                                                          

Epoch 6  train_loss=0.0804  val_auc=0.8871  val_f1=0.5925


                                                                          

Epoch 7  train_loss=0.0792  val_auc=0.8929  val_f1=0.6164
  -> saved best_model_gnn.pth


                                                                          

Epoch 8  train_loss=0.0772  val_auc=0.9002  val_f1=0.6525
  -> saved best_model_gnn.pth


                                                                          

Epoch 9  train_loss=0.0756  val_auc=0.9018  val_f1=0.6757
  -> saved best_model_gnn.pth


                                                                           

Epoch 10  train_loss=0.0755  val_auc=0.9007  val_f1=0.6335


                                                                           

Epoch 11  train_loss=0.0755  val_auc=0.9005  val_f1=0.6638


                                                                           

Epoch 12  train_loss=0.0732  val_auc=0.9055  val_f1=0.6445
  -> saved best_model_gnn.pth


                                                                           

Epoch 13  train_loss=0.0729  val_auc=0.9050  val_f1=0.6418


                                                                           

Epoch 14  train_loss=0.0731  val_auc=0.9028  val_f1=0.6622


                                                                           

Epoch 15  train_loss=0.0728  val_auc=0.9086  val_f1=0.6821
  -> saved best_model_gnn.pth


                                                                           

Epoch 16  train_loss=0.0721  val_auc=0.9065  val_f1=0.6584


                                                                           

Epoch 17  train_loss=0.0728  val_auc=0.9033  val_f1=0.6128


                                                                           

Epoch 18  train_loss=0.0723  val_auc=0.9041  val_f1=0.6049


                                                                           

Epoch 19  train_loss=0.0721  val_auc=0.9084  val_f1=0.6822


                                                                           

Epoch 20  train_loss=0.0692  val_auc=0.9123  val_f1=0.6865
  -> saved best_model_gnn.pth
Loading best model for final evaluation and threshold tuning...
[TEST] AUROC=0.9069  MacroF1=0.7057
              precision    recall  f1-score   support

          CD       0.83      0.62      0.71       498
         HYP       0.35      0.68      0.46       263
          MI       0.73      0.78      0.75       553
        NORM       0.77      0.93      0.85       964
        STTC       0.68      0.86      0.76       523

   micro avg       0.69      0.81      0.74      2801
   macro avg       0.67      0.77      0.71      2801
weighted avg       0.72      0.81      0.75      2801
 samples avg       0.73      0.81      0.75      2801

Final model saved as ptbxl_gnn_hybrid_final.pth


In [3]:
!python --version

Python 3.11.13


In [4]:
!pip freeze>req.txt