In [1]:
import time, random, sys, os
import subprocess
import threading
from pathlib import Path
from datetime import datetime
import warnings
warnings.filterwarnings("ignore")


import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast


import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    roc_auc_score, average_precision_score, f1_score,
    accuracy_score, precision_score, recall_score, confusion_matrix,
    roc_curve, auc
)
from tqdm import tqdm


# ---- Reproducibility ----
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# ---- Device ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM (GB): {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f}")
    print(f"PyTorch: {torch.__version__}")
    print(f"CUDA: {torch.version.cuda}")


# ---- GPU Monitoring (background thread) ----
def monitor_gpu():
    try:
        subprocess.Popen([
            'nvidia-smi',
            '--query-gpu=timestamp,name,utilization.gpu,memory.used,memory.total',
            '--format=csv', '-l', '10'
        ])
    except FileNotFoundError:
        print("Warning: nvidia-smi not found ‚Äì GPU monitoring skipped.")


threading.Thread(target=monitor_gpu, daemon=True).start()


# ---- Paths (auto-detect project root) ----
def find_project_root(start: Path) -> Path:
    for p in [start] + list(start.parents):
        if (p / "data").exists():
            return p
    return start


PROJECT_ROOT = find_project_root(Path.cwd())
DATA_DIR = PROJECT_ROOT / "data" / "processed"


if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))


# ---- Experiment ----
EXP_NAME = "vit_contour_research_approved"
RUN_ID = datetime.now().strftime("%Y%m%d_%H%M%S")
EXP_DIR = PROJECT_ROOT / "experiments" / EXP_NAME / RUN_ID
EXP_DIR.mkdir(parents=True, exist_ok=True)


print(f"\n{'='*70}")
print(f"PROJECT_ROOT: {PROJECT_ROOT}")
print(f"DATA_DIR: {DATA_DIR}")
print(f"EXP_DIR: {EXP_DIR}")
print(f"{'='*70}\n")


cell_start = time.time()
print(f"‚úì Cell 1 initialized in {time.time() - cell_start:.2f}s")


Device: cuda
GPU: NVIDIA GeForce RTX 3050 6GB Laptop GPU
VRAM (GB): 6.44
PyTorch: 2.7.1+cu118
CUDA: 11.8

PROJECT_ROOT: d:\IIT\L6\FYP\ChagaSight
DATA_DIR: d:\IIT\L6\FYP\ChagaSight\data\processed
EXP_DIR: d:\IIT\L6\FYP\ChagaSight\experiments\vit_contour_research_approved\20260112_120757

‚úì Cell 1 initialized in 0.00s


In [2]:
cell_start = time.time()

# ---- Load all metadata ----
datasets = ["ptbxl", "sami_trop", "code15"]
dfs = []

for ds in datasets:
    csv_path = DATA_DIR / "metadata" / f"{ds}_metadata.csv"
    if not csv_path.exists():
        raise FileNotFoundError(f"Missing: {csv_path}")
    
    df = pd.read_csv(csv_path)
    df["dataset"] = ds
    df["label"] = df["label"].astype(float)
    dfs.append(df)
    print(f"Loaded {ds}: {len(df)} records")

df_all = pd.concat(dfs, ignore_index=True)

# ---- TEST 1: Check for duplicates ----
duplicate_ids = df_all[df_all.duplicated(subset=['id'], keep=False)]
if len(duplicate_ids) > 0:
    print(f"‚ö†Ô∏è WARNING: {len(duplicate_ids)} duplicate IDs found")
else:
    print("‚úì No duplicate IDs")

# ---- TEST 2: Check label values ----
print(f"\nLabel statistics:")
print(f"  Min: {df_all['label'].min():.4f}")
print(f"  Max: {df_all['label'].max():.4f}")
print(f"  Mean: {df_all['label'].mean():.4f}")
print(f"  NaN count: {df_all['label'].isna().sum()}")

# Drop any NaN labels
df_all = df_all.dropna(subset=['label']).reset_index(drop=True)

