In [1]:
# Install/upgrade transformers and accelerate (for the Trainer)
!pip install transformers --upgrade --quiet
!pip install accelerate --upgrade --quiet

In [2]:
"""
Multi-Task Learning Classifier for Hospital Readmission Prediction
Uses Clinical-Longformer with three task heads:
1. Main: 30-day readmission (binary)
2. Auxiliary: Hospital mortality (binary)
3. Auxiliary: Admission type (multiclass)
"""

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer,
    AutoModel,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    roc_auc_score,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score
)
import warnings
warnings.filterwarnings('ignore')


# ===========================================================================
# CONFIGURATION
# ===========================================================================
class CONFIG:
    DATA_PATH = "/kaggle/input/mimic-iv"
    MODEL_NAME = "yikuan8/Clinical-Longformer"
    
    # --- Fast Testing Configs ---
    SAMPLE_SIZE = 10000  # Set to None to use full dataset
    EPOCHS = 1
    
    # --- GPU Optimization ---
    TRAIN_BATCH_SIZE = 2
    VALID_BATCH_SIZE = 4
    GRADIENT_ACCUMULATION = 16
    USE_FP16 = True
    MAX_LENGTH = 4096
    
    # --- Training Hyperparameters ---
    LEARNING_RATE = 2e-5
    WEIGHT_DECAY = 0.01
    WARMUP_RATIO = 0.1
    
    # --- Loss Weights ---
    WEIGHT_READMIT = 1.0
    WEIGHT_MORTALITY = 0.5
    WEIGHT_ADM_TYPE = 0.5
    
    # --- Other ---
    RANDOM_SEED = 42
    TEST_SIZE = 0.2
    OUTPUT_DIR = "./mtl_readmission_model_FROZEN" # Changed output dir


# ===========================================================================
# DATA LOADING AND PREPARATION
# ===========================================================================
def load_and_prepare_data(config):
    """Load and merge all data sources."""
    print("=" * 80)
    print("LOADING DATA")
    print("=" * 80)
    
    # Load main admissions file
    admissions = pd.read_csv(f"{config.DATA_PATH}/admissions_with_readmission_labels.csv")
    print(f"‚úì Loaded admissions: {admissions.shape}")
    
    # Load discharge notes
    discharge = pd.read_csv(f"{config.DATA_PATH}/discharge_notes-001.csv")
    print(f"‚úì Loaded discharge notes: {discharge.shape}")
    
    # Load radiology notes
    radiology = pd.read_csv(f"{config.DATA_PATH}/radiology_notes.csv")
    print(f"‚úì Loaded radiology notes: {radiology.shape}")
    
    # Combine notes: concatenate discharge and radiology by hadm_id
    print("\nCombining notes...")
    discharge_grouped = discharge.groupby('hadm_id')['text'].apply(
        lambda x: ' '.join(x.astype(str))
    ).reset_index()
    discharge_grouped.columns = ['hadm_id', 'discharge_text']
    
    radiology_grouped = radiology.groupby('hadm_id')['text'].apply(
        lambda x: ' '.join(x.astype(str))
    ).reset_index()
    radiology_grouped.columns = ['hadm_id', 'radiology_text']
    
    # Merge notes
    notes_combined = discharge_grouped.merge(
        radiology_grouped, on='hadm_id', how='outer'
    )
    
    # Combine all text
    notes_combined['combined_text'] = (
        notes_combined['discharge_text'].fillna('') + ' ' + 
        notes_combined['radiology_text'].fillna('')
    )
    notes_combined['combined_text'] = notes_combined['combined_text'].str.strip()
    
    # Merge with admissions
    df = admissions.merge(notes_combined[['hadm_id', 'combined_text']], 
                          on='hadm_id', how='left')
    
    # Use combined_text if available, otherwise use original text
    df['final_text'] = df['combined_text'].fillna(df.get('text', ''))
    df['final_text'] = df['final_text'].fillna('')
    
    print(f"‚úì Final merged data: {df.shape}")
    
    # Sample data if needed
    if config.SAMPLE_SIZE is not None:
        df = df.sample(n=min(config.SAMPLE_SIZE, len(df)), 
                       random_state=config.RANDOM_SEED)
        print(f"‚úì Sampled {len(df)} rows for testing")
    
    # Prepare labels
    print("\nPreparing labels...")
    
    # Binary labels
    df['readmitted_30day'] = df['readmitted_30day'].astype(int)
    df['hospital_expire_flag'] = df['hospital_expire_flag'].astype(int)
    
    # Multiclass label (admission_type)
    le = LabelEncoder()
    df['admission_type_encoded'] = le.fit_transform(df['admission_type'])
    num_admission_types = len(le.classes_)
    
    print(f"  - Readmission distribution: {df['readmitted_30day'].value_counts().to_dict()}")
    print(f"  - Mortality distribution: {df['hospital_expire_flag'].value_counts().to_dict()}")
    print(f"  - Admission types: {num_admission_types} classes")
    
    # Calculate class weights for readmission (handle imbalance)
    pos_count = df['readmitted_30day'].sum()
    neg_count = len(df) - pos_count
    pos_weight = neg_count / pos_count if pos_count > 0 else 1.0
    
    print(f"  - Positive weight for readmission: {pos_weight:.2f}")
    
    return df, le, num_admission_types, pos_weight


