In [None]:
!pip install --force-reinstall numpy==1.26.4


In [None]:
# Install required survival analysis libraries
!pip install pycox torchtuples lifelines

In [None]:
# Install the missing codec package for LZW compression in TIFF files
!pip install imagecodecs

In [1]:
# Cell 1: Installation and Setup (FINAL)

print("Installing required Python packages...")
!pip install pycox torchtuples lifelines imagecodecs

# --- Standard Libraries ---
import os
import re
import glob
import pandas as pd
import numpy as np
from PIL import Image
from pathlib import Path

# --- PyTorch and Vision ---
import torch
import torch.nn as nn
from torchvision import models, transforms
from torchvision.models import ViT_B_16_Weights 
from torch.utils.data import Dataset, DataLoader

# --- Survival Analysis and Evaluation (CRITICAL IMPORTS) ---
from pycox.models import CoxTime
from pycox.models import loss as pycox_loss # Import loss module for correct name
import torchtuples as tt
from lifelines.utils import concordance_index
from sklearn.model_selection import train_test_split
from tifffile import imread # Assumes imagecodecs installed

# --- CONSTANTS ---
KAGGLE_INPUT_DIR = "/kaggle/input/" 
DATASET_NAME = "time-series-dataset" # ðŸš¨ UPDATE THIS TO YOUR SPECIFIC KAGGLE DATASET FOLDER NAME
IMAGE_DIR = os.path.join(KAGGLE_INPUT_DIR, DATASET_NAME)

FEATURE_OUTPUT_DIR = "vit_extracted_features"

# Model Parameters
VIT_OUT_DIM = 768    
TRANSFORMER_DIM = 128
MAX_SEQ_LEN = 5      
IMG_SIZE = 224       

# Training Parameters (Aggressive Regularization & Safer Learning)
BATCH_SIZE = 16         # Reduced for stability and better CoxPH gradients
NUM_EPOCHS = 200        # High epochs for full convergence
LEARNING_RATE = 1e-5    # Very low/safer starting rate
WEIGHT_DECAY = 1e-3     # Strong L2 Regularization (Penalizes large weights)
GRADIENT_CLIP_VALUE = 1.0 
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"\nRunning on device: {DEVICE}")

Installing required Python packages...

Running on device: cuda


In [2]:
# Cell 2: Data Preprocessing and Metadata Generation (FINAL)

def create_image_metadata_df(image_path_root):
    """Parses filenames to extract Patient ID and a numerical time step."""
    file_paths = glob.glob(os.path.join(image_path_root, "**", "*.tif"), recursive=True)
    records = []
    # Adjust this regex pattern to match your file naming convention precisely!
    pattern = re.compile(r'(\w+)\_SR\_\d+\_IM(\d+)\.tif') 
    
    for path in file_paths:
        file_name = os.path.basename(path)
        match = pattern.search(file_name)
        if match:
            patient_id = match.group(1).replace('_patient', '')
            time_step = int(match.group(2)) 
            records.append({'patient_id': patient_id, 'time_step': time_step, 'file_path': path})
            
    df = pd.DataFrame(records).sort_values(['patient_id', 'time_step']).reset_index(drop=True)
    return df

def generate_placeholder_survival_df(image_df):
    """
    ðŸš¨ CRITICAL: REPLACE THIS WITH CODE TO LOAD YOUR ACTUAL SURVIVAL DATA.
    (Placeholder data is kept for execution, but must be replaced for good results)
    """
    unique_patients = image_df['patient_id'].unique()
    np.random.seed(42)
    # Generate less random placeholder data for better initial results
    T = np.random.randint(100, 1500, len(unique_patients)).astype(np.float32)
    E = np.random.randint(0, 2, len(unique_patients)).astype(np.int64) 

    survival_df = pd.DataFrame({'patient_id': unique_patients, 'time': T, 'event': E})
    return survival_df

# --- EXECUTION ---
full_metadata_df = create_image_metadata_df(IMAGE_DIR)
survival_df = generate_placeholder_survival_df(full_metadata_df)

patient_ids = full_metadata_df[['patient_id']].drop_duplicates()
global final_df
final_df = pd.merge(patient_ids, survival_df, on='patient_id') 

print(f"Total Unique Patients: {len(final_df)}")
print("\nSample Survival Data (time, event):")
print(final_df.head())

Total Unique Patients: 455

