## Q4.1

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score, average_precision_score
import ollama
import concurrent.futures
import time

# Paths and constants
ML_READY_PATH = 'ml_ready_data'
PROCESSED_PATH = 'processed_data'
MODELS_DIR = 'models'
RESULTS_DIR = 'results'

os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

STATIC_VARS = ['Age', 'Gender', 'Height', 'Weight']
TARGET_VAR = 'In_hospital_death'
LLM_MODEL = 'gemma2:2b'

# Load the imputed long-format tables (set-a = train, set-c = test) and the
# aggregated per-patient feature tables. A failure is printed and recorded in
# DATA_AVAILABLE instead of being raised.
print("Loading preprocessed data...")
try:
    print("Loading raw processed data for clinical interpretability...")
    train_data_raw = pd.read_parquet(f'{PROCESSED_PATH}/set-a.parquet')
    test_data_raw = pd.read_parquet(f'{PROCESSED_PATH}/set-c.parquet')

    # Aggregated tables supply the target variable and the patient index
    X_train_agg = pd.read_parquet(f'{ML_READY_PATH}/patient_features-a.parquet')
    X_test_agg = pd.read_parquet(f'{ML_READY_PATH}/patient_features-c.parquet')

    # Index by PatientID when that column is present
    if 'PatientID' in X_train_agg.columns:
        X_train_agg = X_train_agg.set_index('PatientID')
    if 'PatientID' in X_test_agg.columns:
        X_test_agg = X_test_agg.set_index('PatientID')

    # DataFrame.get yields None when the label column is absent
    y_train = X_train_agg.get(TARGET_VAR)
    y_test = X_test_agg.get(TARGET_VAR)

    print("Successfully loaded preprocessed data.")
    DATA_AVAILABLE = True
except Exception as exc:
    print(f"Error loading preprocessed data: {exc}")
    DATA_AVAILABLE = False

# Define helper functions for metrics calculation
def calculate_metrics(y_true, y_pred_proba, set_name="Test"):
    """Compute, print and return ``(AuROC, AuPRC)`` for binary predictions.

    Returns ``(nan, nan)`` without scoring when the labels contain non-finite
    values or only a single class. Non-finite predictions are replaced
    (NaN -> 0.5, +inf -> 1.0, -inf -> 0.0) before scoring. A ``ValueError``
    raised while scoring is printed and also yields ``(nan, nan)``.
    """
    try:
        labels = np.asarray(y_true)
        scores = np.asarray(y_pred_proba)

        if not np.isfinite(labels).all():
            print(f"Warning: Non-finite y_true for {set_name}. Skip.")
            return np.nan, np.nan

        if not np.isfinite(scores).all():
            print(f"Warning: Non-finite y_pred_proba for {set_name}. Replace.")
            scores = np.nan_to_num(scores, nan=0.5, posinf=1.0, neginf=0.0)

        # Both metrics are undefined unless positives and negatives are present
        if np.unique(labels).size < 2:
            print(f"Warning: Only one class in y_true for {set_name}.")
            return np.nan, np.nan

        auroc = roc_auc_score(labels, scores)
        auprc = average_precision_score(labels, scores)
    except ValueError as err:
        print(f"Metrics Error for {set_name}: {err}")
        return np.nan, np.nan

    print(f"{set_name} AuROC: {auroc:.4f}")
    print(f"{set_name} AuPRC: {auprc:.4f}")
    return auroc, auprc

def create_raw_patient_features(data, static_vars=None, target_var=None):
    """
    Create patient-level features from the raw (long-format) time series data.

    One output row is produced per distinct ``PatientID``, in order of first
    appearance. Each row contains:
      * ``PatientID`` and the target value;
      * the static variables and ``ICUType`` (when present as columns), taken
        from the patient's earliest row by ``Hour``;
      * ``<var>_mean``, ``<var>_std``, ``<var>_min`` and ``<var>_max`` for every
        other column that has at least one non-NaN value for that patient
        (columns that are all-NaN for a patient are left out of that row).

    Rows whose ``PatientID`` is NaN are ignored (groupby drops NaN keys).

    Args:
        data: DataFrame with ``PatientID``, ``Hour`` and the target column.
        static_vars: Static column names; defaults to the module-level
            ``STATIC_VARS``.
        target_var: Target column name; defaults to the module-level
            ``TARGET_VAR``.

    Returns:
        DataFrame with one row per patient and ``PatientID`` as a column.
    """
    if static_vars is None:
        static_vars = STATIC_VARS
    if target_var is None:
        target_var = TARGET_VAR

    # Any column that is not an identifier, static, ICUType or the target is
    # treated as a time series. This depends only on the columns, so compute it once.
    id_vars = ['PatientID', 'RecordID', 'Hour']
    exclude_cols = set(static_vars) | set(id_vars) | {'ICUType', target_var}
    time_series_vars = [col for col in data.columns if col not in exclude_cols]

    patient_features = []
    # A single groupby pass replaces re-filtering the whole table once per
    # patient (which was quadratic in table size). sort=False keeps patients in
    # order of first appearance, like Series.unique().
    for patient_id, patient_data in data.groupby('PatientID', sort=False):
        patient_data = patient_data.sort_values('Hour')

        # All rows of a patient carry the same target; take the first one
        features = {
            'PatientID': patient_id,
            target_var: patient_data[target_var].iloc[0],
        }

        # Static features are constant per patient: use the first row
        for var in static_vars:
            if var in patient_data.columns:
                features[var] = patient_data[var].iloc[0]

        # ICUType is kept as a reference (not for training)
        if 'ICUType' in patient_data.columns:
            features['ICUType'] = patient_data['ICUType'].iloc[0]

        # Summary statistics for each time-series variable with any data
        for var in time_series_vars:
            series = patient_data[var]
            if series.isna().all():
                continue
            features[f'{var}_mean'] = series.mean()
            features[f'{var}_std'] = series.std()
            features[f'{var}_min'] = series.min()
            features[f'{var}_max'] = series.max()

        patient_features.append(features)

    return pd.DataFrame(patient_features)

# Build per-patient summary statistics from the raw (imputed) time series so
# the LLM summaries report unscaled, clinically interpretable values.
# The loading step only records failures in DATA_AVAILABLE, so check it here and
# fail with a clear message instead of a NameError on train_data_raw.
if not DATA_AVAILABLE:
    raise RuntimeError(
        "Preprocessed data could not be loaded (see the error printed above); "
        "cannot create raw patient features."
    )

print("Creating raw patient features for clinical interpretation...")
raw_train_features = create_raw_patient_features(train_data_raw).set_index('PatientID')
raw_test_features = create_raw_patient_features(test_data_raw).set_index('PatientID')

# Print some info about the raw features
print(f"Raw train features shape: {raw_train_features.shape}")
print(f"Raw test features shape: {raw_test_features.shape}")

# Identify time series variables from the raw data (same exclusion rule that
# create_raw_patient_features applies)
id_vars = ['PatientID', 'RecordID', 'Hour']
exclude_cols = STATIC_VARS + id_vars + ['ICUType', TARGET_VAR]
TIME_SERIES_VARS = [col for col in train_data_raw.columns if col not in exclude_cols]
print(f"Found {len(TIME_SERIES_VARS)} time series variables.")