# ---- TEST 3: Class distribution per dataset ----
print(f"\nClass distribution by dataset:")
for ds in datasets:
    ds_df = df_all[df_all['dataset'] == ds]
    pos_count = (ds_df['label'] > 0.5).sum()
    neg_count = (ds_df['label'] <= 0.5).sum()
    soft_count = ((ds_df['label'] > 0.1) & (ds_df['label'] < 0.9)).sum()
    print(f"  {ds:12} | Pos: {pos_count:5} | Neg: {neg_count:5} | Soft: {soft_count:5} | Total: {len(ds_df):5}")

# Global distribution
pos_total = (df_all['label'] > 0.5).sum()
neg_total = (df_all['label'] <= 0.5).sum()
imbalance_ratio = neg_total / (pos_total + 1e-6)
print(f"\n  GLOBAL: {pos_total} positive, {neg_total} negative (ratio: {imbalance_ratio:.1f}x imbalance)")
print(f"  ‚ö†Ô∏è Severe class imbalance detected! Weighted loss is CRITICAL.\n")

# ---- TEST 4: Check image file existence ----
def img_exists(p):
    """Cross-platform path checking"""
    clean_path = str(p).replace("\\", "/")
    full_path = (PROJECT_ROOT / Path(clean_path)).resolve()
    return full_path.exists()

exists_mask = df_all["img_path"].apply(img_exists)
missing_count = (~exists_mask).sum()

if missing_count > 0:
    print(f"‚ö†Ô∏è WARNING: {missing_count} missing image files")
    df_all = df_all.loc[exists_mask].reset_index(drop=True)
    print(f"Dropped {missing_count} rows. Remaining: {len(df_all)}")
else:
    print("‚úì All image files exist")

# ---- Create binary labels for metrics only ----
df_all["label_bin"] = (df_all["label"] > 0.5).astype(int)

# ---- TEST 5: Stratified splits with distribution check ----
print(f"\nCreating stratified train/val/test splits...")
train_df, temp_df = train_test_split(
    df_all, test_size=0.2, stratify=df_all["label_bin"], random_state=SEED
)
val_df, test_df = train_test_split(
    temp_df, test_size=0.5, stratify=temp_df["label_bin"], random_state=SEED
)

print(f"Train: {len(train_df)} ({(train_df['label_bin']==1).sum()} positive)")
print(f"Val:   {len(val_df)} ({(val_df['label_bin']==1).sum()} positive)")
print(f"Test:  {len(test_df)} ({(test_df['label_bin']==1).sum()} positive)")

# Check for leakage
overlap = set(train_df['id']) & set(val_df['id']) | set(train_df['id']) & set(test_df['id'])
if len(overlap) > 0:
    raise RuntimeError(f"Data leakage detected! {len(overlap)} overlapping IDs")
print("‚úì No data leakage")

print(f"\n‚úì Cell 2 completed in {time.time() - cell_start:.2f}s")


Loaded ptbxl: 21799 records
Loaded sami_trop: 1631 records
Loaded code15: 39798 records
‚úì No duplicate IDs

Label statistics:
  Min: 0.0000
  Max: 1.0000
  Mean: 0.1595
  NaN count: 0

Class distribution by dataset:
  ptbxl        | Pos:     0 | Neg: 21799 | Soft:     0 | Total: 21799
  sami_trop    | Pos:  1631 | Neg:     0 | Soft:     0 | Total:  1631
  code15       | Pos:   819 | Neg: 38979 | Soft: 39798 | Total: 39798

  GLOBAL: 2450 positive, 60778 negative (ratio: 24.8x imbalance)
  ‚ö†Ô∏è Severe class imbalance detected! Weighted loss is CRITICAL.

‚úì All image files exist

Creating stratified train/val/test splits...
Train: 50582 (1960 positive)
Val:   6323 (245 positive)
Test:  6323 (245 positive)
‚úì No data leakage

‚úì Cell 2 completed in 18.73s


In [3]:
# SET THIS TO True FOR QUICK TEST (5 min), False FOR FULL TRAINING (3 hours)
USE_SAMPLE = False
SAMPLE_FRACTION = 0.01  # 0.01% of data = ~50 samples