Sample Survival Data (time, event):
    patient_id    time  event
0    137covid1  1226.0      1
1   137covid10   960.0      0
2  137covid100  1394.0      1
3  137covid101  1230.0      1
4  137covid103  1195.0      1


In [8]:
# Cell 4: Universal Survival Model Definition (ViT integrated into Transformer)

# --- CRITICAL IMPORTS ---
import torch
import torch.nn as nn
from torchvision import models, transforms
from torchvision.models import ViT_B_16_Weights 
from torch.utils.data import Dataset, DataLoader
from tifffile import imread
from PIL import Image
import os
import numpy as np

# --- 4.1. Positional Encoding Module ---
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=MAX_SEQ_LEN):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

# --- 4.2. Feature Dataset (Loads RAW Images) ---
class UniversalImageSequenceDataset(Dataset):
    """Loads raw image sequences and returns a padded tensor of images."""
    def __init__(self, dataframe, image_root_dir, max_seq_len, transform):
        self.dataframe = dataframe
        self.image_root_dir = image_root_dir
        self.max_seq_len = max_seq_len
        self.transform = transform
        
        # Pre-process metadata to link patient ID to a time-ordered list of paths
        self.patient_sequences = self._get_patient_sequences()

    def _get_patient_sequences(self):
        # This re-runs the metadata parsing to get file paths (requires Cell 2 logic)
        records = []
        file_paths = glob.glob(os.path.join(self.image_root_dir, "**", "*.tif"), recursive=True)
        pattern = re.compile(r'(\w+)\_SR\_\d+\_IM(\d+)\.tif')
        
        for path in file_paths:
            match = pattern.search(os.path.basename(path))
            if match:
                patient_id = match.group(1).replace('_patient', '')
                time_step = int(match.group(2)) 
                records.append({'patient_id': patient_id, 'time_step': time_step, 'file_path': path})
        
        meta_df = pd.DataFrame(records).sort_values(['patient_id', 'time_step'])
        
        # Merge with survival data and create a sequence of paths for each patient
        sequences = {}
        for idx, row in self.dataframe.iterrows():
            patient_id = row['patient_id']
            patient_paths = meta_df[meta_df['patient_id'] == patient_id]['file_path'].tolist()
            sequences[patient_id] = patient_paths
        return sequences

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        patient_id = row['patient_id']
        paths = self.patient_sequences.get(patient_id, [])
        
        image_tensors = []
        for path in paths:
            try:
                # Load and prepare image (robustly handling TIFF)
                img_array = imread(path)
                if img_array.dtype != np.uint8:
                    max_val = np.iinfo(img_array.dtype).max 
                    img_array = (img_array.astype(np.float32) / max_val * 255).astype(np.uint8)
                if len(img_array.shape) == 2:
                    img_array = np.stack([img_array]*3, axis=-1) 
                
                img = Image.fromarray(img_array)
                img_tensor = self.transform(img)
                image_tensors.append(img_tensor)
            except Exception:
                # Use a zero tensor placeholder if load fails (important for sequence integrity)
                image_tensors.append(torch.zeros(3, IMG_SIZE, IMG_SIZE))

        # Pad/Truncate the sequence of tensors
        seq_len = len(image_tensors)
        
        if seq_len > self.max_seq_len:
            image_tensors = image_tensors[:self.max_seq_len]
        elif seq_len < self.max_seq_len:
            pad_tensor = torch.zeros(3, IMG_SIZE, IMG_SIZE)
            for _ in range(self.max_seq_len - seq_len):
                image_tensors.append(pad_tensor)

        X_sequence = torch.stack(image_tensors)
        
        time = row['time']
        event = row['event']
        
        return X_sequence, (time, event)