# Function to create text summary from patient features
def create_text_summary_from_agg(patient_id, df_agg_features, feats_to_include, static_vars=None):
    """
    Generate a one-line text summary of a patient from aggregated features.

    Reads the row ``df_agg_features.loc[patient_id]`` and reports, in order:
      * demographics (Age with an age band, Gender, Height, Weight), plus BMI
        when both Height and Weight are positive;
      * for each variable in ``feats_to_include``, whichever of its
        ``<var>_mean`` / ``<var>_max`` / ``<var>_min`` values are present and
        non-NaN, followed by a unit for known variables;
      * a "High-risk factors" list derived from fixed clinical thresholds,
        de-duplicated in first-seen order so the text is identical across runs.

    Args:
        patient_id: Index label of the patient in ``df_agg_features``.
        df_agg_features: DataFrame indexed by patient ID that holds the static
            columns and ``<var>_mean/_max/_min`` columns.
        feats_to_include: Dynamic variable names to summarise, in output order.
        static_vars: Demographic columns to report (only Age, Gender, Height and
            Weight are formatted). Defaults to the module-level ``STATIC_VARS``.

    Returns:
        ``"Patient data not found."`` if ``patient_id`` is not in the index,
        otherwise a string of the form ``"Patient Summary: ..., ...."``.
    """
    if patient_id not in df_agg_features.index:
        return "Patient data not found."
    if static_vars is None:
        static_vars = STATIC_VARS

    patient_data_agg = df_agg_features.loc[patient_id]
    summary = []

    # --- Demographics ---
    for feat in static_vars:
        if feat not in patient_data_agg:
            continue
        value = patient_data_agg[feat]
        if pd.isna(value):
            continue

        if feat == 'Age':
            age = int(value)
            age_group = "(elderly)"
            for upper, label in ((30, "(young adult)"), (40, "(30s)"), (50, "(40s)"),
                                 (60, "(50s)"), (70, "(60s)"), (80, "(70s)")):
                if age < upper:
                    age_group = label
                    break
            summary.append(f"Age {age} {age_group}")

        elif feat == 'Gender':
            gender_code = round(value)
            # Only codes 0/1 are meaningful. Anything else (e.g. a -1
            # missing-value sentinel) is omitted rather than reported as "Female".
            if gender_code in (0, 1):
                summary.append(f"Gender {'Male' if gender_code == 1 else 'Female'}")

        elif feat in ('Height', 'Weight'):
            # Non-positive values are missing-value sentinels (e.g. -1 for an
            # unrecorded height), not measurements, so they are not reported.
            if value > 0:
                unit = "cm" if feat == 'Height' else "kg"
                summary.append(f"{feat} {value:.1f} {unit}")

    # BMI only when both Height and Weight are valid positive numbers
    if 'Height' in patient_data_agg and 'Weight' in patient_data_agg:
        height = patient_data_agg['Height']
        weight = patient_data_agg['Weight']
        if not pd.isna(height) and not pd.isna(weight) and height > 0 and weight > 0:
            height_m = height / 100
            bmi = weight / (height_m * height_m)
            if bmi < 18.5: bmi_category = "(underweight)"
            elif bmi < 25: bmi_category = "(normal)"
            elif bmi < 30: bmi_category = "(overweight)"
            else: bmi_category = "(obese)"
            summary.append(f"BMI {bmi:.1f} {bmi_category}")

    # --- Dynamic variables ---
    units_by_feat = {
        'HR': "bpm",
        'RespRate': "breaths/min",
        'SysABP': "mmHg", 'DiasABP': "mmHg", 'MAP': "mmHg",
        'NISysABP': "mmHg", 'NIDiasABP': "mmHg", 'NIMAP': "mmHg",
        'Temp': "°C",
        'Glucose': "mg/dL",
        'Lactate': "mmol/L",
        'Creatinine': "mg/dL", 'BUN': "mg/dL",
        'HCT': "%",
        'WBC': "K/uL", 'Platelets': "K/uL",
        'PaO2': "mmHg", 'PaCO2': "mmHg",
        'Urine': "mL",
    }
    high_risk_conditions = []

    for feat in feats_to_include:
        feat_summary = []

        # Series.get returns None for a missing column; pd.isna(None) is True
        mean_val = patient_data_agg.get(f'{feat}_mean')
        if not pd.isna(mean_val):
            feat_summary.append(f"mean {mean_val:.1f}")
            # Threshold checks on the mean value
            if feat == 'HR' and mean_val > 120:
                high_risk_conditions.append("tachycardia")
            elif feat == 'HR' and mean_val < 50:
                high_risk_conditions.append("bradycardia")
            elif feat == 'RespRate' and mean_val > 30:
                high_risk_conditions.append("respiratory distress")
            elif feat == 'RespRate' and mean_val < 8:
                high_risk_conditions.append("respiratory depression")
            elif feat == 'MAP' and mean_val < 65:
                high_risk_conditions.append("low MAP")
            elif feat == 'GCS' and mean_val < 8:
                high_risk_conditions.append("severe altered mental status")
            elif feat == 'GCS' and mean_val < 13:
                high_risk_conditions.append("altered mental status")
            elif feat == 'Lactate' and mean_val > 4.0:
                high_risk_conditions.append("severe lactic acidosis")
            elif feat == 'Creatinine' and mean_val > 1.3:
                high_risk_conditions.append("renal dysfunction")

        max_val = patient_data_agg.get(f'{feat}_max')
        if not pd.isna(max_val):
            feat_summary.append(f"max {max_val:.1f}")
            if feat == 'Temp' and max_val > 38.5:
                high_risk_conditions.append("fever")
            elif feat == 'SysABP' and max_val > 160:
                high_risk_conditions.append("hypertension")

        min_val = patient_data_agg.get(f'{feat}_min')
        if not pd.isna(min_val):
            feat_summary.append(f"min {min_val:.1f}")
            if feat == 'SysABP' and min_val < 90:
                high_risk_conditions.append("hypotension")
            elif feat == 'Temp' and min_val < 36.0:
                high_risk_conditions.append("hypothermia")

        if feat_summary:
            units = units_by_feat.get(feat, "")
            summary.append(f"{feat}: {' '.join(feat_summary)}{' ' + units if units else ''}")

    if high_risk_conditions:
        # dict.fromkeys de-duplicates while keeping first-seen order; a set's
        # order depends on hash randomization, which made prompts differ per run.
        unique_conditions = list(dict.fromkeys(high_risk_conditions))
        summary.append(f"High-risk factors: {', '.join(unique_conditions)}")

    return "Patient Summary: " + ", ".join(summary) + "."

# Create a prompt template string that can be reused.
# The only placeholder is {summary_text}. It is filled via str.format with the
# text built by create_text_summary_from_agg. The model is asked to reply with a
# single integer 1-10; the evaluation code later rescales that score to [0, 1]
# as (score - 1) / 9.
# NOTE: because the template is rendered with str.format, any literal braces
# added to this text must be doubled ({{ and }}).
prompt_template = """Given the following ICU patient summary over 48 hours:
{summary_text}

As a medical expert, estimate the likelihood of in-hospital mortality based ONLY on this data.
Consider these risk factors (based on analysis of ICU patient data):

1. Demographics:
   - Advanced age (particularly >70 years)
   - Gender (males have slightly higher mortality in some ICU settings)

2. Vital Signs:
   - Low or high heart rate (bradycardia <50, tachycardia >120)
   - Abnormal blood pressure (SBP <90 or >160, MAP <65)
   - Abnormal temperature (hypothermia <36°C, fever >38.5°C)
   - Respiratory distress (RR >30 or <8)

3. Neurological:
   - Low GCS (Glasgow Coma Scale) - especially if <8
   - Altered mental status

4. Laboratory:
   - Elevated lactate (>2 mmol/L, especially >4)
   - Abnormal renal function (high creatinine >1.3, high BUN >20)
   - Abnormal WBC (leukocytosis or leukopenia)
   - Thrombocytopenia (platelets <150K)
   - Severe anemia (HCT <30%)

5. Respiratory:
   - Hypoxemia (low PaO2 or SaO2)
   - High oxygen requirements (FiO2 >0.5)
   - Mechanical ventilation
   
6. Other:
   - Low urine output
   - Multiple organ dysfunction

Rate the mortality risk on a scale of 1-10 where:
1 = Very low risk (<5% chance of death)
10 = Very high risk (>80% chance of death)

Respond with ONLY the integer number between 1 and 10.
Score:"""

# Functions to query the LLM for a 1-10 mortality score (with retries)
def _parse_llm_score(content):
    """Return the 1-10 score found in an LLM reply, or None if there is none."""
    # Prefer a standalone integer 1..10; the word boundaries stop digits inside
    # larger numbers (e.g. the "1" in "15") from matching.
    match = re.search(r'\b(10|[1-9])\b', content)
    if match:
        return int(match.group(1))
    # Otherwise clamp the first run of digits into range ("0" -> 1, "15" -> 10)
    digits = re.search(r'\d+', content)
    if digits:
        return max(1, min(10, int(digits.group(0))))
    return None


def query_llm_score_with_retries(summary_text, model_name, max_retries=3):
    """Ask the LLM for a 1-10 mortality-risk score for one patient summary.

    The summary is inserted into ``prompt_template`` and sent with
    ``ollama.chat``. An attempt fails if the request raises or no score can be
    parsed from the reply. Failed attempts are retried up to ``max_retries``
    times in total, with a 1 s pause between attempts. If every attempt fails,
    the last problem is printed so an unreachable server or model is not silent.

    Args:
        summary_text: Patient summary text to embed in the prompt.
        model_name: Ollama model name.
        max_retries: Total number of attempts.

    Returns:
        int score in [1, 10], or ``np.nan`` if every attempt failed.
    """
    prompt = prompt_template.format(summary_text=summary_text)
    last_problem = None

    for attempt in range(max_retries):
        try:
            response = ollama.chat(model=model_name, messages=[{'role': 'user', 'content': prompt}])
            content = response['message']['content']
            score = _parse_llm_score(content)
            if score is not None:
                return score
            last_problem = f"could not parse a score from reply {content[:200]!r}"
        except Exception as exc:  # server/network/model errors: retry
            last_problem = f"{type(exc).__name__}: {exc}"

        if attempt < max_retries - 1:
            time.sleep(1)  # Short delay before retry

    if last_problem is not None:
        print(f"LLM scoring failed after {max_retries} attempts ({model_name}): {last_problem}")
    return np.nan  # Return NaN if all attempts failed

# Optimized batch processing function with a thread pool
def process_patients_in_batches(patient_ids, features_df, features_for_summary,
                                batch_size=10, max_workers=5):
    """Score patients with the LLM in fixed-size batches using a thread pool.

    For each batch, a text summary is built per patient and the LLM requests
    run concurrently. The whole batch is awaited before the next one starts.
    Every 5th patient's summary and score are printed as an example.

    Patients absent from ``features_df`` are not sent to the LLM: their summary
    would only be the placeholder "Patient data not found.", which would still
    receive a score and pollute the evaluation. Their score is left as NaN.

    Args:
        patient_ids: Iterable of patient IDs (index labels of ``features_df``).
        features_df: Aggregated per-patient features indexed by patient ID.
        features_for_summary: Dynamic variables to include in each summary.
        batch_size: Patients per batch.
        max_workers: Maximum number of concurrent LLM requests.

    Returns:
        list of scores (int 1-10 or NaN) aligned with ``patient_ids``.
    """
    patient_ids = list(patient_ids)
    total = len(patient_ids)
    scores = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        for batch_start in range(0, total, batch_size):
            batch_ids = patient_ids[batch_start:batch_start + batch_size]
            batch_end = batch_start + len(batch_ids)

            summaries = [create_text_summary_from_agg(pid, features_df, features_for_summary) for pid in batch_ids]
            batch_scores = [np.nan] * len(batch_ids)

            # Submit LLM queries only for patients that have feature data
            future_to_idx = {}
            for idx, (pid, summary) in enumerate(zip(batch_ids, summaries)):
                if pid in features_df.index:
                    future_to_idx[executor.submit(query_llm_score_with_retries, summary, LLM_MODEL)] = idx
                else:
                    print(f"Skipping patient {pid}: no features found (score left as NaN)")

            for future in concurrent.futures.as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    score = future.result()
                except Exception as exc:
                    print(f"Error processing patient {batch_ids[idx]}: {exc}")
                    continue
                batch_scores[idx] = score

                # Print occasional examples (every 5th patient overall)
                if (batch_start + idx) % 5 == 0:
                    print(f"\nExample {batch_start + idx}:")
                    print(f"Summary: {summaries[idx]}")
                    print(f"LLM Score: {score}/10")

            scores.extend(batch_scores)
            print(f"Processed {batch_end}/{total} patients")

    return scores

# Define features for the patient summaries
features_for_summary_llm = [
    # Vital signs (well-measured with clinical relevance)
    'HR', 'RespRate', 'SysABP', 'DiasABP', 'MAP', 'Temp',

    # Neurological status (strong predictor of outcomes)
    'GCS',

    # Respiratory parameters
    'PaO2', 'SaO2', 'FiO2', 'MechVent',

    # Key lab values
    'Lactate', 'Creatinine', 'BUN', 'Glucose', 'WBC', 'Platelets', 'HCT',

    # Other important indicators
    'Urine', 'pH',
]