if USE_SAMPLE:
    print(f"\n‚ö†Ô∏è  SAMPLE MODE ENABLED (testing on {SAMPLE_FRACTION*100:.3f}% of data)")
    
    n_samples_train = max(32, int(len(train_df) * SAMPLE_FRACTION))
    n_samples_val = max(16, int(len(val_df) * SAMPLE_FRACTION))
    n_samples_test = max(16, int(len(test_df) * SAMPLE_FRACTION))
    
    train_df = train_df.sample(n=n_samples_train, random_state=SEED).reset_index(drop=True)
    val_df = val_df.sample(n=n_samples_val, random_state=SEED).reset_index(drop=True)
    test_df = test_df.sample(n=n_samples_test, random_state=SEED).reset_index(drop=True)
    
    print(f"  Train: {len(train_df)} ({(train_df['label_bin']==1).sum()} positive)")
    print(f"  Val:   {len(val_df)} ({(val_df['label_bin']==1).sum()} positive)")
    print(f"  Test:  {len(test_df)} ({(test_df['label_bin']==1).sum()} positive)")
    print(f"  Expected training time: ~2-3 minutes\n")
else:
    print(f"\n‚úì FULL MODE - Using all {len(train_df)} training samples\n")



‚úì FULL MODE - Using all 50582 training samples



In [4]:
cell_start = time.time()

class ECGImageDataset(Dataset):
    """
    Loads ECG 2D contour images with full validation.
    ‚≠ê FIXED: Handles BOTH uint8 [0,255] and float32 [-3,3] ranges
    """
    
    def __init__(self, df, validate_first=True):
        self.df = df.reset_index(drop=True)
        self.img_paths = [(PROJECT_ROOT / Path(str(p))).resolve() 
                         for p in self.df["img_path"]]
        self.labels = self.df["label"].astype(np.float32).values
        
        if validate_first:
            print(f"  Validating dataset of {len(self)} samples...")
            self._validate_first_n_samples(n=min(10, len(self)))
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        # Load image
        try:
            img = np.load(self.img_paths[idx]).astype(np.float32)
        except Exception as e:
            raise RuntimeError(f"Failed to load {self.img_paths[idx]}: {e}")
        
        # TEST: Shape validation
        if img.shape != (3, 24, 2048):
            raise ValueError(f"Invalid shape {img.shape} at {self.img_paths[idx]}")
        
        # TEST: NaN/Inf check
        if np.isnan(img).any():
            raise ValueError(f"NaN values in {self.img_paths[idx]}")
        if np.isinf(img).any():
            raise ValueError(f"Inf values in {self.img_paths[idx]}")
        
        # ‚≠ê FIXED: Handle BOTH uint8 [0,255] and float32 [-3,3]
        if img.max() > 4:  # Likely uint8 [0,255]
            # Convert from [0,255] to [-1,1]
            img = (img.astype(np.float32) - 128.0) / 128.0
        else:  # Already normalized [-3,3]
            # Clip and normalize to [-1,1]
            img = np.clip(img, -3, 3)
            img = img / 3.0
        
        # Convert to tensor
        img = torch.from_numpy(img)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        
        return img, label
    
    def _validate_first_n_samples(self, n=10):
        """Validate first n samples for sanity"""
        for idx in range(min(n, len(self))):
            try:
                img, label = self[idx]
                assert img.shape == (3, 24, 2048), f"Shape mismatch at {idx}: {img.shape}"
                assert img.min() >= -1.1 and img.max() <= 1.1, f"Range mismatch at {idx}: [{img.min():.2f}, {img.max():.2f}]"
                assert not torch.isnan(img).any(), f"NaN in tensor at {idx}"
                assert label.item() >= 0 and label.item() <= 1, f"Invalid label at {idx}: {label.item()}"
            except Exception as e:
                raise RuntimeError(f"Validation failed at sample {idx}: {e}")
        print(f"  ‚úì First {n} samples validated (range: [-1, 1])")

# Create datasets with validation
print(f"Creating datasets with validation...")
train_ds = ECGImageDataset(train_df, validate_first=True)
val_ds = ECGImageDataset(val_df, validate_first=True)
test_ds = ECGImageDataset(test_df, validate_first=True)

