# Continuation of the project; I did the baseline TF-IDF and I will continue to
# build BiLSTM and ClinicalModernBERT

### 0 . Shared helpers, preprocessor, metrics, plotting

0. HELPERS, PREPROCESSOR, METRICS, DIRS

In [None]:

# CHUNK 0 – HELPERS, PREPROCESSOR, METRICS, DIRS

import os, gc, time, math, random, re
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import (
    roc_auc_score, f1_score, precision_score, recall_score,
    accuracy_score, confusion_matrix
)
from sklearn.calibration import calibration_curve

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="whitegrid")

# ------------------------------
# Output directories
# ------------------------------
BASE_OUT_DIR = Path("/content/drive/MyDrive/clinical_project_cpu/deep_models")
PLOTS_DIR    = BASE_OUT_DIR / "plots"
CALIB_DIR    = BASE_OUT_DIR / "calibration"
EMB_DIR      = BASE_OUT_DIR / "embeddings"

for d in [BASE_OUT_DIR, PLOTS_DIR, CALIB_DIR, EMB_DIR]:
    d.mkdir(parents=True, exist_ok=True)

print("Device in deep-learning chunks:", device)
print("Plots dir:", PLOTS_DIR)
print("Calibration dir:", CALIB_DIR)
print("Embeddings dir:", EMB_DIR)

# ------------------------------
# GPU memory helper
# ------------------------------
def clear_gpu_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ------------------------------
# Clinical text preprocessor
# ------------------------------
class ClinicalTextPreprocessor:
    """
    Simple, domain-aware cleaner that:
    - lowercases
    - keeps alphanumerics and basic punctuation
    - joins TRIAGE + RADIOLOGY with explicit markers
    """
    def __init__(self, lowercase=True, strip_special=True, normalize_ws=True):
        self.lowercase = lowercase
        self.strip_special = strip_special
        self.normalize_ws = normalize_ws

    def _clean(self, text: str) -> str:
        if pd.isna(text):
            text = ""
        text = str(text)
        if self.lowercase:
            text = text.lower()
        if self.strip_special:
            text = re.sub(r"[^a-z0-9\.\,\-\?\!\:\;\(\)\/\s]+", " ", text)
        if self.normalize_ws:
            text = re.sub(r"\s+", " ", text).strip()
        return text

    def join_triage_radiology(self, triage_text, radiology_text) -> str:
        t = self._clean(triage_text)
        r = self._clean(radiology_text)
        parts = []
        if t:
            parts.append(f"TRIAGE: {t}")
        if r:
            parts.append(f"RADIOLOGY: {r}")
        return " [SEP] ".join(parts) if parts else ""

    def __call__(self, row):
        return self.join_triage_radiology(row.get("triage_text", ""),
                                          row.get("radiology_text", ""))

preprocessor = ClinicalTextPreprocessor()

# add combined_text once here
for df in [X_train, X_val, X_test]:
    df["combined_text"] = df.apply(preprocessor, axis=1)

print("Sample combined text:\n", X_train["combined_text"].head(3).tolist())

# ------------------------------
# Vocab builder + tokenizer for BiLSTM
# ------------------------------
from collections import Counter

def build_vocab(texts, vocab_size=10000, min_freq=2):
    counter = Counter()
    for txt in texts:
        for tok in txt.split():
            counter[tok] += 1
    # reserve 0: PAD, 1: UNK
    most_common = [w for w, c in counter.items() if c >= min_freq]
    most_common = sorted(most_common, key=lambda w: counter[w],
                         reverse=True)[:vocab_size-2]
    vocab = {"<PAD>": 0, "<UNK>": 1}
    for i, w in enumerate(most_common, start=2):
        vocab[w] = i
    return vocab

def text_to_ids(text, vocab, max_len):
    tokens = text.split()
    ids = [vocab.get(tok, vocab["<UNK>"]) for tok in tokens]
    if len(ids) > max_len:
        ids = ids[:max_len]
    pad_len = max_len - len(ids)
    ids = ids + [vocab["<PAD>"]] * pad_len
    attn_mask = [1]*min(len(tokens), max_len) + [0]*pad_len
    return np.array(ids, dtype=np.int64), np.array(attn_mask, dtype=np.int64)

# ------------------------------
# Generic binary metrics helper
# ------------------------------
def binary_metrics(y_true, y_prob, threshold=0.5):
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    y_pred = (y_prob >= threshold).astype(int)

    auc = roc_auc_score(y_true, y_prob)
    f1  = f1_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred)
    rec  = recall_score(y_true, y_pred)
    acc  = accuracy_score(y_true, y_pred)

    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    spec = tn / (tn + fp + 1e-8)

    return {
        "auc": auc,
        "f1": f1,
        "precision": prec,
        "recall": rec,
        "accuracy": acc,
        "specificity": spec,
        "confusion": (tn, fp, fn, tp),
    }

