# CRBSI Prediction - MIMIC-IV Data Preprocessing
## Multi-Task Learning with Static and Temporal Features

This notebook preprocesses MIMIC-IV data for catheter-related bloodstream infection (CRBSI) prediction with the SMTAFormer architecture.

**Features:**
- Static: Demographics, catheter characteristics, baseline labs, comorbidities
- Temporal: Vital signs (hourly), Labs (every 12 hours), Catheter events (daily)
- Outcome: CRBSI occurrence (binary classification + time-to-event)

In [2]:
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set display options
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 100)

print("Libraries imported successfully!")

Libraries imported successfully!


## 1. Configuration and Paths

In [3]:
# Path to MIMIC-IV data (update these paths to your data location)
MIMIC_PATH = 'MIMIC-IV (3.1)/'  # UPDATE THIS
HOSP_PATH = MIMIC_PATH + 'hosp/'
ICU_PATH = MIMIC_PATH + 'icu/'

# Output path
OUTPUT_PATH = MIMIC_PATH

# Create output directory
import os
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Time windows configuration
FEATURE_WINDOW_HOURS = 48  # Use the first 48 hours after catheter insertion for feature extraction
PREDICTION_WINDOW_HOURS = 72  # Predict CRBSI in next 72 hours
SURVIVAL_WINDOW_HOURS = 168  # Track up to 7 days for survival analysis

# Temporal sequence lengths (for model input)
VITAL_SEQ_LENGTH = 48  # 48 hours of hourly vitals
LAB_SEQ_LENGTH = 14  # 7 days × 2 measurements/day
CATHETER_EVENT_SEQ_LENGTH = 14  # 14 days of daily catheter events

print("Configuration set!")
print(f"Feature window: {FEATURE_WINDOW_HOURS}h")
print(f"Prediction window: {PREDICTION_WINDOW_HOURS}h")
print(f"Survival tracking: {SURVIVAL_WINDOW_HOURS}h")

Configuration set!
Feature window: 48h
Prediction window: 72h
Survival tracking: 168h


## 2. Load Core MIMIC-IV Tables

In [None]:
print("Loading core tables...")

# Core patient information
patients = pd.read_csv(HOSP_PATH + 'patients.csv')
admissions = pd.read_csv(HOSP_PATH + 'admissions.csv')
icustays = pd.read_csv(ICU_PATH + 'icustays.csv')
transfers = pd.read_csv(HOSP_PATH + 'transfers.csv')

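# Diagnoses and procedures
# Read icd_code as text so numeric-looking ICD-9 codes match the string keys used below
diagnoses_icd = pd.read_csv(HOSP_PATH + 'diagnoses_icd.csv', dtype={'icd_code': str})
d_icd_diagnoses = pd.read_csv(HOSP_PATH + 'd_icd_diagnoses.csv', dtype={'icd_code': str})
procedures_icd = pd.read_csv(HOSP_PATH + 'procedures_icd.csv', dtype={'icd_code': str})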

# Labs and microbiology
labevents = pd.read_csv(HOSP_PATH + 'labevents.csv.gz',
                        usecols=['subject_id', 'hadm_id', 'itemid', 'charttime', 'value', 'valuenum'],
                        low_memory=False)
d_labitems = pd.read_csv(HOSP_PATH + 'd_labitems.csv')
microbiologyevents = pd.read_csv(HOSP_PATH + 'microbiologyevents.csv')

# ICU data
chartevents = pd.read_csv(ICU_PATH + 'chartevents.csv.gz',
                          usecols=['subject_id', 'hadm_id', 'stay_id', 'itemid', 'charttime', 'value', 'valuenum'],
                          low_memory=False)
d_items = pd.read_csv(ICU_PATH + 'd_items.csv')
procedureevents = pd.read_csv(ICU_PATH + 'procedureevents.csv')

# Pharmacy
prescriptions = pd.read_csv(HOSP_PATH + 'prescriptions.csv')

print(f"Loaded data for {len(patients)} patients")
print(f"Loaded {len(admissions)} admissions")
print(f"Loaded {len(icustays)} ICU stays")
print(f"Chart events: {len(chartevents):,} records")
print(f"Lab events: {len(labevents):,} records")

Loading core tables...
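
Reading all of `chartevents.csv.gz` into memory can exceed what a typical workstation has available. One alternative, sketched below under the assumption that only a known subset of `itemid` values is needed, is to stream the file in chunks and keep the matching rows. The helper name `load_filtered_events` and the default chunk size are illustrative and not part of the original notebook.

```python
import pandas as pd

def load_filtered_events(path, itemids, usecols, chunksize=1_000_000):
    """Stream a large events CSV and keep only rows whose itemid is in `itemids`.

    Hypothetical helper: the name and default chunk size are assumptions.
    """
    itemids = set(itemids)
    kept = []
    # With chunksize set, read_csv yields DataFrames one chunk at a time
    for chunk in pd.read_csv(path, usecols=usecols, chunksize=chunksize):
        kept.append(chunk[chunk['itemid'].isin(itemids)])
    if not kept:
        return pd.DataFrame(columns=usecols)
    return pd.concat(kept, ignore_index=True)
```

In the notebook this could be called with the vital-sign and catheter-event itemids defined in later sections, so that only those rows are ever held in memory.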


## 3. Define Central Line and CRBSI Identification

Based on the clinical presentation slides, we identify:
- Central line types: CVC, PICC, Hickman, Port-A, Swan-Ganz, Midline
- CRBSI criteria: Positive blood culture + catheter culture with same organism

In [None]:
# Central line procedure codes (ICD-9 and ICD-10)
CENTRAL_LINE_CODES = {
    # ICD-9
    '3893': 'Venous catheterization',
    '3895': 'Venous catheterization for renal dialysis',
    '8607': 'Insertion of totally implantable vascular access device',
    # ICD-10
    '02H60JZ': 'Insertion of central venous catheter',
    '02HV3JZ': 'Insertion of tunneled central venous catheter',
    '02HV33Z': 'Insertion of PICC',
    '05H033Z': 'Insertion of infusion device into superior vena cava'
}

# CRBSI-related diagnosis codes
CRBSI_CODES = {
    # ICD-9
    '99931': 'Infection due to central venous catheter',
    '99932': 'Bloodstream infection due to central venous catheter',
    # ICD-10  
    'T80211A': 'Bloodstream infection due to central venous catheter',
    'T80212A': 'Local infection due to central venous catheter',
    'T80218A': 'Other infection due to central venous catheter',
    'T80219A': 'Unspecified infection due to central venous catheter'
}

# Common CRBSI pathogens (from slide 10)
CRBSI_ORGANISMS = [
    'STAPHYLOCOCCUS AUREUS',
    'STAPHYLOCOCCUS, COAGULASE NEGATIVE',
    'ENTEROCOCCUS',
    'ESCHERICHIA COLI',
    'KLEBSIELLA PNEUMONIAE',
    'PSEUDOMONAS AERUGINOSA',
    'ACINETOBACTER BAUMANNII',
    'CANDIDA'
]

print("Central line and CRBSI criteria defined")
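
ICD-9 and ICD-10 codes share a single `icd_code` column in `procedures_icd` and `diagnoses_icd`, with `icd_version` telling them apart. A minimal sketch, on toy rows, of matching on the (code, version) pair rather than the code alone. The `CODE_VERSIONS` mapping and the function name are assumptions made for illustration.

```python
import pandas as pd

# Assumed mapping from code to ICD version; the two codes are copied from CENTRAL_LINE_CODES
CODE_VERSIONS = {'3893': 9, '02HV33Z': 10}

def filter_by_code_and_version(df, code_versions):
    """Keep rows whose (icd_code, icd_version) pair appears in `code_versions`."""
    codes = df['icd_code'].astype(str).str.strip()
    expected_version = codes.map(code_versions)   # NaN when the code is not listed
    return df[expected_version == df['icd_version']]

# Toy rows for illustration only
toy = pd.DataFrame({
    'hadm_id': [1, 2, 3, 4],
    'icd_code': ['3893', '02HV33Z', '3893', '5491'],
    'icd_version': [9, 10, 10, 9],
})
matched = filter_by_code_and_version(toy, CODE_VERSIONS)
```

Only the first two toy rows survive: the third has a listed code under the wrong version, and the fourth has an unlisted code.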

## 4. Identify Central Line Cohort

In [None]:
# Identify patients with central line procedures
central_line_procedures = procedures_icd[
    procedures_icd['icd_code'].isin(CENTRAL_LINE_CODES.keys())
].copy()

# Parse datetime
central_line_procedures['chartdate'] = pd.to_datetime(central_line_procedures['chartdate'])

# Merge with admissions and ICU stays
central_line_stays = icustays.merge(
    central_line_procedures[['subject_id', 'hadm_id', 'chartdate']],
    on=['subject_id', 'hadm_id'],
    how='inner'
)

# Parse ICU times
central_line_stays['intime'] = pd.to_datetime(central_line_stays['intime'])
central_line_stays['outtime'] = pd.to_datetime(central_line_stays['outtime'])

# Filter: central line placed during or before ICU stay
central_line_stays = central_line_stays[
    central_line_stays['chartdate'] <= central_line_stays['outtime']
]

# Calculate catheter duration (approximate)
central_line_stays['catheter_duration_hours'] = (
    central_line_stays['outtime'] - central_line_stays['chartdate']
).dt.total_seconds() / 3600

# Filter: Keep stays with catheter duration of at least 48 hours (clinical relevance)
central_line_stays = central_line_stays[
    central_line_stays['catheter_duration_hours'] >= 48
]

print(f"Identified {len(central_line_stays)} ICU stays with central lines")
print(f"Unique patients: {central_line_stays['subject_id'].nunique()}")
print(f"\nCatheter duration statistics (hours):")
print(central_line_stays['catheter_duration_hours'].describe())
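
Because the merge above produces one row per matching procedure, a stay with several central line procedures appears more than once. A minimal sketch, on toy data, of keeping only the earliest insertion per `stay_id`. Treating the first line as the index event is a modeling assumption, not something the original notebook specifies.

```python
import pandas as pd

# Toy stays for illustration: stay 10 has two central line procedures
toy_stays = pd.DataFrame({
    'stay_id': [10, 10, 11],
    'chartdate': pd.to_datetime(['2150-01-03', '2150-01-01', '2150-02-05']),
})

# Sort by insertion date, then keep the first (earliest) row for each stay
first_line = (
    toy_stays.sort_values('chartdate')
             .drop_duplicates(subset='stay_id', keep='first')
             .reset_index(drop=True)
)
```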

## 5. Identify CRBSI Cases

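CRBSI definition (from slides):
1. Same organism from blood sample AND catheter tip
2. Bacterial load from catheter : peripheral blood ≥ 10:1
3. DTP (Differential Time to Positivity) ≥ 2 hours

The cell below does not test these criteria directly. It uses a proxy: an admission is labeled CRBSI if it has a CRBSI diagnosis code or a blood culture positive for one of the organisms listed in section 3.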

In [None]:
# Identify CRBSI from diagnosis codes
crbsi_diagnoses = diagnoses_icd[
    diagnoses_icd['icd_code'].isin(CRBSI_CODES.keys())
][['subject_id', 'hadm_id', 'icd_code']].copy()

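# Identify CRBSI from microbiology (blood culture positive for CRBSI organisms)
import re

microbiologyevents['charttime'] = pd.to_datetime(microbiologyevents['charttime'])

# Substring match so genus-level entries ('ENTEROCOCCUS', 'CANDIDA') also match
# species-level organism names; an exact isin() match would miss them
organism_pattern = '|'.join(re.escape(org) for org in CRBSI_ORGANISMS)

blood_cultures = microbiologyevents[
    (microbiologyevents['spec_type_desc'].str.contains('BLOOD', case=False, na=False)) &
    (microbiologyevents['org_name'].str.contains(organism_pattern, case=False, na=False))
].copy()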

# Combine CRBSI identifications
crbsi_hadm_ids = set(crbsi_diagnoses['hadm_id'].unique()) | set(blood_cultures['hadm_id'].unique())

# Tag CRBSI in central line stays
central_line_stays['crbsi'] = central_line_stays['hadm_id'].isin(crbsi_hadm_ids).astype(int)

# For CRBSI cases, try to find the timing from microbiology
crbsi_timing = blood_cultures.groupby('hadm_id')['charttime'].min().reset_index()
crbsi_timing.columns = ['hadm_id', 'crbsi_time']

central_line_stays = central_line_stays.merge(crbsi_timing, on='hadm_id', how='left')

# Calculate time to CRBSI (for survival analysis)
central_line_stays['time_to_crbsi_hours'] = (
    central_line_stays['crbsi_time'] - central_line_stays['chartdate']
).dt.total_seconds() / 3600

# For non-CRBSI cases, time to event is catheter removal time (censored)
central_line_stays.loc[central_line_stays['crbsi'] == 0, 'time_to_crbsi_hours'] = \
    central_line_stays.loc[central_line_stays['crbsi'] == 0, 'catheter_duration_hours']

print(f"\nCRBSI Cases: {central_line_stays['crbsi'].sum()}")
print(f"Non-CRBSI Cases: {(central_line_stays['crbsi'] == 0).sum()}")
print(f"CRBSI Rate: {central_line_stays['crbsi'].mean():.2%}")
print(f"\nTime to CRBSI (hours) for CRBSI cases:")
print(central_line_stays[central_line_stays['crbsi'] == 1]['time_to_crbsi_hours'].describe())
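
`SURVIVAL_WINDOW_HOURS` is defined in the configuration but not applied in the cell above. A minimal sketch of how administrative censoring at that horizon could be applied to the time-to-event columns. The function `censor_at_horizon` is hypothetical, and excluding rows with missing or negative times (for example a case identified only by diagnosis code, or a culture drawn before insertion) is an assumption.

```python
import numpy as np
import pandas as pd

def censor_at_horizon(df, horizon_hours=168):
    """Cap follow-up at `horizon_hours`; events after the horizon become censored."""
    # Keep rows with a usable, non-negative time (assumption: others are excluded)
    usable = df['time_to_crbsi_hours'].notna() & (df['time_to_crbsi_hours'] >= 0)
    out = df[usable].copy()
    late = out['time_to_crbsi_hours'] > horizon_hours
    out['event'] = np.where(late, 0, out['crbsi'])
    out['time'] = out['time_to_crbsi_hours'].clip(upper=horizon_hours)
    return out

# Toy rows for illustration only
toy = pd.DataFrame({
    'crbsi': [1, 1, 0, 1],
    'time_to_crbsi_hours': [60.0, 200.0, 300.0, np.nan],
})
labels = censor_at_horizon(toy)
```

In the toy data the CRBSI at 200 hours is recorded as censored at 168 hours, and the row with no usable time is dropped.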

## 6. Extract Static Features

Static features (from Enhanced Architecture Specification):
- Demographics: age, sex, BMI, race
- Catheter: type, site, duration at baseline, number of lumens, insertion location
- Clinical context: neutropenia, immunosuppression, ICU, mechanical ventilation, TPN, diabetes, CKD
- Baseline labs: WBC, CRP, albumin

In [None]:
def extract_static_features(stay_row, patients_df, admissions_df, diagnoses_df, labevents_df):
    """
    Extract static features for a given ICU stay
    """
    subject_id = stay_row['subject_id']
    hadm_id = stay_row['hadm_id']
    catheter_time = stay_row['chartdate']
    
    features = {}
    
    # ===== DEMOGRAPHICS =====
    patient = patients_df[patients_df['subject_id'] == subject_id].iloc[0]
    admission = admissions_df[admissions_df['hadm_id'] == hadm_id].iloc[0]
    
    # Age (anchor_age is the age in the patient's anchor_year, used here as an approximation of age at admission)
    features['age'] = patient['anchor_age']
    
    # Sex (M=1, F=0)
    features['sex'] = 1 if patient['gender'] == 'M' else 0
    
    # Race (simplified categories)
    race_map = {
        'WHITE': 0, 'BLACK': 1, 'HISPANIC': 2, 'ASIAN': 3, 'OTHER': 4
    }
    race = str(admission['race']).upper()  # str() guards against a missing (NaN) race value
    for key in race_map:
        if key in race:
            features['race'] = race_map[key]
            break
    else:
        features['race'] = 4  # OTHER
    
    # BMI (would need to calculate from weight/height in chartevents - simplified here)
    features['bmi'] = np.nan  # Placeholder - requires height/weight extraction
    
    # ===== CATHETER CHARACTERISTICS =====
    # These would ideally come from procedure notes or itemids
    # For now, we'll use placeholders
    features['catheter_type'] = 0  # 0=CVC, 1=PICC, 2=Hickman, 3=Port, etc.
    features['insertion_site'] = 0  # 0=Subclavian, 1=Jugular, 2=Femoral
    features['catheter_duration_baseline'] = 0  # Days at prediction time
    features['number_of_lumens'] = 2  # 1, 2, or 3+
    features['insertion_location'] = 1  # 0=OR, 1=ICU, 2=bedside, 3=IR
    
    # ===== COMORBIDITIES (from diagnoses) =====
    patient_diagnoses = [
        str(code).strip()
        for code in diagnoses_df[diagnoses_df['hadm_id'] == hadm_id]['icd_code'].values
    ]

    def has_code(prefixes):
        # Match each code by prefix; searching the stringified array would also
        # match unrelated codes that merely contain the digits (e.g. '250' in '4250')
        return int(any(code.startswith(tuple(prefixes)) for code in patient_diagnoses))

    # Neutropenia (ICD codes)
    neutropenia_codes = ['28800', '28801', '28802', '28803', '28809', 'D70']
    features['neutropenia'] = has_code(neutropenia_codes)

    # Immunosuppression
    immunosupp_codes = ['279', 'D84', 'D89']
    features['immunosuppression'] = has_code(immunosupp_codes)

    # Diabetes
    diabetes_codes = ['250', 'E10', 'E11']
    features['diabetes'] = has_code(diabetes_codes)

    # CKD (binary flag: any CKD code present; the stage itself is not extracted)
    ckd_codes = ['585', 'N18']
    features['ckd_stage'] = has_code(ckd_codes)
    
    # ICU admission
    features['icu_admission'] = 1  # All are in ICU by definition
    
    # Mechanical ventilation (would need procedureevents - placeholder)
    features['mechanical_ventilation'] = 0
    
    # TPN use (would need prescriptions - placeholder)
    features['tpn_use'] = 0
    
    # ===== BASELINE LABS =====
    # Get labs within 24h before catheter insertion
    baseline_labs = labevents_df[
        (labevents_df['subject_id'] == subject_id) &
        (labevents_df['charttime'] >= catheter_time - pd.Timedelta(hours=24)) &
        (labevents_df['charttime'] <= catheter_time)
    ]
    
    # WBC (itemid 51300, 51301)
    wbc = baseline_labs[baseline_labs['itemid'].isin([51300, 51301])]['valuenum'].median()
    features['baseline_wbc'] = wbc if not pd.isna(wbc) else 10.0  # Default normal
    
    # CRP (itemid 50889)
    crp = baseline_labs[baseline_labs['itemid'] == 50889]['valuenum'].median()
    features['baseline_crp'] = crp if not pd.isna(crp) else 5.0
    
    # Albumin (itemid 50862)
    albumin = baseline_labs[baseline_labs['itemid'] == 50862]['valuenum'].median()
    features['baseline_albumin'] = albumin if not pd.isna(albumin) else 3.5
    
    return features

print("Static feature extraction function defined")
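
Comorbidity flags depend on how ICD codes are matched. A minimal sketch contrasting a substring search over the stringified code list with per-code prefix matching. `has_icd_prefix` is a hypothetical helper, and the toy codes were chosen only because they contain the digits `250` without starting with a diabetes prefix.

```python
def has_icd_prefix(codes, prefixes):
    """Return 1 if any code starts with any of the given prefixes, else 0."""
    prefixes = tuple(prefixes)
    return int(any(str(code).strip().startswith(prefixes) for code in codes))

diabetes_prefixes = ['250', 'E10', 'E11']

# Toy codes for illustration: both contain '250' but neither starts with it
toy_codes = ['4250', 'T82504A']

# Substring search over the stringified list reports a hit
substring_hit = int(any(p in str(toy_codes) for p in diabetes_prefixes))

# Prefix matching per code does not
prefix_hit = has_icd_prefix(toy_codes, diabetes_prefixes)
```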

In [None]:
# Extract static features for all stays (sample first 100 for testing)
print("Extracting static features...")

# Parse datetime columns in labevents if not already done
labevents['charttime'] = pd.to_datetime(labevents['charttime'])

static_features_list = []

sample_stays = central_line_stays.head(100)  # may hold fewer than 100 rows
for idx, row in tqdm(sample_stays.iterrows(), total=len(sample_stays)):
    features = extract_static_features(
        row, patients, admissions, diagnoses_icd, labevents
    )
    features['stay_id'] = row['stay_id']
    features['subject_id'] = row['subject_id']
    features['hadm_id'] = row['hadm_id']
    static_features_list.append(features)

static_features_df = pd.DataFrame(static_features_list)

print(f"\nExtracted static features for {len(static_features_df)} stays")
print(f"\nStatic features shape: {static_features_df.shape}")
print("\nStatic features preview:")
display(static_features_df.head())

## 7. Extract Temporal Features - Channel 1: Vital Signs

In [None]:
# Vital signs itemids (from MIMIC-IV d_items)
VITAL_ITEMIDS = {
    'heart_rate': [220045],  # Heart Rate
    'temperature': [223761, 223762],  # Temperature Fahrenheit, Celsius
    'sbp': [220050, 220179],  # Systolic BP (invasive and non-invasive)
    'dbp': [220051, 220180],  # Diastolic BP
    'map': [220052, 220181, 225312],  # Mean Arterial Pressure
    'respiratory_rate': [220210, 224690],  # Respiratory Rate
    'spo2': [220277],  # SpO2
    'gcs_total': [226755]  # Glasgow Coma Scale Total
}

def extract_vital_signs(stay_row, chartevents_df, window_hours=48):
    """
    Extract hourly vital signs for specified window
    Returns: DataFrame with hourly resampled vital signs
    """
    stay_id = stay_row['stay_id']
    catheter_time = stay_row['chartdate']
    
    # Define time window
    start_time = catheter_time
    end_time = catheter_time + pd.Timedelta(hours=window_hours)
    
    # Get all vital signs for this stay in the window
    vitals = chartevents_df[
        (chartevents_df['stay_id'] == stay_id) &
        (chartevents_df['charttime'] >= start_time) &
        (chartevents_df['charttime'] <= end_time) &
        (chartevents_df['itemid'].isin([item for items in VITAL_ITEMIDS.values() for item in items]))
    ].copy()
    
    if len(vitals) == 0:
        return None
    
    # Map itemids to feature names
    itemid_to_feature = {}
    for feature, itemids in VITAL_ITEMIDS.items():
        for itemid in itemids:
            itemid_to_feature[itemid] = feature
    
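    vitals['feature'] = vitals['itemid'].map(itemid_to_feature)

    # Convert Fahrenheit readings (itemid 223761) to Celsius so both temperature itemids share one unit
    is_fahrenheit = vitals['itemid'] == 223761
    vitals.loc[is_fahrenheit, 'valuenum'] = (vitals.loc[is_fahrenheit, 'valuenum'] - 32) * 5 / 9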
    
    # Pivot to wide format
    vitals_pivot = vitals.pivot_table(
        index='charttime',
        columns='feature',
        values='valuenum',
        aggfunc='median'
    )
    
    # Resample to hourly and forward-fill
    vitals_hourly = vitals_pivot.resample('1h').median()
    vitals_hourly = vitals_hourly.ffill().bfill()

    # Ensure we have exactly window_hours rows and one column per vital sign
    # (a vital never charted for this stay remains NaN and is imputed later)
    expected_index = pd.date_range(start=start_time, periods=window_hours, freq='1h')
    vitals_hourly = vitals_hourly.reindex(index=expected_index, columns=list(VITAL_ITEMIDS))
    vitals_hourly = vitals_hourly.ffill().bfill()
    
    # Add time index
    vitals_hourly['hour'] = range(len(vitals_hourly))
    
    return vitals_hourly

print("Vital signs extraction function defined")
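
To make the resampling step concrete, here is a minimal sketch on three toy heart-rate readings. It mirrors the pivot, hourly resample, and fill used in `extract_vital_signs`. The `observed` mask is an addition, shown as one possible way to record which hours had a real measurement before filling.

```python
import pandas as pd

start = pd.Timestamp('2150-01-01 00:00')

# Toy readings for illustration only
toy = pd.DataFrame({
    'charttime': pd.to_datetime(['2150-01-01 00:10', '2150-01-01 00:40', '2150-01-01 03:05']),
    'feature': ['heart_rate', 'heart_rate', 'heart_rate'],
    'valuenum': [80.0, 90.0, 100.0],
})

wide = toy.pivot_table(index='charttime', columns='feature', values='valuenum', aggfunc='median')
hourly = wide.resample('1h').median()

# Fixed grid: 6 hourly rows and a column for every expected vital (missing ones stay NaN)
grid = pd.date_range(start=start, periods=6, freq='1h')
hourly = hourly.reindex(index=grid, columns=['heart_rate', 'spo2'])

observed = hourly.notna()          # True where an hour had a real measurement
filled = hourly.ffill().bfill()    # forward-fill first, then back-fill leading gaps
```

The two readings in the first hour collapse to their median, the gaps are filled from the last known value, and `spo2` stays entirely missing because it was never charted.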

## 8. Extract Temporal Features - Channel 2: Laboratory Values

In [None]:
# Lab itemids
LAB_ITEMIDS = {
    'wbc_count': [51300, 51301],  # White Blood Cells
    'neutrophil_count': [51256],  # Absolute Neutrophil Count
    'crp': [50889],  # C-Reactive Protein
    'procalcitonin': [51449],  # Procalcitonin
    'lactate': [50813],  # Lactate
    'platelets': [51265],  # Platelet Count
    'creatinine': [50912]  # Creatinine
}

def extract_lab_values(stay_row, labevents_df, window_days=7):
    """
    Extract lab values over window period
    Returns: DataFrame with labs resampled to 12-hour intervals
    """
    subject_id = stay_row['subject_id']
    hadm_id = stay_row['hadm_id']
    catheter_time = stay_row['chartdate']
    
    # Define time window
    start_time = catheter_time
    end_time = catheter_time + pd.Timedelta(days=window_days)
    
    # Get all labs for this admission in the window
    labs = labevents_df[
        (labevents_df['subject_id'] == subject_id) &
        (labevents_df['hadm_id'] == hadm_id) &
        (labevents_df['charttime'] >= start_time) &
        (labevents_df['charttime'] <= end_time) &
        (labevents_df['itemid'].isin([item for items in LAB_ITEMIDS.values() for item in items]))
    ].copy()
    
    if len(labs) == 0:
        return None
    
    # Map itemids to feature names
    itemid_to_feature = {}
    for feature, itemids in LAB_ITEMIDS.items():
        for itemid in itemids:
            itemid_to_feature[itemid] = feature
    
    labs['feature'] = labs['itemid'].map(itemid_to_feature)
    
    # Pivot to wide format
    labs_pivot = labs.pivot_table(
        index='charttime',
        columns='feature',
        values='valuenum',
        aggfunc='median'
    )
    
    # Resample to 12-hour intervals (2 per day)
    labs_12h = labs_pivot.resample('12h').median()
    labs_12h = labs_12h.ffill().bfill()

    # Ensure we have the expected number of measurements and one column per lab
    # (a lab never measured for this admission remains NaN and is imputed later)
    expected_index = pd.date_range(start=start_time, periods=window_days*2, freq='12h')
    labs_12h = labs_12h.reindex(index=expected_index, columns=list(LAB_ITEMIDS))
    labs_12h = labs_12h.ffill().bfill()
    
    # Add time index
    labs_12h['measurement'] = range(len(labs_12h))
    
    return labs_12h

print("Lab values extraction function defined")

## 9. Extract Temporal Features - Channel 3: Catheter Events

In [None]:
# Catheter care event itemids (these are examples - may need adjustment)
CATHETER_EVENT_ITEMIDS = {
    'dressing_change': [225111, 225112],  # Dressing changes
    'line_access': [225158],  # Line access
    'blood_draw': [225168, 227719],  # Blood draw from line
    'medication_admin': [225168],  # Medication administration (225168 is also listed under blood_draw; verify against d_items)
    'line_flush': [225166],  # Line flush
    'site_assessment': [224263]  # Site assessment/inflammation
}

def extract_catheter_events(stay_row, procedureevents_df, chartevents_df, window_days=14):
    """
    Extract catheter care events over window period
    Returns: DataFrame with daily event counts
    """
    stay_id = stay_row['stay_id']
    catheter_time = stay_row['chartdate']
    
    # Define time window
    start_time = catheter_time
    end_time = catheter_time + pd.Timedelta(days=window_days)
    
    # Get events from both procedureevents and chartevents
    procedure_events = procedureevents_df[
        (procedureevents_df['stay_id'] == stay_id) &
        (procedureevents_df['starttime'] >= start_time) &
        (procedureevents_df['starttime'] <= end_time)
    ].copy()
    
    chart_events = chartevents_df[
        (chartevents_df['stay_id'] == stay_id) &
        (chartevents_df['charttime'] >= start_time) &
        (chartevents_df['charttime'] <= end_time) &
        (chartevents_df['itemid'].isin([item for items in CATHETER_EVENT_ITEMIDS.values() for item in items]))
    ].copy()
    
    # Create daily bins (midnight timestamps, so bins and event keys share one dtype)
    date_range = pd.date_range(start=start_time.normalize(), periods=window_days, freq='D')

    catheter_events = pd.DataFrame(index=date_range)
    catheter_events['day'] = range(len(catheter_events))

    # Count daily events (simplified - would need more detailed mapping)
    if len(chart_events) > 0:
        chart_events['date'] = chart_events['charttime'].dt.normalize()
        daily_counts = chart_events.groupby('date').size()
        catheter_events['line_access_count'] = daily_counts.reindex(date_range, fill_value=0).values
    else:
        catheter_events['line_access_count'] = 0

    # All procedure events for the stay are counted here, not only medication administration
    if len(procedure_events) > 0:
        procedure_events['date'] = procedure_events['starttime'].dt.normalize()
        daily_counts = procedure_events.groupby('date').size()
        catheter_events['medication_admin_count'] = daily_counts.reindex(date_range, fill_value=0).values
    else:
        catheter_events['medication_admin_count'] = 0
    
    # Add placeholder features (would need more detailed extraction)
    catheter_events['blood_draw_count'] = 0
    catheter_events['dressing_change'] = 0
    catheter_events['site_assessment_score'] = 0
    catheter_events['line_flush_count'] = 0
    
    return catheter_events

print("Catheter events extraction function defined")
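
The function above fills most event columns with zero placeholders. A minimal sketch of how per-event daily counts could be built once an itemid-to-event mapping is settled. The mapping below uses made-up itemids (`1`, `2`) purely for illustration, and real itemids would need to be checked against `d_items`.

```python
import pandas as pd

# Hypothetical itemid-to-event mapping; these itemids are NOT real MIMIC-IV itemids
toy_itemid_to_event = {1: 'dressing_change', 2: 'line_flush'}

# Toy events for illustration only
events = pd.DataFrame({
    'charttime': pd.to_datetime(['2150-01-01 08:00', '2150-01-01 20:00',
                                 '2150-01-01 21:00', '2150-01-03 09:30']),
    'itemid': [1, 2, 2, 1],
})
events['event'] = events['itemid'].map(toy_itemid_to_event)
events['day'] = events['charttime'].dt.normalize()   # midnight Timestamp of each event

days = pd.date_range(start='2150-01-01', periods=3, freq='D')

# One row per day, one column per event type, zero where nothing was charted
daily = (
    events.groupby(['day', 'event']).size()
          .unstack(fill_value=0)
          .reindex(index=days, columns=list(toy_itemid_to_event.values()), fill_value=0)
)
```

Reindexing against the full day grid and the full event list keeps the output shape fixed even when a day or an event type has no records.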

## 10. Create Complete Patient Dataset

In [None]:
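# Parse datetime columns if not already done
# (procedureevents.starttime is compared against timestamps in extract_catheter_events)
chartevents['charttime'] = pd.to_datetime(chartevents['charttime'])
procedureevents['starttime'] = pd.to_datetime(procedureevents['starttime'])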

# Extract temporal features for a sample of patients
print("Extracting temporal features for sample patients...")

complete_dataset = []

sample_stays = central_line_stays.head(10)  # may hold fewer than 10 rows
for idx, row in tqdm(sample_stays.iterrows(), total=len(sample_stays)):
    patient_data = {
        'stay_id': row['stay_id'],
        'subject_id': row['subject_id'],
        'hadm_id': row['hadm_id'],
        'crbsi': row['crbsi'],
        'time_to_event': row['time_to_crbsi_hours'],
        'event': row['crbsi']  # 1 if CRBSI occurred, 0 if censored
    }
    
    # Extract static features
    static_feats = extract_static_features(row, patients, admissions, diagnoses_icd, labevents)
    patient_data['static_features'] = static_feats
    
    # Extract temporal features
    try:
        vitals = extract_vital_signs(row, chartevents, window_hours=FEATURE_WINDOW_HOURS)
        labs = extract_lab_values(row, labevents, window_days=7)
        catheter = extract_catheter_events(row, procedureevents, chartevents, window_days=14)
        
        if vitals is not None:
            patient_data['vital_signs'] = vitals
        if labs is not None:
            patient_data['lab_values'] = labs
        if catheter is not None:
            patient_data['catheter_events'] = catheter
            
        complete_dataset.append(patient_data)
    except Exception as e:
        print(f"Error processing stay_id {row['stay_id']}: {e}")
        continue

print(f"\nSuccessfully processed {len(complete_dataset)} patients")

## 11. Prepare Data for Model Input

In [None]:
def prepare_model_inputs(patient_data, vital_seq_len=48, lab_seq_len=14, catheter_seq_len=14):
    """
    Prepare data in the format required by SMTAFormer
    
    Returns:
        - static_vector: (m,) array of static features
        - temporal_sequences: list of 3 arrays for [vitals, labs, catheter_events]
        - labels: dict with 'binary', 'time', 'event', 'decision'
    """
    # Static features
    static_dict = patient_data['static_features']
    static_keys = ['age', 'sex', 'race', 'bmi', 'catheter_type', 'insertion_site',
                   'catheter_duration_baseline', 'number_of_lumens', 'insertion_location',
                   'neutropenia', 'immunosuppression', 'icu_admission', 'mechanical_ventilation',
                   'tpn_use', 'diabetes', 'ckd_stage', 'baseline_wbc', 'baseline_crp', 'baseline_albumin']
    
    static_vector = np.array([static_dict.get(k, 0) for k in static_keys], dtype=np.float32)
    
    # Temporal sequences
    temporal_sequences = []
    
    # Vital signs: (seq_len, n_features)
    if 'vital_signs' in patient_data:
        vitals = patient_data['vital_signs']
        vital_cols = ['heart_rate', 'temperature', 'sbp', 'dbp', 'map', 'respiratory_rate', 'spo2', 'gcs_total']
        vital_array = vitals[vital_cols].values[:vital_seq_len].astype(np.float32)
        
        # Pad if necessary
        if len(vital_array) < vital_seq_len:
            padding = np.zeros((vital_seq_len - len(vital_array), len(vital_cols)), dtype=np.float32)
            vital_array = np.vstack([vital_array, padding])
        
        temporal_sequences.append(vital_array)
    else:
        temporal_sequences.append(np.zeros((vital_seq_len, 8), dtype=np.float32))
    
    # Lab values: (seq_len, n_features)
    if 'lab_values' in patient_data:
        labs = patient_data['lab_values']
        lab_cols = ['wbc_count', 'neutrophil_count', 'crp', 'procalcitonin', 'lactate', 'platelets', 'creatinine']
        lab_array = labs[lab_cols].values[:lab_seq_len].astype(np.float32)
        
        if len(lab_array) < lab_seq_len:
            padding = np.zeros((lab_seq_len - len(lab_array), len(lab_cols)), dtype=np.float32)
            lab_array = np.vstack([lab_array, padding])
        
        temporal_sequences.append(lab_array)
    else:
        temporal_sequences.append(np.zeros((lab_seq_len, 7), dtype=np.float32))
    
    # Catheter events: (seq_len, n_features)
    if 'catheter_events' in patient_data:
        catheter = patient_data['catheter_events']
        catheter_cols = ['line_access_count', 'blood_draw_count', 'medication_admin_count',
                        'dressing_change', 'site_assessment_score', 'line_flush_count']
        catheter_array = catheter[catheter_cols].values[:catheter_seq_len].astype(np.float32)
        
        if len(catheter_array) < catheter_seq_len:
            padding = np.zeros((catheter_seq_len - len(catheter_array), len(catheter_cols)), dtype=np.float32)
            catheter_array = np.vstack([catheter_array, padding])
        
        temporal_sequences.append(catheter_array)
    else:
        temporal_sequences.append(np.zeros((catheter_seq_len, 6), dtype=np.float32))
    
    # Labels
    labels = {
        'binary': patient_data['crbsi'],  # Binary CRBSI occurrence
        'time': patient_data['time_to_event'],  # Time to event (hours)
        'event': patient_data['event'],  # Event indicator (1=event, 0=censored)
        'decision': 2  # Decision label (0=remove_now, 1=remove_24h, 2=continue) - needs clinical rule
    }
    
    # Generate decision label based on risk and clinical necessity
    if labels['binary'] == 1 and labels['time'] < 24:
        labels['decision'] = 0  # REMOVE_IMMEDIATELY
    elif labels['binary'] == 1 and labels['time'] < 72:
        labels['decision'] = 1  # REMOVE_WITHIN_24H
    else:
        labels['decision'] = 2  # CONTINUE_WITH_MONITORING
    
    return static_vector, temporal_sequences, labels

print("Model input preparation function defined")
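
The three temporal channels in `prepare_model_inputs` repeat the same truncate-then-pad logic. A minimal sketch of that step as a standalone helper. The name `pad_or_truncate` is hypothetical, and padding with zeros at the end follows what the function above does.

```python
import numpy as np

def pad_or_truncate(array, seq_len):
    """Return a float32 array with exactly `seq_len` rows (zero-padded at the end)."""
    array = np.asarray(array, dtype=np.float32)[:seq_len]
    if len(array) < seq_len:
        padding = np.zeros((seq_len - len(array), array.shape[1]), dtype=np.float32)
        array = np.vstack([array, padding])
    return array

short = pad_or_truncate(np.ones((3, 2)), seq_len=5)      # padded with two zero rows
too_long = pad_or_truncate(np.ones((8, 2)), seq_len=5)   # truncated to five rows
```

Note that zero padding applied before z-scoring is indistinguishable from a measured zero, so a model that uses these sequences may also want a padding mask.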

In [None]:
# Prepare all patient data for model
print("Preparing model inputs...")

model_data = []

for patient_data in complete_dataset:
    try:
        static, temporal, labels = prepare_model_inputs(patient_data)
        
        model_data.append({
            'stay_id': patient_data['stay_id'],
            'subject_id': patient_data['subject_id'],
            'static': static,
            'temporal': temporal,
            'labels': labels
        })
    except Exception as e:
        print(f"Error preparing data for stay {patient_data['stay_id']}: {e}")
        continue

print(f"\nPrepared {len(model_data)} samples for model training")

# Show example
if len(model_data) > 0:
    example = model_data[0]
    print(f"\nExample data structure:")
    print(f"Static features shape: {example['static'].shape}")
    print(f"Vital signs shape: {example['temporal'][0].shape}")
    print(f"Lab values shape: {example['temporal'][1].shape}")
    print(f"Catheter events shape: {example['temporal'][2].shape}")
    print(f"Labels: {example['labels']}")

## 12. Data Normalization and Imputation

In [None]:
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer

def normalize_and_impute(model_data):
    """
    Normalize static and temporal features
    Impute missing values
    """
    # Extract all static features for normalization
    all_static = np.array([d['static'] for d in model_data])
    
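    # Impute missing values (mean strategy)
    # keep_empty_features=True (scikit-learn >= 1.2) keeps all-NaN columns such as the
    # BMI placeholder, which would otherwise be dropped and shrink the feature vector
    imputer = SimpleImputer(strategy='mean', keep_empty_features=True)
    all_static_imputed = imputer.fit_transform(all_static)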
    
    # Normalize (z-score)
    scaler_static = StandardScaler()
    all_static_normalized = scaler_static.fit_transform(all_static_imputed)
    
    # Update static features in model_data
    for i, data in enumerate(model_data):
        data['static'] = all_static_normalized[i]
    
    # Normalize temporal features
    for channel_idx in range(3):  # 3 temporal channels
        # Collect all data for this channel
        all_temporal = np.vstack([d['temporal'][channel_idx] for d in model_data])
        
        # Reshape for normalization
        n_samples = len(model_data)
        seq_len, n_features = model_data[0]['temporal'][channel_idx].shape
        
        # Impute and normalize
        all_temporal_reshaped = all_temporal.reshape(-1, n_features)
        all_temporal_imputed = imputer.fit_transform(all_temporal_reshaped)
        
        scaler_temporal = StandardScaler()
        all_temporal_normalized = scaler_temporal.fit_transform(all_temporal_imputed)
        
        # Reshape back
        all_temporal_normalized = all_temporal_normalized.reshape(n_samples, seq_len, n_features)
        
        # Update in model_data
        for i, data in enumerate(model_data):
            data['temporal'][channel_idx] = all_temporal_normalized[i]
    
    return model_data, scaler_static

# Apply normalization
print("Normalizing and imputing data...")
model_data_normalized, static_scaler = normalize_and_impute(model_data)

print("\nNormalization complete!")
print(f"Static features - Mean: {model_data_normalized[0]['static'].mean():.3f}, Std: {model_data_normalized[0]['static'].std():.3f}")
print(f"Temporal features (vitals) - Mean: {model_data_normalized[0]['temporal'][0].mean():.3f}, Std: {model_data_normalized[0]['temporal'][0].std():.3f}")
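
`normalize_and_impute` fits the imputer and scalers on every sample, so if the data is split afterwards the held-out sets have already influenced the statistics. A minimal sketch, using NumPy only, of estimating the statistics on training rows and reusing them elsewhere. The helper names `fit_stats` and `apply_stats` are hypothetical.

```python
import numpy as np

def fit_stats(train):
    """Column means and standard deviations, ignoring NaNs; a zero std is replaced by 1."""
    mean = np.nanmean(train, axis=0)
    std = np.nanstd(train, axis=0)
    std = np.where(std == 0, 1.0, std)
    return mean, std

def apply_stats(x, mean, std):
    """Impute NaNs with the training mean, then z-score with training statistics."""
    filled = np.where(np.isnan(x), mean, x)
    return (filled - mean) / std

# Toy matrices for illustration only
train = np.array([[1.0, 10.0], [3.0, np.nan], [5.0, 30.0]])
test = np.array([[3.0, np.nan]])

mean, std = fit_stats(train)
train_z = apply_stats(train, mean, std)
test_z = apply_stats(test, mean, std)
```

A column that is entirely NaN in the training rows (such as the BMI placeholder) has no defined mean and would need separate handling.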

## 13. Save Processed Data

In [None]:
import pickle

# Save processed data
print("Saving processed data...")

# Save as pickle
with open(OUTPUT_PATH + 'crbsi_processed_data.pkl', 'wb') as f:
    pickle.dump(model_data_normalized, f)

# Save scaler for future use
with open(OUTPUT_PATH + 'static_scaler.pkl', 'wb') as f:
    pickle.dump(static_scaler, f)

# Save metadata
metadata = {
    'n_samples': len(model_data_normalized),
    'n_crbsi_cases': sum(d['labels']['binary'] for d in model_data_normalized),
    'feature_window_hours': FEATURE_WINDOW_HOURS,
    'prediction_window_hours': PREDICTION_WINDOW_HOURS,
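    # Same order as static_keys inside prepare_model_inputs (that name is local to the function)
    'static_features': ['age', 'sex', 'race', 'bmi', 'catheter_type', 'insertion_site',
                        'catheter_duration_baseline', 'number_of_lumens', 'insertion_location',
                        'neutropenia', 'immunosuppression', 'icu_admission', 'mechanical_ventilation',
                        'tpn_use', 'diabetes', 'ckd_stage', 'baseline_wbc', 'baseline_crp', 'baseline_albumin'],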
    'vital_signs_features': ['heart_rate', 'temperature', 'sbp', 'dbp', 'map', 'respiratory_rate', 'spo2', 'gcs_total'],
    'lab_features': ['wbc_count', 'neutrophil_count', 'crp', 'procalcitonin', 'lactate', 'platelets', 'creatinine'],
    'catheter_event_features': ['line_access_count', 'blood_draw_count', 'medication_admin_count', 'dressing_change', 'site_assessment_score', 'line_flush_count']
}

with open(OUTPUT_PATH + 'metadata.pkl', 'wb') as f:
    pickle.dump(metadata, f)

print(f"\nData saved to {OUTPUT_PATH}")
print(f"Total samples: {metadata['n_samples']}")
print(f"CRBSI cases: {metadata['n_crbsi_cases']}")
print(f"CRBSI rate: {metadata['n_crbsi_cases']/metadata['n_samples']:.2%}")

## 14. Data Summary and Visualization

In [None]:
# Visualize class distribution
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Binary outcome distribution
ax = axes[0, 0]
binary_labels = [d['labels']['binary'] for d in model_data_normalized]
ax.bar(['No CRBSI', 'CRBSI'], [binary_labels.count(0), binary_labels.count(1)])
ax.set_title('Binary Outcome Distribution')
ax.set_ylabel('Count')

# Time to event distribution
ax = axes[0, 1]
times = [d['labels']['time'] for d in model_data_normalized if d['labels']['event'] == 1]
ax.hist(times, bins=20, edgecolor='black')
ax.set_title('Time to CRBSI (for events)')
ax.set_xlabel('Hours')
ax.set_ylabel('Count')

# Decision distribution
ax = axes[1, 0]
decisions = [d['labels']['decision'] for d in model_data_normalized]
decision_names = ['Remove Now', 'Remove 24h', 'Continue']
ax.bar(decision_names, [decisions.count(0), decisions.count(1), decisions.count(2)])
ax.set_title('Decision Label Distribution')
ax.set_ylabel('Count')

# Example vital signs trajectory
ax = axes[1, 1]
if len(model_data_normalized) > 0:
    example_vitals = model_data_normalized[0]['temporal'][0]  # First patient's vitals
    ax.plot(example_vitals[:, 0], label='Heart Rate')
    ax.plot(example_vitals[:, 1], label='Temperature')
    ax.plot(example_vitals[:, 6], label='SpO2')
    ax.set_title('Example Vital Signs Trajectory (Normalized)')
    ax.set_xlabel('Hour')
    ax.set_ylabel('Normalized Value')
    ax.legend()

plt.tight_layout()
plt.savefig(OUTPUT_PATH + 'data_summary.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nData summary visualization saved!")

## 15. Next Steps

The preprocessed data is now ready for model training. Next steps:

1. **Split data** into train/validation/test sets
2. **Implement SMTAFormer** architecture (3 prediction heads)
3. **Define loss functions** for multi-task learning:
   - Binary: Focal Loss
   - Survival: Cox Partial Likelihood
   - Decision: Custom decision loss
4. **Train model** with early stopping
5. **Evaluate** on test set:
   - Binary: AUC, Precision, Recall
   - Survival: C-index
   - Decision: Accuracy, Confusion Matrix
6. **Clinical validation** and interpretation

See the Enhanced_CRBSI_Architecture_Specification.md for detailed model architecture.
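
For step 1, a patient can contribute several ICU stays, so a split at the sample level could place the same patient in both the training and test data. A minimal sketch of a split grouped by `subject_id`. The fractions, the seed, and the function name `split_by_subject` are assumptions.

```python
import numpy as np

def split_by_subject(subject_ids, val_frac=0.15, test_frac=0.15, seed=0):
    """Assign each sample to train/val/test so that no subject spans two sets."""
    subject_ids = np.asarray(subject_ids)
    unique = np.unique(subject_ids)
    rng = np.random.default_rng(seed)
    rng.shuffle(unique)   # shuffle subjects, not samples

    n_test = int(round(len(unique) * test_frac))
    n_val = int(round(len(unique) * val_frac))
    test_subjects = set(unique[:n_test].tolist())
    val_subjects = set(unique[n_test:n_test + n_val].tolist())

    return np.array([
        'test' if s in test_subjects else 'val' if s in val_subjects else 'train'
        for s in subject_ids.tolist()
    ])

# Toy subject ids for illustration: subjects 1 and 3 each have two samples
toy_subject_ids = [1, 1, 2, 3, 3, 4, 5, 6, 7, 8, 9, 10]
assignment = split_by_subject(toy_subject_ids, val_frac=0.2, test_frac=0.2)
```

With `model_data_normalized`, the ids could be taken from each sample's `subject_id` entry. Given the low CRBSI rate, stratifying the subjects by outcome may also be worth considering.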

In [None]:
print("="*80)
print("PREPROCESSING COMPLETE!")
print("="*80)
print(f"\nProcessed {len(model_data_normalized)} patient samples")
print(f"CRBSI cases: {sum(d['labels']['binary'] for d in model_data_normalized)}")
print(f"\nData saved to: {OUTPUT_PATH}")
print(f"\nReady for model training!")