# ---- Compute class weights for weighted loss ----
pos_count = (train_df['label'] > 0.5).sum()
neg_count = (train_df['label'] <= 0.5).sum()
pos_weight = neg_count / (pos_count + 1e-6)

print(f"\nClass weights:")
print(f"  Positive weight: {pos_weight:.2f} (compensate for {neg_count}/{pos_count} imbalance)")
print(f"  Negative weight: 1.0")

# ---- DataLoaders (Windows-safe, no multiprocessing delays) ----
batch_size = 16
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, 
                         num_workers=0, pin_memory=False)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                       num_workers=0, pin_memory=False)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False,
                        num_workers=0, pin_memory=False)

# ---- Sanity check: First batch ----
x_batch, y_batch = next(iter(train_loader))
print(f"\nFirst batch check:")
print(f"  Shape: {x_batch.shape} (expect [16, 3, 24, 2048])")
print(f"  Range: [{x_batch.min().item():.3f}, {x_batch.max().item():.3f}] (expect [-1, 1])")
print(f"  Labels: {y_batch[:5].tolist()} (expect values in [0, 1])")
print(f"  Label dtype: {y_batch.dtype}")

assert x_batch.shape == (batch_size, 3, 24, 2048), "Batch shape mismatch!"
assert not torch.isnan(x_batch).any(), "NaN in batch!"
assert x_batch.min() >= -1.1 and x_batch.max() <= 1.1, f"Range mismatch: [{x_batch.min():.2f}, {x_batch.max():.2f}]"
print("  ‚úì First batch validation passed")

print(f"\n‚úì Cell 3 completed in {time.time() - cell_start:.2f}s")


Creating datasets with validation...
  Validating dataset of 50582 samples...
  ‚úì First 10 samples validated (range: [-1, 1])
  Validating dataset of 6323 samples...
  ‚úì First 10 samples validated (range: [-1, 1])
  Validating dataset of 6323 samples...
  ‚úì First 10 samples validated (range: [-1, 1])

Class weights:
  Positive weight: 24.81 (compensate for 48622/1960 imbalance)
  Negative weight: 1.0

First batch check:
  Shape: torch.Size([16, 3, 24, 2048]) (expect [16, 3, 24, 2048])
  Range: [-1.000, 0.992] (expect [-1, 1])
  Labels: [0.0, 0.20000000298023224, 0.0, 0.20000000298023224, 0.0] (expect values in [0, 1])
  Label dtype: torch.float32
  ‚úì First batch validation passed

‚úì Cell 3 completed in 11.34s


In [5]:
cell_start = time.time()

class ViTClassifier(nn.Module):
    """
    Vision Transformer for 2D ECG images (3x24x2048).
    Implements AoL (Aggregation of Layers) for improved performance.
    
    Architecture:
    - Patch embedding: (3, 24, 2048) ‚Üí (512, 3, 128) ‚Üí flattened (384, 512)
    - 12 transformer blocks
    - CLS token aggregation across all layers (AoL)
    - Classification head
    """
    
    def __init__(self, patch_h=8, patch_w=16, embed_dim=512, 
                 depth=12, heads=8, mlp_ratio=4.0, dropout=0.15):
        super().__init__()
        
        self.patch_h = patch_h
        self.patch_w = patch_w
        self.embed_dim = embed_dim
        
        # ---- Patch embedding (Conv2d for efficiency) ----
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=(patch_h, patch_w),
                                     stride=(patch_h, patch_w))
        
        # Expected patches: (24 / 8) * (2048 / 16) = 3 * 128 = 384
        num_patches = (24 // patch_h) * (2048 // patch_w)
        
        # ---- Position embedding + CLS token ----
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        
        # ---- Transformer encoder (12 blocks) ----
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=heads,
                dim_feedforward=int(embed_dim * mlp_ratio),
                dropout=dropout,
                activation='gelu',
                batch_first=True
            ) for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
        
        # ---- Classification head ----
        self.head = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(embed_dim, 1)
        )
        
        # Initialize weights
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=0.02)
    
    def forward(self, x, return_feats=False):
        """
        Forward pass with AoL.
        
        Args:
            x: (B, 3, 24, 2048)
            return_feats: If True, return feature vectors (for alignment)
        
        Returns:
            If return_feats: (B, embed_dim) features
            Else: (B, 1) logits
        """
        B = x.shape[0]  # ‚≠ê FIXED: Extract batch size as integer
        
        # Patch embedding
        x = self.patch_embed(x)  # (B, embed_dim, 3, 128)
        x = x.flatten(2).transpose(1, 2)  # (B, 384, embed_dim)
        
        # Add CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, 385, embed_dim)
        
        # Add position embeddings
        x = x + self.pos_embed
        
        # ---- Transformer blocks with AoL ----
        layer_outputs = []
        for block in self.blocks:
            x = block(x)
            layer_outputs.append(x[:, 0])  # CLS token from each layer
        
        # ---- AoL: Average of all CLS tokens across layers ----
        feats = torch.stack(layer_outputs, dim=1).mean(dim=1)  # (B, embed_dim)
        x = self.norm(feats)
        
        if return_feats:
            return x
        
        # Classification
        return self.head(x)  # (B, 1)