# --- 4.3. Universal Model (ViT + Transformer + CoxPH) ---
class UniversalSurvivalModel(nn.Module):
    def __init__(self, input_dim=VIT_OUT_DIM, transformer_dim=TRANSFORMER_DIM):
        super().__init__()
        
        # 1. Image Feature Extractor (ViT - NOW PART OF THE TRAINABLE NET)
        self.vit = models.vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        self.vit.heads = nn.Identity()
        
        # 2. Sequence Modeling Components
        self.proj = nn.Linear(input_dim, transformer_dim)
        self.pos_encoder = SinusoidalPositionalEncoding(d_model=transformer_dim)
        self.dropout_in = nn.Dropout(p=0.25) 

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=transformer_dim, 
            nhead=4, 
            dim_feedforward=2 * transformer_dim, 
            batch_first=True, 
            dropout=0.25
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
        
        # 3. Final Risk Output
        self.cox_head = nn.Linear(transformer_dim, 1)

    def forward(self, x_seq, t_star=None):
        B, S, C, H, W = x_seq.shape # Batch, Sequence Length, Channel, Height, Width
        
        # Reshape for ViT: (B * S, C, H, W)
        x_flat = x_seq.view(B * S, C, H, W)
        
        # 1. Feature Extraction (ViT)
        with torch.no_grad(): # ViT is frozen for efficiency, fine-tune later if needed
             features_flat = self.vit(x_flat) 
        
        # Reshape back to sequence: (B, S, Feature Dim)
        features = features_flat.view(B, S, -1)
        
        # 2. Sequence Processing (Transformer)
        x = self.proj(features)                  
        x = self.pos_encoder(x)          
        x = self.dropout_in(x) 
        
        z = self.transformer_encoder(x) 
        
        final_z = z[:, -1, :] # Use the final sequence state
        
        h_hat = self.cox_head(final_z)
        return h_hat.squeeze()
        
# --- ViT Preprocessing Transform (for DataLoader) ---
image_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# Cell 5: Training Function and Setup with EARLY STOPPING ðŸ›‘

# --- CRITICAL IMPORTS (Ensure these are available from previous cells) ---
from pycox.models import CoxTime
from pycox.models.loss import CoxPHLoss 
import torchtuples as tt
from sklearn.model_selection import train_test_split
import numpy as np # Needed for infinity/best_loss tracking

# FIX: Initialize the loss using the class name CoxPHLoss() directly
cox_loss = CoxPHLoss()

# --- EARLY STOPPING CONSTANT ---
PATIENCE = 10 # Stop if validation loss doesn't improve for 10 epochs
MAX_EPOCHS = 200 # Set a high ceiling just in case (the original 200)

# --- EVALUATION FUNCTION (To get loss on validation/test data) ---
def evaluate_loss(model, data_loader, device):
    """Calculates the average CoxPH loss on the given data loader."""
    net = model.net
    net.eval()
    total_loss = 0.0
    with torch.no_grad():
        for X_batch, (T_batch, E_batch) in data_loader:
            X_batch = X_batch.to(device)
            T_batch = T_batch.to(device)
            E_batch = E_batch.to(device)
            
            h_hat = net(X_batch) 
            
            loss = model.loss.forward(h_hat.cpu(), T_batch.cpu(), E_batch.cpu())
            total_loss += loss.item() * len(X_batch) # Multiply by batch size to get total loss

    return total_loss / len(data_loader.dataset) # Return average loss per sample

# --- Training Function with EARLY STOPPING ---
def train_model_early_stop(model, train_loader, test_loader, device, max_epochs, patience):
    """Executes the training loop with Early Stopping based on test loss."""
    net = model.net 
    print(f"\n--- Starting Universal Model Training (Max {max_epochs} Epochs with Patience={patience}) ---") 
    
    best_loss = np.inf
    patience_counter = 0
    best_epoch = 0
    
    for epoch in range(1, max_epochs + 1):
        # 1. Training Step
        net.train()
        total_train_loss = 0.0
        for X_batch, (T_batch, E_batch) in train_loader:
            # ... (Standard training loop remains the same) ...
            X_batch = X_batch.to(device)
            T_batch = T_batch.to(device)
            E_batch = E_batch.to(device)
            
            model.optimizer.zero_grad()
            h_hat = net(X_batch) 
            
            loss = model.loss.forward(h_hat.cpu(), T_batch.cpu(), E_batch.cpu())
            loss = loss.to(device)
            
            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), max_norm=GRADIENT_CLIP_VALUE)
            model.optimizer.step()
            total_train_loss += loss.item()
            
        avg_train_loss = total_train_loss / len(train_loader)

        # 2. Validation Step (Early Stopping Check)
        avg_val_loss = evaluate_loss(model, test_loader, device)
        
        print(f"Epoch {epoch}/{max_epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Patience: {patience_counter}")
        
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience_counter = 0
            best_epoch = epoch
            # CRITICAL: Save the best model state dict
            torch.save(net.state_dict(), 'best_model_state.pt') 
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"\nðŸ›‘ EARLY STOPPING triggered at epoch {epoch}. Best loss: {best_loss:.4f} at epoch {best_epoch}.")
                break
    
    # Load the best weights back into the model before exiting
    print(f"Loading best weights from epoch {best_epoch}...")
    net.load_state_dict(torch.load('best_model_state.pt'))
    
    print("\n--- Training Complete ---") 

