In [1]:
!/opt/conda/envs/rapids/bin/python -m pip install transformers[torch]
!pip install --upgrade typing_extensions



In [2]:
import os
import time
import random
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import TensorDataset, DataLoader

from transformers import BertModel, BertConfig, AutoTokenizer, AutoModel
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, recall_score, precision_score

# ----- Focal Loss Definition -----
class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=None, reduction='mean', pos_weight=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.pos_weight = pos_weight

    def forward(self, logits, targets):
        bce_loss = F.binary_cross_entropy_with_logits(
            logits, targets, reduction='none', pos_weight=self.pos_weight
        )
        pt = torch.exp(-bce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * bce_loss

        if self.alpha is not None:
            focal_loss = self.alpha * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# ----- Get Positive Weight for Loss -----
def get_pos_weight(labels_series, device):
    positive = labels_series.sum()
    negative = len(labels_series) - positive
    if positive == 0:
        weight = torch.tensor(1.0, dtype=torch.float, device=device)
    else:
        weight = torch.tensor(negative / positive, dtype=torch.float, device=device)
    return weight

# ----- BioClinicalBERT Fine-Tuning Wrapper -----
class BioClinicalBERT_FT(nn.Module):
    def __init__(self, base_model, config, device):
        super(BioClinicalBERT_FT, self).__init__()
        self.BioBert = base_model
        self.device = device

    def forward(self, input_ids, attention_mask):
        outputs = self.BioBert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = outputs.last_hidden_state[:, 0, :]  # CLS token
        return cls_embedding

# ----- Function to Compute Aggregated Text Embeddings -----
def apply_bioclinicalbert_on_patient_notes(df, note_columns, tokenizer, model, device, aggregation="mean"):
    """
    For each unique patient (by subject_id), extracts all non-null notes from the specified note columns,
    tokenizes them, computes the CLS embedding via BioClinicalBERT, and aggregates them (mean or max).
    """
    patient_ids = df["subject_id"].unique()
    aggregated_embeddings = []
    for pid in tqdm(patient_ids, desc="Aggregating text embeddings"):
        patient_data = df[df["subject_id"] == pid]
        notes = []
        for col in note_columns:
            vals = patient_data[col].dropna().tolist()
            notes.extend([v for v in vals if isinstance(v, str) and v.strip() != ""])
        if len(notes) == 0:
            aggregated_embeddings.append(np.zeros(model.BioBert.config.hidden_size))
        else:
            embeddings = []
            for note in notes:
                encoded = tokenizer.encode_plus(
                    text=note,
                    add_special_tokens=True,
                    max_length=128,
                    padding='max_length',
                    truncation=True,
                    return_attention_mask=True,
                    return_tensors='pt'
                )
                input_ids = encoded['input_ids'].to(device)
                attn_mask = encoded['attention_mask'].to(device)
                with torch.no_grad():
                    emb = model(input_ids, attn_mask)
                embeddings.append(emb.cpu().numpy())
            embeddings = np.vstack(embeddings)
            agg_emb = np.mean(embeddings, axis=0) if aggregation == "mean" else np.max(embeddings, axis=0)
            aggregated_embeddings.append(agg_emb)
    aggregated_embeddings = np.vstack(aggregated_embeddings)
    return aggregated_embeddings

# ----- Evaluation Function for Mortality -----
def evaluate_model(model, dataloader, device, threshold=0.5):
    model.eval()
    all_logits = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            aggregated_text_embedding, labels_mortality = [x.to(device) for x in batch]
            logits = model(aggregated_text_embedding)
            all_logits.append(logits.cpu())
            all_labels.append(labels_mortality.cpu())
    all_logits = torch.cat(all_logits, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    probs = torch.sigmoid(all_logits).numpy()
    labels_np = all_labels.numpy()

    try:
        aucroc = roc_auc_score(labels_np, probs)
    except Exception:
        aucroc = float('nan')
    try:
        auprc = average_precision_score(labels_np, probs)
    except Exception:
        auprc = float('nan')
    preds = (probs > threshold).astype(int)
    f1 = f1_score(labels_np, preds, zero_division=0)
    recall = recall_score(labels_np, preds, zero_division=0)
    precision = precision_score(labels_np, preds, zero_division=0)
    metrics = {"aucroc": aucroc, "auprc": auprc, "f1": f1, "recall": recall, "precision": precision}
    return metrics, all_logits

# ----- Text-based Mortality Model (Structure Similar to Fusion Code) -----
class MultimodalTransformer(nn.Module):
    def __init__(self, text_embed_size, hidden_size=512):
        super(MultimodalTransformer, self).__init__()
        self.text_projector = nn.Sequential(
            nn.Linear(text_embed_size, 256),
            nn.ReLU()
        )
        self.classifier = nn.Sequential(
            nn.Linear(256, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 1)  # Single output for mortality
        )

    def forward(self, aggregated_text_embedding):
        text_proj = self.text_projector(aggregated_text_embedding)
        mortality_logits = self.classifier(text_proj)
        return mortality_logits

# ----- Training Step Function -----
def train_step(model, dataloader, optimizer, device, criterion):
    model.train()
    running_loss = 0.0
    for batch in dataloader:
        aggregated_text_embedding, labels_mortality = [x.to(device) for x in batch]
        optimizer.zero_grad()
        mortality_logits = model(aggregated_text_embedding)
        # Updated loss: only mortality
        loss = criterion(mortality_logits, labels_mortality.unsqueeze(1))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss

# ----- Main Training Pipeline -----
def train_pipeline():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Load unstructured data
    df = pd.read_csv("final_unstructured.csv", low_memory=False)

    # Identify note columns (assumed to start with "note_")
    note_columns = [col for col in df.columns if col.startswith("note_")]
    if len(note_columns) == 0:
        raise ValueError("No note columns found in the data.")

    # Filter rows with valid note text
    def has_valid_note(row):
        for col in note_columns:
            if pd.notnull(row[col]) and isinstance(row[col], str) and row[col].strip():
                return True
        return False

    df_filtered = df[df.apply(has_valid_note, axis=1)].copy()
    print("After filtering, number of rows:", len(df_filtered))

    # Ensure required columns exist
    if "subject_id" not in df_filtered.columns:
        raise ValueError("Column 'subject_id' not found in data.")
    if "short_term_mortality" not in df_filtered.columns:
        raise ValueError("Column 'short_term_mortality' not found in data.")

    # Initialize tokenizer and BioClinicalBERT for text embeddings
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    bioclinical_bert_base = BertModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    bioclinical_bert_ft = BioClinicalBERT_FT(bioclinical_bert_base, bioclinical_bert_base.config, device).to(device)

    # Compute aggregated text embeddings per patient
    print("Computing aggregated text embeddings for each patient...")
    aggregated_text_embeddings_np = apply_bioclinicalbert_on_patient_notes(
        df_filtered, note_columns, tokenizer, bioclinical_bert_ft, device, aggregation="mean"
    )
    print("Aggregated text embeddings shape:", aggregated_text_embeddings_np.shape)

    # Group by subject_id to obtain one label per patient (using first occurrence)
    grouped = df_filtered.groupby("subject_id", sort=False).first().reset_index()
    unique_subject_ids = df_filtered["subject_id"].unique()
    grouped['order'] = pd.Categorical(grouped['subject_id'], categories=unique_subject_ids, ordered=True)
    grouped = grouped.sort_values('order')

    labels_mortality = torch.tensor(grouped["short_term_mortality"].values, dtype=torch.float32)
    aggregated_text_embeddings_t = torch.tensor(aggregated_text_embeddings_np, dtype=torch.float32)

    # Create dataset and dataloader
    dataset = TensorDataset(aggregated_text_embeddings_t, labels_mortality)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Compute positive weight for mortality
    mortality_pos_weight = get_pos_weight(grouped["short_term_mortality"], device)

    # Initialize model, optimizer, scheduler, and loss criterion
    model = MultimodalTransformer(text_embed_size=768, hidden_size=512).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)
    criterion = FocalLoss(gamma=2, pos_weight=mortality_pos_weight, reduction='mean')

    # Training Loop
    num_epochs = 5
    for epoch in range(num_epochs):
        train_loss = train_step(model, dataloader, optimizer, device, criterion)
        epoch_loss = train_loss / len(dataloader)
        print(f"[Epoch {epoch+1}] Train Loss: {epoch_loss:.4f}")
        scheduler.step(epoch_loss)
        metrics, _ = evaluate_model(model, dataloader, device, threshold=0.5)
        print(f"Metrics at threshold=0.5 after epoch {epoch+1}: {metrics}")

    print("Training complete.")

if __name__ == "__main__":
    train_pipeline()


Using device: cuda
After filtering, number of rows: 46091
Computing aggregated text embeddings for each patient...


Aggregating text embeddings: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46091/46091 [2:40:08<00:00,  4.80it/s]


Aggregated text embeddings shape: (46091, 768)
[Epoch 1] Train Loss: 0.3989
Metrics at threshold=0.5 after epoch 1: {'aucroc': 0.9469224961352024, 'auprc': 0.7256341613759754, 'f1': 0.47859308671922385, 'recall': 0.928252175958598, 'precision': 0.3224119617615818}
[Epoch 2] Train Loss: 0.3287
Metrics at threshold=0.5 after epoch 2: {'aucroc': 0.9535491564688638, 'auprc': 0.7499799206606024, 'f1': 0.4532466804863018, 'recall': 0.951540813926135, 'precision': 0.2974702162082659}
[Epoch 3] Train Loss: 0.3110
Metrics at threshold=0.5 after epoch 3: {'aucroc': 0.9582834744091259, 'auprc': 0.7676505781935334, 'f1': 0.5352877307274702, 'recall': 0.9277816984239002, 'precision': 0.37615641392465426}
[Epoch 4] Train Loss: 0.2971
Metrics at threshold=0.5 after epoch 4: {'aucroc': 0.9622182307345972, 'auprc': 0.7851289705168459, 'f1': 0.4778356854248404, 'recall': 0.9597741707833451, 'precision': 0.3181038515515359}
[Epoch 5] Train Loss: 0.2836
Metrics at threshold=0.5 after epoch 5: {'aucroc': 0