# Initialize model
model = ViTClassifier().to(device)

# Count parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"ViT Model parameters: {num_params:,}")
print(f"Model device: {next(model.parameters()).device}")

# Sanity check: Forward pass
with torch.no_grad():
    x_test = torch.randn(2, 3, 24, 2048).to(device)
    
    # Test logits
    logits = model(x_test)
    assert logits.shape == (2, 1), f"Expected (2, 1), got {logits.shape}"
    
    # Test features
    feats = model(x_test, return_feats=True)
    assert feats.shape == (2, 512), f"Expected (2, 512), got {feats.shape}"
    
    print(f"‚úì Forward pass check passed")
    print(f"  Logits shape: {logits.shape}, range: [{logits.min():.2f}, {logits.max():.2f}]")
    print(f"  Features shape: {feats.shape}, range: [{feats.min():.2f}, {feats.max():.2f}]")

print(f"\n‚úì Cell 4 completed in {time.time() - cell_start:.2f}s")


ViT Model parameters: 38,224,897
Model device: cuda:0
‚úì Forward pass check passed
  Logits shape: torch.Size([2, 1]), range: [-0.95, -0.39]
  Features shape: torch.Size([2, 512]), range: [-3.33, 2.96]

‚úì Cell 4 completed in 0.44s


In [None]:
cell_start = time.time()

# ---- Hyperparameters (adaptive for sample mode) ----
if USE_SAMPLE:
    num_epochs = 2
    print(f"‚ö†Ô∏è  SAMPLE MODE: Training for {num_epochs} epochs only (~2-3 min)")
else:
    num_epochs = 5
    print(f"‚úì FULL MODE: Training for {num_epochs} epochs (~2-3 hours)")

learning_rate = 1e-4
warmup_epochs = min(2, num_epochs // 2)
use_amp = False  #  FIXED: Disable AMP to avoid conflicts

# ---- Optimizer & Loss ----
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))
print(f"Loss function: BCEWithLogitsLoss with pos_weight={pos_weight:.2f}")

# ---- Scheduler ----
warmup_scheduler = LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs else 1.0
)

cosine_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=max(1, num_epochs - warmup_epochs),
    eta_min=learning_rate / 10
)

scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[warmup_epochs]
)

print(f"Scheduler: SequentialLR with warmup={warmup_epochs}, cosine={max(1, num_epochs - warmup_epochs)}")

# ---- Early Stopping ----
patience = 7
patience_counter = 0
best_val_auc = 0.0
best_model_path = EXP_DIR / "model_best.pth"

# ---- Training Loop ----
history = {
    'epoch': [],
    'train_loss': [],
    'val_auc': [],
    'val_auprc': [],
    'val_f1': [],
    'challenge_score': [],
    'lr': []
}

print(f"\n{'='*70}")
print(f"Starting training for {num_epochs} epochs...")
print(f"  Batch size: 16")
print(f"  Gradient accumulation: 2 steps")
print(f"  Effective batch: 32")
print(f"  AMP enabled: False")
print(f"{'='*70}\n")