# ------------------------------
# Plot helpers
# ------------------------------
def plot_training_curves(history, model_name, out_dir=PLOTS_DIR):
    """
    history: dict with keys "train_loss", "val_loss", "val_auc"
    """
    epochs = range(1, len(history["train_loss"])+1)

    # Loss
    plt.figure(figsize=(6,4))
    plt.plot(epochs, history["train_loss"], label="Train loss")
    plt.plot(epochs, history["val_loss"], label="Val loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"{model_name} – Loss")
    plt.legend()
    plt.tight_layout()
    path_loss = out_dir / f"{model_name}_loss.png"
    plt.savefig(path_loss, dpi=150)
    plt.close()

    # AUC
    plt.figure(figsize=(6,4))
    plt.plot(epochs, history["val_auc"], marker="o")
    plt.xlabel("Epoch")
    plt.ylabel("Validation AUC")
    plt.title(f"{model_name} – Validation AUC")
    plt.tight_layout()
    path_auc = out_dir / f"{model_name}_val_auc.png"
    plt.savefig(path_auc, dpi=150)
    plt.close()

    print(f"[{model_name}] Saved loss + AUC curves to {out_dir}")
    return str(path_loss), str(path_auc)

def plot_confusion_matrix(cm, model_name, out_dir=PLOTS_DIR):
    tn, fp, fn, tp = cm
    mat = np.array([[tn, fp],
                    [fn, tp]])
    plt.figure(figsize=(4,4))
    sns.heatmap(mat, annot=True, fmt="d", cmap="Blues",
                xticklabels=["Pred 0", "Pred 1"],
                yticklabels=["True 0", "True 1"])
    plt.title(f"{model_name} – Confusion Matrix")
    plt.tight_layout()
    path_cm = out_dir / f"{model_name}_confusion_matrix.png"
    plt.savefig(path_cm, dpi=150)
    plt.close()
    print(f"[{model_name}] Saved confusion matrix to {path_cm}")
    return str(path_cm)


Device in deep-learning chunks: cuda
Plots dir: /content/drive/MyDrive/clinical_project_cpu/deep_models/plots
Calibration dir: /content/drive/MyDrive/clinical_project_cpu/deep_models/calibration
Embeddings dir: /content/drive/MyDrive/clinical_project_cpu/deep_models/embeddings
Sample combined text:
 ['TRIAGE: abdominal pain, vomiting', 'TRIAGE: anorexia', 'TRIAGE: abdominal pain, abdominal distention, bowel obstruction [SEP] RADIOLOGY: examination: chest (posteroanterior and lateral) indication: with shortness of breath// eval pneumonia comparison: prior dated findings: posteroanterior and lateral views of the chest provided. lungs are clear. clips are noted near the ge junction. 2 discrete metallic stents partially visualized in the right upper quadrant. there is negation 0 focal consolidation, effusion, or pneumothorax. there are negation 0 signs of congestion or edema. the cardiomediastinal silhouette is normal. imaged osseous structures are intact. negation 0 free air below the rig

1. BiLSTM COMPACT ENCODER

In [None]:

# 1 BiLSTM COMPACT ENCODER

MAX_LEN_LSTM = 512
VOCAB_SIZE   = 10000

print(f"Baseline TF-IDF AUC to beat: {baseline_auc:.3f}")

# 1) Build vocabulary on training combined text
print("\n[BiLSTM] Building vocabulary...")
vocab = build_vocab(X_train["combined_text"].tolist(),
                    vocab_size=VOCAB_SIZE, min_freq=2)
print(f"Vocab size (including PAD/UNK): {len(vocab)}")

# 2) Dataset for BiLSTM  (INCLUDES stay_id)
class ClinicalLSTMDataset(Dataset):
    def __init__(self, df, labels, vocab, max_len=512):
        self.texts   = df["combined_text"].tolist()
        self.labels  = labels.values.astype(np.int64)
        # prefer explicit stay_id column if present, else use index
        if "stay_id" in df.columns:
            self.stay_ids = df["stay_id"].values.astype(np.int64)
        else:
            self.stay_ids = df.index.values.astype(np.int64)
        self.vocab   = vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        text = self.texts[idx]
        ids, attn = text_to_ids(text, self.vocab, self.max_len)
        return {
            "input_ids":      torch.tensor(ids,  dtype=torch.long),
            "attention_mask": torch.tensor(attn, dtype=torch.float32),
            "label":          torch.tensor(self.labels[idx],   dtype=torch.long),
            "stay_id":        torch.tensor(self.stay_ids[idx], dtype=torch.long),
        }

train_ds_lstm = ClinicalLSTMDataset(X_train, y_train, vocab, max_len=MAX_LEN_LSTM)
val_ds_lstm   = ClinicalLSTMDataset(X_val,   y_val,   vocab, max_len=MAX_LEN_LSTM)
test_ds_lstm  = ClinicalLSTMDataset(X_test,  y_test,  vocab, max_len=MAX_LEN_LSTM)

train_loader_lstm = DataLoader(train_ds_lstm, batch_size=32, shuffle=True)
val_loader_lstm   = DataLoader(val_ds_lstm,   batch_size=64, shuffle=False)
test_loader_lstm  = DataLoader(test_ds_lstm,  batch_size=64, shuffle=False)

# 3) BiLSTM with attention + encode_text
class BiLSTMEncoder(nn.Module):
    """
    Compact domain-aware encoder:
    - Embedding + BiLSTM
    - Attention pooling to get fixed-length embedding
    - encode_text() returns [batch, hidden_dim*2] embeddings
    """
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=128,
                 num_layers=2, dropout=0.2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True,
        )
        self.attn = nn.Linear(hidden_dim*2, 1)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_dim*2, 1)   # binary

    def encode_text(self, input_ids, attention_mask):
        emb = self.embedding(input_ids)                      # [B, T, D]
        lstm_out, _ = self.lstm(emb)                         # [B, T, 2H]

        scores = self.attn(lstm_out).squeeze(-1)             # [B, T]
        scores = scores.masked_fill(attention_mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)              # [B, T]
        ctx = torch.bmm(weights.unsqueeze(1), lstm_out)      # [B, 1, 2H]
        ctx = ctx.squeeze(1)                                 # [B, 2H]
        return ctx

    def forward(self, input_ids, attention_mask):
        ctx = self.encode_text(input_ids, attention_mask)
        logits = self.classifier(self.dropout(ctx)).squeeze(-1)   # [B]
        return logits