# Keep only features whose aggregated '<feature>_mean' column exists
available_features = [
    feature for feature in features_for_summary_llm
    if f"{feature}_mean" in raw_train_features.columns
]
missing_features = set(features_for_summary_llm) - set(available_features)
if missing_features:
    print(f"Warning: Some desired features not found: {missing_features}")
    features_for_summary_llm = available_features

print(f"Using {len(features_for_summary_llm)} features for patient summaries.")

# Labels are required for evaluation; stop before spending any LLM calls.
if y_test is None:
    raise RuntimeError(
        f"Target column '{TARGET_VAR}' not found in test features; cannot evaluate LLM scores."
    )

# Test one patient summary to verify it looks reasonable
test_pid = X_test_agg.index[0]
test_summary = create_text_summary_from_agg(test_pid, raw_test_features, features_for_summary_llm)
print(f"\nSample patient summary:\n{test_summary}")

# Main execution
print(f"\n--- Evaluating LLM Prompting ({LLM_MODEL}) ---")

test_size = 20        # Pilot run size (start small!)
remaining_size = 100  # Extra patients scored after the pilot run

print(f"Testing on first {test_size} patients...")
test_patient_ids = X_test_agg.index[:test_size].tolist()
print("testing patients with ids " + str(test_patient_ids))

print("Processing test patients...")
test_scores = process_patients_in_batches(test_patient_ids, raw_test_features, features_for_summary_llm)

# Score more patients automatically when >= 80% of pilot scores are valid,
# otherwise only if the user confirms.
valid_scores = [s for s in test_scores if not pd.isna(s)]
pilot_ok = len(valid_scores) >= len(test_scores) * 0.8
if pilot_ok:
    run_more = True
else:
    print(f"Warning: Only {len(valid_scores)}/{len(test_scores)} test predictions were valid.")
    print("Check Ollama server connection and model availability.")
    proceed = input("Do you want to continue with more patients? (y/n): ")
    run_more = proceed.strip().lower() == 'y'
    if not run_more:
        print("Stopping after test run")

if run_more and remaining_size > 0:
    prefix = "Test run successful. " if pilot_ok else ""
    print(f"{prefix}Processing additional {remaining_size} patients...")
    additional_patient_ids = X_test_agg.index[test_size:test_size + remaining_size].tolist()
    additional_scores = process_patients_in_batches(additional_patient_ids, raw_test_features, features_for_summary_llm)
    test_patient_ids.extend(additional_patient_ids)
    test_scores.extend(additional_scores)

# Convert 1-10 scores to [0, 1] probabilities and evaluate performance.
# Casting labels to float lets np.isnan handle them regardless of stored dtype.
y_true = y_test.loc[test_patient_ids].to_numpy(dtype=float)
score_array = np.array(test_scores, dtype=float)
y_pred_proba = (score_array - 1.0) / 9.0

# Drop patients without a parsed score or without a label
valid_indices = ~np.isnan(y_pred_proba) & ~np.isnan(y_true)

if np.sum(valid_indices) == 0:
    print("Error: No valid LLM scores/targets.")
else:
    print(f"Valid LLM scores: {np.sum(valid_indices)}/{len(y_true)}")

    # Calculate and report metrics
    auroc_llm, auprc_llm = calculate_metrics(
        y_true[valid_indices],
        y_pred_proba[valid_indices],
        set_name=f"Test (LLM Prompt {LLM_MODEL})"
    )

    # Save results
    with open(f"{RESULTS_DIR}/llm_prompt_results.txt", "w") as f:
        f.write(f"LLM Model: {LLM_MODEL}\n")
        f.write(f"Valid predictions: {np.sum(valid_indices)}/{len(y_true)}\n")
        f.write(f"AuROC: {auroc_llm:.4f}\n")
        f.write(f"AuPRC: {auprc_llm:.4f}\n")

    # Histogram with one bin centred on each integer score 1..10. The default
    # bins=10 spans only the observed min..max, so bins did not line up with scores.
    plt.figure(figsize=(10, 6))
    plt.hist(score_array[valid_indices], bins=np.arange(0.5, 11.5, 1.0), alpha=0.7)
    plt.xticks(range(1, 11))
    plt.title(f'Distribution of LLM Risk Scores ({LLM_MODEL})')
    plt.xlabel('Risk Score (1-10)')
    plt.ylabel('Count')
    plt.savefig(f"{RESULTS_DIR}/llm_score_distribution.png")
    plt.close()

print(f"\n--- Finished LLM Evaluation ---")

Loading preprocessed data...
Loading raw processed data for clinical interpretability...
Successfully loaded preprocessed data.
Creating raw patient features for clinical interpretation...
Raw train features shape: (4000, 150)
Raw test features shape: (4000, 150)
Found 36 time series variables.
Using 20 features for patient summaries.

Sample patient summary:
Patient Summary: Age 72 (70s), Gender Male, Height -1.0 cm, Weight 60.6 kg, HR: mean 52.2 max 60.0 min 43.0 bpm, RespRate: mean 16.6 max 20.0 min 12.0 breaths/min, SysABP: mean 125.4 max 150.0 min 94.0 mmHg, DiasABP: mean 54.6 max 67.0 min 49.0 mmHg, MAP: mean 77.4 max 96.0 min 62.0 mmHg, Temp: mean 36.6 max 37.2 min 35.1 °C, GCS: mean 13.9 max 15.0 min 10.0, PaO2: mean 123.0 max 123.0 min 123.0 mmHg, Lactate: mean 1.3 max 1.3 min 1.3 mmol/L, Creatinine: mean 0.8 max 0.9 min 0.7 mg/dL, BUN: mean 8.2 max 9.0 min 6.0 mg/dL, Glucose: mean 123.3 max 145.0 min 112.0 mg/dL, WBC: mean 8.2 max 9.0 min 7.3 K/uL, Platelets: mean 313.5 max 3

## Q4.2

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re
import time
from tqdm.auto import tqdm
from sklearn.preprocessing import RobustScaler
from sklearn.metrics import roc_auc_score, average_precision_score, silhouette_score
from sklearn.linear_model import LogisticRegression
from sklearn.manifold import TSNE
import pickle
import warnings
import ollama
import concurrent.futures

warnings.filterwarnings('ignore')

# Define paths (matching your preprocessing code)
ML_READY_PATH = 'ml_ready_data'   # Where your scaled and feature-engineered files are
PROCESSED_PATH = 'processed_data' # Where your initial imputed files are
MODELS_DIR = 'models'             # Directory to save trained models
RESULTS_DIR = 'results'           # Directory to save plots and result summaries

os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

STATIC_VARS = ['Age', 'Gender', 'Height', 'Weight']
TARGET_VAR = 'In_hospital_death'
LLM_MODEL = 'gemma2:2b'

# Set random seed for reproducibility
SEED = 42
np.random.seed(SEED)

# Try to load processed data; failures are printed and recorded in DATA_AVAILABLE
print("Loading preprocessed data...")
try:
    # Load the patient features (aggregated data)
    X_train_agg = pd.read_parquet(f'{ML_READY_PATH}/patient_features-a.parquet')
    X_test_agg = pd.read_parquet(f'{ML_READY_PATH}/patient_features-c.parquet')

    # Set index to PatientID for convenience
    X_train_agg = X_train_agg.set_index('PatientID') if 'PatientID' in X_train_agg.columns else X_train_agg
    X_test_agg = X_test_agg.set_index('PatientID') if 'PatientID' in X_test_agg.columns else X_test_agg

    # Extract labels (None when the target column is absent)
    y_train = X_train_agg[TARGET_VAR] if TARGET_VAR in X_train_agg.columns else None
    y_test = X_test_agg[TARGET_VAR] if TARGET_VAR in X_test_agg.columns else None

    # Time-series variables are the base names of aggregated '<var>_mean'
    # columns. Match the suffix only: a substring test plus str.replace would
    # also derive bogus names from columns that merely contain '_mean'.
    TIME_SERIES_VARS = []
    for col in X_train_agg.columns:
        if not (isinstance(col, str) and col.endswith('_mean')):
            continue
        base_var = col[:-len('_mean')]
        if base_var not in STATIC_VARS and base_var != TARGET_VAR:
            TIME_SERIES_VARS.append(base_var)

    # De-duplicate while keeping column order; list(set(...)) would give a
    # different order on each run because string hashing is randomized.
    TIME_SERIES_VARS = list(dict.fromkeys(TIME_SERIES_VARS))
    print(f"Found {len(TIME_SERIES_VARS)} time series variables.")

    print("Successfully loaded preprocessed data.")
    DATA_AVAILABLE = True
except Exception as e:
    print(f"Error loading preprocessed data: {e}")
    DATA_AVAILABLE = False

# Define helper functions for metrics calculation
def calculate_metrics(y_true, y_pred_proba, set_name="Test"):
    """Compute, print and return ``(AuROC, AuPRC)`` for binary predictions.

    Returns ``(nan, nan)`` without scoring when the labels contain non-finite
    values or only a single class. Non-finite predictions are replaced
    (NaN -> 0.5, +inf -> 1.0, -inf -> 0.0) before scoring. A ``ValueError``
    raised while scoring is printed and also yields ``(nan, nan)``.
    """
    try:
        labels = np.asarray(y_true)
        scores = np.asarray(y_pred_proba)

        if not np.isfinite(labels).all():
            print(f"Warning: Non-finite y_true for {set_name}. Skip.")
            return np.nan, np.nan

        if not np.isfinite(scores).all():
            print(f"Warning: Non-finite y_pred_proba for {set_name}. Replace.")
            scores = np.nan_to_num(scores, nan=0.5, posinf=1.0, neginf=0.0)

        # Both metrics are undefined unless positives and negatives are present
        if np.unique(labels).size < 2:
            print(f"Warning: Only one class in y_true for {set_name}.")
            return np.nan, np.nan

        auroc = roc_auc_score(labels, scores)
        auprc = average_precision_score(labels, scores)
    except ValueError as err:
        print(f"Metrics Error for {set_name}: {err}")
        return np.nan, np.nan

    print(f"{set_name} AuROC: {auroc:.4f}")
    print(f"{set_name} AuPRC: {auprc:.4f}")
    return auroc, auprc