for epoch in range(num_epochs):
    # ---- Train ----
    model.train()
    train_loss = 0.0
    optimizer.zero_grad()
    
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1:02d}", leave=False)
    
    for step, (imgs, labels) in enumerate(train_bar):
        imgs = imgs.to(device)
        labels = labels.to(device).unsqueeze(1)
        
        #  SIMPLIFIED: No autocast complexity
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss = loss / 2
        loss.backward()
        
        if (step + 1) % 2 == 0:
            optimizer.step()
            optimizer.zero_grad()
        
        train_loss += loss.item() * 2
        train_bar.set_postfix(loss=train_loss / (step + 1))
    
    train_loss_avg = train_loss / len(train_loader)
    
    # ---- Validation ----
    model.eval()
    val_preds, val_trues = [], []
    
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs = imgs.to(device)
            logits = model(imgs)
            probs = torch.sigmoid(logits).cpu().numpy().squeeze()
            val_preds.extend(probs if isinstance(probs, np.ndarray) else [probs])
            val_trues.extend(labels.numpy())
    
    val_preds = np.array(val_preds)
    val_trues = np.array(val_trues)
    
    # ---- Metrics ----
    try:
        val_auc = roc_auc_score(val_trues, val_preds)
        val_auprc = average_precision_score(val_trues, val_preds)
    except:
        val_auc = 0.0
        val_auprc = 0.0
    
    val_trues_bin = (val_trues > 0.5).astype(int)
    val_preds_bin = (val_preds >= 0.5).astype(int)
    
    try:
        val_f1 = f1_score(val_trues_bin, val_preds_bin)
    except:
        val_f1 = 0.0
    
    if val_trues_bin.sum() > 0:
        sorted_idx = np.argsort(val_preds)[::-1]
        top_5_pct_idx = max(1, int(0.05 * len(val_preds)))
        challenge_score = val_trues_bin[sorted_idx[:top_5_pct_idx]].mean()
    else:
        challenge_score = 0.0
    
    # ---- Early stopping ----
    improved = ""
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        patience_counter = 0
        torch.save({
            'model_state_dict': model.state_dict(),
            'epoch': epoch + 1,
            'val_auc': val_auc
        }, best_model_path)
        improved = "‚úÖ (best)"
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"\n‚ö†Ô∏è  Early stopping triggered after {epoch+1} epochs")
            break
    
    scheduler.step()
    
    history['epoch'].append(epoch + 1)
    history['train_loss'].append(train_loss_avg)
    history['val_auc'].append(val_auc)
    history['val_auprc'].append(val_auprc)
    history['val_f1'].append(val_f1)
    history['challenge_score'].append(challenge_score)
    history['lr'].append(optimizer.param_groups[0]['lr'])
    
    print(f"Epoch {epoch+1:02d} | Loss: {train_loss_avg:.4f} | "
          f"AUROC: {val_auc:.4f} | AUPRC: {val_auprc:.4f} | "
          f"F1: {val_f1:.4f} | Challenge: {challenge_score:.4f} {improved}")

pd.DataFrame(history).to_csv(EXP_DIR / "metrics.csv", index=False)
print(f"\n‚úì Cell 5 completed in {time.time() - cell_start:.2f}s")


‚úì FULL MODE: Training for 5 epochs (~2-3 hours)
Loss function: BCEWithLogitsLoss with pos_weight=24.81
Scheduler: SequentialLR with warmup=2, cosine=3

Starting training for 5 epochs...
  Batch size: 16
  Gradient accumulation: 2 steps
  Effective batch: 32
  AMP enabled: False



                                                                          

Epoch 01 | Loss: 1.8952 | AUROC: 0.0000 | AUPRC: 0.0000 | F1: 0.1029 | Challenge: 0.5475 


                                                                          

Epoch 02 | Loss: 1.6700 | AUROC: 0.0000 | AUPRC: 0.0000 | F1: 0.0835 | Challenge: 0.4272 


Epoch 03:  72%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 2277/3162 [46:19<18:35,  1.26s/it, loss=1.83] 

In [None]:
cell_start = time.time()

