In [7]:
import sys
from pathlib import Path

sys.path.insert(0, str(Path.cwd().parent) if 'notebooks' in str(Path.cwd()) else str(Path.cwd()))

from src.db import attach_duckdb, load_sql, duckdb_to_df

In [10]:
### FETCH COHORT FROM REMOTE DB (WITH CHECKPOINT) ###

import pandas as pd
import os
from src.db import get_db, load_sql

print("Fetching Cohort...")
checkpoint_path = "/tmp/cohort_checkpoint.csv"

# Try to load from checkpoint if it exists
if os.path.exists(checkpoint_path):
    print(f"Loading cohort from checkpoint ({checkpoint_path})...")
    cohort_df = pd.read_csv(checkpoint_path)
else:
    try:
        db = get_db()
        db.attach_postgres(dbname="mimic")
        cohort_df = db.execute_query(load_sql("cohorts.sql"))
        
        # Save checkpoint for recovery
        cohort_df.to_csv(checkpoint_path, index=False)
        print(f"✓ Saved cohort checkpoint to {checkpoint_path}")
    except Exception as e:
        print(f"❌ Error loading cohort: {type(e).__name__}: {e}")

cohort_ids = tuple(cohort_df['hadm_id'].tolist())
print(f"✓ {len(cohort_ids)} patients in cohort")

ImportError: cannot import name 'get_db' from 'src.db' (/workspaces/MHIA123-Project/playground/src/db.py)

In [5]:
### FETCH FEATURES (OPTIMIZED SQL WITH TEMP TABLE) ###

import pandas as pd
import numpy as np
from src.db import get_db, load_sql

print("Fetching Features...")
print(f"Total cohort IDs: {len(cohort_ids)}")

try:
    db = get_db()
    
    # Create temp table from cohort_ids (much faster than IN clause)
    print("Creating temporary table from cohort IDs...")
    cohort_ids_df = pd.DataFrame({'hadm_id': list(cohort_ids)})
    db.create_temp_table(cohort_ids_df, table_name='temp_cohort')
    
    # Optimized SQL using temp table join instead of IN clause
    optimized_query = """
    WITH cohort_filtered AS (
        SELECT DISTINCT
            adm.subject_id,
            adm.hadm_id,
            adm.admittime,
            CASE 
                WHEN adm.hospital_expire_flag = 1 THEN adm.deathtime
                ELSE adm.dischtime
            END AS end_time
        FROM mimic.mimiciv_hosp.admissions adm
        INNER JOIN temp_cohort tc ON adm.hadm_id = tc.hadm_id
    ),
    vitals_agg AS (
        SELECT 
            c.hadm_id,
            AVG(CASE WHEN v.charttime <= c.admittime + INTERVAL '24' HOUR THEN v.heart_rate END) as hr_base,
            AVG(CASE WHEN v.charttime <= c.admittime + INTERVAL '24' HOUR THEN v.mbp END) as map_base,
            AVG(CASE WHEN v.charttime >= c.end_time - INTERVAL '24' HOUR THEN v.heart_rate END) as hr_end,
            AVG(CASE WHEN v.charttime >= c.end_time - INTERVAL '24' HOUR THEN v.mbp END) as map_end
        FROM cohort_filtered c
        LEFT JOIN mimic.mimiciv_derived.vitalsign v 
            ON c.subject_id = v.subject_id
            AND v.charttime BETWEEN c.admittime AND c.end_time
        GROUP BY c.hadm_id
    ),
    labs_agg AS (
        SELECT 
            c.hadm_id,
            AVG(CASE WHEN l.charttime <= c.admittime + INTERVAL '24' HOUR THEN l.creatinine END) as crea_base,
            AVG(CASE WHEN l.charttime <= c.admittime + INTERVAL '24' HOUR THEN l.lactate END) as lac_base,
            AVG(CASE WHEN l.charttime >= c.end_time - INTERVAL '24' HOUR THEN l.creatinine END) as crea_end,
            AVG(CASE WHEN l.charttime >= c.end_time - INTERVAL '24' HOUR THEN l.lactate END) as lac_end
        FROM cohort_filtered c
        LEFT JOIN mimic.mimiciv_derived.chemistry l
            ON c.subject_id = l.subject_id
            AND l.charttime BETWEEN c.admittime AND c.end_time
        GROUP BY c.hadm_id
    )
    SELECT 
        c.hadm_id,
        COALESCE(v.hr_base, 0) as hr_base,
        COALESCE(v.map_base, 0) as map_base,
        COALESCE(l.crea_base, 0) as crea_base,
        COALESCE(l.lac_base, 0) as lac_base,
        COALESCE(v.hr_end, 0) as hr_end,
        COALESCE(v.map_end, 0) as map_end,
        COALESCE(l.crea_end, 0) as crea_end,
        COALESCE(l.lac_end, 0) as lac_end
    FROM cohort_filtered c
    LEFT JOIN vitals_agg v ON c.hadm_id = v.hadm_id
    LEFT JOIN labs_agg l ON c.hadm_id = l.hadm_id
    """
    
    print("Executing optimized query...")
    features_df = db.execute_query(optimized_query)
    db.unregister_temp_table('temp_cohort')
    
    print(f"✓ Features DF fetched: {features_df.shape}")
    print(f"✓ Features memory: {features_df.memory_usage(deep=True).sum() / 1e6:.2f} MB")
    
except Exception as e:
    print(f"❌ Optimized query failed: {type(e).__name__}: {e}")