# Function for creating text summary from patient features
def create_text_summary_from_agg(patient_id, df_agg_features, feats_to_include, static_vars=None):
    """
    Generate a one-line text summary of a patient from aggregated features.

    Reads the row ``df_agg_features.loc[patient_id]`` and reports, in order:
      * demographics (Age with an age band, Gender, Height, Weight), plus BMI
        when both Height and Weight are positive;
      * for each variable in ``feats_to_include``, whichever of its
        ``<var>_mean`` / ``<var>_max`` / ``<var>_min`` values are present and
        non-NaN, followed by a unit for known variables;
      * a "High-risk factors" list derived from fixed clinical thresholds,
        de-duplicated in first-seen order so the text is identical across runs.

    Args:
        patient_id: Index label of the patient in ``df_agg_features``.
        df_agg_features: DataFrame indexed by patient ID that holds the static
            columns and ``<var>_mean/_max/_min`` columns.
        feats_to_include: Dynamic variable names to summarise, in output order.
        static_vars: Demographic columns to report (only Age, Gender, Height and
            Weight are formatted). Defaults to the module-level ``STATIC_VARS``.

    Returns:
        ``"Patient data not found."`` if ``patient_id`` is not in the index,
        otherwise a string of the form ``"Patient Summary: ..., ...."``.
    """
    if patient_id not in df_agg_features.index:
        return "Patient data not found."
    if static_vars is None:
        static_vars = STATIC_VARS

    patient_data_agg = df_agg_features.loc[patient_id]
    summary = []

    # --- Demographics ---
    for feat in static_vars:
        if feat not in patient_data_agg:
            continue
        value = patient_data_agg[feat]
        if pd.isna(value):
            continue

        if feat == 'Age':
            age = int(value)
            age_group = "(elderly)"
            for upper, label in ((30, "(young adult)"), (40, "(30s)"), (50, "(40s)"),
                                 (60, "(50s)"), (70, "(60s)"), (80, "(70s)")):
                if age < upper:
                    age_group = label
                    break
            summary.append(f"Age {age} {age_group}")

        elif feat == 'Gender':
            gender_code = round(value)
            # Only codes 0/1 are meaningful. Anything else (e.g. a -1
            # missing-value sentinel) is omitted rather than reported as "Female".
            if gender_code in (0, 1):
                summary.append(f"Gender {'Male' if gender_code == 1 else 'Female'}")

        elif feat in ('Height', 'Weight'):
            # Non-positive values are missing-value sentinels (e.g. -1 for an
            # unrecorded height), not measurements, so they are not reported.
            if value > 0:
                unit = "cm" if feat == 'Height' else "kg"
                summary.append(f"{feat} {value:.1f} {unit}")

    # BMI only when both Height and Weight are valid positive numbers
    if 'Height' in patient_data_agg and 'Weight' in patient_data_agg:
        height = patient_data_agg['Height']
        weight = patient_data_agg['Weight']
        if not pd.isna(height) and not pd.isna(weight) and height > 0 and weight > 0:
            height_m = height / 100
            bmi = weight / (height_m * height_m)
            if bmi < 18.5: bmi_category = "(underweight)"
            elif bmi < 25: bmi_category = "(normal)"
            elif bmi < 30: bmi_category = "(overweight)"
            else: bmi_category = "(obese)"
            summary.append(f"BMI {bmi:.1f} {bmi_category}")

    # --- Dynamic variables ---
    units_by_feat = {
        'HR': "bpm",
        'RespRate': "breaths/min",
        'SysABP': "mmHg", 'DiasABP': "mmHg", 'MAP': "mmHg",
        'NISysABP': "mmHg", 'NIDiasABP': "mmHg", 'NIMAP': "mmHg",
        'Temp': "°C",
        'Glucose': "mg/dL",
        'Lactate': "mmol/L",
        'Creatinine': "mg/dL", 'BUN': "mg/dL",
        'HCT': "%",
        'WBC': "K/uL", 'Platelets': "K/uL",
        'PaO2': "mmHg", 'PaCO2': "mmHg",
        'Urine': "mL",
    }
    high_risk_conditions = []

    for feat in feats_to_include:
        feat_summary = []

        # Series.get returns None for a missing column; pd.isna(None) is True
        mean_val = patient_data_agg.get(f'{feat}_mean')
        if not pd.isna(mean_val):
            feat_summary.append(f"mean {mean_val:.1f}")
            # Threshold checks on the mean value
            if feat == 'HR' and mean_val > 120:
                high_risk_conditions.append("tachycardia")
            elif feat == 'HR' and mean_val < 50:
                high_risk_conditions.append("bradycardia")
            elif feat == 'RespRate' and mean_val > 30:
                high_risk_conditions.append("respiratory distress")
            elif feat == 'RespRate' and mean_val < 8:
                high_risk_conditions.append("respiratory depression")
            elif feat == 'MAP' and mean_val < 65:
                high_risk_conditions.append("low MAP")
            elif feat == 'GCS' and mean_val < 8:
                high_risk_conditions.append("severe altered mental status")
            elif feat == 'GCS' and mean_val < 13:
                high_risk_conditions.append("altered mental status")
            elif feat == 'Lactate' and mean_val > 4.0:
                high_risk_conditions.append("severe lactic acidosis")
            elif feat == 'Creatinine' and mean_val > 1.3:
                high_risk_conditions.append("renal dysfunction")

        max_val = patient_data_agg.get(f'{feat}_max')
        if not pd.isna(max_val):
            feat_summary.append(f"max {max_val:.1f}")
            if feat == 'Temp' and max_val > 38.5:
                high_risk_conditions.append("fever")
            elif feat == 'SysABP' and max_val > 160:
                high_risk_conditions.append("hypertension")

        min_val = patient_data_agg.get(f'{feat}_min')
        if not pd.isna(min_val):
            feat_summary.append(f"min {min_val:.1f}")
            if feat == 'SysABP' and min_val < 90:
                high_risk_conditions.append("hypotension")
            elif feat == 'Temp' and min_val < 36.0:
                high_risk_conditions.append("hypothermia")

        if feat_summary:
            units = units_by_feat.get(feat, "")
            summary.append(f"{feat}: {' '.join(feat_summary)}{' ' + units if units else ''}")

    if high_risk_conditions:
        # dict.fromkeys de-duplicates while keeping first-seen order; a set's
        # order depends on hash randomization, which made prompts differ per run.
        unique_conditions = list(dict.fromkeys(high_risk_conditions))
        summary.append(f"High-risk factors: {', '.join(unique_conditions)}")

    return "Patient Summary: " + ", ".join(summary) + "."

# Ollama embedding lookup with a simple fixed-delay retry
def get_ollama_embedding(text, model_name, max_retries=3):
    """Return the embedding vector Ollama produces for ``text``.

    Up to ``max_retries`` attempts are made; a failed attempt is followed by a
    one-second pause unless it was the last one. When every attempt fails, the
    final error is printed and ``None`` is returned (also returned when
    ``max_retries`` is not positive).
    """
    attempt = 0
    while attempt < max_retries:
        attempt += 1
        try:
            return ollama.embeddings(model=model_name, prompt=text)['embedding']
        except Exception as err:
            if attempt >= max_retries:
                print(f"Embedding Error ({model_name}): {err}")
                return None
            time.sleep(1)  # brief back-off before the next attempt
    return None

# Process patient summaries in parallel
def process_summaries_parallel(patients, features_df, features_list, max_workers=5):
    """Generate a text summary for every patient on a thread pool.

    Args:
        patients: Iterable of patient IDs passed to ``create_text_summary_from_agg``.
        features_df: Aggregated per-patient feature table.
        features_list: Base feature names to include in each summary.
        max_workers: Number of worker threads.

    Returns:
        dict patient_id -> summary string, keyed in the order of ``patients``
        (not thread-completion order) so downstream batching is deterministic.
        A patient whose summary raises is reported and omitted rather than
        aborting the whole run.
    """
    patients = list(patients)
    finished = {}

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_pid = {
            executor.submit(create_text_summary_from_agg, pid, features_df, features_list): pid
            for pid in patients
        }
        for future in tqdm(concurrent.futures.as_completed(future_to_pid),
                           total=len(future_to_pid), desc="Generating summaries"):
            pid = future_to_pid[future]
            try:
                finished[pid] = future.result()
            except Exception as e:
                print(f"Error generating summary for patient {pid}: {e}")

    # Re-key in input order; as_completed yields in nondeterministic order.
    return {pid: finished[pid] for pid in patients if pid in finished}

# Embed all summaries: sequential batches, concurrent requests inside each batch
def get_embeddings_parallel(summaries, model_name, batch_size=20, max_workers=5):
    """Embed every summary with Ollama, ``batch_size`` patients at a time.

    Each batch is sent concurrently on a fresh pool of ``max_workers`` threads.
    The code sleeps one second between batches (not after the last) to avoid
    overloading the Ollama server.

    Returns:
        dict patient_id -> embedding for patients whose embedding was obtained;
        patients whose lookup returned ``None`` or raised are left out.
    """
    def embed_patient(pid):
        # get_ollama_embedding already retries and returns None on failure.
        return pid, get_ollama_embedding(summaries[pid], model_name)

    ids = list(summaries.keys())
    n_batches = (len(ids) + batch_size - 1) // batch_size
    collected = {}

    for b in range(n_batches):
        lo = b * batch_size
        chunk = ids[lo:min(lo + batch_size, len(ids))]
        print(f"Processing batch {b+1}/{n_batches} ({len(chunk)} patients)")

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
            pending = [pool.submit(embed_patient, pid) for pid in chunk]
            progress = tqdm(concurrent.futures.as_completed(pending),
                            total=len(pending), desc="Generating embeddings")
            for done in progress:
                try:
                    pid, vector = done.result()
                except Exception as e:
                    print(f"Error in embedding generation: {e}")
                    continue
                if vector is not None:
                    collected[pid] = vector

        if b != n_batches - 1:
            time.sleep(1)

    return collected

# --- Main execution for Q4.2: Using LLMs to retrieve embeddings ---

def _filter_summary_features(candidates, columns):
    """Return the candidates that have a ``<feature>_mean`` column in ``columns`` (order kept)."""
    available = set(columns)
    return [feature for feature in candidates if f"{feature}_mean" in available]


