# LSTM + Patient Features Fusion

Implements the dual-branch architecture from `specs/lstm_patient_fusion_spec.md`:
- **LSTM branch**: vital-sign time series `(batch, 7, 5)` → `(batch, 32)`
- **Patient branch**: static patient features `(batch, n_static)` (Tier 1: age, gender, BMI, pain score, race, ethnicity), fed unchanged into the fusion layer
- **Fusion**: `concat` → Dense(32, ReLU) → Dense(4, softmax). The network outputs logits; softmax is applied at inference time.

Evaluated with macro-averaged one-vs-rest (OVR) ROC AUC.

## 1. Setup

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler, OneHotEncoder, label_binarize
from sklearn.metrics import roc_auc_score

# Pick the fastest available backend: Apple Silicon GPU (MPS), then CUDA, else CPU.
DEVICE = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)
print(f"Device: {DEVICE}")

# Paths are relative to the current working directory.
DATA_DIR, MODELS_DIR = Path('../data'), Path('../models')
MODELS_DIR.mkdir(exist_ok=True)

VITAL_COLS = ['heart_rate', 'systolic_bp', 'diastolic_bp', 'respiratory_rate', 'oxygen_saturation']
N_LAGS = 6                   # lag columns per vital; plus the current reading = 7 timesteps
TRAIN_SAMPLE_SIZE = 500_000  # rows drawn (stratified) from the full training set
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE)

# Tier 1 patient features
NUMERIC_PATIENT_COLS     = ['age', 'bmi', 'pain_score']
CATEGORICAL_PATIENT_COLS = ['race', 'ethnicity']
GENDER_COL = 'gender'

## 2. Load and Reshape Vital Time Series

Reshapes flat lag columns into `(n_samples, 7, 5)`: 7 timesteps (lag6→current), 5 vital channels.

In [None]:
def reshape_lagged_to_sequences(df: pd.DataFrame, vital_cols: list, n_lags: int) -> np.ndarray:
    """Stack pre-computed lag columns into a ``(n_samples, n_lags + 1, len(vital_cols))`` array.

    For each vital ``v`` the columns ``v_lag{n_lags}``, ..., ``v_lag1``, ``v`` are read in
    that order, so timestep 0 is the oldest reading and the last timestep is the current
    one. The last axis follows the order of ``vital_cols``.
    """
    lag_steps = range(n_lags, 0, -1)  # oldest lag first
    per_vital = [
        df[[*(f'{vital}_lag{step}' for step in lag_steps), vital]].values
        for vital in vital_cols
    ]
    return np.stack(per_vital, axis=-1)


train_raw = pd.read_csv(DATA_DIR / 'train_data_lagged.csv')
test_raw  = pd.read_csv(DATA_DIR / 'test_data_lagged.csv')

# (n_rows, N_LAGS + 1, len(VITAL_COLS)) sequences, oldest reading first, plus labels.
X_ts_train_full = reshape_lagged_to_sequences(train_raw, VITAL_COLS, N_LAGS)
X_ts_test       = reshape_lagged_to_sequences(test_raw,  VITAL_COLS, N_LAGS)
y_train_full, y_test = train_raw['label'].values, test_raw['label'].values

for split_label, seqs in (("Train", X_ts_train_full), ("Test ", X_ts_test)):
    print(f"{split_label} sequences: {seqs.shape}")

## 3. Load and Preprocess Patient Features

**Tier 1** features per spec:
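- `age`, `bmi`, `pain_score`: median-imputed, then standardized
- `gender`: binary encoded (M=0, F=1), missing→-1
- `race`, `ethnicity`: one-hot encoded; missing→`unknown`, categories unseen during fitting→all zeros

Medians, scaling statistics, and one-hot categories are all fit on training encounters only.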

In [None]:
patients = pd.read_csv(DATA_DIR / 'patients.csv')
# Quick look at the raw patient table: size, first rows, and column dtypes.
print(f"patients.csv shape: {patients.shape}", patients.head(3), sep="\n")
print("\nColumn dtypes:", patients.dtypes, sep="\n")

In [None]:
def encode_gender(series: pd.Series) -> np.ndarray:
    """Map gender to a binary column: M/male → 0, F/female → 1, anything else → -1.

    Matching is case-insensitive and ignores surrounding whitespace, so values such as
    ``'MALE'`` or ``' F'`` are encoded instead of being treated as missing. Missing
    values (NaN/None) and unrecognized labels become -1.

    Returns:
        A float array of shape ``(len(series), 1)``.
    """
    mapping = {'m': 0, 'male': 0, 'f': 1, 'female': 1}
    # astype(str) turns NaN/None into 'nan'/'None', which fall through to -1 below.
    normalized = series.astype(str).str.strip().str.lower()
    return normalized.map(mapping).fillna(-1).astype(float).values.reshape(-1, 1)


class PatientFeaturePreprocessor:
    """Fit on train patients; transform any split. Produces a dense float32 array.

    Output columns, left to right:
      * ``numeric_cols``: median-imputed, then standardized with ``StandardScaler``.
        A column absent from the frame is filled with its fitted median.
      * One gender column from ``encode_gender`` (all -1.0 if ``gender_col`` is absent).
      * One-hot columns for the ``categorical_cols`` present at fit time. Missing values
        become ``'unknown'``; categories unseen at fit time encode as all zeros
        (``handle_unknown='ignore'``).
    """

    def __init__(self, numeric_cols=NUMERIC_PATIENT_COLS,
                 categorical_cols=CATEGORICAL_PATIENT_COLS, gender_col=GENDER_COL):
        # Copy the column lists so later mutation of the module-level defaults (or of the
        # caller's lists) cannot change the layout of an already-fitted preprocessor.
        self.numeric_cols = list(numeric_cols)
        self.categorical_cols = list(categorical_cols)
        self.gender_col = gender_col
        self.num_scaler = StandardScaler()
        self.cat_encoder = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
        self._num_medians: dict = {}       # numeric column -> fill value learned in fit()
        self._cat_cols_present: list = []  # categorical columns that existed at fit time

    def fit(self, patients_df: pd.DataFrame) -> 'PatientFeaturePreprocessor':
        """Learn fill values, scaling statistics and categories from ``patients_df``; return self."""
        for col in self.numeric_cols:
            median = patients_df[col].median() if col in patients_df.columns else np.nan
            # An absent or entirely-missing column has no median (NaN). Fall back to 0.0
            # so imputation never leaves NaNs that would propagate into the network.
            self._num_medians[col] = 0.0 if pd.isna(median) else float(median)
        self.num_scaler.fit(self._impute_numeric(patients_df))

        self._cat_cols_present = [c for c in self.categorical_cols if c in patients_df.columns]
        if self._cat_cols_present:
            self.cat_encoder.fit(self._categorical_frame(patients_df))
        return self

    def _impute_numeric(self, df: pd.DataFrame) -> np.ndarray:
        """Return numeric columns as a ``(len(df), len(numeric_cols))`` float array, NaNs filled."""
        parts = []
        for col in self.numeric_cols:
            if col in df.columns:
                # to_numpy(dtype=float, na_value=nan) also accepts pandas nullable dtypes
                # (e.g. Int64), where fillna with a fractional median could fail.
                vals = df[col].to_numpy(dtype=float, na_value=np.nan)
                vals = np.where(np.isnan(vals), self._num_medians[col], vals)
            else:
                vals = np.full(len(df), self._num_medians.get(col, 0.0), dtype=float)
            parts.append(vals.reshape(-1, 1))
        return np.hstack(parts)

    def _categorical_frame(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return the fit-time categorical columns as strings, with missing values as 'unknown'."""
        # astype(object) first so fillna also works on pandas 'category' columns, which
        # reject fill values that are not already one of their declared categories.
        return df[self._cat_cols_present].astype(object).fillna('unknown').astype(str)

    def transform(self, patients_df: pd.DataFrame) -> np.ndarray:
        """Return the ``(len(patients_df), n_static)`` float32 feature matrix."""
        parts = [self.num_scaler.transform(self._impute_numeric(patients_df))]
        if self.gender_col in patients_df.columns:
            parts.append(encode_gender(patients_df[self.gender_col]))
        else:
            # Keep a constant output width: an absent gender column encodes as "missing".
            parts.append(np.full((len(patients_df), 1), -1.0))
        if self._cat_cols_present:
            parts.append(self.cat_encoder.transform(self._categorical_frame(patients_df)))
        return np.hstack(parts).astype(np.float32)


# Fit patient-feature statistics on training encounters only (no test leakage).
in_train = patients['encounter_id'].isin(train_raw['encounter_id'])
train_patients = patients.loc[in_train]
pat_preprocessor = PatientFeaturePreprocessor().fit(train_patients)
# Width of the static feature vector = transformed width of a single row.
n_static = pat_preprocessor.transform(train_patients.head(1)).shape[1]
print(f"n_static features: {n_static}")

## 4. Join Patient Features to Lagged Data

In [None]:
def build_static_features(lagged_df: pd.DataFrame, patients_df: pd.DataFrame,
                           preprocessor: PatientFeaturePreprocessor) -> np.ndarray:
    """Left-join patient columns onto lagged rows by ``encounter_id``, then transform.

    The result has exactly one row per row of ``lagged_df``, in the same order, so it
    lines up with the vital-sign sequences built from the same frame. Encounters absent
    from ``patients_df`` get NaN patient columns, which ``preprocessor.transform`` fills.

    Raises:
        pandas.errors.MergeError: if ``patients_df`` has more than one row for some
            ``encounter_id``. Such a join would duplicate lagged rows and silently
            misalign static features with the sequences.
    """
    wanted = [*preprocessor.numeric_cols, *preprocessor.categorical_cols, preprocessor.gender_col]
    pat_cols = ['encounter_id'] + [c for c in wanted if c in patients_df.columns]
    # validate='many_to_one' guarantees the left join cannot add rows.
    merged = lagged_df[['encounter_id']].merge(
        patients_df[pat_cols], on='encounter_id', how='left', validate='many_to_one'
    )
    return preprocessor.transform(merged)


# Both splits go through the same train-fitted preprocessor, so column layouts match.
X_static_train_full, X_static_test = (
    build_static_features(split_df, patients, pat_preprocessor)
    for split_df in (train_raw, test_raw)
)

print(f"Static train: {X_static_train_full.shape}")
print(f"Static test:  {X_static_test.shape}")

## 5. Stratified Sample + Standardize Vitals

In [None]:
# Stratified subsample of the training rows so class proportions are preserved.
n_train_rows = len(y_train_full)
if X_static_train_full.shape[0] != n_train_rows:
    # Row-wise indexing below assumes static features align 1:1 with lagged rows.
    raise ValueError(
        f"Static features have {X_static_train_full.shape[0]} rows but the lagged data has "
        f"{n_train_rows}; check for duplicate encounter_ids in patients.csv."
    )

if n_train_rows > TRAIN_SAMPLE_SIZE:
    splitter = StratifiedShuffleSplit(n_splits=1, train_size=TRAIN_SAMPLE_SIZE, random_state=RANDOM_STATE)
    idx, _ = next(splitter.split(X_ts_train_full, y_train_full))
else:
    # StratifiedShuffleSplit rejects train_size >= n_samples; the data already fits
    # the sampling budget, so use every row in its original order.
    idx = np.arange(n_train_rows)

X_ts_train   = X_ts_train_full[idx]
X_stat_train = X_static_train_full[idx]
y_train      = y_train_full[idx]

print(f"Sampled train: {X_ts_train.shape}, static: {X_stat_train.shape}")
print(pd.Series(y_train).value_counts().sort_index())

In [None]:
n_samples, n_timesteps, n_vitals = X_ts_train.shape

# One mean/std per vital channel, pooled over every (sample, timestep) position of the
# sampled training sequences; the same statistics are applied to the test sequences.
ts_scaler = StandardScaler().fit(X_ts_train.reshape(-1, n_vitals))

X_ts_train_sc = ts_scaler.transform(X_ts_train.reshape(-1, n_vitals)).reshape(n_samples, n_timesteps, n_vitals)
n_test_seqs = X_ts_test.shape[0]
X_ts_test_sc = ts_scaler.transform(X_ts_test.reshape(-1, n_vitals)).reshape(n_test_seqs, n_timesteps, n_vitals)

print(f"Scaled train: {X_ts_train_sc.shape}, test: {X_ts_test_sc.shape}")

## 6. Build Fusion Model

```
vitals (7,5) → LSTM(64)→Dropout→LSTM(32)→Dropout ─┐
                                                  ├─ concat → Linear(32)+ReLU→Dropout→Linear(4) → logits
static (n_static,) ───────────────────────────────┘
```

In [None]:
class LSTMFusionModel(nn.Module):
    """Two stacked LSTMs over the vitals, fused with static features by an MLP head.

    ``forward`` returns unnormalized logits of shape ``(batch, n_classes)``; apply
    softmax (or use ``nn.CrossEntropyLoss``) outside the model.
    """

    def __init__(self, n_vitals: int, n_static: int, n_classes: int = 4):
        super().__init__()
        hidden1, hidden2 = 64, 32
        # Submodules are created in this order and under these names so parameter
        # initialization and state_dict keys stay stable for saved checkpoints.
        self.lstm1 = nn.LSTM(n_vitals, hidden1, batch_first=True)
        self.drop1 = nn.Dropout(0.3)
        self.lstm2 = nn.LSTM(hidden1, hidden2, batch_first=True)
        self.drop2 = nn.Dropout(0.3)
        self.classifier = nn.Sequential(
            nn.Linear(hidden2 + n_static, 32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(32, n_classes),
        )

    def forward(self, x_ts: torch.Tensor, x_static: torch.Tensor) -> torch.Tensor:
        seq_feats, _ = self.lstm1(x_ts)
        seq_feats, _ = self.lstm2(self.drop1(seq_feats))
        # Summarize the sequence by the second LSTM's output at the final timestep.
        ts_embedding = self.drop2(seq_feats[:, -1, :])
        return self.classifier(torch.cat((ts_embedding, x_static), dim=1))


# Number of output classes = number of distinct labels in the sampled training set.
n_classes = np.unique(y_train).size
model = LSTMFusionModel(n_vitals, n_static, n_classes).to(DEVICE)
print(model)
total_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"\nTrainable parameters: {total_params:,}")

## 7. Train

In [None]:
# "Balanced" inverse-frequency weights, w_c = N / (n_classes * count_c): rarer classes
# contribute more to the loss.
class_counts = np.bincount(y_train)
class_weights_np = len(y_train) / (n_classes * class_counts)
# Cast to float32 on the CPU before moving to the device.
class_weights_t = torch.tensor(class_weights_np, dtype=torch.float32).to(DEVICE)
print(f"Class weights: {dict(enumerate(class_weights_np.round(3)))}")

criterion = nn.CrossEntropyLoss(weight=class_weights_t)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Halve the LR after 2 epochs without val-loss improvement, never below 1e-5.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2, min_lr=1e-5,
)

# DataLoader factory shared by training, validation and inference.
def make_loader(X_ts, X_stat, y, batch_size=512, shuffle=True):
    """Wrap (sequences, static features, labels) arrays in a ``DataLoader``.

    Sequences and static features become float32 tensors and labels int64 (long),
    the dtypes expected by the model and ``nn.CrossEntropyLoss``.
    """
    tensors = (
        torch.tensor(X_ts, dtype=torch.float32),
        torch.tensor(X_stat, dtype=torch.float32),
        torch.tensor(y, dtype=torch.long),
    )
    return DataLoader(TensorDataset(*tensors), batch_size=batch_size, shuffle=shuffle)

# Hold out a stratified 10% of the sampled training rows for validation / early stopping.
val_split = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=RANDOM_STATE)
tr_idx, val_idx = next(val_split.split(X_ts_train_sc, y_train))

# Only the training loader reshuffles each epoch.
train_loader, val_loader = (
    make_loader(X_ts_train_sc[rows], X_stat_train[rows], y_train[rows], shuffle=is_train)
    for rows, is_train in ((tr_idx, True), (val_idx, False))
)

In [None]:
EPOCHS = 15
PATIENCE = 3  # epochs without a val-loss improvement before stopping early


def _run_epoch(loader, train: bool):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With ``train=True`` the model is in training mode (dropout active) and the weights
    are updated after every batch; otherwise it is in eval mode with gradients disabled.
    The reported loss is the per-batch loss weighted by batch size.
    """
    model.train(train)
    loss_sum = n_correct = n_seen = 0
    with torch.set_grad_enabled(train):
        for X_ts_b, X_st_b, y_b in loader:
            X_ts_b, X_st_b, y_b = X_ts_b.to(DEVICE), X_st_b.to(DEVICE), y_b.to(DEVICE)
            logits = model(X_ts_b, X_st_b)
            loss = criterion(logits, y_b)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            loss_sum  += loss.item() * len(y_b)
            n_correct += (logits.argmax(1) == y_b).sum().item()
            n_seen    += len(y_b)
    return loss_sum / n_seen, n_correct / n_seen


best_val_loss = float('inf')
epochs_no_improve = 0
best_state = None
history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

for epoch in range(1, EPOCHS + 1):
    t_loss, t_acc = _run_epoch(train_loader, train=True)
    v_loss, v_acc = _run_epoch(val_loader, train=False)
    for key, value in (('train_loss', t_loss), ('val_loss', v_loss),
                       ('train_acc', t_acc), ('val_acc', v_acc)):
        history[key].append(value)

    scheduler.step(v_loss)
    lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch:02d}  train_loss={t_loss:.4f}  val_loss={v_loss:.4f}  "
          f"train_acc={t_acc:.4f}  val_acc={v_acc:.4f}  lr={lr:.2e}")

    # Early stopping on validation loss; keep a CPU copy of the best weights.
    if v_loss < best_val_loss:
        best_val_loss = v_loss
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print(f"Early stopping at epoch {epoch}")
            break

# A NaN validation loss never compares as "< best", so best_state can remain None.
# Fail with a clear message instead of an opaque error from load_state_dict(None).
if best_state is None:
    raise RuntimeError(
        "Validation loss never improved (is it NaN?); there is no checkpoint to restore."
    )
model.load_state_dict(best_state)
torch.save(model.state_dict(), MODELS_DIR / 'lstm_patient_fusion.pt')
print(f"\nBest model saved to {MODELS_DIR / 'lstm_patient_fusion.pt'}")

## 8. Training Curves

In [None]:
import matplotlib.pyplot as plt

# 1-based x-axis so each point lines up with the "Epoch NN" number printed during training.
epochs_axis = range(1, len(history['train_loss']) + 1)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, (metric, title) in zip(axes, [('loss', 'Loss'), ('acc', 'Accuracy')]):
    ax.plot(epochs_axis, history[f'train_{metric}'], label='train')
    ax.plot(epochs_axis, history[f'val_{metric}'],   label='val')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.legend()

plt.tight_layout()
plt.show()

## 9. Evaluate on Test Set

In [None]:
def predict_proba(model, X_ts_sc, X_stat, batch_size=512):
    """Return class probabilities (softmax over the model's logits) as a NumPy array.

    Rows follow the input order because the loader is not shuffled. ``make_loader``
    requires labels, which are unused here, so a zero placeholder vector is passed.
    """
    model.eval()
    placeholder_y = np.zeros(len(X_ts_sc), dtype=np.int64)
    loader = make_loader(X_ts_sc, X_stat, placeholder_y, batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        batch_probs = [
            torch.softmax(model(ts_b.to(DEVICE), st_b.to(DEVICE)), dim=1).cpu().numpy()
            for ts_b, st_b, _ in loader
        ]
    return np.vstack(batch_probs)


y_proba = predict_proba(model, X_ts_test_sc, X_static_test)

# Headline metric: unweighted mean of the per-class one-vs-rest AUCs.
roc_auc = roc_auc_score(y_test, y_proba, multi_class='ovr', average='macro')
print(f"Test ROC AUC (macro, OVR): {roc_auc:.4f}")

# Per-class breakdown of the same one-vs-rest AUCs.
y_test_bin = label_binarize(y_test, classes=list(range(n_classes)))
for cls in range(n_classes):
    is_cls, cls_scores = y_test_bin[:, cls], y_proba[:, cls]
    auc_cls = roc_auc_score(is_cls, cls_scores)
    print(f"  Class {cls} AUC: {auc_cls:.4f}")

## 10. Holdout Predictions

In [None]:
holdout_raw = pd.read_csv(DATA_DIR / 'holdout_data_lagged.csv')

# Reuse the train-fitted transforms (vital scaler + patient preprocessor) unchanged.
X_ts_holdout = reshape_lagged_to_sequences(holdout_raw, VITAL_COLS, N_LAGS)
n_holdout = X_ts_holdout.shape[0]
X_ts_holdout_sc = ts_scaler.transform(
    X_ts_holdout.reshape(-1, n_vitals)
).reshape(n_holdout, n_timesteps, n_vitals)
X_static_holdout = build_static_features(holdout_raw, patients, pat_preprocessor)

y_holdout_proba = predict_proba(model, X_ts_holdout_sc, X_static_holdout)
print(f"Holdout predictions shape: {y_holdout_proba.shape}")

# Score only when ground truth is available; otherwise this is a pure submission run.
if 'label' not in holdout_raw.columns:
    print("Holdout has no labels; predictions ready for submission.")
else:
    roc_auc_holdout = roc_auc_score(
        holdout_raw['label'], y_holdout_proba, multi_class='ovr', average='macro'
    )
    print(f"Holdout ROC AUC (macro, OVR): {roc_auc_holdout:.4f}")

## 11. Save Submission CSV

In [None]:
prob_columns = [f'label_{k}' for k in range(n_classes)]
submission = pd.DataFrame(y_holdout_proba, columns=prob_columns)
# Lead with encounter ids, when the holdout file has them, so rows can be matched back.
if 'encounter_id' in holdout_raw:
    submission.insert(0, 'encounter_id', holdout_raw['encounter_id'].values)

out_path = DATA_DIR / 'lstm_fusion_holdout_predictions.csv'
submission.to_csv(out_path, index=False)
print(f"Saved: {out_path}")
submission.head()