# LSTM + Patient Features Fusion — Google Colab

Implements the dual-branch architecture from `specs/lstm_patient_fusion_spec.md`.

## Key design decisions for imbalanced classes

| Metric | Why |
|---|---|
| **Macro ROC AUC** | Competition metric; treats all 4 classes equally regardless of frequency |
| **Macro F1** | Training monitor; captures the precision/recall tradeoff — better than accuracy |
| **Per-class F1** | Reveals which deterioration labels (esp. rare 2 & 3) the model handles worst |
| **AUPRC** | Precision-recall curve area; more sensitive than ROC AUC when positives are rare |

> **Why not accuracy?** With imbalanced classes a model that always predicts the majority class scores high accuracy but is clinically useless. F1 penalises both missing real events (low recall) and over-alerting (low precision).

## 0. Install Dependencies

In [None]:
%pip install -q torch gdown pandas scikit-learn matplotlib

## 1. Load Data from Google Drive

Loads data from the [shared Drive folder](https://drive.google.com/drive/folders/13NvOvSW1W0shkAxnYyZJlHFqokhHjUeI).
- **Colab**: Mounts Drive. Update `COLAB_DATA_PATH` to match your folder after mounting.
- **Local**: Downloads via gdown.

In [None]:
from pathlib import Path
import gdown

GOOGLE_DRIVE_FOLDER_ID = "13NvOvSW1W0shkAxnYyZJlHFqokhHjUeI"
GOOGLE_DRIVE_URL = f"https://drive.google.com/drive/folders/{GOOGLE_DRIVE_FOLDER_ID}"
COLAB_DATA_PATH = "/content/drive/MyDrive/Medhack_data"  # <-- update if needed
_MARKER_FILE = "train_data_lagged.csv"  # file whose presence identifies the data folder


def _find_data_dir(root: Path):
    """Return *root* or its first direct subfolder containing the marker CSV, else ``None``."""
    if (root / _MARKER_FILE).exists():
        return root
    if root.is_dir():
        for sub in root.iterdir():
            if sub.is_dir() and (sub / _MARKER_FILE).exists():
                return sub
    return None


def get_data_dir() -> Path:
    """Locate the directory that holds the competition CSVs.

    On Colab (``google.colab`` importable), Drive is mounted and ``COLAB_DATA_PATH``
    is used as-is. Elsewhere, the shared Drive folder is downloaded with gdown into
    ``data_from_drive/``. If a previous run already left the CSVs there, the cached
    copy is reused and nothing is downloaded. The download can end up one folder
    deeper, so direct subfolders are searched as well.

    Raises:
        FileNotFoundError: if ``train_data_lagged.csv`` cannot be found.
    """
    # Only the import decides which environment we are in; keep the try minimal
    # so unrelated ImportErrors from mounting are not mistaken for "not on Colab".
    try:
        from google.colab import drive
    except ImportError:
        drive = None

    if drive is not None:
        drive.mount("/content/drive", force_remount=False)
        data_dir = Path(COLAB_DATA_PATH)
        if not (data_dir / "train_data_lagged.csv").exists():
            raise FileNotFoundError(
                f"train_data_lagged.csv not found in {data_dir}. "
                "Update COLAB_DATA_PATH to match your Drive folder."
            )
        return data_dir

    out = Path("data_from_drive")
    cached = _find_data_dir(out)
    if cached is not None:
        # Delete data_from_drive/ to force a fresh download.
        return cached

    out.mkdir(exist_ok=True)
    gdown.download_folder(url=GOOGLE_DRIVE_URL, output=str(out), quiet=False)
    found = _find_data_dir(out)
    if found is None:
        raise FileNotFoundError(f"train_data_lagged.csv not found in {out}.")
    return found

DATA_DIR = get_data_dir()
print(f"DATA_DIR = {DATA_DIR}")

## 2. Setup

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler, OneHotEncoder, label_binarize
from sklearn.metrics import (
    roc_auc_score, f1_score, classification_report,
    confusion_matrix, ConfusionMatrixDisplay,
    precision_recall_curve, average_precision_score,
)

def _select_device() -> torch.device:
    """Pick the compute backend: CUDA (Colab GPU), then MPS (Apple Silicon), else CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    if torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')


DEVICE = _select_device()
print(f"Device: {DEVICE}")

# Colab exposes /content; anywhere else, save checkpoints to ../models.
_colab_root = Path('/content')
MODELS_DIR = _colab_root / 'models' if _colab_root.exists() else Path('../models')
MODELS_DIR.mkdir(exist_ok=True)

# Vital-sign sequence settings
VITAL_COLS = ['heart_rate', 'systolic_bp', 'diastolic_bp', 'respiratory_rate', 'oxygen_saturation']
N_LAGS = 6                   # lag columns per vital; sequences have N_LAGS + 1 timesteps
TRAIN_SAMPLE_SIZE = 500_000  # rows kept from the training split
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE)

# Static patient feature columns
NUMERIC_PATIENT_COLS     = ['age', 'bmi', 'pain_score']
CATEGORICAL_PATIENT_COLS = ['race', 'ethnicity']
GENDER_COL = 'gender'
GENDER_COL = 'gender'

## 3. Load and Reshape Vital Time Series

In [None]:
def reshape_lagged_to_sequences(df: pd.DataFrame, vital_cols: list, n_lags: int) -> np.ndarray:
    """Assemble per-row lag columns into a 3-D sequence array for the LSTM.

    For every vital ``v``, *df* must contain ``v_lag{n_lags}`` ... ``v_lag1`` and ``v``.
    Timesteps run oldest -> newest: ``lag_n`` first, the current value last.

    Returns:
        Array of shape (n_samples, n_lags + 1, len(vital_cols)).
    """
    # Column suffixes in timestep order; '' selects the current (unlagged) value.
    step_suffixes = [f'_lag{k}' for k in range(n_lags, 0, -1)] + ['']
    per_vital = [df[[f'{vital}{sfx}' for sfx in step_suffixes]].values for vital in vital_cols]
    # Each entry is (n_samples, n_lags + 1); stacking on a new last axis makes vitals the feature dim.
    return np.stack(per_vital, axis=-1)


# Read the pre-lagged splits; each row already carries its own lag columns.
train_raw = pd.read_csv(DATA_DIR / 'train_data_lagged.csv')
test_raw  = pd.read_csv(DATA_DIR / 'test_data_lagged.csv')

# Vitals as (n_samples, timesteps, n_vitals) sequences; labels as flat arrays.
X_ts_train_full = reshape_lagged_to_sequences(train_raw, VITAL_COLS, N_LAGS)
X_ts_test       = reshape_lagged_to_sequences(test_raw,  VITAL_COLS, N_LAGS)
y_train_full    = train_raw['label'].values
y_test          = test_raw['label'].values

print(f"Train: {X_ts_train_full.shape}  |  Test: {X_ts_test.shape}")

# The class balance printed here is why accuracy is a poor yardstick.
print("\nClass distribution (train):")
n_train_rows = len(y_train_full)
counts = pd.Series(y_train_full).value_counts().sort_index()
for label, n_rows in counts.items():
    share = 100 * n_rows / n_train_rows
    print(f"  Label {label}: {n_rows:>8,}  ({share:.1f}%)")

## 4. Load and Preprocess Patient Features (Tier 1)

In [None]:
# Static patient table keyed by encounter_id (joined onto the vitals rows later).
patients_csv = DATA_DIR / 'patients.csv'
patients = pd.read_csv(patients_csv)
print(f"patients.csv: {patients.shape}")
print(patients.head(3))

In [None]:
def encode_gender(series: pd.Series) -> np.ndarray:
    """Encode gender values as a float column vector: male -> 0, female -> 1, else -1.

    Matching ignores case and surrounding whitespace, so 'M', 'male', 'MALE' and
    ' Male ' all map to 0. Missing (NaN/None) and unrecognised values map to -1.

    Returns:
        Array of shape (len(series), 1), dtype float64.
    """
    mapping = {'m': 0, 'male': 0, 'f': 1, 'female': 1}
    # astype(str) turns NaN/None into 'nan'/'None', which fall through to -1 below.
    normalized = series.astype(str).str.strip().str.lower()
    return normalized.map(mapping).fillna(-1).astype(float).values.reshape(-1, 1)


class PatientFeaturePreprocessor:
    """Fit on train patients; transform any split.

    ``transform`` returns a float32 matrix whose columns are, in order:
    - the standardised numeric columns;
    - one gender column (via ``encode_gender``; -1 for every row when the column is absent);
    - the one-hot encoded categorical columns that were present at fit time.

    Missing numeric values are filled with the training median. Missing categoricals
    become the category ``'unknown'``. Categories unseen during fit encode as all
    zeros (``handle_unknown='ignore'``).
    """

    def __init__(self, numeric_cols=NUMERIC_PATIENT_COLS,
                 categorical_cols=CATEGORICAL_PATIENT_COLS, gender_col=GENDER_COL):
        # Private list copies: the defaults are shared module-level lists, and the
        # attributes are concatenated with `+` elsewhere, which needs real lists.
        self.numeric_cols = list(numeric_cols)
        self.categorical_cols = list(categorical_cols)
        self.gender_col = gender_col
        self.num_scaler = StandardScaler()
        self.cat_encoder = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
        self._num_medians: dict = {}       # column -> fill value learned in fit
        self._cat_cols_present: list = []  # categorical columns actually present at fit time

    @staticmethod
    def _to_numeric(series: pd.Series) -> pd.Series:
        """Coerce to numbers; entries that cannot be parsed (e.g. strings) become NaN."""
        return pd.to_numeric(series, errors='coerce')

    def fit(self, df: pd.DataFrame) -> 'PatientFeaturePreprocessor':
        """Learn numeric medians, scaler statistics and categorical vocabularies from *df*."""
        for col in self.numeric_cols:
            median = self._to_numeric(df[col]).median() if col in df.columns else 0.0
            # An all-missing column has a NaN median; fall back to 0.0 so the
            # imputed values (and therefore the model inputs) stay finite.
            self._num_medians[col] = 0.0 if pd.isna(median) else median
        self.num_scaler.fit(self._impute_numeric(df))
        self._cat_cols_present = [c for c in self.categorical_cols if c in df.columns]
        if self._cat_cols_present:
            self.cat_encoder.fit(df[self._cat_cols_present].fillna('unknown').astype(str))
        return self

    def _impute_numeric(self, df: pd.DataFrame) -> np.ndarray:
        """Return the numeric columns as an (n_rows, n_numeric) float array with NaNs imputed."""
        parts = []
        for col in self.numeric_cols:
            if col in df.columns:
                vals = self._to_numeric(df[col]).fillna(self._num_medians[col]).to_numpy(dtype=float)
            else:
                # Column missing from this split: use the fit-time median for every row.
                vals = np.full(len(df), self._num_medians.get(col, 0.0), dtype=float)
            parts.append(vals.reshape(-1, 1))
        return np.hstack(parts).astype(float)

    def transform(self, df: pd.DataFrame) -> np.ndarray:
        """Return the float32 feature matrix for *df* (column order in the class docstring)."""
        parts = [self.num_scaler.transform(self._impute_numeric(df))]
        parts.append(encode_gender(df[self.gender_col]) if self.gender_col in df.columns
                     else np.full((len(df), 1), -1.0))
        if self._cat_cols_present:
            parts.append(self.cat_encoder.transform(
                df[self._cat_cols_present].fillna('unknown').astype(str)
            ))
        return np.hstack(parts).astype(np.float32)


# Fit only on patients whose encounters appear in the training split, so no
# statistics come from test encounters.
train_encounter_ids = train_raw['encounter_id'].unique()
train_patients = patients[patients['encounter_id'].isin(train_encounter_ids)]
pat_prep = PatientFeaturePreprocessor().fit(train_patients)

# Width of the static feature vector, probed by transforming a single row.
n_static = pat_prep.transform(train_patients.head(1)).shape[1]
print(f"n_static features: {n_static}")

## 5. Join, Sample, and Scale

In [None]:
def build_static_features(lagged_df, patients_df, preprocessor):
    """Return static patient features aligned row-for-row with *lagged_df*.

    Each row of ``lagged_df`` is matched to its patient record through
    ``encounter_id``. The join is a left join that preserves ``lagged_df``'s row
    order, so rows without a match reach the preprocessor with missing values.

    ``patients_df`` is de-duplicated on ``encounter_id`` (first record wins) before
    the join. A duplicated id would otherwise multiply rows in the merge, and the
    returned matrix would silently stop lining up with arrays built from the same
    ``lagged_df``.
    """
    pat_cols = ['encounter_id'] + [
        c for c in preprocessor.numeric_cols + preprocessor.categorical_cols + [preprocessor.gender_col]
        if c in patients_df.columns
    ]
    lookup = patients_df[pat_cols].drop_duplicates(subset='encounter_id', keep='first')
    merged = lagged_df[['encounter_id']].merge(
        lookup, on='encounter_id', how='left', validate='many_to_one'
    )
    return preprocessor.transform(merged)


X_static_train_full = build_static_features(train_raw, patients, pat_prep)
X_static_test       = build_static_features(test_raw,  patients, pat_prep)

# Stratified sample of TRAIN_SAMPLE_SIZE rows. StratifiedShuffleSplit raises
# unless at least one row per class is left outside the sample. When the data
# does not exceed the budget by that margin, keep every row instead.
n_available = len(y_train_full)
n_label_values = np.unique(y_train_full).size
if n_available - TRAIN_SAMPLE_SIZE >= n_label_values:
    splitter = StratifiedShuffleSplit(n_splits=1, train_size=TRAIN_SAMPLE_SIZE, random_state=RANDOM_STATE)
    idx, _ = next(splitter.split(X_ts_train_full, y_train_full))
else:
    idx = np.arange(n_available)
X_ts_train   = X_ts_train_full[idx]
X_stat_train = X_static_train_full[idx]
y_train      = y_train_full[idx]

# Scale vitals (fit on train only); each vital is standardised across all timesteps.
n_samples, n_timesteps, n_vitals = X_ts_train.shape
ts_scaler = StandardScaler()
ts_scaler.fit(X_ts_train.reshape(-1, n_vitals))

X_ts_train_sc = ts_scaler.transform(X_ts_train.reshape(-1, n_vitals)).reshape(n_samples, n_timesteps, n_vitals)
X_ts_test_sc  = ts_scaler.transform(X_ts_test.reshape(-1, n_vitals)).reshape(X_ts_test.shape[0], n_timesteps, n_vitals)

# StandardScaler ignores NaN when fitting but passes it through transform, and
# any NaN vital would turn the LSTM outputs and the loss into NaN. Impute 0.0,
# which is the training mean in standardised units. This is a no-op when no NaN
# is present.
X_ts_train_sc = np.where(np.isnan(X_ts_train_sc), 0.0, X_ts_train_sc)
X_ts_test_sc  = np.where(np.isnan(X_ts_test_sc),  0.0, X_ts_test_sc)

print(f"Train: ts={X_ts_train_sc.shape}, static={X_stat_train.shape}")
print(f"Test:  ts={X_ts_test_sc.shape},  static={X_static_test.shape}")

## 6. Build Fusion Model

In [None]:
class LSTMFusionModel(nn.Module):
    """Dual-branch classifier fusing a stacked-LSTM vitals encoder with static features.

    Two stacked LSTMs run over the vitals sequence, and the hidden output of the
    final timestep is kept. That vector is concatenated with the static patient
    features and fed to a small MLP head that emits unnormalised class logits.

    Layer widths and dropout rates are configurable. The defaults reproduce the
    original fixed architecture (64 -> 32 LSTM units, 32-unit head, dropout
    0.3 / 0.2), so module names and parameter shapes (and thus saved
    checkpoints) are unchanged.
    """

    def __init__(self, n_vitals: int, n_static: int, n_classes: int = 4,
                 lstm1_units: int = 64, lstm2_units: int = 32, head_units: int = 32,
                 lstm_dropout: float = 0.3, head_dropout: float = 0.2):
        super().__init__()
        # Construction order is kept identical so default-seeded initialisation matches.
        self.lstm1 = nn.LSTM(n_vitals, lstm1_units, batch_first=True)
        self.drop1 = nn.Dropout(lstm_dropout)
        self.lstm2 = nn.LSTM(lstm1_units, lstm2_units, batch_first=True)
        self.drop2 = nn.Dropout(lstm_dropout)
        self.classifier = nn.Sequential(
            nn.Linear(lstm2_units + n_static, head_units),
            nn.ReLU(),
            nn.Dropout(head_dropout),
            nn.Linear(head_units, n_classes),
        )

    def forward(self, x_ts: torch.Tensor, x_static: torch.Tensor) -> torch.Tensor:
        """Map (batch, timesteps, n_vitals) vitals + (batch, n_static) features to (batch, n_classes) logits."""
        seq, _ = self.lstm1(x_ts)              # (batch, timesteps, lstm1_units)
        seq, _ = self.lstm2(self.drop1(seq))   # (batch, timesteps, lstm2_units)
        last = self.drop2(seq[:, -1, :])       # last timestep -> (batch, lstm2_units)
        return self.classifier(torch.cat([last, x_static], dim=1))


# One output per distinct label observed in the training sample.
n_classes = np.unique(y_train).size
model = LSTMFusionModel(n_vitals, n_static, n_classes).to(DEVICE)
print(model)
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTrainable parameters: {n_trainable:,}")

## 7. Train

### Why train loss can rise while val loss falls
Class weights penalize minority-class errors heavily (up to ~10×). In early epochs the model
shifts away from always predicting the majority class — this causes more minority-class mistakes
which the weighted loss inflates. Val loss falls because the model is genuinely improving on
rare classes. The pattern stabilises after a few epochs.

### What to monitor
Watch **val macro F1**, not val accuracy. F1 is the harmonic mean of precision and recall —
it tells you whether the model is actually catching deterioration events vs. just over-alerting.

In [None]:
# Balanced (inverse-frequency) class weights: w_c = N / (n_classes * count_c).
# Each class's total weight is then N / n_classes, whatever its frequency.
class_counts = np.bincount(y_train)
class_weights_np = len(y_train) / (n_classes * class_counts)
print("Class weights:", {c: round(w, 3) for c, w in enumerate(class_weights_np)})

loss_weights = torch.tensor(class_weights_np, dtype=torch.float32).to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=loss_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Halve the LR after 2 stagnant epochs of the monitored metric, never below 1e-5.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, patience=2, min_lr=1e-5,
)

def make_loader(X_ts, X_stat, y, batch_size=512, shuffle=True):
    """Wrap aligned arrays as a DataLoader yielding (vitals, static, label) batches.

    Vitals and static features become float32 tensors; labels become int64 (long).
    """
    tensors = (
        torch.tensor(X_ts, dtype=torch.float32),
        torch.tensor(X_stat, dtype=torch.float32),
        torch.tensor(y, dtype=torch.long),
    )
    # Page-locked host memory only helps host->GPU copies, so enable it just for CUDA.
    use_pinned = DEVICE.type == 'cuda'
    return DataLoader(TensorDataset(*tensors), batch_size=batch_size,
                      shuffle=shuffle, pin_memory=use_pinned)

# Hold out 10% of the sampled training rows (stratified by label) for early stopping.
# NOTE(review): rows are split individually, so lagged windows from the same
# encounter can land in both train and val. Val F1 is then likely optimistic.
# Consider splitting by encounter_id (e.g. sklearn GroupShuffleSplit) — TODO confirm.
val_split = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=RANDOM_STATE)
tr_idx, val_idx = next(val_split.split(X_ts_train_sc, y_train))

train_loader = make_loader(X_ts_train_sc[tr_idx], X_stat_train[tr_idx], y_train[tr_idx], shuffle=True)
val_loader = make_loader(X_ts_train_sc[val_idx], X_stat_train[val_idx], y_train[val_idx], shuffle=False)

In [None]:
# Training configuration: maximum epochs, and early-stopping patience (number of
# consecutive epochs without a val macro-F1 improvement before stopping).
EPOCHS  = 15
PATIENCE = 3

best_val_f1 = -1.0  # below any attainable F1, so epoch 1 always becomes the first checkpoint
epochs_no_improve = 0
best_state = None  # CPU copy of the weights from the best-F1 epoch
history = {'train_loss': [], 'val_loss': [], 'val_f1_macro': []}

for epoch in range(1, EPOCHS + 1):
    # --- Train ---
    model.train()  # enables dropout
    train_loss = train_total = 0
    for X_ts_b, X_st_b, y_b in train_loader:
        X_ts_b, X_st_b, y_b = X_ts_b.to(DEVICE), X_st_b.to(DEVICE), y_b.to(DEVICE)
        optimizer.zero_grad()
        loss = criterion(model(X_ts_b, X_st_b), y_b)
        loss.backward()
        optimizer.step()
        # criterion returns a (class-weighted) batch mean; scaling by batch size
        # approximates a per-sample epoch average when the last batch is smaller.
        train_loss  += loss.item() * len(y_b)
        train_total += len(y_b)

    # --- Validate: collect all predictions for F1 ---
    model.eval()  # disables dropout
    val_loss = val_total = 0
    val_preds, val_true = [], []
    with torch.no_grad():
        for X_ts_b, X_st_b, y_b in val_loader:
            X_ts_b, X_st_b, y_b = X_ts_b.to(DEVICE), X_st_b.to(DEVICE), y_b.to(DEVICE)
            logits = model(X_ts_b, X_st_b)
            val_loss  += criterion(logits, y_b).item() * len(y_b)
            val_total += len(y_b)
            val_preds.append(logits.argmax(1).cpu())
            val_true.append(y_b.cpu())

    t_loss   = train_loss / train_total
    v_loss   = val_loss   / val_total
    val_preds_np = torch.cat(val_preds).numpy()
    val_true_np  = torch.cat(val_true).numpy()
    # Macro F1 over the whole validation set (not averaged per batch).
    v_f1 = f1_score(val_true_np, val_preds_np, average='macro', zero_division=0)

    history['train_loss'].append(t_loss)
    history['val_loss'].append(v_loss)
    history['val_f1_macro'].append(v_f1)

    # The LR schedule follows val loss, while checkpointing/early stopping below
    # follow val macro F1. `lr` is read after the step, i.e. the LR for the next epoch.
    scheduler.step(v_loss)
    lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch:02d}  train_loss={t_loss:.4f}  val_loss={v_loss:.4f}  "
          f"val_f1_macro={v_f1:.4f}  lr={lr:.2e}")

    # Early stop on val macro F1 (maximise)
    if v_f1 > best_val_f1:
        best_val_f1 = v_f1
        # Clone to CPU so later optimizer steps cannot overwrite the snapshot.
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print(f"Early stopping at epoch {epoch} (best val macro F1={best_val_f1:.4f})")
            break

# Restore the best-F1 weights (not the last epoch's) before saving.
model.load_state_dict(best_state)
torch.save(model.state_dict(), MODELS_DIR / 'lstm_patient_fusion.pt')
print(f"\nBest model saved  |  val macro F1 = {best_val_f1:.4f}")

## 8. Training Curves

In [None]:
# 1-based epoch numbers so the x-axis lines up with the "Epoch NN" training log.
epoch_nums = list(range(1, len(history['train_loss']) + 1))

fig, axes = plt.subplots(1, 2, figsize=(13, 4))

axes[0].plot(epoch_nums, history['train_loss'], label='train loss')
axes[0].plot(epoch_nums, history['val_loss'],   label='val loss')
axes[0].set_title('Weighted Loss')
axes[0].set_xlabel('Epoch')
axes[0].legend()

axes[1].plot(epoch_nums, history['val_f1_macro'], color='tab:green', label='val macro F1')
if history['val_f1_macro']:
    # Checkpoint restored after training: the first epoch reaching the max val F1.
    # The loop only updates on strict improvement, which matches argmax's first-hit rule.
    best_epoch = int(np.argmax(history['val_f1_macro'])) + 1
    axes[1].axvline(best_epoch, linestyle='--', color='grey', label=f'best epoch ({best_epoch})')
axes[1].set_title('Val Macro F1 (↑ better)')
axes[1].set_xlabel('Epoch')
axes[1].set_ylim(0, 1)
axes[1].legend()

plt.tight_layout()
plt.show()

## 9. Evaluate on Test Set

In [None]:
def predict_proba(model, X_ts_sc, X_stat, batch_size=512, device=None):
    """Return softmax class probabilities for every row, shape (n_samples, n_classes).

    Runs the model in eval mode (dropout off) without gradients, in fixed-size
    batches to bound memory. Inputs are sliced directly, so no dummy labels or
    DataLoader are needed.

    Args:
        model: module called as ``model(x_ts, x_static)`` and returning logits.
        X_ts_sc: scaled vitals array; axis 0 indexes samples.
        X_stat: static feature array aligned row-for-row with ``X_ts_sc``.
        batch_size: rows per forward pass.
        device: where to run the forward passes. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()
    probs = []
    with torch.no_grad():
        for start in range(0, len(X_ts_sc), batch_size):
            stop = start + batch_size
            x_ts = torch.as_tensor(X_ts_sc[start:stop], dtype=torch.float32, device=device)
            x_st = torch.as_tensor(X_stat[start:stop], dtype=torch.float32, device=device)
            logits = model(x_ts, x_st)
            probs.append(torch.softmax(logits, dim=1).cpu().numpy())
    return np.vstack(probs)


y_proba = predict_proba(model, X_ts_test_sc, X_static_test)
y_pred  = y_proba.argmax(axis=1)

# --- Primary metric ---
roc_auc = roc_auc_score(y_test, y_proba, multi_class='ovr', average='macro')
print(f"Test ROC AUC (macro, OVR): {roc_auc:.4f}")
print()

# --- Classification report: per-class precision, recall, F1 ---
# labels= pins the report to every class 0..n_classes-1, so target_names always
# lines up even if a class is absent from both y_test and y_pred. zero_division=0
# reports the same 0.0 sklearn would, without the UndefinedMetricWarning.
report_labels = list(range(n_classes))
print(classification_report(y_test, y_pred,
      labels=report_labels,
      target_names=[f'Label {i}' for i in report_labels],
      digits=4,
      zero_division=0))

### 9a. Confusion Matrix

In [None]:
class_ids = list(range(n_classes))
tick_labels = [f'Label {i}' for i in class_ids]

# labels= keeps the matrix n_classes x n_classes even if a class never appears in
# y_test or y_pred, so it always matches tick_labels.
cm = confusion_matrix(y_test, y_pred, labels=class_ids)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

ConfusionMatrixDisplay(cm, display_labels=tick_labels).plot(ax=axes[0], colorbar=False)
axes[0].set_title('Confusion Matrix (counts)')

# Normalised by true label (row) — reveals per-class recall. Rows with no true
# samples stay at 0 instead of becoming NaN from 0/0.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)
ConfusionMatrixDisplay(cm_norm.round(2), display_labels=tick_labels).plot(ax=axes[1], colorbar=False)
axes[1].set_title('Confusion Matrix (row-normalised = recall per class)')

plt.tight_layout()
plt.show()

### 9b. Precision-Recall Curves (one vs. rest)

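For imbalanced classes AUPRC is more informative than ROC AUC. For a rare class, ROC AUC can look strong even when most positive predictions are false alarms: the huge pool of true negatives keeps the false-positive rate low.
AUPRC has no such cushion. A random classifier scores ≈ the class prevalence (the dashed baseline below), so for a rare class the baseline is near zero and any lift above it reflects real signal.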

In [None]:
# One-hot targets with one column per class (label k -> column k). This is built
# directly rather than with label_binarize, which collapses a 2-class problem into
# a single column and would break the per-class indexing below.
class_ids = np.arange(n_classes)
y_test_bin = (np.asarray(y_test).reshape(-1, 1) == class_ids).astype(int)

# squeeze=False keeps a 2-D axes grid even when n_classes == 1.
fig, axes = plt.subplots(1, n_classes, figsize=(5 * n_classes, 4), squeeze=False)
for cls in range(n_classes):
    ax = axes[0, cls]
    prec, rec, _ = precision_recall_curve(y_test_bin[:, cls], y_proba[:, cls])
    auprc = average_precision_score(y_test_bin[:, cls], y_proba[:, cls])
    prevalence = y_test_bin[:, cls].mean()
    ax.plot(rec, prec, lw=2)
    ax.axhline(prevalence, linestyle='--', color='grey', label=f'baseline (prevalence={prevalence:.2f})')
    ax.set_title(f'Label {cls} — AUPRC={auprc:.3f}')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_ylim(0, 1)
    ax.legend(fontsize=8)

plt.suptitle('Precision-Recall Curves (one-vs-rest)', y=1.02)
plt.tight_layout()
plt.show()

print("\nPer-class summary:")
for cls in range(n_classes):
    y_true_cls = y_test_bin[:, cls]
    prev = y_true_cls.mean()
    auprc = average_precision_score(y_true_cls, y_proba[:, cls])
    # ROC AUC is undefined (sklearn raises) when the test set holds only positives
    # or only negatives for this class; report NaN instead of aborting the cell.
    if 0 < prev < 1:
        roc = roc_auc_score(y_true_cls, y_proba[:, cls])
    else:
        roc = float('nan')
    print(f"  Label {cls}:  ROC AUC={roc:.4f}  AUPRC={auprc:.4f}  prevalence={prev:.3f}")

## 10. Holdout Predictions

In [None]:
holdout_raw = pd.read_csv(DATA_DIR / 'holdout_data_lagged.csv')

# Same preprocessing as train/test: lag reshape, then the train-fitted vitals scaler.
X_ts_holdout = reshape_lagged_to_sequences(holdout_raw, VITAL_COLS, N_LAGS)
X_ts_holdout_sc = ts_scaler.transform(
    X_ts_holdout.reshape(-1, n_vitals)
).reshape(X_ts_holdout.shape[0], n_timesteps, n_vitals)
# StandardScaler passes NaN through transform unchanged, and a NaN vital
# propagates through the LSTM into a NaN probability row. Impute 0.0 (the
# training mean in standardised units) so every holdout row gets a usable
# prediction. This is a no-op when no NaN is present.
X_ts_holdout_sc = np.where(np.isnan(X_ts_holdout_sc), 0.0, X_ts_holdout_sc)
X_static_holdout = build_static_features(holdout_raw, patients, pat_prep)

y_holdout_proba = predict_proba(model, X_ts_holdout_sc, X_static_holdout)
print(f"Holdout predictions shape: {y_holdout_proba.shape}")

if 'label' in holdout_raw.columns:
    roc_auc_holdout = roc_auc_score(
        holdout_raw['label'], y_holdout_proba, multi_class='ovr', average='macro'
    )
    print(f"Holdout ROC AUC (macro, OVR): {roc_auc_holdout:.4f}")
else:
    print("Holdout has no labels; predictions ready for submission.")

## 11. Save Submission CSV

In [None]:
# One probability column per class, in class-index order.
prob_columns = [f'label_{i}' for i in range(n_classes)]
submission = pd.DataFrame(y_holdout_proba, columns=prob_columns)
if 'encounter_id' in holdout_raw.columns:
    # Lead with the identifier so each probability row can be traced to its encounter.
    submission.insert(0, 'encounter_id', holdout_raw['encounter_id'].values)

out_path = DATA_DIR / 'lstm_fusion_holdout_predictions.csv'
submission.to_csv(out_path, index=False)
print(f"Saved: {out_path}")
submission.head()