# ===========================================================================
# DATASET CLASS
# ===========================================================================
class MTLDataset(Dataset):
    """Custom dataset for multi-task learning."""
    
    def __init__(self, texts, readmit_labels, mortality_labels, 
                 adm_type_labels, tokenizer, max_length):
        self.texts = texts
        self.readmit_labels = readmit_labels
        self.mortality_labels = mortality_labels
        self.adm_type_labels = adm_type_labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        
        # Tokenize
        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'readmit_label': torch.tensor(self.readmit_labels[idx], dtype=torch.float),
            'mortality_label': torch.tensor(self.mortality_labels[idx], dtype=torch.float),
            'adm_type_label': torch.tensor(self.adm_type_labels[idx], dtype=torch.long)
        }


# ===========================================================================
# MULTI-TASK MODEL
# ===========================================================================
class MTLReadmissionModel(nn.Module):
    """Multi-Task Learning model with shared trunk and three task heads."""
    
    def __init__(self, model_name, num_admission_types):
        super().__init__()
        
        # Shared trunk: Clinical-Longformer
        self.backbone = AutoModel.from_pretrained(model_name)
        hidden_size = self.backbone.config.hidden_size
        
        # Task heads
        self.readmit_head = nn.Linear(hidden_size, 1)  # Binary
        self.mortality_head = nn.Linear(hidden_size, 1)  # Binary
        self.adm_type_head = nn.Linear(hidden_size, num_admission_types)  # Multiclass
        
        # Dropout for regularization
        self.dropout = nn.Dropout(0.1)
    
    def forward(self, input_ids, attention_mask, **kwargs):
        # By adding **kwargs, this function now ignores any
        # extra keys (like 'readmit_label', 'labels', etc.)
        # passed to it during the evaluation_loop/prediction_step.
        
        # Get [CLS] token representation from backbone
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        
        # Use [CLS] token (first token)
        cls_output = outputs.last_hidden_state[:, 0, :]
        cls_output = self.dropout(cls_output)
        
        # Three separate predictions
        readmit_logits = self.readmit_head(cls_output)
        mortality_logits = self.mortality_head(cls_output)
        adm_type_logits = self.adm_type_head(cls_output)
        
        return {
            'readmit_logits': readmit_logits,
            'mortality_logits': mortality_logits,
            'adm_type_logits': adm_type_logits
        }