# ---- Check if model file exists ----
if best_model_path.exists():
    print(f"Loading best model from: {best_model_path}")
    checkpoint = torch.load(best_model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']} (Val AUROC: {checkpoint['val_auc']:.4f})")
else:
    print(f"‚ö†Ô∏è Model file not found: {best_model_path}")
    print(f"Using current model state (from last epoch)")
    print(f"This is normal for quick tests with few epochs\n")

# Test inference
model.eval()
test_preds, test_trues = [], []

print(f"\nRunning test inference on {len(test_ds)} samples...")
with torch.no_grad():
    test_bar = tqdm(test_loader, desc="Test", leave=False)
    for imgs, labels in test_bar:
        imgs = imgs.to(device)
        logits = model(imgs)
        probs = torch.sigmoid(logits).cpu().numpy().squeeze()
        test_preds.extend(probs if isinstance(probs, np.ndarray) else [probs])
        test_trues.extend(labels.numpy())

test_preds = np.array(test_preds)
test_trues = np.array(test_trues)

# Compute metrics
test_trues_bin = (test_trues > 0.5).astype(int)
test_preds_bin = (test_preds >= 0.5).astype(int)

try:
    test_auc = roc_auc_score(test_trues_bin, test_preds)
    test_auprc = average_precision_score(test_trues_bin, test_preds)
except:
    test_auc = 0.0
    test_auprc = 0.0

try:
    test_acc = accuracy_score(test_trues_bin, test_preds_bin)
    test_f1 = f1_score(test_trues_bin, test_preds_bin)
    test_prec = precision_score(test_trues_bin, test_preds_bin, zero_division=0)
    test_rec = recall_score(test_trues_bin, test_preds_bin, zero_division=0)
except:
    test_acc = 0.0
    test_f1 = 0.0
    test_prec = 0.0
    test_rec = 0.0

# Challenge score
if test_trues_bin.sum() > 0:
    sorted_idx = np.argsort(test_preds)[::-1]
    top_5_pct_idx = max(1, int(0.05 * len(test_preds)))
    challenge_score_test = test_trues_bin[sorted_idx[:top_5_pct_idx]].mean()
else:
    challenge_score_test = 0.0

print(f"\n{'='*70}")
print(f"FINAL TEST RESULTS")
print(f"{'='*70}")
print(f"AUROC:           {test_auc:.4f}  (target: >0.70)")
print(f"AUPRC:           {test_auprc:.4f}")
print(f"Accuracy:        {test_acc:.4f}")
print(f"F1 Score:        {test_f1:.4f}")
print(f"Precision:       {test_prec:.4f}")
print(f"Recall:          {test_rec:.4f}")
print(f"Challenge Score: {challenge_score_test:.4f}  (target: >0.35)")
print(f"{'='*70}\n")

# Save test results
test_results = {
    'metric': ['AUROC', 'AUPRC', 'Accuracy', 'F1', 'Precision', 'Recall', 'Challenge'],
    'value': [test_auc, test_auprc, test_acc, test_f1, test_prec, test_rec, challenge_score_test]
}
pd.DataFrame(test_results).to_csv(EXP_DIR / "test_results.csv", index=False)

print(f"‚úì Cell 6 completed in {time.time() - cell_start:.2f}s")


In [None]:
cell_start = time.time()

# Load history
df_hist = pd.DataFrame(history)

# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss
axes[0, 0].plot(df_hist['epoch'], df_hist['train_loss'], 'b-o', label='Train')
axes[0, 0].set_title('Training Loss', fontsize=12, fontweight='bold')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].grid(True, alpha=0.3)

# AUROC
axes[0, 1].plot(df_hist['epoch'], df_hist['val_auc'], 'g-o', label='Val AUROC')
axes[0, 1].axhline(0.7, color='r', linestyle='--', label='Target (0.7)')
axes[0, 1].set_title('Validation AUROC', fontsize=12, fontweight='bold')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('AUROC')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# AUPRC
axes[1, 0].plot(df_hist['epoch'], df_hist['val_auprc'], 'm-o')
axes[1, 0].set_title('Validation AUPRC', fontsize=12, fontweight='bold')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('AUPRC')
axes[1, 0].grid(True, alpha=0.3)