def _prompt_patient_limit(n_train, n_test):
    """Ask for an optional per-split patient cap; return a positive int, or None for "all".

    Empty, non-numeric or non-positive answers, and a missing stdin (EOFError in
    non-interactive runs), all mean "process every patient". KeyboardInterrupt is
    deliberately not swallowed.
    """
    try:
        answer = input(f"Specify maximum number of patients to process (default: all {n_train} train, {n_test} test): ")
        limit = int(answer)
    except (ValueError, EOFError):
        return None
    return limit if limit > 0 else None


def _stack_labelled_embeddings(emb_dict, patient_ids, labels):
    """Stack embeddings of ``patient_ids`` row-wise and drop rows whose label is missing.

    Returns ``(X, y)``: ``X`` is (n, dim) float, ``y`` float labels, rows in the
    order of ``patient_ids``. ``dtype=float`` makes ragged embeddings fail loudly
    instead of silently producing an object array.
    """
    X = np.array([emb_dict[pid] for pid in patient_ids], dtype=float)
    y = labels.loc[patient_ids].to_numpy(dtype=float, na_value=np.nan)
    keep = ~np.isnan(y)
    return X[keep], y[keep]


def _plot_llm_tsne(X, y, seed, max_samples=1000):
    """Save a 2-D t-SNE scatter of ``X`` coloured by ``y``; return the silhouette score or None.

    At most ``max_samples`` points are used, drawn with a seeded generator so the
    plot is reproducible. The silhouette score is computed on the 2-D projection,
    only when there are at least two classes with two or more points each.
    """
    from sklearn.manifold import TSNE
    from sklearn.metrics import silhouette_score

    if len(X) > max_samples:
        print(f"Sampling {max_samples} points for t-SNE visualization...")
        picked = np.random.default_rng(seed).choice(len(X), max_samples, replace=False)
        X, y = X[picked], y[picked]

    tsne = TSNE(
        n_components=2,
        random_state=seed,
        perplexity=min(30, len(X) - 1),  # t-SNE requires perplexity < n_samples
        n_jobs=-1,
    )
    X_2d = tsne.fit_transform(X)

    silhouette = None
    if len(np.unique(y)) > 1 and min(np.bincount(y.astype(int))) >= 2:
        silhouette = silhouette_score(X_2d, y)
        print(f"t-SNE Silhouette Score (LLM Embeddings): {silhouette:.4f}")
    else:
        print("Cannot calculate silhouette score: insufficient class distribution")

    fig = plt.figure(figsize=(10, 8))
    try:
        for label in np.unique(y):
            mask = y == label
            plt.scatter(
                X_2d[mask, 0], X_2d[mask, 1],
                label=f"{'Survived' if int(label) == 0 else 'Died'}",
                alpha=0.7,
            )
        plt.title(f"t-SNE Visualization of LLM Embeddings ({LLM_MODEL})")
        plt.legend()
        plt.grid(alpha=0.3)
        plt.savefig(f"{RESULTS_DIR}/llm_embedding_tsne.png")
    finally:
        plt.close(fig)  # release the figure even if saving fails
    return silhouette


def run_llm_embeddings(seed=None):
    """Q4.2: embed text summaries of each patient with an Ollama LLM and probe them.

    Steps:
      1. Keep the summary features that have a ``<feature>_mean`` column.
      2. Smoke-test the Ollama endpoint on three test patients.
      3. Optionally cap the number of patients per split (interactive prompt).
      4. Build summaries and embeddings for train and test.
      5. Fit a class-balanced liblinear logistic-regression probe on train.
      6. Report AuROC/AuPRC on test.
      7. Write the embeddings, the probe, a t-SNE plot and a text report under
         MODELS_DIR / RESULTS_DIR.

    Args:
        seed: Random seed for the probe and t-SNE. ``None`` uses a module-level
            ``SEED`` if one is defined, otherwise 42.

    Returns:
        None. Returns early, after printing the reason, when data or labels are
        unavailable, the smoke test fails, or there are too few labelled
        samples/classes to train.
    """
    import pickle
    from sklearn.linear_model import LogisticRegression

    print(f"\n--- Q4.2: Using LLMs to retrieve embeddings ({LLM_MODEL}) ---")

    if seed is None:
        # SEED may not be defined in this cell; fall back instead of failing with
        # a NameError after the slow embedding step.
        seed = globals().get("SEED", 42)

    if not DATA_AVAILABLE or y_train is None or y_test is None:
        print("Error: preprocessed data or target labels are unavailable; cannot run Q4.2.")
        return

    train_patient_ids = X_train_agg.index.tolist()
    test_patient_ids = X_test_agg.index.tolist()

    # Select features based on clinical importance
    desired_features = [
        # Vital signs
        'HR', 'RespRate', 'SysABP', 'DiasABP', 'MAP', 'Temp',
        # Neurological
        'GCS',
        # Respiratory
        'PaO2', 'SaO2', 'FiO2', 'MechVent',
        # Key lab values
        'Lactate', 'Creatinine', 'BUN', 'Glucose', 'WBC', 'Platelets', 'HCT',
        # Other
        'Urine', 'pH'
    ]
    features_for_summary_llm = _filter_summary_features(desired_features, X_train_agg.columns)
    missing_features = set(desired_features) - set(features_for_summary_llm)
    if missing_features:
        print(f"Warning: Some desired features not found: {missing_features}")
    print(f"Using {len(features_for_summary_llm)} features for patient summaries.")

    # Smoke test: make sure the Ollama server answers before the long run.
    print("Testing embedding functionality with a small subset...")
    smoke_ids = test_patient_ids[:3]
    smoke_summaries = {
        pid: create_text_summary_from_agg(pid, X_test_agg, features_for_summary_llm)
        for pid in smoke_ids
    }
    print("Testing embedding generation...")
    smoke_ok = False
    for pid in smoke_ids:
        embedding = get_ollama_embedding(smoke_summaries[pid], model_name=LLM_MODEL)
        if embedding is not None:
            smoke_ok = True
            print(f"Successfully got embedding for patient {pid}, dimension: {len(embedding)}")
        else:
            print(f"Failed to get embedding for patient {pid}")
    if not smoke_ok:
        print("ERROR: Failed to generate test embeddings. Check Ollama server.")
        return

    limit = _prompt_patient_limit(len(train_patient_ids), len(test_patient_ids))
    if limit is not None:
        train_patient_ids = train_patient_ids[:limit]
        test_patient_ids = test_patient_ids[:limit]
        print(f"Will process {len(train_patient_ids)} train and {len(test_patient_ids)} test patients")
    else:
        print(f"Using all patients: {len(train_patient_ids)} train and {len(test_patient_ids)} test")

    print("\nGenerating patient summaries for train set...")
    train_summaries = process_summaries_parallel(
        train_patient_ids, X_train_agg, features_for_summary_llm, max_workers=5
    )
    print("\nGenerating patient summaries for test set...")
    test_summaries = process_summaries_parallel(
        test_patient_ids, X_test_agg, features_for_summary_llm, max_workers=5
    )

    print(f"\nGenerating LLM embeddings for train set using {LLM_MODEL}...")
    X_train_llm_emb_dict = get_embeddings_parallel(
        train_summaries, model_name=LLM_MODEL, batch_size=10, max_workers=5
    )
    print(f"\nGenerating LLM embeddings for test set using {LLM_MODEL}...")
    X_test_llm_emb_dict = get_embeddings_parallel(
        test_summaries, model_name=LLM_MODEL, batch_size=10, max_workers=5
    )

    train_ids_with_emb = [pid for pid in train_patient_ids if pid in X_train_llm_emb_dict]
    test_ids_with_emb = [pid for pid in test_patient_ids if pid in X_test_llm_emb_dict]
    if not train_ids_with_emb or not test_ids_with_emb:
        print("Error: Failed to generate sufficient LLM embeddings.")
        return

    print(f"Successfully generated embeddings for {len(train_ids_with_emb)}/{len(train_patient_ids)} train patients")
    print(f"Successfully generated embeddings for {len(test_ids_with_emb)}/{len(test_patient_ids)} test patients")

    X_train_llm_emb, y_train_llm_emb = _stack_labelled_embeddings(X_train_llm_emb_dict, train_ids_with_emb, y_train)
    X_test_llm_emb, y_test_llm_emb = _stack_labelled_embeddings(X_test_llm_emb_dict, test_ids_with_emb, y_test)

    if len(y_train_llm_emb) < 2 or len(y_test_llm_emb) == 0:
        print("Warning: Not enough valid samples for LLM embedding probe.")
        return
    if len(np.unique(y_train_llm_emb)) < 2:
        # LogisticRegression.fit raises on single-class targets (common with a small cap).
        print("Warning: Training labels contain a single class; cannot fit the LLM embedding probe.")
        return

    print(f"LLM Embeddings shape: Train X={X_train_llm_emb.shape}, Test X={X_test_llm_emb.shape}")

    # Persist the embeddings first so the slow Ollama step is not lost if a later step fails.
    with open(f"{RESULTS_DIR}/llm_embeddings_train.pkl", "wb") as f:
        pickle.dump({"embeddings": X_train_llm_emb, "labels": y_train_llm_emb}, f)
    with open(f"{RESULTS_DIR}/llm_embeddings_test.pkl", "wb") as f:
        pickle.dump({"embeddings": X_test_llm_emb, "labels": y_test_llm_emb}, f)

    print("\nTraining Linear Probe on LLM Embeddings...")
    probe_llm = LogisticRegression(
        solver='liblinear', random_state=seed,
        max_iter=1000, C=1.0, class_weight='balanced'
    )
    probe_llm.fit(X_train_llm_emb, y_train_llm_emb)

    y_pred_proba_llm_probe = probe_llm.predict_proba(X_test_llm_emb)[:, 1]
    print(f"\nLinear Probe Results (LLM Embeddings - {LLM_MODEL}):")
    auroc_llm_probe, auprc_llm_probe = calculate_metrics(
        y_test_llm_emb, y_pred_proba_llm_probe,
        set_name=f"Test (LLM Emb Probe {LLM_MODEL})"
    )

    os.makedirs(f"{MODELS_DIR}/llm_probe", exist_ok=True)
    with open(f"{MODELS_DIR}/llm_probe/probe_model.pkl", "wb") as f:
        pickle.dump(probe_llm, f)

    print("\nVisualizing LLM embeddings with t-SNE...")
    silhouette = _plot_llm_tsne(X_train_llm_emb, y_train_llm_emb, seed)

    with open(f"{RESULTS_DIR}/llm_embedding_results.txt", "w") as f:
        f.write(f"LLM Model for Embeddings: {LLM_MODEL}\n")
        f.write(f"Embedding Dimension: {X_train_llm_emb.shape[1]}\n")
        f.write(f"Valid training samples: {len(X_train_llm_emb)}\n")
        f.write(f"Valid test samples: {len(X_test_llm_emb)}\n")
        f.write(f"Linear Probe AuROC: {auroc_llm_probe:.4f}\n")
        f.write(f"Linear Probe AuPRC: {auprc_llm_probe:.4f}\n")
        if silhouette is not None:
            f.write(f"t-SNE Silhouette Score: {silhouette:.4f}\n")

    print(f"\nEmbeddings successfully generated and evaluated. Results saved to {RESULTS_DIR}.")

