# ChagaSight — Vision Transformer (Baseline Training)

Baseline ViT training on 2D ECG contour images  
Datasets: PTB-XL (negatives), SaMi-Trop (positives), CODE-15 (soft labels)

Baseline configuration:
- 1% subset (pipeline verification)
- No data augmentation
- AMP enabled
- Strict data integrity checks


In [12]:
import time, random
from pathlib import Path

import numpy as np
import pandas as pd
import torch

start_time = time.time()

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

torch.backends.cudnn.benchmark = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
    print("VRAM (GB):", torch.cuda.get_device_properties(0).total_memory / 1e9)

# Find project root robustly (VS Code safe)
def find_project_root(start: Path) -> Path:
    for p in [start] + list(start.parents):
        if (p / "data").exists():
            return p
    return start

PROJECT_ROOT = find_project_root(Path.cwd())
DATA_DIR = PROJECT_ROOT / "data" / "processed"
MODEL_DIR = PROJECT_ROOT / "models"
MODEL_DIR.mkdir(exist_ok=True)

print("PROJECT_ROOT:", PROJECT_ROOT)
print("DATA_DIR:", DATA_DIR)

print(f"⏱ Cell 1 time: {time.time() - start_time:.2f}s")


Device: cuda
GPU: NVIDIA GeForce RTX 3050 6GB Laptop GPU
VRAM (GB): 6.441926656
PROJECT_ROOT: d:\IIT\L6\FYP\ChagaSight
DATA_DIR: d:\IIT\L6\FYP\ChagaSight\data\processed
⏱ Cell 1 time: 0.00s


In [13]:
# =========================
# Cell 2 — Metadata loading + integrity filtering + 1% subset
# =========================
import time
from pathlib import Path
from sklearn.model_selection import train_test_split

start_time = time.time()

datasets = ["ptbxl", "sami_trop", "code15"]
dfs = []

# Load metadata CSVs
for ds in datasets:
    csv_path = DATA_DIR / "metadata" / f"{ds}_metadata.csv"
    if not csv_path.exists():
        raise FileNotFoundError(f"Missing metadata CSV: {csv_path}")
    df = pd.read_csv(csv_path)
    df["dataset"] = ds
    dfs.append(df)

df_all = pd.concat(dfs, ignore_index=True)
print("Total metadata rows:", len(df_all))

# -------------------------
# HARD integrity filter (RELATIVE PATH SAFE)
# -------------------------
def img_exists(p):
    return (PROJECT_ROOT / Path(p)).exists()

exists_mask = df_all["img_path"].apply(img_exists)
missing_count = (~exists_mask).sum()

if missing_count > 0:
    print(f"⚠️ Dropping {missing_count} rows with missing image files")
    print(df_all.loc[~exists_mask, ["dataset", "img_path"]].head())

df_all = df_all.loc[exists_mask].reset_index(drop=True)
print("Rows after integrity filter:", len(df_all))

# -------------------------
# START SMALL: 1% subset
# -------------------------
subset_frac = 0.01
df_all = df_all.sample(frac=subset_frac, random_state=SEED).reset_index(drop=True)
print("Subset records (1%):", len(df_all))

# -------------------------
# Binary label ONLY for stratification / metrics
# -------------------------
df_all["label_bin"] = (df_all["label"] > 0.5).astype(int)

# -------------------------
# Train / Val / Test split
# -------------------------
train_df, temp_df = train_test_split(
    df_all,
    test_size=0.2,
    stratify=df_all["label_bin"],
    random_state=SEED
)

val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    stratify=temp_df["label_bin"],
    random_state=SEED
)

print(f"Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)}")
print(f"⏱ Cell 2 time: {time.time() - start_time:.2f}s")


Total metadata rows: 63228
Rows after integrity filter: 63228
Subset records (1%): 632
Train: 505 | Val: 63 | Test: 64
⏱ Cell 2 time: 3.33s


In [14]:
# =========================
# Cell 3 — Dataset + DataLoaders (baseline, no augmentation)
# =========================
import time
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

start_time = time.time()

class ECGImageDataset(Dataset):
    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Resolve relative path safely
        img_path = PROJECT_ROOT / Path(row["img_path"])
        if not img_path.exists():
            raise FileNotFoundError(f"Missing image file: {img_path}")

        img = np.load(img_path).astype(np.float32)

        # Strict shape check (research safety)
        if img.shape != (3, 24, 2048):
            raise ValueError(f"Invalid image shape {img.shape} at {img_path}")

        img = torch.from_numpy(img)  # (3, 24, 2048)
        label = torch.tensor(row["label"], dtype=torch.float32)

        return img, label