# 4) Training loop (logs loss + AUC)
def train_lstm_model(model, train_loader, val_loader,
                     n_epochs=6, lr=3e-4, weight_decay=0.02):

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr,
                                  weight_decay=weight_decay)
    criterion = nn.BCEWithLogitsLoss()

    history = {"train_loss": [], "val_loss": [], "val_auc": []}
    best_auc = 0.0
    best_state = None
    patience = 3
    no_improve = 0

    for epoch in range(1, n_epochs+1):
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            ids   = batch["input_ids"].to(device)
            mask  = batch["attention_mask"].to(device)
            labels = batch["label"].float().to(device)

            optimizer.zero_grad()
            logits = model(ids, mask)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * ids.size(0)

        train_loss = running_loss / len(train_loader.dataset)

        # ----- validation -----
        model.eval()
        val_loss = 0.0
        all_probs, all_labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                ids   = batch["input_ids"].to(device)
                mask  = batch["attention_mask"].to(device)
                labels = batch["label"].float().to(device)

                logits = model(ids, mask)
                loss = criterion(logits, labels)
                val_loss += loss.item() * ids.size(0)

                probs = torch.sigmoid(logits).cpu().numpy()
                all_probs.append(probs)
                all_labels.append(batch["label"].numpy())

        val_loss = val_loss / len(val_loader.dataset)
        all_probs = np.concatenate(all_probs)
        all_labels = np.concatenate(all_labels)

        metrics = binary_metrics(all_labels, all_probs)
        val_auc = metrics["auc"]

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_auc"].append(val_auc)

        print(f"[BiLSTM] Epoch {epoch}/{n_epochs} "
              f"Train loss={train_loss:.4f} | Val loss={val_loss:.4f} "
              f"| Val AUC={val_auc:.4f}")

        if val_auc > best_auc + 1e-4:
            best_auc = val_auc
            best_state = model.state_dict()
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print("[BiLSTM] Early stopping – no improvement.")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    return model, history, best_auc

# 5) Train BiLSTM
clear_gpu_memory()
bilstm_model = BiLSTMEncoder(vocab_size=len(vocab), embed_dim=128,
                             hidden_dim=128, num_layers=2, dropout=0.2)

t0 = time.time()
bilstm_model, history_lstm, best_val_auc_lstm = train_lstm_model(
    bilstm_model,
    train_loader_lstm,
    val_loader_lstm,
    n_epochs=6,
    lr=3e-4,
    weight_decay=0.02,
)
elapsed = (time.time() - t0) / 60.0
print(f"[BiLSTM] Training finished in {elapsed:.1f} minutes. "
      f"Best Val AUC={best_val_auc_lstm:.4f}")

plot_training_curves(history_lstm, model_name="BiLSTM")

# 6) Evaluate on TEST and plot confusion matrix + full metrics
def evaluate_lstm_on_test(model, loader):
    model.eval()
    all_probs, all_labels = [], []
    with torch.no_grad():
        for batch in loader:
            ids   = batch["input_ids"].to(device)
            mask  = batch["attention_mask"].to(device)
            labels = batch["label"].numpy()

            logits = model(ids, mask)
            probs = torch.sigmoid(logits).cpu().numpy()
            all_probs.append(probs)
            all_labels.append(labels)

    all_probs  = np.concatenate(all_probs)
    all_labels = np.concatenate(all_labels)
    metrics = binary_metrics(all_labels, all_probs, threshold=0.5)

    print("\n[BiLSTM] TEST METRICS")
    for k in ["auc", "f1", "precision", "recall",
              "specificity", "accuracy"]:
        print(f"  {k:>11}: {metrics[k]:.4f}")
    tn, fp, fn, tp = metrics["confusion"]
    print(f"  Confusion: TN={tn}, FP={fp}, FN={fn}, TP={tp}")

    plot_confusion_matrix(metrics["confusion"], model_name="BiLSTM")
    return all_labels, all_probs, metrics

y_test_lstm, y_prob_lstm, metrics_lstm = evaluate_lstm_on_test(
    bilstm_model, test_loader_lstm
)

print(f"\n[BiLSTM] Improvement over TF-IDF baseline: "
      f"{metrics_lstm['auc'] - baseline_auc:+.4f} AUC")


Baseline TF-IDF AUC to beat: 0.818

[BiLSTM] Building vocabulary...
Vocab size (including PAD/UNK): 10000
[BiLSTM] Epoch 1/6 Train loss=0.4793 | Val loss=0.4633 | Val AUC=0.8087
[BiLSTM] Epoch 2/6 Train loss=0.4501 | Val loss=0.4538 | Val AUC=0.8196
[BiLSTM] Epoch 3/6 Train loss=0.4381 | Val loss=0.4501 | Val AUC=0.8245
[BiLSTM] Epoch 4/6 Train loss=0.4292 | Val loss=0.4483 | Val AUC=0.8252
[BiLSTM] Epoch 5/6 Train loss=0.4191 | Val loss=0.4543 | Val AUC=0.8232
[BiLSTM] Epoch 6/6 Train loss=0.4084 | Val loss=0.4586 | Val AUC=0.8205
[BiLSTM] Training finished in 8.7 minutes. Best Val AUC=0.8252
[BiLSTM] Saved loss + AUC curves to /content/drive/MyDrive/clinical_project_cpu/deep_models/plots