# ===========================================================================
# CUSTOM TRAINER
# ===========================================================================
class MTLTrainer(Trainer):
    """Custom trainer with multi-task loss."""
    
    def __init__(self, pos_weight, weight_readmit, weight_mortality, 
                 weight_adm_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pos_weight = torch.tensor([pos_weight])
        self.weight_readmit = weight_readmit
        self.weight_mortality = weight_mortality
        self.weight_adm_type = weight_adm_type
        
        # Loss functions
        self.bce_loss = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)
        self.bce_loss_mortality = nn.BCEWithLogitsLoss()
        self.ce_loss = nn.CrossEntropyLoss()
    
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Move pos_weight to correct device
        model_device = next(model.parameters()).device
        if self.pos_weight.device != model_device:
            self.pos_weight = self.pos_weight.to(model_device)
            self.bce_loss = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)
            
        # Extract labels
        readmit_labels = inputs.pop('readmit_label')
        mortality_labels = inputs.pop('mortality_label')
        adm_type_labels = inputs.pop('adm_type_label')
        
        # Remove the composite 'labels' key, which is only for metrics
        inputs.pop('labels', None) 
        
        # Forward pass
        # 'inputs' now ONLY contains 'input_ids' and 'attention_mask'
        outputs = model(**inputs)
        
        # Calculate individual losses
        loss_readmit = self.bce_loss(outputs['readmit_logits'].squeeze(),readmit_labels)
        
        loss_mortality = self.bce_loss_mortality(outputs['mortality_logits'].squeeze(),mortality_labels)
        
        loss_adm_type = self.ce_loss(outputs['adm_type_logits'],adm_type_labels)
        
        # Combined loss with weights
        total_loss = (self.weight_readmit * loss_readmit + \
                      self.weight_mortality * loss_mortality + \
                      self.weight_adm_type * loss_adm_type)
        
        return (total_loss, outputs) if return_outputs else total_loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Custom prediction step to return predictions in the correct format."""
        
        # Extract labels
        readmit_labels = inputs.get('readmit_label')
        mortality_labels = inputs.get('mortality_label')
        adm_type_labels = inputs.get('adm_type_label')
        
        # Stack labels for metrics
        labels = torch.stack([
            readmit_labels,
            mortality_labels,
            adm_type_labels.float()
        ], dim=1)
        
        # Remove individual label keys and composite labels
        inputs_for_model = {
            'input_ids': inputs['input_ids'],
            'attention_mask': inputs['attention_mask']
        }
        
        # Get predictions
        with torch.no_grad():
            outputs = model(**inputs_for_model)
            
            # Calculate loss
            loss = self.compute_loss(model, inputs.copy(), return_outputs=False)
        
        # Stack predictions as tuple (readmit_logits, mortality_logits, adm_type_logits)
        predictions = (
            outputs['readmit_logits'].detach().cpu(),
            outputs['mortality_logits'].detach().cpu(),
            outputs['adm_type_logits'].detach().cpu()
        )
        
        if prediction_loss_only:
            return (loss, None, None)
        
        return (loss, predictions, labels.detach().cpu())

        
# ===========================================================================
# EVALUATION METRICS
# ===========================================================================
    
def compute_metrics(eval_pred):
    """Compute metrics for the main task (readmission) only."""
    predictions, labels = eval_pred
    
    # predictions is now a tuple: (readmit_logits, mortality_logits, adm_type_logits)
    # Extract readmission predictions
    if isinstance(predictions, tuple):
        readmit_logits = predictions[0]
    else:
        readmit_logits = predictions
    
    # Extract readmission labels (first column)
    readmit_labels = labels[:, 0]
    
    # Convert logits to probabilities and predictions
    readmit_logits_np = readmit_logits.numpy() if isinstance(readmit_logits, torch.Tensor) else readmit_logits
    readmit_probs = 1 / (1 + np.exp(-readmit_logits_np.squeeze()))  # Sigmoid
    readmit_preds = (readmit_probs > 0.5).astype(int)
    
    # Handle edge case: if all predictions are same class
    try:
        roc_auc = roc_auc_score(readmit_labels, readmit_probs)
    except ValueError:
        roc_auc = 0.0
    
    # Calculate metrics
    metrics = {
        'roc_auc': roc_auc,
        'accuracy': accuracy_score(readmit_labels, readmit_preds),
        'precision': precision_score(readmit_labels, readmit_preds, zero_division=0),
        'recall': recall_score(readmit_labels, readmit_preds, zero_division=0),
        'f1': f1_score(readmit_labels, readmit_preds, zero_division=0)
    }
    
    return metrics


def custom_data_collator(features):
    """Custom collator to handle multiple labels."""
    batch = {
        'input_ids': torch.stack([f['input_ids'] for f in features]),
        'attention_mask': torch.stack([f['attention_mask'] for f in features]),
        'readmit_label': torch.stack([f['readmit_label'] for f in features]),
        'mortality_label': torch.stack([f['mortality_label'] for f in features]),
        'adm_type_label': torch.stack([f['adm_type_label'] for f in features]),
    }
    
    return batch

2025-11-17 21:02:50.007225: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763413370.029461     424 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763413370.036334     424 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# ===========================================================================
# MAIN SCRIPT
# ===========================================================================

print("\n" + "=" * 80)
print("MULTI-TASK LEARNING FOR HOSPITAL READMISSION PREDICTION")
print("          (BASE MODEL FROZEN - FEATURE EXTRACTION)")
print("=" * 80 + "\n")

# Set random seeds
torch.manual_seed(CONFIG.RANDOM_SEED)
np.random.seed(CONFIG.RANDOM_SEED)

# Load data
df, label_encoder, num_admission_types, pos_weight = load_and_prepare_data(CONFIG)

# Split data
print("\n" + "=" * 80)
print("SPLITTING DATA")
print("=" * 80)
train_df, val_df = train_test_split(
    df, 
    test_size=CONFIG.TEST_SIZE, 
    random_state=CONFIG.RANDOM_SEED,
    stratify=df['readmitted_30day']
)
print(f"‚úì Train size: {len(train_df)}")
print(f"‚úì Validation size: {len(val_df)}")

# Load tokenizer
print("\n" + "=" * 80)
print("LOADING TOKENIZER")
print("=" * 80)
tokenizer = AutoTokenizer.from_pretrained(CONFIG.MODEL_NAME)
print(f"‚úì Loaded tokenizer: {CONFIG.MODEL_NAME}")

# Create datasets
print("\n" + "=" * 80)
print("CREATING DATASETS")
print("=" * 80)
train_dataset = MTLDataset(
    texts=train_df['final_text'].values,
    readmit_labels=train_df['readmitted_30day'].values,
    mortality_labels=train_df['hospital_expire_flag'].values,
    adm_type_labels=train_df['admission_type_encoded'].values,
    tokenizer=tokenizer,
    max_length=CONFIG.MAX_LENGTH
)

val_dataset = MTLDataset(
    texts=val_df['final_text'].values,
    readmit_labels=val_df['readmitted_30day'].values,
    mortality_labels=val_df['hospital_expire_flag'].values,
    adm_type_labels=val_df['admission_type_encoded'].values,
    tokenizer=tokenizer,
    max_length=CONFIG.MAX_LENGTH
)
print(f"‚úì Train dataset: {len(train_dataset)} samples")
print(f"‚úì Validation dataset: {len(val_dataset)} samples")



MULTI-TASK LEARNING FOR HOSPITAL READMISSION PREDICTION
          (BASE MODEL FROZEN - FEATURE EXTRACTION)

LOADING DATA
‚úì Loaded admissions: (374139, 16)
‚úì Loaded discharge notes: (331731, 7)
‚úì Loaded radiology notes: (1144023, 7)

Combining notes...
‚úì Final merged data: (374139, 18)
‚úì Sampled 100 rows for testing

Preparing labels...
  - Readmission distribution: {0: 79, 1: 21}
  - Mortality distribution: {0: 97, 1: 3}
  - Admission types: 9 classes
  - Positive weight for readmission: 3.76

SPLITTING DATA
‚úì Train size: 80
‚úì Validation size: 20

LOADING TOKENIZER
‚úì Loaded tokenizer: yikuan8/Clinical-Longformer

CREATING DATASETS
‚úì Train dataset: 80 samples
‚úì Validation dataset: 20 samples


In [4]:

# Initialize model
print("\n" + "=" * 80)
print("INITIALIZING MODEL")
print("=" * 80)
model = MTLReadmissionModel(CONFIG.MODEL_NAME, num_admission_types)
print(f"‚úì Model created with {num_admission_types} admission type classes")


# ---
# --- ‚ùÑÔ∏è START: USER-REQUESTED CHANGE (FREEZE BASE MODEL) ‚ùÑÔ∏è ---
# ---
print("\n" + "-" * 80)
print("‚ùÑÔ∏è FREEZING BASE MODEL PARAMETERS (Clinical-Longformer)...")
for param in model.backbone.parameters():
    param.requires_grad = False
print("‚úì Base model frozen. Only the classification heads will be fine-tuned.")
print("-" * 80)
# ---
# --- ‚ùÑÔ∏è END: USER-REQUESTED CHANGE ‚ùÑÔ∏è ---
# ---


# Training arguments
training_args = TrainingArguments(
    output_dir=CONFIG.OUTPUT_DIR,
    num_train_epochs=CONFIG.EPOCHS,
    per_device_train_batch_size=CONFIG.TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=CONFIG.VALID_BATCH_SIZE,
    gradient_accumulation_steps=CONFIG.GRADIENT_ACCUMULATION,
    learning_rate=CONFIG.LEARNING_RATE,
    weight_decay=CONFIG.WEIGHT_DECAY,
    warmup_ratio=CONFIG.WARMUP_RATIO,
    fp16=CONFIG.USE_FP16,
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="roc_auc",
    greater_is_better=True,
    save_total_limit=2,
    report_to="none",
    seed=CONFIG.RANDOM_SEED,
    remove_unused_columns=False,
)

# Initialize trainer
print("\n" + "=" * 80)
print("INITIALIZING TRAINER")
print("=" * 80)
trainer = MTLTrainer(
    pos_weight=pos_weight,
    weight_readmit=CONFIG.WEIGHT_READMIT,
    weight_mortality=CONFIG.WEIGHT_MORTALITY,
    weight_adm_type=CONFIG.WEIGHT_ADM_TYPE,
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=custom_data_collator,
    compute_metrics=compute_metrics,
)
print("‚úì Trainer initialized")

# Train
print("\n" + "=" * 80)
print("TRAINING MODEL (ONLY HEADS)")
print("=" * 80)
trainer.train()


INITIALIZING MODEL


pytorch_model.bin:   0%|          | 0.00/595M [00:00<?, ?B/s]

Some weights of LongformerModel were not initialized from the model checkpoint at yikuan8/Clinical-Longformer and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


‚úì Model created with 9 admission type classes

--------------------------------------------------------------------------------
‚ùÑÔ∏è FREEZING BASE MODEL PARAMETERS (Clinical-Longformer)...
‚úì Base model frozen. Only the classification heads will be fine-tuned.
--------------------------------------------------------------------------------

INITIALIZING TRAINER


model.safetensors:   0%|          | 0.00/595M [00:00<?, ?B/s]

‚úì Trainer initialized

TRAINING MODEL (ONLY HEADS)


Epoch,Training Loss,Validation Loss,Roc Auc,Accuracy,Precision,Recall,F1
1,No log,2.500375,0.578125,0.2,0.2,1.0,0.333333


TrainOutput(global_step=3, training_loss=2.692969640096029, metrics={'train_runtime': 39.9602, 'train_samples_per_second': 2.002, 'train_steps_per_second': 0.075, 'total_flos': 0.0, 'train_loss': 2.692969640096029, 'epoch': 1.0})

In [5]:
# Final evaluation
print("\n" + "=" * 80)
print("FINAL EVALUATION (MAIN TASK: 30-DAY READMISSION)")
print("=" * 80)
eval_results = trainer.evaluate()

print("\nüìä RESULTS (BASE MODEL FROZEN):")
print(f"  ROC-AUC:   {eval_results['eval_roc_auc']:.4f}")
print(f"  Accuracy:  {eval_results['eval_accuracy']:.4f}")
print(f"  Precision: {eval_results['eval_precision']:.4f}")
print(f"  Recall:    {eval_results['eval_recall']:.4f}")
print(f"  F1 Score:  {eval_results['eval_f1']:.4f}")

print("\n" + "=" * 80)
print("TRAINING COMPLETE!")
print("=" * 80)
print(f"Model saved to: {CONFIG.OUTPUT_DIR}")


FINAL EVALUATION (MAIN TASK: 30-DAY READMISSION)



üìä RESULTS (BASE MODEL FROZEN):
  ROC-AUC:   0.5781
  Accuracy:  0.2000
  Precision: 0.2000
  Recall:    1.0000
  F1 Score:  0.3333

TRAINING COMPLETE!
Model saved to: ./mtl_readmission_model_FROZEN