if __name__ == "__main__":
    # Script/notebook entry point: report any failure with its traceback rather
    # than propagating it, so the closing banner is always printed.
    import traceback

    try:
        run_llm_embeddings()
    except Exception as e:
        print(f"Error during Q4.2 execution: {e}")
        traceback.print_exc()

    print("\n--- Finished Q4.2 ---")

### Q4.3

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import RobustScaler
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.linear_model import LogisticRegression
import pickle
import warnings
from chronos.models import ChronosForecaster  # You may need to install this library
import concurrent.futures
import time

# Silence all library warnings for the rest of the session (deliberately global).
warnings.filterwarnings('ignore')

# ---- Paths -------------------------------------------------------------------
ML_READY_PATH = 'ml_ready_data'    # scaled + feature-engineered per-patient tables
PROCESSED_PATH = 'processed_data'  # initial imputed hourly tables
MODELS_DIR = 'models'              # trained model checkpoints
RESULTS_DIR = 'results'            # plots and result summaries

os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# ---- Dataset / model constants ------------------------------------------------
STATIC_VARS = ['Age', 'Gender', 'Height', 'Weight']
TARGET_VAR = 'In_hospital_death'
# NOTE(review): published Chronos checkpoints appear to be named "amazon/chronos-t5-*"
# on the Hugging Face Hub; confirm this name before relying on it.
CHRONOS_MODEL_NAME = "chronos/chronos-t5-small"

# ---- Reproducibility: seed NumPy and PyTorch (CPU, plus every GPU if present) --
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# ---- Data loading; the outcome is recorded in DATA_AVAILABLE ----------------------
print("Loading preprocessed data...")
try:
    # Hourly time-series tables
    train_data_raw = pd.read_parquet(f'{PROCESSED_PATH}/set-a.parquet')
    test_data_raw = pd.read_parquet(f'{PROCESSED_PATH}/set-c.parquet')

    # Per-patient aggregates: source of the labels and of the patient index
    X_train_agg = pd.read_parquet(f'{ML_READY_PATH}/patient_features-a.parquet')
    X_test_agg = pd.read_parquet(f'{ML_READY_PATH}/patient_features-c.parquet')
    if 'PatientID' in X_train_agg.columns:
        X_train_agg = X_train_agg.set_index('PatientID')
    if 'PatientID' in X_test_agg.columns:
        X_test_agg = X_test_agg.set_index('PatientID')

    # Labels are None when the target column is absent
    y_train = None
    if TARGET_VAR in X_train_agg.columns:
        y_train = X_train_agg[TARGET_VAR]
    y_test = None
    if TARGET_VAR in X_test_agg.columns:
        y_test = X_test_agg[TARGET_VAR]

    # Time-series variables: every raw column that is not an identifier,
    # a static covariate, the ICU type or the label
    id_vars = ['PatientID', 'RecordID', 'Hour']
    exclude_cols = STATIC_VARS + id_vars + ['ICUType', TARGET_VAR]
    TIME_SERIES_VARS = [name for name in train_data_raw.columns if name not in exclude_cols]
    print(f"Found {len(TIME_SERIES_VARS)} time series variables.")

    print("Successfully loaded preprocessed data.")
    DATA_AVAILABLE = True
except Exception as e:
    print(f"Error loading preprocessed data: {e}")
    DATA_AVAILABLE = False

# Define metrics calculation function
def calculate_metrics(y_true, y_pred_proba, set_name="Test"):
    """Compute, print and return AuROC and AuPRC for binary predictions.

    Args:
        y_true: Binary ground-truth labels (array-like). They are converted to
            float, so ``None`` becomes NaN and non-numeric labels are rejected.
        y_pred_proba: Positive-class scores, same length as ``y_true``.
            Non-finite scores are replaced (NaN -> 0.5, +inf -> 1.0, -inf -> 0.0).
        set_name: Label used in the printed messages.

    Returns:
        ``(auroc, auprc)``, or ``(nan, nan)`` when the metrics are undefined:
        non-finite or non-numeric labels, a single class, mismatched lengths, etc.
    """
    try:
        # Float conversion turns None labels into NaN (caught by the finiteness
        # check) and makes non-numeric input raise ValueError/TypeError, handled
        # below, instead of np.isfinite crashing on an object array.
        y_true = np.asarray(y_true, dtype=float)
        y_pred_proba = np.asarray(y_pred_proba, dtype=float)
        if not np.all(np.isfinite(y_true)):
            print(f"Warning: Non-finite y_true for {set_name}. Skip.")
            return np.nan, np.nan
        if not np.all(np.isfinite(y_pred_proba)):
            print(f"Warning: Non-finite y_pred_proba for {set_name}. Replace.")
            y_pred_proba = np.nan_to_num(y_pred_proba, nan=0.5, posinf=1.0, neginf=0.0)
        if len(np.unique(y_true)) < 2:
            print(f"Warning: Only one class in y_true for {set_name}.")
            return np.nan, np.nan

        auroc = roc_auc_score(y_true, y_pred_proba)
        auprc = average_precision_score(y_true, y_pred_proba)

        print(f"{set_name} AuROC: {auroc:.4f}")
        print(f"{set_name} AuPRC: {auprc:.4f}")
        return auroc, auprc
    except (ValueError, TypeError) as e:
        print(f"Metrics Error for {set_name}: {e}")
        return np.nan, np.nan

# Function to preprocess time series data for Chronos
def preprocess_for_chronos(data, patient_id, variable, min_length=10):
    """
    Preprocess one time-series variable of one patient for Chronos.

    The patient's rows are ordered by ``Hour``. Gaps are forward-filled, then any
    leading gap is back-filled from the first observation. A series with no
    observations at all (patient absent, or all values NaN) becomes zeros. The
    result is right-padded with zeros to at least ``min_length`` in every case.

    Args:
        data: DataFrame with 'PatientID', 'Hour' and ``variable`` columns.
        patient_id: Patient whose rows are used.
        variable: Column to extract.
        min_length: Minimum length of the returned series (default 10).

    Returns:
        1-D float numpy array of length ``max(n_rows, min_length)``.
    """
    rows = data[data['PatientID'] == patient_id].sort_values('Hour')
    values = rows[variable].astype(float).ffill().bfill().to_numpy()

    # After ffill + bfill, NaNs can only remain when nothing was observed.
    if np.isnan(values).all():
        values = np.zeros(len(values))

    if len(values) < min_length:
        values = np.concatenate([values, np.zeros(min_length - len(values))])

    return values

# Load a Chronos embedding function (real model when available, placeholder otherwise)
def load_chronos_model(model_name=None):
    """
    Load a pre-trained Chronos time-series foundation model as an embedding function.

    First tries ``chronos.ChronosPipeline`` (the chronos-forecasting package). The
    returned callable maps a 1-D series to one vector by mean-pooling the encoder
    embeddings over time.

    If the package or checkpoint cannot be loaded, a clearly-labelled placeholder
    is returned instead: a 128-d pseudo-random vector seeded from the series
    values. It is reproducible and independent of thread scheduling, but carries
    no clinical signal.

    NOTE(review): the cell-level ``from chronos.models import ChronosForecaster``
    import may not exist in chronos-forecasting; confirm.

    Args:
        model_name: Checkpoint to load; defaults to ``CHRONOS_MODEL_NAME``.

    Returns:
        Callable ``f(time_series) -> np.ndarray``, or ``None`` on unexpected failure.
    """
    import zlib

    try:
        if model_name is None:
            model_name = CHRONOS_MODEL_NAME
        print(f"Loading pre-trained Chronos model: {model_name}")

        try:
            from chronos import ChronosPipeline
            device = "cuda" if torch.cuda.is_available() else "cpu"
            pipeline = ChronosPipeline.from_pretrained(
                model_name, device_map=device, torch_dtype=torch.float32
            )
        except Exception as e:
            print(f"Chronos model unavailable ({e}); using placeholder embeddings "
                  "that carry NO clinical signal.")
            pipeline = None

        if pipeline is not None:
            def chronos_embedding_function(time_series):
                """Mean-pool Chronos encoder embeddings of one series over time."""
                context = torch.as_tensor(np.asarray(time_series, dtype=np.float32))
                with torch.no_grad():
                    emb, _ = pipeline.embed(context)
                # presumably (1, context_len + 1, d_model) per the chronos README; verify
                return emb[0].float().mean(dim=0).cpu().numpy()

            return chronos_embedding_function

        def mock_chronos_embedding_function(time_series):
            """Placeholder: deterministic pseudo-random 128-d vector keyed on the series values."""
            series = np.ascontiguousarray(time_series, dtype=np.float64)
            # Per-call generator (not the global RNG) so results do not depend on
            # which worker thread runs first.
            rng = np.random.default_rng(zlib.crc32(series.tobytes()))
            return rng.standard_normal(128)

        return mock_chronos_embedding_function

    except Exception as e:
        print(f"Error loading Chronos model: {e}")
        return None