# -------------------------
# DataLoaders
# -------------------------
batch_size = 16  # RTX 3050 (6GB) safe

train_ds = ECGImageDataset(train_df)
val_ds   = ECGImageDataset(val_df)
test_ds  = ECGImageDataset(test_df)

# Oversample confident positives (CODE-15 & SaMi-Trop)
weights = train_df["label"].apply(lambda x: 10.0 if x > 0.7 else 1.0).values
sampler = WeightedRandomSampler(weights, len(weights), replacement=True)

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=0,      # Windows-safe
    pin_memory=True
)

val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

test_loader = DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

# -------------------------
# Sanity check: first batch
# -------------------------
print("Running DataLoader sanity check...")
x_batch, y_batch = next(iter(train_loader))

print("✓ Batch image shape :", x_batch.shape)   # (16, 3, 24, 2048)
print("✓ Batch label shape :", y_batch.shape)
print("✓ Sample labels    :", y_batch[:10].tolist())
print(
    f"✓ Image value range : "
    f"[{x_batch.min().item():.3f}, {x_batch.max().item():.3f}]"
)

print(f"⏱ Cell 3 time: {time.time() - start_time:.2f}s")


Running DataLoader sanity check...
✓ Batch image shape : torch.Size([16, 3, 24, 2048])
✓ Batch label shape : torch.Size([16])
✓ Sample labels    : [0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.0, 0.20000000298023224, 1.0, 0.800000011920929, 0.20000000298023224, 0.800000011920929, 1.0]
✓ Image value range : [0.000, 255.000]
⏱ Cell 3 time: 0.03s


In [15]:
# =========================
# Cell 4 — Vision Transformer model + forward sanity test
# =========================
import time
import torch
import torch.nn as nn

start_time = time.time()

# -------------------------
# Patch Embedding
# -------------------------
class PatchEmbedding(nn.Module):
    def __init__(self, patch_size=16, in_ch=3, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(
            in_ch,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
        self.num_patches = (24 // patch_size) * (2048 // patch_size)

    def forward(self, x):
        x = self.proj(x)                  # (B, E, H', W')
        x = x.flatten(2).transpose(1, 2)  # (B, N, E)
        return x


# -------------------------
# Transformer Block
# -------------------------
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=768, heads=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim,
            heads,
            dropout=dropout,
            batch_first=True
        )
        self.norm2 = nn.LayerNorm(embed_dim)

        mlp_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        y, _ = self.attn(self.norm1(x), self.norm1(x), self.norm1(x))
        x = x + y
        x = x + self.mlp(self.norm2(x))
        return x


# -------------------------
# ViT Classifier
# -------------------------
class ViTClassifier(nn.Module):
    def __init__(
        self,
        patch_size=16,
        embed_dim=768,
        depth=12,
        heads=12,
        mlp_ratio=4.0,
        dropout=0.1
    ):
        super().__init__()

        self.patch_embed = PatchEmbedding(
            patch_size=patch_size,
            in_ch=3,
            embed_dim=embed_dim
        )

        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim)
        )
        self.pos_drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            TransformerBlock(
                embed_dim=embed_dim,
                heads=heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout
            )
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, 1)  # binary logits

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x[:, 0])
        return self.head(x).squeeze(-1)


# -------------------------
# Instantiate model
# -------------------------
model = ViTClassifier().to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"ViT trainable parameters: {num_params:,}")
print("Model device:", next(model.parameters()).device)

# -------------------------
# Forward sanity test
# -------------------------
model.eval()
with torch.no_grad():
    x_batch_gpu = x_batch.to(device)
    logits = model(x_batch_gpu)

print("✓ Forward pass OK")
print("✓ Logits shape :", logits.shape)  # (B,)

if device.type == "cuda":
    mem = torch.cuda.max_memory_allocated() / 1e9
    print(f"✓ GPU memory used (GB): {mem:.2f}")

print(f"⏱ Cell 4 time: {time.time() - start_time:.2f}s")


ViT trainable parameters: 85,747,201
Model device: cuda:0
✓ Forward pass OK
✓ Logits shape : torch.Size([16])
✓ GPU memory used (GB): 2.46
⏱ Cell 4 time: 0.59s


In [16]:
# =========================
# Cell 5 — Training loop (baseline, AMP, progress, timing)
# =========================
import time
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score
from tqdm.auto import tqdm

start_time = time.time()

# -------------------------
# Training configuration
# -------------------------
num_epochs = 5
learning_rate = 3e-4
weight_decay = 0.05

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs
)

use_amp = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

best_val_auc = 0.0
best_model_path = MODEL_DIR / "vit_baseline_best.pth"

print("Starting training...")
print("AMP enabled:", use_amp)