[BiLSTM] TEST METRICS
          auc: 0.8144
           f1: 0.4889
    precision: 0.7028
       recall: 0.3748
  specificity: 0.9391
     accuracy: 0.7824
  Confusion: TN=20296, FP=1317, FN=5195, TP=3115
[BiLSTM] Saved confusion matrix to /content/drive/MyDrive/clinical_project_cpu/deep_models/plots/

2. ClinicalModernBERT ENCODER

In [52]:
# CHUNK 2 – TASK 6: ClinicalModernBERT ENCODER

from transformers import AutoTokenizer, AutoModel

BERT_MODEL_NAME = "simonlee711/clinical_modernbert"  # same as before
MAX_LEN_BERT    = 256
BATCH_SIZE_BERT = 16

print("\n[Clinical BERT] Loading tokenizer & base model...")
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
bert_base = AutoModel.from_pretrained(BERT_MODEL_NAME)

class ClinicalBertEncoder(nn.Module):
    """
    Compact encoder on top of ModernBERT:
    - mean pooling over last hidden states (with mask)
    - encode_text() returns [B, hidden_dim]
    """
    def __init__(self, bert_model, dropout=0.3):
        super().__init__()
        self.bert = bert_model
        hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, 1)

    def encode_text(self, input_ids, attention_mask):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        last_hidden = outputs.last_hidden_state   # [B, T, H]
        mask = attention_mask.unsqueeze(-1)       # [B, T, 1]
        summed  = torch.sum(last_hidden * mask, dim=1)   # [B, H]
        counts  = torch.clamp(mask.sum(dim=1), min=1e-9) # [B, 1]
        mean_pooled = summed / counts
        return mean_pooled

    def forward(self, input_ids, attention_mask):
        emb = self.encode_text(input_ids, attention_mask)
        logits = self.classifier(self.dropout(emb)).squeeze(-1)
        return logits

# Dataset includes stay_id
class ClinicalBertDataset(Dataset):
    def __init__(self, df, labels, tokenizer, max_len=256):
        self.texts    = df["combined_text"].tolist()
        self.labels   = labels.values.astype(np.int64)
        if "stay_id" in df.columns:
            self.stay_ids = df["stay_id"].values.astype(np.int64)
        else:
            self.stay_ids = df.index.values.astype(np.int64)
        self.tokenizer = tokenizer
        self.max_len   = max_len

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        text = self.texts[idx]
        enc = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_len,
            padding="max_length",
            return_tensors="pt",
        )
        return {
            "input_ids":      enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "label":          torch.tensor(self.labels[idx],   dtype=torch.long),
            "stay_id":        torch.tensor(self.stay_ids[idx], dtype=torch.long),
        }

train_ds_bert = ClinicalBertDataset(X_train, y_train, tokenizer,
                                    max_len=MAX_LEN_BERT)
val_ds_bert   = ClinicalBertDataset(X_val,   y_val,   tokenizer,
                                    max_len=MAX_LEN_BERT)
test_ds_bert  = ClinicalBertDataset(X_test,  y_test,  tokenizer,
                                    max_len=MAX_LEN_BERT)

train_loader_bert = DataLoader(train_ds_bert, batch_size=BATCH_SIZE_BERT,
                               shuffle=True)
val_loader_bert   = DataLoader(val_ds_bert,   batch_size=BATCH_SIZE_BERT*2,
                               shuffle=False)
test_loader_bert  = DataLoader(test_ds_bert,  batch_size=BATCH_SIZE_BERT*2,
                               shuffle=False)

def train_bert_model(model, train_loader, val_loader,
                     n_epochs=5, lr=2e-5, weight_decay=0.01):

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr,
                                  weight_decay=weight_decay)
    criterion = nn.BCEWithLogitsLoss()

    history = {"train_loss": [], "val_loss": [], "val_auc": []}
    best_auc = 0.0
    best_state = None
    patience = 3
    no_improve = 0

    for epoch in range(1, n_epochs+1):
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            ids   = batch["input_ids"].to(device)
            mask  = batch["attention_mask"].to(device)
            labels = batch["label"].float().to(device)

            optimizer.zero_grad()
            logits = model(ids, mask)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * ids.size(0)

        train_loss = running_loss / len(train_loader.dataset)

        # ----- validation -----
        model.eval()
        val_loss = 0.0
        all_probs, all_labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                ids   = batch["input_ids"].to(device)
                mask  = batch["attention_mask"].to(device)
                labels = batch["label"].float().to(device)

                logits = model(ids, mask)
                loss = criterion(logits, labels)
                val_loss += loss.item() * ids.size(0)

                probs = torch.sigmoid(logits).cpu().numpy()
                all_probs.append(probs)
                all_labels.append(batch["label"].numpy())

        val_loss = val_loss / len(val_loader.dataset)
        all_probs  = np.concatenate(all_probs)
        all_labels = np.concatenate(all_labels)

        metrics = binary_metrics(all_labels, all_probs)
        val_auc = metrics["auc"]

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_auc"].append(val_auc)

        print(f"[Clinical BERT] Epoch {epoch}/{n_epochs} "
              f"Train loss={train_loss:.4f} | Val loss={val_loss:.4f} "
              f"| Val AUC={val_auc:.4f}")

        if val_auc > best_auc + 1e-4:
            best_auc = val_auc
            best_state = model.state_dict()
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print("[Clinical BERT] Early stopping – no improvement.")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model, history, best_auc