# Merge carefully
#print("Merging cohort + features...")
#full_data = pd.merge(cohort_df, features_df, on='hadm_id', how='inner')
#print(f"✓ Merged data shape: {full_data.shape}")

Fetching Features...
Total cohort IDs: 11726
Creating temporary table from cohort IDs...
Executing optimized query...
❌ Optimized query failed: BinderException: Binder Error: Catalog "mimic" does not exist!


In [13]:
### DATA VALIDATION & MEMORY CHECK ###
import numpy as np

def print_memory_status():
    mem = psutil.virtual_memory()
    pid = os.getpid()
    proc = psutil.Process(pid)
    print(f"System: {mem.available/1e9:.2f} GB avail | {mem.percent}% used")
    print(f"Process: {proc.memory_info().rss/1e6:.2f} MB")

print("=== AFTER MERGE ===")
print_memory_status()

print(f"Cohort DF: {cohort_df.shape}")
print(f"Features DF: {features_df.shape}")
print(f"Merged DF: {full_data.shape}")

# If dataset is still too large, downsample
MAX_SAMPLES = 2000
if len(full_data) > MAX_SAMPLES:
    print(f"⚠️ Dataset too large ({len(full_data)} > {MAX_SAMPLES}), downsampling...")
    full_data = full_data.sample(n=MAX_SAMPLES, random_state=42, stratify=full_data['label'] if 'label' in full_data.columns else None)
    print(f"✓ Downsampled to {len(full_data)} samples")

print(f"✓ Full data shape: {full_data.shape}")
print(f"✓ Full data memory: {full_data.memory_usage(deep=True).sum() / 1e6:.2f} MB")

=== AFTER MERGE ===
System: 4.33 GB avail | 47.3% used
Process: 1765.51 MB
Cohort DF: (11726, 9)
Features DF: (1000, 9)
Merged DF: (1000, 17)
✓ Full data shape: (1000, 17)
✓ Full data memory: 0.38 MB


In [14]:
### DATA PREPROCESSING & VALIDATION ###

import numpy as np

# Impute all numeric columns with median
print("Imputing missing values...")
numeric_cols = full_data.select_dtypes(include=[np.number]).columns.tolist()
for col in numeric_cols:
    if col != 'label':  # Don't impute the target
        full_data[col] = full_data[col].fillna(full_data[col].median())

# Ensure gender is numeric
if 'gender' in full_data.columns:
    full_data['gender'] = full_data['gender'].apply(lambda x: 1 if str(x).upper() == 'M' else 0)

# Define feature columns (include baseline and end values)
feature_cols = [col for col in full_data.columns 
                if col not in ['hadm_id', 'subject_id', 'label', 'admittime', 'dischtime', 'deathtime', 'end_time']
                and full_data[col].dtype in [np.float64, np.float32, np.int64, np.int32]]

print(f"Feature columns: {feature_cols}")

# ⚠️ CRITICAL: Validate data before model training
print("Validating data integrity...")
assert len(full_data) > 0, "❌ Data is empty after preprocessing!"
assert len(feature_cols) > 0, "❌ No feature columns found!"
assert full_data[feature_cols].isna().sum().sum() == 0, f"❌ NaN values in features: {full_data[feature_cols].isna().sum()}"
assert not np.isinf(full_data[feature_cols].values).any(), "❌ Inf values in features!"
print(f"✓ Data validated: {len(full_data)} samples, {len(feature_cols)} features")
if 'label' in full_data.columns:
    print(f"✓ Class balance: {(full_data['label'] == 1).sum()} positives, {(full_data['label'] == 0).sum()} negatives")

In [15]:
### SPLIT DATA TO TRAINING AND TEST SETS ###

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X = full_data[feature_cols].values
y = full_data['label'].values

# 8. Split and Scale
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

In [16]:
### MODEL DEFINITION ###

from src.model import CardiacDataset, MortalityPredictor
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

train_dataset = CardiacDataset(X_train_scaled, y_train)
test_dataset = CardiacDataset(X_test_scaled, y_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

model = MortalityPredictor(input_dim=X_train.shape[1])

criterion = nn.BCELoss() # Binary Cross Entropy
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [17]:
### MODEL TRAINING ###

EPOCHS = 20
print("Starting Training...")

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0
    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        y_pred = model(X_batch)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {train_loss/len(train_loader):.4f}")

In [18]:
### MODEL EVALUATION ###

import torch
import numpy as np
from sklearn.metrics import roc_auc_score, classification_report, brier_score_loss

model.eval()
all_preds = []
all_targets = []

with torch.no_grad():
    for X_batch, y_batch in test_loader:
        X_batch = X_batch.to(device)
        y_pred = model(X_batch)
        all_preds.extend(y_pred.cpu().numpy())
        all_targets.extend(y_batch.numpy())

y_pred_probs = np.array(all_preds).flatten()
y_targets = np.array(all_targets).flatten()

# Calculate metrics
auc = roc_auc_score(y_targets, y_pred_probs)
brier = brier_score_loss(y_targets, y_pred_probs)

print(f"✓ Model Evaluation Results:")
print(f"  AUC-ROC: {auc:.4f}")
print(f"  Brier Score: {brier:.4f}")

# Binary predictions at 0.5 threshold
y_pred_binary = (y_pred_probs > 0.5).astype(int)
print(f"\nClassification Report:")
print(classification_report(y_targets, y_pred_binary))