# -------------------------
# Epoch loop
# -------------------------
for epoch in range(num_epochs):
    epoch_start = time.time()

    # ---- Training ----
    model.train()
    train_loss = 0.0

    train_bar = tqdm(
        train_loader,
        desc=f"Epoch {epoch+1}/{num_epochs} [Train]",
        leave=False
    )

    for imgs, labels in train_bar:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(imgs)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        train_bar.set_postfix(loss=f"{loss.item():.4f}")

    train_loss /= len(train_loader)

    # ---- Validation ----
    model.eval()
    val_preds = []
    val_trues = []

    with torch.no_grad():
        for imgs, labels in tqdm(
            val_loader,
            desc=f"Epoch {epoch+1}/{num_epochs} [Val]",
            leave=False
        ):
            imgs = imgs.to(device, non_blocking=True)
            probs = torch.sigmoid(model(imgs)).cpu().numpy()

            val_preds.extend(probs)
            val_trues.extend(labels.numpy())

    val_trues = np.asarray(val_trues)
    val_preds = np.asarray(val_preds)

    # Binarise labels for metrics ONLY
    val_trues_bin = (val_trues > 0.5).astype(int)

    val_auc = roc_auc_score(val_trues_bin, val_preds)

    # Save best model
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        torch.save(model.state_dict(), best_model_path)
        improved = "✅"
    else:
        improved = ""

    scheduler.step()

    print(
        f"Epoch {epoch+1:02d} | "
        f"loss={train_loss:.4f} | "
        f"val AUROC={val_auc:.4f} {improved} | "
        f"time={time.time() - epoch_start:.1f}s"
    )

print("\nTraining complete.")
print("Best validation AUROC:", best_val_auc)
print(f"⏱ Cell 5 total time: {time.time() - start_time:.2f}s")


Starting training...
AMP enabled: True


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


Epoch 1/5 [Train]:   0%|          | 0/32 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=use_amp):


Epoch 1/5 [Val]:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 01 | loss=1.0297 | val AUROC=0.8033 ✅ | time=10.6s


Epoch 2/5 [Train]:   0%|          | 0/32 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=use_amp):


Epoch 2/5 [Val]:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 02 | loss=0.6076 | val AUROC=0.9098 ✅ | time=10.0s


Epoch 3/5 [Train]:   0%|          | 0/32 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=use_amp):


Epoch 3/5 [Val]:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 03 | loss=0.6258 | val AUROC=0.9672 ✅ | time=9.9s


Epoch 4/5 [Train]:   0%|          | 0/32 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=use_amp):


Epoch 4/5 [Val]:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 04 | loss=0.6047 | val AUROC=0.9672  | time=9.4s


Epoch 5/5 [Train]:   0%|          | 0/32 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=use_amp):


Epoch 5/5 [Val]:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 05 | loss=0.6063 | val AUROC=0.9672  | time=9.7s

Training complete.
Best validation AUROC: 0.9672131147540983
⏱ Cell 5 total time: 49.57s


In [17]:
# =========================
# Cell 6 — Test evaluation (held-out set)
# =========================
import time
from sklearn.metrics import roc_auc_score
from tqdm.auto import tqdm

start_time = time.time()

# -------------------------
# Load best model
# -------------------------
best_model_path = MODEL_DIR / "vit_baseline_best.pth"
assert best_model_path.exists(), "Best model checkpoint not found!"

model.load_state_dict(torch.load(best_model_path, map_location=device))
model.eval()

print("Loaded best model from:", best_model_path)

# -------------------------
# Test loop
# -------------------------
test_preds = []
test_trues = []

with torch.no_grad():
    for imgs, labels in tqdm(test_loader, desc="Test evaluation"):
        imgs = imgs.to(device, non_blocking=True)

        probs = torch.sigmoid(model(imgs)).cpu().numpy()
        test_preds.extend(probs)
        test_trues.extend(labels.numpy())

test_preds = np.asarray(test_preds)
test_trues = np.asarray(test_trues)

# -------------------------
# Metrics (binary labels ONLY for metrics)
# -------------------------
test_trues_bin = (test_trues > 0.5).astype(int)
test_auc = roc_auc_score(test_trues_bin, test_preds)

print("\n=== TEST RESULTS ===")
print(f"Test AUROC : {test_auc:.4f}")
print(f"⏱ Cell 6 time: {time.time() - start_time:.2f}s")


Loaded best model from: d:\IIT\L6\FYP\ChagaSight\models\vit_baseline_best.pth


Test evaluation:   0%|          | 0/4 [00:00<?, ?it/s]


=== TEST RESULTS ===
Test AUROC : 0.9194
⏱ Cell 6 time: 1.09s