# --- EXECUTION SETUP ---
print("DEBUG: Starting Universal Model Setup and Raw Data Loading...")

# Data split
df_train, df_test = train_test_split(final_df, test_size=0.2, random_state=42)

# Create Datasets and Loaders (Requires UniversalImageSequenceDataset and image_transform from Cell 4)
train_dataset = UniversalImageSequenceDataset(df_train, IMAGE_DIR, MAX_SEQ_LEN, image_transform)
test_dataset = UniversalImageSequenceDataset(df_test, IMAGE_DIR, MAX_SEQ_LEN, image_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# We need to use the test_loader for validation in the loop
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False) 
print(f"DEBUG: DataLoader initialized. Total training batches: {len(train_loader)}")

# Model Initialization (Requires UniversalSurvivalModel defined in Cell 4)
net = UniversalSurvivalModel().to(DEVICE)

# Initialize Adam with HIGH WEIGHT_DECAY for L2 regularization
optimizer = tt.optim.Adam(LEARNING_RATE, weight_decay=WEIGHT_DECAY) 
global model
model = CoxTime(
    net=net,
    loss=cox_loss,
    optimizer=optimizer, 
    device=DEVICE
)

print(f"DEBUG: CoxTime Model initialized. Device: {DEVICE}")

# Prepare targets for baseline hazard estimation later in Cell 6
T_train_np = df_train['time'].values
E_train_np = df_train['event'].values
global tt_target_train
tt_target_train = tt.tuplefy(T_train_np, E_train_np)

# --- RUN TRAINING ---
try:
    # Use the Early Stopping function
    train_model_early_stop(model, train_loader, test_loader, DEVICE, MAX_EPOCHS, PATIENCE)
except Exception as e:
    print(f"\n!!! CRITICAL TRAINING ERROR: {e} !!!")

DEBUG: Starting Universal Model Setup and Raw Data Loading...
DEBUG: DataLoader initialized. Total training batches: 23
DEBUG: CoxTime Model initialized. Device: cuda

--- Starting Universal Model Training (Max 200 Epochs with Patience=10) ---
Epoch 1/200 | Train Loss: 2.0379 | Val Loss: 1.8000 | Patience: 0
Epoch 2/200 | Train Loss: 1.9711 | Val Loss: 1.7993 | Patience: 0
Epoch 3/200 | Train Loss: 1.9669 | Val Loss: 1.7991 | Patience: 0
Epoch 4/200 | Train Loss: 1.9655 | Val Loss: 1.7986 | Patience: 0
Epoch 5/200 | Train Loss: 2.0113 | Val Loss: 1.7985 | Patience: 0
Epoch 6/200 | Train Loss: 1.9631 | Val Loss: 1.7985 | Patience: 0
Epoch 7/200 | Train Loss: 1.9739 | Val Loss: 1.7983 | Patience: 0
Epoch 8/200 | Train Loss: 1.9659 | Val Loss: 1.7981 | Patience: 0
Epoch 9/200 | Train Loss: 1.9440 | Val Loss: 1.7981 | Patience: 0
Epoch 10/200 | Train Loss: 1.9320 | Val Loss: 1.7976 | Patience: 0


In [None]:
# Cell 6: Evaluation and Baseline Calculation (FINAL ROBUST SOLUTION - SHAPE FIX)

import torchtuples as tt 
import numpy as np
import torch
import pandas as pd
from lifelines.utils import concordance_index

# --- BASELINE CALCULATION FUNCTION ---
def compute_baseline_and_ready_model(model, train_loader, tt_target_train, device):
    """Computes the baseline hazard, which enables predict_survival_function."""
    print("\nDEBUG: Estimating Baseline Survival Function...")
    
    X_train_list = [X_batch for X_batch, _ in train_loader]
    X_train_all_unsorted = torch.cat(X_train_list) 
    
    T_train_np = tt_target_train[0]
    E_train_np = tt_target_train[1]
    sort_idx = np.argsort(T_train_np)

    X_train_all_sorted = X_train_all_unsorted[sort_idx].to(device)
    T_train_np_sorted = T_train_np[sort_idx]
    E_train_np_sorted = E_train_np[sort_idx]

    model.compute_baseline_hazards(
        input=X_train_all_sorted, 
        target=tt.tuplefy(T_train_np_sorted, E_train_np_sorted)
    )
    print("DEBUG: Baseline Survival Function estimated successfully.")
    return model