# Function to get Chronos embeddings for a patient's time series
def get_chronos_embeddings(data, patient_id, variables, chronos_model):
    """
    Get Chronos embeddings for all specified variables of a single patient.

    The patient's rows are selected from ``data`` once and reused for every
    variable, rather than re-scanning the full table per variable.

    Args:
        data: Long-format DataFrame with a 'PatientID' column.
        patient_id: Patient to embed.
        variables: Column names to embed.
        chronos_model: Callable mapping a 1-D series to an embedding.

    Returns:
        dict variable -> embedding. Variables whose preprocessing or embedding
        raised are reported and omitted.
    """
    try:
        patient_rows = data[data['PatientID'] == patient_id]
    except Exception:
        # Fall back to the full table so the per-variable handler below reports the problem.
        patient_rows = data

    embeddings = {}
    for var in variables:
        try:
            # preprocess_for_chronos filters by patient again, but now over a few rows only.
            time_series = preprocess_for_chronos(patient_rows, patient_id, var)
            embeddings[var] = chronos_model(time_series)
        except Exception as e:
            print(f"Error getting embedding for patient {patient_id}, variable {var}: {e}")

    return embeddings

# Fan per-patient Chronos embedding work out to a thread pool
def process_patients_in_parallel(patient_ids, data, variables, chronos_model, max_workers=5):
    """Compute per-variable Chronos embeddings for many patients concurrently.

    Returns:
        dict patient_id -> {variable: embedding}. A patient whose worker raised
        is reported and left out.
    """
    results = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        future_to_pid = {}
        for patient in patient_ids:
            job = pool.submit(get_chronos_embeddings, data, patient, variables, chronos_model)
            future_to_pid[job] = patient

        finished = concurrent.futures.as_completed(future_to_pid)
        for job in tqdm(finished, total=len(future_to_pid), desc="Processing patients"):
            patient = future_to_pid[job]
            try:
                results[patient] = job.result()
            except Exception as exc:
                print(f"Error processing patient {patient}: {exc}")

    return results

# Collapse per-variable embeddings into one vector per patient
def average_embeddings(patient_embeddings):
    """
    Average each patient's per-variable embeddings element-wise.

    Args:
        patient_embeddings: dict patient_id -> {variable: embedding or None}.

    Returns:
        dict patient_id -> mean embedding. Patients with no non-None embedding
        are omitted.
    """
    averaged = {}
    for pid, by_var in patient_embeddings.items():
        vectors = [vec for vec in (by_var or {}).values() if vec is not None]
        if vectors:
            averaged[pid] = np.stack(vectors).mean(axis=0)
    return averaged

# Channel aggregation neural network
class ChannelAggregator(nn.Module):
    """Attention-pool per-variable embeddings into one vector and predict a probability.

    Args:
        input_dim: Size of each variable embedding (and of a pre-aggregated embedding).
        hidden_dim: Width of the attention scorer and of the first predictor layer.
        output_dim: Number of sigmoid outputs.
        num_variables: When truthy, the attention scorer needed for
            variable-level input is built. Its value is not otherwise used.
    """

    def __init__(self, input_dim, hidden_dim=64, output_dim=1, num_variables=None):
        super(ChannelAggregator, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        if num_variables:
            # Scores each variable embedding independently. The softmax across
            # variables is applied in forward(), because normalising a single score
            # per row would always give weight 1. The Linear layers stay at indices
            # 0 and 2, so state_dict keys are attention.0.* / attention.2.*.
            self.attention = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1),
            )

        # Final prediction layers
        self.predictor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, output_dim),
            nn.Sigmoid()
        )

    def forward(self, x, var_embeddings=None):
        """
        Predict from either pre-aggregated or per-variable embeddings.

        Args:
            x: (batch_size, embedding_dim) aggregated embeddings. Ignored when
                ``var_embeddings`` is given.
            var_embeddings: Optional (batch_size, num_variables, embedding_dim)
                tensor. Its variables are pooled with attention weights that sum
                to 1 per sample.

        Returns:
            (batch_size, output_dim) probabilities.

        Raises:
            AttributeError: ``var_embeddings`` given, but the model was built
                without ``num_variables`` (so it has no attention scorer).
        """
        if var_embeddings is not None:
            if not hasattr(self, "attention"):
                raise AttributeError(
                    "ChannelAggregator was built without num_variables; "
                    "variable-level input requires the attention layer."
                )
            # nn.Linear acts on the last dim: (B, V, D) -> (B, V, 1) scores.
            scores = self.attention(var_embeddings)
            # Normalise across variables (dim=1) so each sample's weights sum to 1.
            weights = torch.softmax(scores, dim=1)
            x = (var_embeddings * weights).sum(dim=1)

        return self.predictor(x)

# Function to train the channel aggregator model

def _stack_channel_embeddings(embeddings, labels, var_names, embedding_dim):
    """Build a (patients, variables, dim) float32 array and float32 labels from nested dicts.

    Variables missing (or None) for a patient are zero-filled so every row has the
    same shape. Patients are skipped when they have no label, none of
    ``var_names``, or a non-finite label (a NaN target would make BCELoss NaN).
    """
    features, targets = [], []
    for pid, per_var in embeddings.items():
        per_var = per_var or {}
        if pid not in labels:
            continue
        if not any(per_var.get(var) is not None for var in var_names):
            continue
        label = float(labels[pid])
        if not np.isfinite(label):
            continue
        features.append(np.stack([
            np.asarray(per_var[var], dtype=np.float32)
            if per_var.get(var) is not None
            else np.zeros(embedding_dim, dtype=np.float32)
            for var in var_names
        ]))
        targets.append(label)

    if not features:
        return (np.empty((0, len(var_names), embedding_dim), dtype=np.float32),
                np.empty(0, dtype=np.float32))
    return np.stack(features), np.asarray(targets, dtype=np.float32)


def _mean_bce_loss(model, loader, criterion, device):
    """Average ``criterion`` over every sample in ``loader`` with the model in eval mode."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            total += criterion(model(None, batch_x), batch_y).item() * batch_x.size(0)
            count += batch_x.size(0)
    return total / count


def train_channel_aggregator(train_embeddings, train_labels, val_embeddings=None, val_labels=None,
                            batch_size=32, epochs=50, learning_rate=0.001, patience=5):
    """
    Train a ChannelAggregator on per-variable embeddings.

    Args:
        train_embeddings: dict {patient_id -> {variable -> embedding}}.
        train_labels: Mapping {patient_id -> label}; a dict or an ID-indexed Series.
        val_embeddings: Optional validation embeddings in the same format.
        val_labels: Optional validation labels in the same format.
        batch_size, epochs, learning_rate: Usual optimisation settings (Adam + BCE).
        patience: Epochs without validation-loss improvement before early stopping.

    Behaviour:
        Variables are taken in first-seen order across all training patients.
        Missing variables are zero-filled, and patients with non-finite labels
        are dropped.

        With validation data, the real validation BCE is computed every epoch.
        The best weights are saved to ``{MODELS_DIR}/channel_aggregator.pt`` and
        restored at the end.

        Without validation data, a checkpoint is written every 10 epochs.

    Returns:
        Trained model.

    Raises:
        ValueError: no training embeddings, or no patient with both embeddings
            and a finite label.
    """
    var_names = list(dict.fromkeys(
        var for per_var in train_embeddings.values() for var in (per_var or {})
    ))
    first = next(
        (np.asarray(emb) for per_var in train_embeddings.values()
         for emb in (per_var or {}).values() if emb is not None),
        None,
    )
    if first is None:
        raise ValueError("train_embeddings contains no embeddings to train on.")
    embedding_dim = first.shape[0]

    X_train_np, y_train_np = _stack_channel_embeddings(
        train_embeddings, train_labels, var_names, embedding_dim
    )
    if len(X_train_np) == 0:
        raise ValueError("No training patients with both embeddings and a finite label.")

    train_dataset = TensorDataset(
        torch.from_numpy(X_train_np), torch.from_numpy(y_train_np).view(-1, 1)
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # len() checks work for both dicts and pandas Series (whose truthiness is ambiguous).
    use_validation = (val_embeddings is not None and len(val_embeddings) > 0
                      and val_labels is not None and len(val_labels) > 0)
    val_loader = None
    if use_validation:
        X_val_np, y_val_np = _stack_channel_embeddings(
            val_embeddings, val_labels, var_names, embedding_dim
        )
        if len(X_val_np) == 0:
            print("Warning: no usable validation patients; training without early stopping.")
            use_validation = False
        else:
            val_loader = DataLoader(
                TensorDataset(torch.from_numpy(X_val_np), torch.from_numpy(y_val_np).view(-1, 1)),
                batch_size=batch_size,
            )

    model = ChannelAggregator(
        input_dim=embedding_dim,
        hidden_dim=64,
        output_dim=1,
        num_variables=len(var_names)
    )
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.BCELoss()

    best_path = f"{MODELS_DIR}/channel_aggregator.pt"
    best_val_loss = float('inf')
    early_stop_counter = 0
    saved_best = False  # only reload a checkpoint written by *this* run

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(None, batch_x), batch_y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * batch_x.size(0)
        train_loss = running_loss / len(train_loader.dataset)

        val_loss = None
        if use_validation:
            val_loss = _mean_bce_loss(model, val_loader, criterion, device)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                early_stop_counter = 0
                torch.save(model.state_dict(), best_path)
                saved_best = True
            else:
                early_stop_counter += 1
                if early_stop_counter >= patience:
                    print(f"Early stopping at epoch {epoch}")
                    break
        elif epoch % 10 == 0:
            # No validation data: keep periodic snapshots instead.
            torch.save(model.state_dict(), f"{MODELS_DIR}/channel_aggregator_epoch{epoch}.pt")

        if (epoch + 1) % 5 == 0:
            message = f'Epoch {epoch+1}/{epochs}, Loss: {train_loss:.4f}'
            if val_loss is not None:
                message += f', Val Loss: {val_loss:.4f}'
            print(message)

    if use_validation and saved_best:
        model.load_state_dict(torch.load(best_path, map_location=device))

    return model

# Function to get predictions from the channel aggregator model
def predict_with_channel_aggregator(model, embeddings, device=None, batch_size=32):
    """
    Get predictions from a trained channel aggregator model.

    Args:
        model: Trained model, called as ``model(None, X)`` with X of shape
            (batch, num_variables, embedding_dim).
        embeddings: dict {patient_id -> {variable -> embedding}}.
        device: Torch device; defaults to the first CUDA device if available, else CPU.
        batch_size: Number of patients per forward pass (default 32).

    Details:
        Variables are taken in first-seen order across all patients, and a
        variable missing (or None) for a patient is zero-filled. Patients with
        no embedding at all are skipped.

        Each output is paired with the patient whose row produced it, so a
        skipped patient never shifts predictions onto its neighbours.

    Returns:
        dict {patient_id -> prediction}; empty if ``embeddings`` holds nothing usable.
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model.to(device)
    model.eval()

    predictions = {}
    if not embeddings:
        return predictions

    var_names = list(dict.fromkeys(
        var for per_var in embeddings.values() for var in (per_var or {})
    ))
    embedding_dim = next(
        (np.asarray(emb).shape[0] for per_var in embeddings.values()
         for emb in (per_var or {}).values() if emb is not None),
        None,
    )
    if embedding_dim is None:
        return predictions

    patient_ids = list(embeddings.keys())
    for start in range(0, len(patient_ids), batch_size):
        kept_ids, rows = [], []
        for pid in patient_ids[start:start + batch_size]:
            per_var = embeddings[pid] or {}
            if not any(per_var.get(var) is not None for var in var_names):
                continue  # nothing to predict from
            rows.append(np.stack([
                np.asarray(per_var[var], dtype=np.float32)
                if per_var.get(var) is not None
                else np.zeros(embedding_dim, dtype=np.float32)
                for var in var_names
            ]))
            kept_ids.append(pid)

        if not rows:
            continue

        X_batch = torch.from_numpy(np.stack(rows)).to(device)
        with torch.no_grad():
            outputs = model(None, X_batch).cpu().numpy().flatten()

        # Pair outputs with the patients actually stacked, not the raw batch slice.
        predictions.update(zip(kept_ids, outputs))

    return predictions