clear_gpu_memory()
bert_model = ClinicalBertEncoder(bert_base, dropout=0.3)

t0 = time.time()
bert_model, history_bert, best_val_auc_bert = train_bert_model(
    bert_model,
    train_loader_bert,
    val_loader_bert,
    n_epochs=5,
    lr=2e-5,
    weight_decay=0.01,
)
elapsed = (time.time() - t0) / 60.0
print(f"[Clinical BERT] Training finished in {elapsed:.1f} minutes. "
      f"Best Val AUC={best_val_auc_bert:.4f}")

plot_training_curves(history_bert, model_name="ClinicalBERT")

def evaluate_bert_on_test(model, loader):
    model.eval()
    all_probs, all_labels = [], []
    with torch.no_grad():
        for batch in loader:
            ids   = batch["input_ids"].to(device)
            mask  = batch["attention_mask"].to(device)

            logits = model(ids, mask)
            probs = torch.sigmoid(logits).cpu().numpy()
            all_probs.append(probs)
            all_labels.append(batch["label"].numpy())

    all_probs  = np.concatenate(all_probs)
    all_labels = np.concatenate(all_labels)
    metrics = binary_metrics(all_labels, all_probs, threshold=0.5)

    print("\n[Clinical BERT] TEST METRICS")
    for k in ["auc", "f1", "precision", "recall",
              "specificity", "accuracy"]:
        print(f"  {k:>11}: {metrics[k]:.4f}")
    tn, fp, fn, tp = metrics["confusion"]
    print(f"  Confusion: TN={tn}, FP={fp}, FN={fn}, TP={tp}")

    plot_confusion_matrix(metrics["confusion"], model_name="ClinicalBERT")
    return all_labels, all_probs, metrics

y_test_bert, y_prob_bert, metrics_bert = evaluate_bert_on_test(
    bert_model, test_loader_bert
)
print(f"\n[Clinical BERT] Improvement over TF-IDF baseline: "
      f"{metrics_bert['auc'] - baseline_auc:+.4f} AUC")

# Decide best deep model for Task 7
if metrics_bert["auc"] >= metrics_lstm["auc"]:
    best_model_name = "ClinicalBERT"
    best_model = bert_model
    best_y_true = y_test_bert
    best_y_prob = y_prob_bert
else:
    best_model_name = "BiLSTM"
    best_model = bilstm_model
    best_y_true = y_test_lstm
    best_y_prob = y_prob_lstm

print(f"\n[Task 6] Best deep model for Task 7: {best_model_name} "
      f"(AUC={max(metrics_bert['auc'], metrics_lstm['auc']):.4f})")



[Clinical BERT] Loading tokenizer & base model...


Some weights of BertModel were not initialized from the model checkpoint at simonlee711/clinical_modernbert and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Clinical BERT] Epoch 1/5 Train loss=0.4651 | Val loss=0.4530 | Val AUC=0.8200
[Clinical BERT] Epoch 2/5 Train loss=0.4302 | Val loss=0.4446 | Val AUC=0.8267
[Clinical BERT] Epoch 3/5 Train loss=0.3825 | Val loss=0.4636 | Val AUC=0.8180
[Clinical BERT] Epoch 4/5 Train loss=0.2895 | Val loss=0.5454 | Val AUC=0.8000
[Clinical BERT] Epoch 5/5 Train loss=0.2173 | Val loss=0.7089 | Val AUC=0.7747
[Clinical BERT] Early stopping – no improvement.
[Clinical BERT] Training finished in 163.1 minutes. Best Val AUC=0.8267
[ClinicalBERT] Saved loss + AUC curves to /content/drive/MyDrive/clinical_project_cpu/deep_models/plots

[Clinical BERT] TEST METRICS
          auc: 0.7722
           f1: 0.5374
    precision: 0.6020
       recall: 0.4853
  specificity: 0.8766
     accuracy: 0.7680
  Confusion: TN=18947, FP=2666, FN=4277, TP=4033
[ClinicalBERT] Saved confusion matrix to /content/drive/MyDrive/clinical_project_cpu/deep_models/plots/ClinicalBERT_confusion_matrix.png

[Clinical BERT] Improvement ove

3. CALIBRATION & RELIABILITY

In [53]:
# CHUNK 3 – TASK 7: CALIBRATION & RELIABILITY

class TemperatureScaler(nn.Module):
    """Temperature scaling on logits."""
    def __init__(self):
        super().__init__()
        self.log_T = nn.Parameter(torch.zeros(1))

    def forward(self, logits):
        T = torch.exp(self.log_T)
        return logits / T

def fit_temperature(model, val_loader):
    model.eval()
    logits_list, labels_list = [], []
    with torch.no_grad():
        for batch in val_loader:
            ids   = batch["input_ids"].to(device)
            mask  = batch["attention_mask"].to(device)
            labels = batch["label"].float().to(device)
            logits = model(ids, mask)
            logits_list.append(logits)
            labels_list.append(labels)
    logits = torch.cat(logits_list)
    labels = torch.cat(labels_list)

    scaler = TemperatureScaler().to(device)
    optimizer = torch.optim.LBFGS([scaler.log_T], lr=0.01, max_iter=50)
    nll_criterion = nn.BCEWithLogitsLoss()

    def _closure():
        optimizer.zero_grad()
        scaled_logits = scaler(logits)
        loss = nll_criterion(scaled_logits, labels)
        loss.backward()
        return loss

    optimizer.step(_closure)
    with torch.no_grad():
        T = torch.exp(scaler.log_T).item()
    print(f"[Calibration] Optimal temperature: {T:.3f}")
    return scaler