# --- EVALUATION FUNCTION ---
def evaluate_model(model, test_loader, device):
    """Calculates the Concordance Index (C-index) and Survival Functions."""
    net = model.net
    net.eval()
    
    h_hat_list, T_list, E_list, X_list = [], [], [], []
    
    with torch.no_grad():
        for X_batch, (T_batch, E_batch) in test_loader:
            X_batch = X_batch.to(device)
            h_hat = net(X_batch)
            h_hat_list.append(h_hat.cpu().numpy().flatten())
            T_list.extend(T_batch.cpu().numpy())
            E_list.extend(E_batch.cpu().numpy())
            X_list.append(X_batch.cpu())

    h_test_all = np.concatenate(h_hat_list)
    T_test_all = np.array(T_list)
    E_test_all = np.array(E_list)
    
    c_index = concordance_index(T_test_all, -h_test_all, E_test_all) 
    
    if hasattr(model, 'predict_survival_function'):
        surv_method = model.predict_survival_function
    elif hasattr(model, 'predict_surv'):
        surv_method = model.predict_surv
    else:
        raise AttributeError("Prediction method not found.")

    surv_output = surv_method(tt.tuplefy(torch.cat(X_list)))

    return c_index, surv_output

# --- EXECUTION ---

# 1. Prepare the model by calculating the baseline hazard
global model
model = compute_baseline_and_ready_model(model, train_loader, tt_target_train, DEVICE)

# 2. Run the final evaluation
c_index, surv_result = evaluate_model(model, test_loader, DEVICE) 

# ðŸš¨ CRITICAL FIX: Transpose the array (surv_result.T) to match the index shape
if isinstance(surv_result, np.ndarray) or isinstance(surv_result, torch.Tensor):
    # Transpose the prediction array before creating the DataFrame
    global surv_df
    surv_df = pd.DataFrame(surv_result.T, index=model.baseline_hazards_.index) 
elif hasattr(surv_result, 'to_pandas'):
    surv_df = surv_result.to_pandas()
else:
    surv_df = surv_result

print(f"\n--- Model Performance ---")
print(f"Test Concordance Index (C-index): {c_index:.4f}")

In [None]:
# Cell 7: Final Visualization and CLINICAL RESOURCE PREDICTION ðŸš‘

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# --- 7.1. Clinical Prediction Constants ---
CRITICAL_TIME = 730 # 2 years (730 days) - The timeframe for the prediction
MAX_FINANCIAL_LOSS = 50000.0 # Kept for the Expected Loss calculation

# --- NEW CLINICAL THRESHOLDS ---
# Probability of Survival (S(t)) thresholds for clinical placement at T_critical
P_ICU_THRESHOLD = 0.40      # Survival <= 40% requires ICU
P_WARD_THRESHOLD = 0.70     # Survival <= 70% requires Inpatient Ward

# --- 7.2. Core Prediction Functions ---

def calculate_mean_time_to_event(surv_df):
    """Calculates Mean Time to Event (MTTE)."""
    return surv_df.apply(lambda col: np.trapz(col.values, surv_df.index.values), axis=0)

def get_survival_at_critical_time(surv_df, critical_time):
    """Safely retrieves predicted S(t) at the critical time using interpolation."""
    s_t = surv_df.apply(lambda col: np.interp(critical_time, surv_df.index.values, col.values), axis=0)
    return s_t

def assign_clinical_triage(s_t, p_icu, p_ward):
    """
    Assigns each patient to a level of care based on their predicted Survival Probability.
    """
    def assign_care_level(prob_survival):
        if prob_survival <= p_icu:
            return '1 - Critical: ICU Admission'
        elif prob_survival <= p_ward:
            return '2 - Serious: Inpatient Ward'
        else:
            return '3 - Stable: Outpatient Follow-up'

    return s_t.apply(assign_care_level).rename('Clinical_Care_Level')