# Main execution function for Q4.3
def run_time_series_foundation_models():
    """Execute Q4.3: Using time-series foundation models.

    Steps:
      1. Generate Chronos embeddings for each patient's selected time series
         (via ``process_patients_in_parallel``) for the train and test sets.
      2. Part 1: average embeddings across variables and fit a class-balanced
         logistic-regression linear probe, saved under
         ``{MODELS_DIR}/chronos_probe/avg_probe_model.pkl``.
      3. Part 2: train the neural channel aggregator on the per-variable
         embeddings and score it on the test set.
      4. Write AuROC/AuPRC for both approaches to
         ``{RESULTS_DIR}/chronos_results.txt``.

    Relies on notebook-level globals populated earlier (``train_data_raw``,
    ``test_data_raw``, ``y_train``, ``y_test``, ``DATA_AVAILABLE``,
    ``TIME_SERIES_VARS``, ``SEED``). Prompts on stdin for an optional cap on
    the number of patients per split. Returns None.
    """
    print("\n--- Q4.3: Using time-series foundation models ---")

    # Bail out early (before loading the model) if the loading cell failed or
    # the target column was missing; y_train/y_test are None in that case.
    if not DATA_AVAILABLE or y_train is None or y_test is None:
        print("Preprocessed data or labels are unavailable. Exiting.")
        return

    # Select a subset of key time series variables (to reduce computation time)
    selected_variables = [
        'HR', 'RespRate', 'SysABP', 'DiasABP', 'MAP', 'Temp', 'GCS',
        'PaO2', 'SaO2', 'FiO2', 'Lactate', 'Creatinine', 'Glucose'
    ]

    # Filter to only include variables that exist in our dataset
    available_variables = [var for var in selected_variables if var in TIME_SERIES_VARS]
    print(f"Using {len(available_variables)} available time series variables: {available_variables}")

    # Load Chronos model
    chronos_model = load_chronos_model()
    if chronos_model is None:
        print("Failed to load Chronos model. Exiting.")
        return

    # Get patient IDs
    train_patient_ids = train_data_raw['PatientID'].unique()
    test_patient_ids = test_data_raw['PatientID'].unique()

    # For testing/development, optionally use a smaller subset
    max_patients_input = input(f"Specify maximum number of patients to process (default: all {len(train_patient_ids)} train, {len(test_patient_ids)} test): ")
    try:
        max_patients = int(max_patients_input)
    except ValueError:
        # Empty or non-numeric input means "use every patient".
        max_patients = None

    # Non-positive caps would yield empty (0) or all-but-last (negative) slices.
    if max_patients is not None and max_patients > 0:
        train_patient_ids = train_patient_ids[:max_patients]
        test_patient_ids = test_patient_ids[:max_patients]
        print(f"Will process {len(train_patient_ids)} train and {len(test_patient_ids)} test patients")
    else:
        print(f"Using all patients: {len(train_patient_ids)} train and {len(test_patient_ids)} test")

    # Generate Chronos embeddings for each patient's time series
    print("\nGenerating Chronos embeddings for training set...")
    train_embeddings = process_patients_in_parallel(
        train_patient_ids, train_data_raw, available_variables, chronos_model
    )

    print("\nGenerating Chronos embeddings for test set...")
    test_embeddings = process_patients_in_parallel(
        test_patient_ids, test_data_raw, available_variables, chronos_model
    )

    print(f"Successfully generated embeddings for {len(train_embeddings)} train patients and {len(test_embeddings)} test patients")

    # --- Part 1: Simple averaging approach ---
    print("\n--- Part 1: Average embeddings across variables ---")

    # Average embeddings across variables for each patient
    train_avg_embeddings = average_embeddings(train_embeddings)
    test_avg_embeddings = average_embeddings(test_embeddings)

    print(f"Averaged embeddings for {len(train_avg_embeddings)} train patients and {len(test_avg_embeddings)} test patients")

    # Convert to numpy arrays for the linear probe
    train_ids = list(train_avg_embeddings.keys())
    test_ids = list(test_avg_embeddings.keys())

    X_train = np.array([train_avg_embeddings[pid] for pid in train_ids])
    # Use a distinct local name: assigning to ``y_train`` anywhere in this
    # function would make it local and raise UnboundLocalError on this lookup.
    y_train_vals = np.array([y_train.loc[pid] if pid in y_train.index else np.nan for pid in train_ids])

    X_test = np.array([test_avg_embeddings[pid] for pid in test_ids])
    y_test_vals = np.array([y_test.loc[pid] if pid in y_test.index else np.nan for pid in test_ids])

    # Remove NaN labels (patients without a label in the aggregated features)
    valid_train_idx = ~np.isnan(y_train_vals)
    valid_test_idx = ~np.isnan(y_test_vals)

    X_train_clean = X_train[valid_train_idx]
    y_train_clean = y_train_vals[valid_train_idx]
    X_test_clean = X_test[valid_test_idx]
    y_test_clean = y_test_vals[valid_test_idx]

    print(f"Clean data shapes: Train X={X_train_clean.shape}, Test X={X_test_clean.shape}")

    # Train linear probe
    print("\nTraining linear probe on averaged Chronos embeddings...")
    probe = LogisticRegression(
        solver='liblinear', random_state=SEED,
        max_iter=1000, C=1.0, class_weight='balanced'
    )
    probe.fit(X_train_clean, y_train_clean)

    # Evaluate on test set
    y_pred_proba = probe.predict_proba(X_test_clean)[:, 1]

    print("\nLinear Probe Results (Averaged Chronos Embeddings):")
    auroc_avg, auprc_avg = calculate_metrics(
        y_test_clean, y_pred_proba, set_name="Test (Avg Chronos Embeddings)"
    )

    # Save probe model
    os.makedirs(f"{MODELS_DIR}/chronos_probe", exist_ok=True)
    with open(f"{MODELS_DIR}/chronos_probe/avg_probe_model.pkl", "wb") as f:
        pickle.dump(probe, f)

    # --- Part 2: Neural Network for channel aggregation ---
    print("\n--- Part 2: Neural network for channel aggregation ---")

    # Labels as {patient_id -> label}, reusing the arrays aligned with the ids
    # above and keeping only patients with a non-NaN label.
    train_labels_dict = {pid: label for pid, label in zip(train_ids, y_train_vals) if not np.isnan(label)}
    test_labels_dict = {pid: label for pid, label in zip(test_ids, y_test_vals) if not np.isnan(label)}

    # Train channel aggregator model
    print("\nTraining channel aggregator neural network...")
    aggregator_model = train_channel_aggregator(
        train_embeddings, train_labels_dict,
        batch_size=16, epochs=30, learning_rate=0.001
    )

    # Get predictions
    print("\nGenerating predictions with channel aggregator model...")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    predictions = predict_with_channel_aggregator(aggregator_model, test_embeddings, device)

    # Evaluate only test patients that have both a valid label and a prediction
    test_pids_for_eval = [pid for pid in test_labels_dict if pid in predictions]
    y_true_nn = np.array([test_labels_dict[pid] for pid in test_pids_for_eval])
    y_pred_nn = np.array([predictions[pid] for pid in test_pids_for_eval])

    print("\nChannel Aggregator Neural Network Results:")
    auroc_nn, auprc_nn = calculate_metrics(
        y_true_nn, y_pred_nn, set_name="Test (Channel Aggregator NN)"
    )

    # Save results
    with open(f"{RESULTS_DIR}/chronos_results.txt", "w") as f:
        f.write("Chronos Time Series Foundation Model Results\n")
        f.write("===========================================\n\n")
        f.write("1. Simple Averaging Approach\n")
        f.write(f"Number of training patients: {len(X_train_clean)}\n")
        f.write(f"Number of test patients: {len(X_test_clean)}\n")
        f.write(f"AuROC: {auroc_avg:.4f}\n")
        f.write(f"AuPRC: {auprc_avg:.4f}\n\n")
        f.write("2. Neural Network Channel Aggregation\n")
        f.write(f"Number of training patients: {len(train_labels_dict)}\n")
        f.write(f"Number of test patients with predictions: {len(y_true_nn)}\n")
        f.write(f"AuROC: {auroc_nn:.4f}\n")
        f.write(f"AuPRC: {auprc_nn:.4f}\n")

    print(f"\nResults saved to {RESULTS_DIR}/chronos_results.txt")
    print("\n--- Finished Q4.3 ---")

if __name__ == "__main__":
    # Run the Q4.3 pipeline; report any failure with its full traceback
    # instead of letting the exception propagate out of the cell/script.
    try:
        run_time_series_foundation_models()
    except Exception as err:
        import traceback

        print(f"Error executing Q4.3: {err}")
        traceback.print_exc()