# Challenge Score
axes[1, 1].plot(df_hist['epoch'], df_hist['challenge_score'], 'c-o')
axes[1, 1].axhline(0.35, color='r', linestyle='--', label='Target (0.35)')
axes[1, 1].set_title('Challenge Score (Top 5%)', fontsize=12, fontweight='bold')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Score')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(EXP_DIR / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.close()

print(f"‚úì Training curves saved")

# Confusion matrix
cm = confusion_matrix(test_trues_bin, test_preds_bin)
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(cm, cmap='Blues')
ax.set_xticks([0, 1])
ax.set_yticks([0, 1])
ax.set_xticklabels(['Negative', 'Positive'])
ax.set_yticklabels(['Negative', 'Positive'])
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix (Test Set)')
for i in range(2):
    for j in range(2):
        ax.text(j, i, str(cm[i, j]), ha='center', va='center', color='white', fontsize=16, fontweight='bold')
plt.colorbar(im)
plt.tight_layout()
plt.savefig(EXP_DIR / 'confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.close()

print(f"‚úì Confusion matrix saved")

# ROC curve
fpr, tpr, _ = roc_curve(test_trues_bin, test_preds)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, 'b-', linewidth=2, label=f'ROC (AUC={roc_auc:.4f})')
ax.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Chance')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('ROC Curve (Test Set)')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(EXP_DIR / 'roc_curve.png', dpi=150, bbox_inches='tight')
plt.close()

print(f"‚úì ROC curve saved")

print(f"\n‚úì Cell 7 completed in {time.time() - cell_start:.2f}s")


In [None]:
cell_start = time.time()

# Create submission CSV
submission_df = pd.DataFrame({
    'record_id': test_df['id'].values,
    'probability': test_preds,
    'binary_prediction': test_preds_bin,
    'true_label_soft': test_trues,
    'true_label_binary': test_trues_bin
})

submission_df.to_csv(EXP_DIR / 'test_predictions.csv', index=False)
print(f"‚úì Predictions saved to {EXP_DIR / 'test_predictions.csv'}")

# Summary statistics
print(f"\nPrediction statistics:")
print(f"  Mean probability: {test_preds.mean():.4f}")
print(f"  Std probability: {test_preds.std():.4f}")
print(f"  Min probability: {test_preds.min():.4f}")
print(f"  Max probability: {test_preds.max():.4f}")
print(f"  Predicted positive count: {test_preds_bin.sum()} / {len(test_preds)}")

print(f"\n‚úì Cell 8 completed in {time.time() - cell_start:.2f}s")


In [None]:
print(f"\n{'='*70}")
print(f"FINAL SUMMARY")
print(f"{'='*70}")

print(f"\nüìä METRICS ACHIEVED:")
print(f"  ‚úì Test AUROC:           {test_auc:.4f} (target: >0.70)")
print(f"  ‚úì Test AUPRC:           {test_auprc:.4f}")
print(f"  ‚úì Test Challenge Score: {challenge_score_test:.4f} (target: >0.35)")

if test_auc >= 0.70:
    print(f"  ‚úÖ AUROC TARGET ACHIEVED!")
else:
    print(f"  ‚ö†Ô∏è  AUROC below target.")

if challenge_score_test >= 0.35:
    print(f"  ‚úÖ CHALLENGE SCORE TARGET ACHIEVED!")
else:
    print(f"  ‚ö†Ô∏è  Challenge score below target.")

print(f"\nüìÅ OUTPUTS SAVED:")
print(f"  ‚Ä¢ {best_model_path}")
print(f"  ‚Ä¢ {EXP_DIR / 'metrics.csv'}")
print(f"  ‚Ä¢ {EXP_DIR / 'test_results.csv'}")
print(f"  ‚Ä¢ {EXP_DIR / 'test_predictions.csv'}")
print(f"  ‚Ä¢ {EXP_DIR / 'training_curves.png'}")
print(f"  ‚Ä¢ {EXP_DIR / 'confusion_matrix.png'}")
print(f"  ‚Ä¢ {EXP_DIR / 'roc_curve.png'}")

print(f"\n{'='*70}")
print(f"‚ú® All cells completed successfully!")
print(f"{'='*70}\n")