def calculate_risk_volatility(surv_df, critical_time):
    """Calculates the Volatility Score (rate of risk change)."""
    gradients = surv_df.apply(lambda col: np.gradient(col.values, surv_df.index.values), axis=0)
    
    def interpolate_gradient(gradients_series, index, critical_time):
        return np.interp(critical_time, index.values, gradients_series.values)

    volatility_at_t = gradients.apply(lambda col: interpolate_gradient(col, surv_df.index, critical_time), axis=0)
    abs_volatility = np.abs(volatility_at_t)
    min_vol, max_vol = abs_volatility.min(), abs_volatility.max()
    
    if max_vol == min_vol:
        volatility_score = pd.Series(50, index=abs_volatility.index)
    else:
        volatility_score = 100 * (abs_volatility - min_vol) / (max_vol - min_vol)
        
    return volatility_score.rename('Volatility_Score')


# --- 7.3. EXECUTION AND VISUALIZATION ---
print("\n--- CLINICAL RESOURCE PREDICTION & FINANCIAL RISK ANALYSIS ---")

# 1. Get Survival Probability at T_critical
survival_at_t = get_survival_at_critical_time(surv_df, CRITICAL_TIME)

# 2. NEW: Assign Clinical Care Level (Predicting ICU Need)
patient_clinical_triage = assign_clinical_triage(survival_at_t, P_ICU_THRESHOLD, P_WARD_THRESHOLD)

# 3. Calculate Volatility Score
patient_volatility = calculate_risk_volatility(surv_df, CRITICAL_TIME)

# 4. Calculate Expected Loss (Financial Metric)
prob_event = 1.0 - survival_at_t
expected_loss = prob_event * MAX_FINANCIAL_LOSS

# --- PRINTING RESULTS ---
print(f"Prediction Window: T={CRITICAL_TIME} Days.")

print(f"\n1. Patient Clinical Care Level (ICU Prediction): ðŸ©º")
print(patient_clinical_triage.head())

print(f"\n2. Patient Expected Loss (Financial Risk): ðŸ’µ")
print(expected_loss.head().map('${:,.2f}'.format))

print(f"\n3. Patient Survival Volatility Score (Rate of Risk Change): ðŸ’¥")
print(patient_volatility.head().map('{:.1f}'.format))


# --- Descriptive Statistics of Clinical Triage ---
clinical_counts = patient_clinical_triage.value_counts(normalize=True).map('{:.1%}'.format)
print(f"\n--- Hospital Resource Allocation Summary ---")
print(clinical_counts)


# --- 7.4. Visualization ---
plt.figure(figsize=(12, 7))

triage_groups = patient_clinical_triage.unique()
palette = {
    '1 - Critical: ICU Admission': 'red', 
    '2 - Serious: Inpatient Ward': 'orange', 
    '3 - Stable: Outpatient Follow-up': 'green'
}

for group in triage_groups:
    patients_in_group = patient_clinical_triage[patient_clinical_triage == group].index
    
    if len(patients_in_group) > 0:
        surv_df[patients_in_group].iloc[:, :min(5, len(patients_in_group))].plot(
            legend=False, 
            alpha=0.6, 
            color=palette.get(group, 'blue'),
            ax=plt.gca(),
            label=group if group == triage_groups[0] else "" 
        )

plt.title('Survival Curves Grouped by Predicted Clinical Care Level')
plt.ylabel('Probability of Survival, S(t)')
plt.xlabel(f'Time (Duration in days)')

plt.axvline(CRITICAL_TIME, color='r', linestyle=':', label=f'Critical Prediction Time ({CRITICAL_TIME} days)')

# Create custom legend for the groups
from matplotlib.lines import Line2D
legend_elements = [
    Line2D([0], [0], color=palette['1 - Critical: ICU Admission'], lw=3, label=f'ICU (S(t) $\\leq {P_ICU_THRESHOLD:.0f}$%)'),
    Line2D([0], [0], color=palette['2 - Serious: Inpatient Ward'], lw=3, label=f'Ward (S(t) $\\leq {P_WARD_THRESHOLD:.0f}$%)'),
    Line2D([0], [0], color=palette['3 - Stable: Outpatient Follow-up'], lw=3, label='Outpatient'),
    Line2D([0], [0], color='r', linestyle=':', lw=1, label=f'T={CRITICAL_TIME} Days')
]
plt.legend(handles=legend_elements, loc='best')
plt.grid(True, linestyle='--', alpha=0.6)
plt.show()