def reliability_plot(y_true, y_prob, title, out_path):
    prob_true, prob_pred = calibration_curve(
        y_true, y_prob, n_bins=10, strategy="quantile"
    )

    plt.figure(figsize=(5,5))
    plt.plot([0,1], [0,1], "k--", label="Perfectly calibrated")
    plt.plot(prob_pred, prob_true, marker="o", label="Model")
    plt.xlabel("Predicted probability")
    plt.ylabel("Observed frequency")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.close()
    print(f"[Calibration] Saved reliability plot to {out_path}")

# choose loaders for best model
if best_model_name == "ClinicalBERT":
    val_loader_best  = val_loader_bert
    test_loader_best = test_loader_bert
else:
    val_loader_best  = val_loader_lstm
    test_loader_best = test_loader_lstm

clear_gpu_memory()
best_model = best_model.to(device)
temp_scaler = fit_temperature(best_model, val_loader_best)

# apply calibration to TEST logits
best_model.eval()
logits_uncal, logits_cal, labels_test = [], [], []

with torch.no_grad():
    for batch in test_loader_best:
        ids   = batch["input_ids"].to(device)
        mask  = batch["attention_mask"].to(device)
        labels = batch["label"].numpy()

        logits = best_model(ids, mask)
        logits_uncal.append(logits.cpu().numpy())
        logits_cal.append(temp_scaler(logits).cpu().numpy())
        labels_test.append(labels)

logits_uncal = np.concatenate(logits_uncal).squeeze()
logits_cal   = np.concatenate(logits_cal).squeeze()
labels_test  = np.concatenate(labels_test)

probs_uncal = 1 / (1 + np.exp(-logits_uncal))
probs_cal   = 1 / (1 + np.exp(-logits_cal))

metrics_uncal = binary_metrics(labels_test, probs_uncal)
metrics_cal   = binary_metrics(labels_test, probs_cal)

print("\n[Calibration] TEST comparison (AUC & F1):")
print(f"  Uncalibrated AUC: {metrics_uncal['auc']:.4f}, "
      f"F1: {metrics_uncal['f1']:.4f}")
print(f"  Temp-scaled  AUC: {metrics_cal['auc']:.4f}, "
      f"F1: {metrics_cal['f1']:.4f}")

# Reliability plots
reliability_plot(
    labels_test,
    probs_uncal,
    title=f"{best_model_name} – Reliability (Uncalibrated)",
    out_path=CALIB_DIR / f"{best_model_name}_reliability_uncal.png",
)
reliability_plot(
    labels_test,
    probs_cal,
    title=f"{best_model_name} – Reliability (Temp-scaled)",
    out_path=CALIB_DIR / f"{best_model_name}_reliability_temp_scaled.png",
)


[Calibration] Optimal temperature: 1.081

[Calibration] TEST comparison (AUC & F1):
  Uncalibrated AUC: 0.8144, F1: 0.4889
  Temp-scaled  AUC: 0.8144, F1: 0.4889
[Calibration] Saved reliability plot to /content/drive/MyDrive/clinical_project_cpu/deep_models/calibration/BiLSTM_reliability_uncal.png
[Calibration] Saved reliability plot to /content/drive/MyDrive/clinical_project_cpu/deep_models/calibration/BiLSTM_reliability_temp_scaled.png


4. ROBUSTNESS & SIMPLE RATIONALES

In [54]:
# ROBUSTNESS & SIMPLE RATIONALES

import copy

# ---------- simple text augmentations ----------
def simple_paraphrase(text: str) -> str:
    repl = {
        "shortness of breath": "difficulty breathing",
        "sob ": "short of breath ",
        "chest pain": "pain in the chest",
        "no evidence of": "no clear sign of",
        "denies": "reports no",
    }
    out = text
    for k, v in repl.items():
        out = out.replace(k, v)
    return out

def reorder_sections(text: str) -> str:
    if "TRIAGE:" in text and "RADIOLOGY:" in text:
        triage_part = re.findall(r"TRIAGE:.*?(?=RADIOLOGY:|$)", text,
                                 flags=re.IGNORECASE)
        radiol_part = re.findall(r"RADIOLOGY:.*?(?=$)", text,
                                 flags=re.IGNORECASE)
        if triage_part and radiol_part:
            return radiol_part[0] + " [SEP] " + triage_part[0]
    return text

def robustness_check(model, df, labels, tokenizer_or_vocab,
                     n_samples=200, model_type="bert"):
    model.eval()
    idxs = np.random.choice(len(df), size=min(n_samples, len(df)),
                            replace=False)
    orig_preds, para_preds, reorder_preds = [], [], []

    with torch.no_grad():
        for idx in idxs:
            row = df.iloc[idx]
            base_text = row["combined_text"]

            texts = {
                "orig": base_text,
                "para": simple_paraphrase(base_text),
                "reorder": reorder_sections(base_text),
            }

            probs = {}
            for key, txt in texts.items():
                if model_type == "bert":
                    enc = tokenizer_or_vocab(
                        txt,
                        truncation=True,
                        max_length=MAX_LEN_BERT,
                        padding="max_length",
                        return_tensors="pt",
                    )
                    ids  = enc["input_ids"].to(device)
                    mask = enc["attention_mask"].to(device)
                else:
                    ids_np, mask_np = text_to_ids(
                        txt, tokenizer_or_vocab, MAX_LEN_LSTM
                    )
                    ids  = torch.tensor(ids_np, dtype=torch.long,
                                        device=device).unsqueeze(0)
                    mask = torch.tensor(mask_np, dtype=torch.float32,
                                        device=device).unsqueeze(0)

                logits = model(ids, mask)
                prob = torch.sigmoid(logits).item()
                probs[key] = prob

            orig_preds.append(1 if probs["orig"]   >= 0.5 else 0)
            para_preds.append(1 if probs["para"]   >= 0.5 else 0)
            reorder_preds.append(1 if probs["reorder"] >= 0.5 else 0)

    orig_preds    = np.array(orig_preds)
    para_preds    = np.array(para_preds)
    reorder_preds = np.array(reorder_preds)

    para_consistency    = (orig_preds == para_preds).mean()
    reorder_consistency = (orig_preds == reorder_preds).mean()
    print(f"\n[Robustness] Paraphrasing consistency: {para_consistency:.3f}")
    print(f"[Robustness] Section reordering consistency: "
          f"{reorder_consistency:.3f}")

    plt.figure(figsize=(4,4))
    sns.barplot(x=["Paraphrase", "Reorder"],
                y=[para_consistency, reorder_consistency])
    plt.ylim(0,1)
    plt.ylabel("Consistency rate")
    plt.title(f"{best_model_name} – Robustness")
    plt.tight_layout()
    path_rob = PLOTS_DIR / f"{best_model_name}_robustness.png"
    plt.savefig(path_rob, dpi=150)
    plt.close()
    print(f"[Robustness] Saved bar plot to {path_rob}")

    return para_consistency, reorder_consistency

if best_model_name == "ClinicalBERT":
    rob_para, rob_reorder = robustness_check(
        best_model, X_test, y_test, tokenizer,
        n_samples=200, model_type="bert"
    )
else:
    rob_para, rob_reorder = robustness_check(
        best_model, X_test, y_test, vocab,
        n_samples=200, model_type="lstm"
    )

# Simple rationales for ClinicalBERT only (optional)
def extract_rationale_tokens_bert(text, model, tokenizer, top_k=10):
    model.eval()
    with torch.no_grad():
        enc = tokenizer(
            text,
            truncation=True,
            max_length=MAX_LEN_BERT,
            padding="max_length",
            return_tensors="pt",
        )
        ids  = enc["input_ids"].to(device)
        mask = enc["attention_mask"].to(device)
        outputs = model.bert(
            input_ids=ids,
            attention_mask=mask,
            output_attentions=True
        )
        attns    = outputs.attentions[-1]          # [B, heads, T, T]
        cls_attn = attns.mean(1)[0, 0]             # [T]
        cls_attn = cls_attn * mask[0]
        scores   = cls_attn.cpu().numpy()
        top_idx  = scores.argsort()[-top_k:][::-1]

        tokens = tokenizer.convert_ids_to_tokens(ids[0])
        rationale_tokens = [
            tokens[i] for i in top_idx
            if tokens[i] not in ["[PAD]", "[CLS]", "[SEP]"]
        ]
    return rationale_tokens, top_idx

def faithfulness_deletion_bert(text, model, tokenizer, top_k=10):
    model.eval()
    enc = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LEN_BERT,
        padding="max_length",
        return_tensors="pt",
    )
    ids  = enc["input_ids"][0]
    mask = enc["attention_mask"][0]

    with torch.no_grad():
        logits = model(ids.unsqueeze(0).to(device),
                       mask.unsqueeze(0).to(device))
        base_prob = torch.sigmoid(logits).item()

    _, top_idx = extract_rationale_tokens_bert(
        text, model, tokenizer, top_k=top_k
    )
    keep_mask = np.ones_like(ids.cpu().numpy(), dtype=bool)
    keep_mask[top_idx] = False

    new_ids = ids.cpu().numpy()[keep_mask]
    new_ids = new_ids[:MAX_LEN_BERT]
    pad_len = MAX_LEN_BERT - len(new_ids)
    new_ids = np.concatenate(
        [new_ids, [tokenizer.pad_token_id]*pad_len]
    )
    new_mask = np.array(
        [1]*min(len(new_ids)-pad_len, MAX_LEN_BERT) + [0]*pad_len
    )

    new_ids  = torch.tensor(new_ids, dtype=torch.long).unsqueeze(0).to(device)
    new_mask = torch.tensor(new_mask, dtype=torch.long).unsqueeze(0).to(device)

    with torch.no_grad():
        new_logits = model(new_ids, new_mask)
        new_prob   = torch.sigmoid(new_logits).item()

    drop = base_prob - new_prob
    return base_prob, new_prob, drop

if best_model_name == "ClinicalBERT":
    idxs = np.random.choice(len(X_test), size=50, replace=False)
    drops = []
    for idx in idxs:
        txt = X_test.iloc[idx]["combined_text"]
        base_p, new_p, diff = faithfulness_deletion_bert(
            txt, best_model, tokenizer, top_k=10
        )
        drops.append(diff)
    avg_drop = float(np.mean(drops))
    print(f"\n[Rationales] Avg prob drop after deleting "
          f"top-10 rationale tokens: {avg_drop:.3f}")
else:
    print("\n[Rationales] Rationales implemented only for ClinicalBERT.")



[Robustness] Paraphrasing consistency: 0.995
[Robustness] Section reordering consistency: 0.980
[Robustness] Saved bar plot to /content/drive/MyDrive/clinical_project_cpu/deep_models/plots/BiLSTM_robustness.png

[Rationales] Rationales implemented only for ClinicalBERT.


5. BiLSTM TEXT EMBEDDINGS

In [55]:
# CHUNK 5 – EXPORT BiLSTM TEXT EMBEDDINGS

def export_lstm_embeddings(model, dataset, split_name, batch_size=128):
    """
    Uses model.encode_text() to get fixed-length embeddings
    and saves them as NPZ: embeddings, labels, stay_ids.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    model = model.to(device)
    model.eval()

    all_emb, all_labels, all_stays = [], [], []

    with torch.no_grad():
        for batch in loader:
            ids  = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)

            emb = model.encode_text(ids, mask)   # [B, 2H]
            all_emb.append(emb.cpu().numpy())
            all_labels.append(batch["label"].numpy())
            all_stays.append(batch["stay_id"].numpy())

    all_emb    = np.vstack(all_emb)
    all_labels = np.concatenate(all_labels)
    all_stays  = np.concatenate(all_stays)

    out_path = EMB_DIR / f"bilstm_embeddings_{split_name}.npz"
    np.savez_compressed(
        out_path,
        embeddings=all_emb,
        labels=all_labels,
        stay_ids=all_stays,
    )
    print(f"[BiLSTM] Saved {split_name} embeddings to {out_path} "
          f"with shape {all_emb.shape}")
    return out_path

path_lstm_train = export_lstm_embeddings(bilstm_model, train_ds_lstm, "train")
path_lstm_val   = export_lstm_embeddings(bilstm_model, val_ds_lstm,   "val")
path_lstm_test  = export_lstm_embeddings(bilstm_model, test_ds_lstm,  "test")


[BiLSTM] Saved train embeddings to /content/drive/MyDrive/clinical_project_cpu/deep_models/embeddings/bilstm_embeddings_train.npz with shape (141300, 256)
[BiLSTM] Saved val embeddings to /content/drive/MyDrive/clinical_project_cpu/deep_models/embeddings/bilstm_embeddings_val.npz with shape (30273, 256)
[BiLSTM] Saved test embeddings to /content/drive/MyDrive/clinical_project_cpu/deep_models/embeddings/bilstm_embeddings_test.npz with shape (29923, 256)


6. ClinicalModernBERT EMBEDDINGS

In [56]:
# ============================================
# CHUNK 6 – EXPORT ClinicalModernBERT EMBEDDINGS
# ============================================
def export_bert_embeddings(model, dataset, split_name, batch_size=64):
    """
    Uses model.encode_text() to get ModernBERT embeddings
    and saves them as NPZ: embeddings, labels, stay_ids.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    model = model.to(device)
    model.eval()

    all_emb, all_labels, all_stays = [], [], []

    with torch.no_grad():
        for batch in loader:
            ids  = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)

            emb = model.encode_text(ids, mask)   # [B, H]
            all_emb.append(emb.cpu().numpy())
            all_labels.append(batch["label"].numpy())
            all_stays.append(batch["stay_id"].numpy())

    all_emb    = np.vstack(all_emb)
    all_labels = np.concatenate(all_labels)
    all_stays  = np.concatenate(all_stays)

    out_path = EMB_DIR / f"clinicalmodernbert_embeddings_{split_name}.npz"
    np.savez_compressed(
        out_path,
        embeddings=all_emb,
        labels=all_labels,
        stay_ids=all_stays,
    )
    print(f"[ClinicalModernBERT] Saved {split_name} embeddings to {out_path} "
          f"with shape {all_emb.shape}")
    return out_path

path_bert_train = export_bert_embeddings(bert_model, train_ds_bert, "train")
path_bert_val   = export_bert_embeddings(bert_model, val_ds_bert,   "val")
path_bert_test  = export_bert_embeddings(bert_model, test_ds_bert,  "test")


[ClinicalModernBERT] Saved train embeddings to /content/drive/MyDrive/clinical_project_cpu/deep_models/embeddings/clinicalmodernbert_embeddings_train.npz with shape (141300, 768)
[ClinicalModernBERT] Saved val embeddings to /content/drive/MyDrive/clinical_project_cpu/deep_models/embeddings/clinicalmodernbert_embeddings_val.npz with shape (30273, 768)
[ClinicalModernBERT] Saved test embeddings to /content/drive/MyDrive/clinical_project_cpu/deep_models/embeddings/clinicalmodernbert_embeddings_test.npz with shape (29923, 768)


7. ZIP ALL OUTPUTS FOR DOWNLOAD

In [57]:
# ============================================
# CHUNK 7 – ZIP ALL OUTPUTS FOR DOWNLOAD
# ============================================
import shutil
from google.colab import files

zip_base = "/content/clinical_text_deep_models_outputs"
shutil.make_archive(zip_base, "zip", BASE_OUT_DIR)

zip_path = zip_base + ".zip"
print("Created archive:", zip_path)

# trigger download in Colab
files.download(zip_path)


Created archive: /content/clinical_text_deep_models_outputs.zip


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>