In [None]:
print("--- Installing Libraries ---")
# Install Gradio and PyArrow. Tensorflow is pre-installed
!pip install gradio pyarrow -q
import os
import json
import pandas as pd
import numpy as np
import zipfile
from datetime import datetime
from google.colab import drive
import gc
import shutil
# ---------------------------------------------------------------------------
# Cell 1: reset the workspace, copy fhir.zip from Google Drive, extract the
# non-empty JSON bundles and gather them all directly under /content/fhir/.
# ---------------------------------------------------------------------------
print("\n--- Cleaning up previous data... ---")
# Same effect as `rm -rf`: remove leftovers from earlier runs, ignore missing paths.
for stale_dir in ('/content/fhir/', '/content/output/', '/content/temp_parquet/'):
    shutil.rmtree(stale_dir, ignore_errors=True)
if os.path.exists('/content/fhir.zip'):  # We only need to clean up the zip file
    os.remove('/content/fhir.zip')
# Create the target directory
os.makedirs('/content/fhir', exist_ok=True)

print("\n--- Mounting Google Drive... ---")
drive.mount('/content/drive', force_remount=True)  # Force remount to be safe

print("\n--- Copying fhir.zip from Drive to Colab's local disk... ---")
drive_zip_path = '/content/drive/My Drive/fhir.zip'
local_zip_path = '/content/fhir.zip'

if not os.path.exists(drive_zip_path):
    print(f"--- WARNING: {drive_zip_path} not found in Google Drive. Stopping. ---")
    # Actually stop: without the archive every later step would run on an
    # empty directory and only fail much further down the pipeline.
    raise FileNotFoundError(drive_zip_path)

print(f"Copying {drive_zip_path}...")
shutil.copy(drive_zip_path, local_zip_path)
print(f"Copy complete: {local_zip_path}")

print("\n--- Unpacking local fhir.zip... ---")
print(f"Extracting {local_zip_path}...")
with zipfile.ZipFile(local_zip_path, 'r') as zip_ref:
    # Find all non-empty JSON files within the zip
    json_files = [
        info for info in zip_ref.infolist()
        if info.filename.endswith('.json') and info.file_size > 0
    ]
    print(f"Found {len(json_files)} non-empty JSON files in zip.")

    # Extract only the non-empty JSON files
    for i, file_info in enumerate(json_files):
        if i % 5000 == 0:
            print(f"    Extracting file {i}/{len(json_files)}...")
        zip_ref.extract(file_info, '/content/')  # Extract to /content/

print(f"--- Finished processing: {local_zip_path} ---")
# Clean up the local archive to save disk space
os.remove(local_zip_path)
print(f"Removed local archive: {local_zip_path}")


def _move_visible_entries(src_dir, dst_dir):
    """Move every non-hidden entry of src_dir into dst_dir, one entry at a time.

    Replaces `mv src_dir/* dst_dir/`: a shell glob expands all names onto one
    command line, which can exceed the argument limit for very large folders.
    """
    for name in os.listdir(src_dir):
        if name.startswith('.'):
            continue  # a shell `*` glob skips dot-files as well
        shutil.move(os.path.join(src_dir, name), os.path.join(dst_dir, name))


print("\n--- All archives processed! Consolidating files... ---")
# This logic is important as fhir.zip often extracts into a nested folder.
# This code finds that nested folder and moves all files to /content/fhir/

if os.path.exists('/content/output/fhir'):
    print("Consolidating from /content/output/fhir...")
    _move_visible_entries('/content/output/fhir', '/content/fhir')
    shutil.rmtree('/content/output', ignore_errors=True)

# Also check if files were extracted inside a 'fhir' folder (e.g., /content/fhir/fhir/)
if os.path.exists('/content/fhir/fhir'):
    print("Consolidating from /content/fhir/fhir...")
    _move_visible_entries('/content/fhir/fhir', '/content/fhir')
    shutil.rmtree('/content/fhir/fhir', ignore_errors=True)

if os.path.exists('/content/fhir'):
    print("Final 'fhir' directory is ready. Checking file count...")
    # Count only the JSON files that will be processed
    file_count = len([f for f in os.listdir('/content/fhir') if f.endswith('.json')])
    print(f"Total .json files in /content/fhir: {file_count}")
else:
    print("--- ERROR: Final 'fhir' directory not found. Please check archive contents. ---")

In [None]:
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
import json
import pandas as pd
import numpy as np
import shutil
import gc
from datetime import datetime

# Maps a coded diagnosis (ICD-10 or SNOMED CT code string) to the lowercase
# phrase that the history_* keyword checks in parse_fhir_bundle_24h search for.
# The parser looks codes up by exact string match, which is why some ICD-10
# entries list both the category ('K21') and a sub-code ('K21.9').
# NOTE(review): a sub-code that is not listed here (e.g. a dotted variant of
# 'I10'-style categories) will not match — confirm the code granularity of the
# input data.
CONDITION_CODE_MAP = {
    # ICD-10 codes (category level)
    'I10': 'hypertension', 'E11': 'diabetes', 'I50': 'congestive heart failure', 'J44': 'copd', 'D64': 'anemia',
    'E78': 'hyperlipidemia', 'I25': 'coronary artery disease', 'N18': 'chronic kidney disease', 'E66': 'obesity',
    'J45': 'asthma', 'J18': 'pneumonia', 'J11': 'influenza', 'I48': 'atrial fibrillation',
    'I82': 'dvt', 'I26': 'pulmonary embolism', 'E03': 'hypothyroidism', 'E05': 'hyperthyroidism',
    'M10': 'gout', 'M19': 'osteoarthritis', 'M06': 'rheumatoid arthritis', 'K85': 'pancreatitis',
    'K92': 'gi bleed', 'K50': 'crohn\'s disease', 'K51': 'ulcerative colitis', 'L03': 'cellulitis',
    'N39': 'urinary tract infection', 'G40': 'epilepsy', 'R56': 'seizure',
    'G30': 'alzheimer\'s', 'F03': 'dementia', 'F41': 'anxiety', 'F32': 'depression',

    # ICD-10 codes for the later-added history flags (gerd ... dizziness/vertigo)
    'K21': 'gerd', 'K21.9': 'gerd',
    'J30': 'allergies', 'J30.9': 'allergies',
    'G43': 'migraine', 'G43.9': 'migraine',
    'G47.33': 'sleep apnea',
    'M54.5': 'low back pain',
    'F17.2': 'nicotine dependence',
    'F10.2': 'alcohol use disorder',
    'L20': 'atopic dermatitis', 'L20.9': 'atopic dermatitis',
    'D50.9': 'iron deficiency', 'D50.8': 'iron deficiency',
    'M81.0': 'osteoporosis',
    'G47.00': 'insomnia',
    'K58.0': 'ibs', 'K58.9': 'ibs',
    'J32.9': 'sinusitis', 'J32.4': 'sinusitis',
    'R42': 'dizziness', 'H81.10': 'vertigo',

    # SNOMED CT codes
    # NOTE(review): 'sepsis' is mapped here, but no history_* flag in
    # parse_fhir_bundle_24h searches for it.
    '59621000': 'hypertension', '44054006': 'diabetes', '42343007': 'congestive heart failure',
    '13645005': 'copd', '271737000': 'anemia', '55822004': 'hyperlipidemia', '53741008': 'coronary artery disease',
    '709044004': 'chronic kidney disease', '414915002': 'obesity', '195967001': 'asthma',
    '233604007': 'pneumonia', '6142004': 'influenza', '49436004': 'atrial fibrillation',
    '128063001': 'dvt', '103228002': 'pulmonary embolism', '40930008': 'hypothyroidism', '42725007': 'hyperthyroidism',
    '90560007': 'gout', '396275006': 'osteoarthritis', '69896004': 'rheumatoid arthritis', '71831004': 'pancreatitis',
    '74474003': 'gi bleed', '34000006': 'crohn\'s disease', '64766004': 'ulcerative colitis', '23131000': 'cellulitis',
    '68566005': 'urinary tract infection', '84757009': 'epilepsy', '91175000': 'seizure', '26929004': 'alzheimer\'s',
    '52448006': 'dementia', '48694002': 'anxiety', '35489007': 'depression', '91302008': 'sepsis',

    # SNOMED CT codes for the later-added history flags
    '69808002': 'gerd',
    '61598001': 'allergies',
    '3860009': 'migraine',
    '73430006': 'sleep apnea',
    '279039007': 'low back pain',
    '77176002': 'nicotine dependence',
    '65363002': 'alcohol use disorder',
    '24079001': 'atopic dermatitis',
    '34443006': 'iron deficiency',
    '64859006': 'osteoporosis',
    '193462001': 'insomnia',
    '44796009': 'ibs',
    '40055000': 'sinusitis',
    '48340000': 'vertigo', '271789005': 'dizziness'
}
# End of CONDITION_CODE_MAP


# Observation code -> output column name for the readings that are kept.
_OBS_CODE_MAP = {
    '8480-6': 'sbp', '8462-4': 'dbp', '8867-4': 'heart_rate', '9279-1': 'respiratory_rate', '59408-5': 'spo2',
    '8310-5': 'temperature', '6690-2': 'wbc_count', '718-7': 'hemoglobin', '4544-3': 'hematocrit', '777-3': 'platelet_count',
    '2093-3': 'cholesterol_total', '2571-8': 'triglycerides', '18262-6': 'cholesterol_ldl', '2085-9': 'cholesterol_hdl',
    '72514-3': 'pain_score', '48643-1': 'hba1c', '55758-7': 'phq2_score', '70274-6': 'gad7_score', '8302-2': 'height_m',
    '29463-7': 'weight_kg', '3094-0': 'bun', '2951-2': 'sodium', '2823-3': 'potassium', '2345-7': 'glucose',
    '1975-2': 'bilirubin', '2160-0': 'creatinine', '10839-9': 'troponin_i', '1988-5': 'crp', '2524-7': 'lactate',
    '33959-8': 'procalcitonin', '48065-7': 'd_dimer', '5821-4': 'urine_wbc', '5811-5': 'urine_nitrite',
    '5799-2': 'urine_leukocyte', '14563-1': 'stool_occult_blood', '1751-7': 'albumin', '1742-6': 'alt',
    '1920-8': 'ast', '2991-8': 'tsh', '3005-4': 'free_t4', '14338-8': 'lipase', '1798-8': 'amylase',
    '34714-6': 'inr', '3173-2': 'ptt', '2601-5': 'magnesium', '2777-1': 'phosphate', '11568-8': 'rheumatoid_factor',
}

# (feature name, text source, keywords). A flag is 1 when any keyword occurs
# as a substring of any collected text of that source. The order is the
# column order of the emitted rows, so keep it stable.
# NOTE(review): substring matching is loose — e.g. 'ibs' also matches any text
# containing "ribs" and 'statin' matches "nystatin"; confirm this is acceptable.
_KEYWORD_FLAGS = (
    ('history_hypertension', 'condition', ('hypertension',)),
    ('history_diabetes', 'condition', ('diabetes',)),
    ('history_chf', 'condition', ('congestive heart failure',)),
    ('history_copd', 'condition', ('copd',)),
    ('history_anemia', 'condition', ('anemia',)),
    ('history_hyperlipidemia', 'condition', ('hyperlipidemia',)),
    ('history_cad', 'condition', ('coronary artery disease',)),
    ('history_ckd', 'condition', ('chronic kidney disease',)),
    ('history_obesity', 'condition', ('obesity',)),
    ('history_asthma', 'condition', ('asthma',)),
    ('history_pneumonia', 'condition', ('pneumonia',)),
    ('history_flu', 'condition', ('influenza',)),
    ('on_statin', 'medication', ('statin',)),
    ('on_metformin', 'medication', ('metformin',)),
    ('on_aspirin', 'medication', ('aspirin',)),
    ('on_lisinopril', 'medication', ('lisinopril',)),
    ('history_cabg', 'procedure', ('cabg', 'coronary artery bypass')),
    ('history_appendectomy', 'procedure', ('appendectomy',)),
    ('history_afib', 'condition', ('atrial fibrillation',)),
    ('history_dvt_pe', 'condition', ('dvt', 'deep vein thrombosis', 'pulmonary embolism')),
    ('history_thyroid', 'condition', ('hypothyroidism', 'hyperthyroidism', 'thyroid disease')),
    ('history_gout', 'condition', ('gout',)),
    ('history_arthritis', 'condition', ('osteoarthritis', 'rheumatoid arthritis')),
    ('history_pancreatitis', 'condition', ('pancreatitis',)),
    ('history_gi_bleed', 'condition', ('gi bleed', 'gastrointestinal bleed', 'varices')),
    ('history_ibd', 'condition', ("crohn's disease", 'ulcerative colitis')),
    ('history_cellulitis', 'condition', ('cellulitis',)),
    ('history_uti', 'condition', ('urinary tract infection',)),
    ('history_seizure', 'condition', ('epilepsy', 'seizure')),
    ('history_dementia', 'condition', ('dementia', "alzheimer's")),
    ('history_anxiety', 'condition', ('anxiety',)),
    ('history_depression', 'condition', ('depression',)),
    ('on_anticoagulant', 'medication', ('warfarin', 'eliquis', 'xarelto', 'heparin')),
    ('on_beta_blocker', 'medication', ('metoprolol', 'carvedilol', 'atenolol')),
    ('on_diuretic', 'medication', ('furosemide', 'hydrochlorothiazide', 'spironolactone')),
    ('on_ccb', 'medication', ('amlodipine', 'diltiazem')),
    ('on_insulin', 'medication', ('insulin',)),
    ('on_ppi', 'medication', ('omeprazole', 'pantoprazole', 'esomeprazole')),
    ('on_thyroid_med', 'medication', ('levothyroxine',)),
    ('on_gout_med', 'medication', ('allopurinol',)),
    ('proc_colonoscopy', 'procedure', ('colonoscopy',)),
    ('proc_egd', 'procedure', ('egd', 'esophagogastroduodenoscopy')),
    ('proc_dialysis', 'procedure', ('dialysis',)),
    ('history_gerd', 'condition', ('gerd',)),
    ('history_allergies', 'condition', ('allergies',)),
    ('history_migraine', 'condition', ('migraine',)),
    ('history_sleep_apnea', 'condition', ('sleep apnea',)),
    ('history_low_back_pain', 'condition', ('low back pain',)),
    ('history_smoker', 'condition', ('nicotine dependence',)),
    ('history_alcohol_use', 'condition', ('alcohol use disorder',)),
    ('history_eczema', 'condition', ('atopic dermatitis',)),
    ('history_iron_deficiency', 'condition', ('iron deficiency',)),
    ('history_osteoporosis', 'condition', ('osteoporosis',)),
    ('history_insomnia', 'condition', ('insomnia',)),
    ('history_ibs', 'condition', ('ibs',)),
    ('history_sinusitis', 'condition', ('sinusitis',)),
    ('history_dizziness', 'condition', ('dizziness', 'vertigo')),
)


def _lowered(value):
    """Return value lower-cased, or '' when it is missing or not a string."""
    return value.lower() if isinstance(value, str) else ''


def _concept_texts(concept):
    """Collect the lower-cased `text` and coding `display` strings of a CodeableConcept."""
    found = set()
    text = _lowered(concept.get('text'))
    if text:
        found.add(text)
    for coding in concept.get('coding') or []:
        display = _lowered(coding.get('display'))
        if display:
            found.add(display)
    return found


def _first_code(concept):
    """Return the code of the first coding of a CodeableConcept, or None if there is none."""
    codings = (concept or {}).get('coding') or [{}]
    return (codings[0] or {}).get('code')


def parse_fhir_bundle_24h(file_path):
    """Flatten one FHIR bundle JSON file into a list of per-observation rows.

    Each row is a dict holding the patient-level features (patient_id taken
    from the file name, gender, age, and the 0/1 history/medication/procedure
    flags of _KEYWORD_FLAGS), the observation timestamp under 'obs_time', and
    the value of every reading of that observation whose code is listed in
    _OBS_CODE_MAP. Rows are ordered by timestamp.

    Args:
        file_path: path of the bundle JSON file.

    Returns:
        The list of row dicts, or None when the file is empty, is not a
        bundle with an 'entry' key, has no timestamped Observation, or cannot
        be parsed at all.
    """
    # Local import: the file-level import only brings in `datetime` itself.
    from datetime import timezone

    try:
        if os.path.getsize(file_path) == 0:
            return None
        with open(file_path, 'r') as f:
            bundle = json.load(f)
        if not isinstance(bundle, dict) or 'entry' not in bundle:
            return None

        patient_id = os.path.basename(file_path).split('.')[0]
        patient_features = {'patient_id': patient_id}
        observations = []
        texts = {'condition': set(), 'medication': set(), 'procedure': set()}

        for entry in bundle.get('entry') or []:
            resource = entry.get('resource') or {}
            r_type = resource.get('resourceType')

            if r_type == 'Patient':
                patient_features['gender'] = resource.get('gender')
                birth_date = resource.get('birthDate')
                if birth_date:
                    try:
                        born = datetime.strptime(birth_date, '%Y-%m-%d')
                        # Age in whole years relative to the fixed reference date 2025-01-01.
                        patient_features['age'] = (datetime(2025, 1, 1) - born).days // 365
                    except (ValueError, TypeError):
                        patient_features['age'] = None

            elif r_type == 'Condition':
                concept = resource.get('code') or {}
                text = _lowered(concept.get('text'))
                if text:
                    texts['condition'].add(text)
                # Coded diagnoses contribute the canonical phrase from the code map.
                for coding in concept.get('coding') or []:
                    code = coding.get('code')
                    if code in CONDITION_CODE_MAP:
                        texts['condition'].add(CONDITION_CODE_MAP[code])

            elif r_type == 'Observation':
                obs_time = resource.get('effectiveDateTime')
                if obs_time:
                    stamp = datetime.fromisoformat(obs_time.replace('Z', '+00:00'))
                    if stamp.tzinfo is None:
                        # Naive and aware datetimes cannot be sorted together;
                        # treat offset-less stamps as UTC.
                        stamp = stamp.replace(tzinfo=timezone.utc)
                    observations.append((stamp, resource))

            elif r_type == 'MedicationRequest':
                texts['medication'] |= _concept_texts(resource.get('medicationCodeableConcept') or {})

            elif r_type == 'Procedure':
                texts['procedure'] |= _concept_texts(resource.get('code') or {})

        if not observations:
            return None
        observations.sort(key=lambda pair: pair[0])

        # Static 0/1 flags, computed once per patient and copied into every row.
        for feature, source, keywords in _KEYWORD_FLAGS:
            patient_features[feature] = int(
                any(keyword in text for keyword in keywords for text in texts[source])
            )

        patient_rows = []
        for obs_time, obs in observations:
            row = patient_features.copy()
            row['obs_time'] = obs_time
            # Panel observations carry their readings in 'component'; a plain
            # observation is its own single reading.
            readings = (obs['component'] or []) if 'component' in obs else [obs]
            for reading in readings:
                code = _first_code(reading.get('code'))
                if code in _OBS_CODE_MAP:
                    row[_OBS_CODE_MAP[code]] = (reading.get('valueQuantity') or {}).get('value')
            patient_rows.append(row)
        return patient_rows
    except Exception:
        # Deliberate best-effort boundary: one unreadable or malformed bundle
        # must not abort the whole batch, so it is reported as "no data".
        return None

def process_and_save(file_path, output_dir):
    """Parse one bundle and write its rows to `<output_dir>/<patient_id>.parquet`.

    Nothing is written when the parser yields no rows (None or an empty list).
    """
    rows = parse_fhir_bundle_24h(file_path)
    if not rows:
        return
    # The output file is named after the part of the input name before the first dot.
    stem = os.path.basename(file_path).split('.')[0]
    target_path = os.path.join(output_dir, f'{stem}.parquet')
    frame = pd.DataFrame(rows)
    frame.to_parquet(target_path, compression='snappy', index=False)

# Setup directories: start from an empty temp folder for the per-patient parquet files.
data_directory = '/content/fhir/'
temp_output_dir = '/content/temp_parquet/'
if os.path.exists(temp_output_dir):
    shutil.rmtree(temp_output_dir)
os.makedirs(temp_output_dir)

all_json_files = [os.path.join(data_directory, f) for f in os.listdir(data_directory) if f.endswith('.json')]
print(f"Total candidate JSON files: {len(all_json_files)}")

print("\n--- Parsing in parallel and saving to temporary files... ---")
failed_files = []  # (path, exception) for every worker task that raised
# Limit workers to 2 to conserve RAM in the free Colab environment
with ProcessPoolExecutor(max_workers=2) as executor:
    futures = {executor.submit(process_and_save, fp, temp_output_dir): fp for fp in all_json_files}
    for i, future in enumerate(as_completed(futures), 1):
        try:
            # Without result() an exception raised inside the worker
            # (e.g. a failed parquet write) is silently discarded.
            future.result()
        except Exception as exc:
            failed_files.append((futures[future], exc))
        if i % 1000 == 0:
            print(f"Processed {i} of {len(all_json_files)} files...")
            gc.collect()

print("\n✅ All files parsed and saved to disk.")
if failed_files:
    print(f"--- WARNING: {len(failed_files)} file(s) raised an error and were skipped. ---")
    for failed_path, exc in failed_files[:10]:
        print(f"    {failed_path}: {exc!r}")

# Delete source files immediately to free disk space
print("--- Removing source JSON files to free disk space... ---")
if os.path.exists(data_directory):
    shutil.rmtree(data_directory)
gc.collect()

print("--- Source JSON directory removed. ---")

In [None]:
import pandas as pd
import numpy as np
import os
import shutil
import gc

def optimize_dtypes(df):
    """Shrink numeric columns of `df` in place to reduce its memory footprint.

    float64 columns become float32. Each int64 column is cast to the smallest
    of int8 / int16 / int32 that holds both its minimum and maximum; a column
    whose values do not fit in int32 (or that is empty) is left as int64
    rather than being truncated.

    Args:
        df: DataFrame to modify.

    Returns:
        The same DataFrame object, for call chaining.
    """
    for col in df.select_dtypes(include=['float64']).columns:
        df[col] = df[col].astype('float32')
    for col in df.select_dtypes(include=['int64']).columns:
        lowest, highest = df[col].min(), df[col].max()
        for candidate in ('int8', 'int16', 'int32'):
            limits = np.iinfo(candidate)
            if limits.min <= lowest and highest <= limits.max:
                df[col] = df[col].astype(candidate)
                break
    return df

# ---------------------------------------------------------------------------
# Cell 3: read the per-patient parquet files in groups of `chunk_size`, build
# rolling / derived / delta features, evaluate the rule-based risk labels for
# three timeframes, and write one processed parquet file per group.
# ---------------------------------------------------------------------------


def _measured_below(frame, column, threshold):
    """Boolean mask: `column` holds a real (non-zero) reading strictly below `threshold`.

    Missing readings are stored as 0 after `fillna(0)`, so a plain
    `< threshold` test would flag every row where the lab was never taken.
    """
    values = frame[column]
    return (values > 0) & (values < threshold)


print("--- Processing data with streaming writes (NO RAM SPIKE!) ---")
temp_output_dir = '/content/temp_parquet/'
all_parquet_files = [os.path.join(temp_output_dir, f) for f in os.listdir(temp_output_dir) if f.endswith('.parquet')]

final_chunks_dir = '/content/final_parquet_chunks/'
if os.path.exists(final_chunks_dir):
    shutil.rmtree(final_chunks_dir)
os.makedirs(final_chunks_dir)

chunk_size = 2000

# Measurement columns that must exist in every chunk (created as NaN if no
# file of the chunk contains them).
essential_columns = [
    'sbp', 'dbp', 'heart_rate', 'respiratory_rate', 'wbc_count', 'hemoglobin', 'hematocrit', 'platelet_count',
    'cholesterol_total', 'triglycerides', 'cholesterol_ldl', 'cholesterol_hdl', 'hba1c', 'pain_score', 'phq2_score',
    'gad7_score', 'height_m', 'weight_kg', 'bun', 'sodium', 'potassium', 'glucose', 'bilirubin', 'temperature',
    'creatinine', 'spo2', 'troponin_i', 'crp', 'lactate', 'procalcitonin', 'd_dimer', 'urine_wbc', 'urine_nitrite',
    'urine_leukocyte', 'stool_occult_blood', 'albumin', 'alt', 'ast', 'tsh', 'free_t4', 'lipase', 'amylase',
    'inr', 'ptt', 'magnesium', 'phosphate', 'rheumatoid_factor'
]

# Every column referenced by the label rules; any that is still missing when
# the rules are evaluated is created as a constant 0.
all_rule_cols = [
    # Labs
    'lactate', 'crp', 'troponin_i', 'phq2_score', 'gad7_score', 'urine_wbc', 'urine_nitrite', 'urine_leukocyte', 'd_dimer',
    'stool_occult_blood', 'lipase', 'amylase', 'inr', 'tsh', 'uric_acid', 'rheumatoid_factor', 'albumin', 'magnesium',
    'phosphate', 'chloride', 'bicarbonate', 'alt', 'ast', 'temperature', 'pain_score', 'hba1c', 'bilirubin', 'potassium',
    'platelet_count', 'sodium', 'wbc_count', 'age',

    # History/Med Flags
    'history_cad', 'history_afib', 'history_hypertension', 'history_pneumonia', 'history_chf', 'history_diabetes',
    'history_ckd', 'history_copd', 'on_statin', 'history_uti', 'history_dvt_pe', 'history_gi_bleed',
    'history_pancreatitis', 'on_anticoagulant', 'history_thyroid', 'history_gout', 'history_arthritis',
    'history_ibd', 'history_cellulitis', 'history_seizure', 'history_dementia', 'on_lisinopril', 'on_aspirin',
    'on_beta_blocker', 'on_diuretic', 'history_obesity', 'history_anxiety', 'history_depression', 'on_ppi',

    # Later-added history flags
    'history_gerd', 'history_allergies', 'history_migraine', 'history_sleep_apnea',
    'history_low_back_pain', 'history_smoker', 'history_alcohol_use', 'history_eczema',
    'history_iron_deficiency', 'history_osteoporosis', 'history_insomnia',
    'history_ibs', 'history_sinusitis', 'history_dizziness',

    # Delta Flags
    'delta_heart_rate', 'delta_hemoglobin', 'delta_creatinine'
]


total_chunks = (len(all_parquet_files) - 1) // chunk_size + 1
# Tracked separately so the summary below also works when no chunk was processed.
n_conditions = 0

for i in range(0, len(all_parquet_files), chunk_size):
    chunk_files = all_parquet_files[i:i + chunk_size]
    if not chunk_files: continue
    current_chunk_num = i // chunk_size + 1
    print(f"Processing chunk {current_chunk_num}/{total_chunks}...")

    df_chunk = pd.concat([pd.read_parquet(f) for f in chunk_files], ignore_index=True)

    for col in essential_columns:
        if col not in df_chunk.columns: df_chunk[col] = np.nan

    df_chunk['obs_time'] = pd.to_datetime(df_chunk['obs_time'], utc=True)
    df_chunk = df_chunk.sort_values(['patient_id', 'obs_time']).set_index('obs_time')

    # Per-patient 3-hour rolling mean / std (time-based window on the obs_time index).
    vitals_to_track = [
        'heart_rate', 'sbp', 'dbp', 'respiratory_rate', 'temperature', 'glucose',
        'wbc_count', 'hemoglobin', 'creatinine', 'potassium', 'lactate', 'troponin_i'
    ]
    for vital in vitals_to_track:
        if vital in df_chunk.columns:
            rolling_window = df_chunk.groupby('patient_id')[vital].rolling(window='3h', min_periods=1)
            df_chunk[f'{vital}_avg_3hr'] = rolling_window.mean().reset_index(0, drop=True)
            df_chunk[f'{vital}_std_3hr'] = rolling_window.std().reset_index(0, drop=True)

    # Carry each patient's last known value forward, then mark "never measured" as 0.
    df_chunk.reset_index(inplace=True)
    cols_to_fill = df_chunk.columns.difference(['patient_id'])
    df_chunk[cols_to_fill] = df_chunk.groupby('patient_id')[cols_to_fill].ffill()
    df_chunk.fillna(0, inplace=True)

    if 'gender' in df_chunk.columns:
        # Explicit indicator instead of get_dummies(drop_first=True): the
        # dummy columns produced by get_dummies depend on which values happen
        # to occur in the chunk, so chunks could end up with different schemas.
        df_chunk['gender_male'] = (df_chunk['gender'] == 'male').astype('int8')
        df_chunk = df_chunk.drop(columns=['gender'])

    # 0 means "not measured": swap it for NaN so ratios do not divide by zero.
    # NOTE(review): BMI below assumes 'height_m' is in metres — confirm the unit of the source data.
    safe_height_sq = (df_chunk['height_m'].replace(0, np.nan)) ** 2
    safe_sbp = df_chunk['sbp'].replace(0, np.nan)
    safe_creatinine = df_chunk['creatinine'].replace(0, np.nan)

    df_chunk['map'] = ((2 * df_chunk['dbp']) + df_chunk['sbp']) / 3
    df_chunk['shock_index'] = df_chunk['heart_rate'] / safe_sbp
    df_chunk['bun_creatinine_ratio'] = df_chunk['bun'] / safe_creatinine

    # Ensure all rule columns exist before calculating rules
    for col in all_rule_cols:
        if col not in df_chunk.columns:
            df_chunk[col] = 0

    # Only meaningful where all three electrolytes were measured; otherwise the
    # zero placeholders would make the "gap" equal to the sodium value itself.
    electrolytes_measured = (df_chunk['sodium'] > 0) & (df_chunk['chloride'] > 0) & (df_chunk['bicarbonate'] > 0)
    df_chunk['anion_gap'] = (
        df_chunk['sodium'] - (df_chunk['chloride'] + df_chunk['bicarbonate'])
    ).where(electrolytes_measured, 0)
    df_chunk['pulse_pressure'] = df_chunk['sbp'] - df_chunk['dbp']

    # NOTE(review): deltas are taken on zero-filled data, so the first real
    # reading after a run of placeholders shows up as a large jump — confirm intended.
    delta_cols_list = ['sbp', 'dbp', 'heart_rate', 'wbc_count', 'hemoglobin', 'map', 'creatinine', 'lactate', 'troponin_i']
    for col in delta_cols_list:
        if col in df_chunk.columns:
            df_chunk[f'delta_{col}'] = df_chunk.groupby('patient_id')[col].diff().fillna(0)

    # Chunk-wide creatinine baseline over measured values only (NaN if none;
    # comparisons against NaN are False, so no AKI label is raised then).
    creatinine_baseline = safe_creatinine.median()

    # --- MULTI-TIMEFRAME LABEL ENGINEERING (65 CONDITIONS) ---
    # Each entry holds three boolean masks, used for the _6h, _24h and _48h
    # labels respectively. "Below threshold" rules go through _measured_below
    # so that missing labs (stored as 0) never trigger them.
    conditions = {
        # --- Original 28 Conditions ---
        'sepsis': [((df_chunk['heart_rate'] > 110) & (df_chunk['wbc_count'] > 14) & (df_chunk['lactate'] > 2)), ((df_chunk['heart_rate'] > 100) & (df_chunk['wbc_count'] > 12)), ((df_chunk['heart_rate'] > 90) & (df_chunk['wbc_count'] > 11))],
        'anemia': [_measured_below(df_chunk, 'hemoglobin', 10), _measured_below(df_chunk, 'hemoglobin', 12), _measured_below(df_chunk, 'hemoglobin', 13)],
        'hyperlipidemia': [((df_chunk['cholesterol_total'] > 240) | (df_chunk['cholesterol_ldl'] > 160)), ((df_chunk['cholesterol_total'] > 200) | (df_chunk['cholesterol_ldl'] > 130)), (df_chunk['on_statin'] == 1)],
        'mi': [(df_chunk['troponin_i'] > 0.1), (df_chunk['troponin_i'] > 0.04), ((df_chunk['history_cad'] == 1) & (df_chunk['pain_score'] > 5))],
        'stroke': [((df_chunk['history_afib'] == 1) & (df_chunk['age'] > 70)), ((df_chunk['history_hypertension'] == 1) & (df_chunk['age'] > 60)), (df_chunk['age'] > 55)],
        'depression': [(df_chunk['phq2_score'] > 4), (df_chunk['phq2_score'] > 2), (df_chunk['history_depression'] == 1)],
        'anxiety': [(df_chunk['gad7_score'] > 15), (df_chunk['gad7_score'] > 9), (df_chunk['history_anxiety'] == 1)],
        'pneumonia': [((df_chunk['respiratory_rate'] > 24) & (df_chunk['wbc_count'] > 13) & (df_chunk['crp'] > 50)), ((df_chunk['respiratory_rate'] > 20) & (df_chunk['wbc_count'] > 12)), (df_chunk['history_pneumonia'] == 1)],
        'chf_exacerbation': [((df_chunk['history_chf'] == 1) & (df_chunk['respiratory_rate'] > 22)), ((df_chunk['history_chf'] == 1) & (df_chunk['respiratory_rate'] > 20)), (df_chunk['history_chf'] == 1)],
        'hypertension': [((df_chunk['sbp'] > 160) | (df_chunk['dbp'] > 100)), ((df_chunk['sbp'] > 140) | (df_chunk['dbp'] > 90)), (df_chunk['history_hypertension'] == 1)],
        'diabetes': [(df_chunk['glucose'] > 200), ((df_chunk['hba1c'] > 6.5) | (df_chunk['glucose'] > 126)), (df_chunk['history_diabetes'] == 1)],
        'hypoglycemia': [_measured_below(df_chunk, 'glucose', 60), _measured_below(df_chunk, 'glucose', 70), _measured_below(df_chunk, 'glucose', 80)],
        'aki': [(df_chunk['creatinine'] > (creatinine_baseline + 0.5)), (df_chunk['creatinine'] > (creatinine_baseline + 0.3)), (df_chunk['creatinine'] > (creatinine_baseline + 0.2))],
        'tachycardia': [(df_chunk['heart_rate'] > 120), (df_chunk['heart_rate'] > 100), (df_chunk['heart_rate'] > 95)],
        'bradycardia': [_measured_below(df_chunk, 'heart_rate', 50), _measured_below(df_chunk, 'heart_rate', 60), _measured_below(df_chunk, 'heart_rate', 65)],
        'hypotension': [_measured_below(df_chunk, 'sbp', 85), _measured_below(df_chunk, 'sbp', 90), _measured_below(df_chunk, 'sbp', 95)],
        'acute_bronchitis': [((df_chunk['respiratory_rate'] > 22) & _measured_below(df_chunk, 'wbc_count', 10)), ((df_chunk['respiratory_rate'] > 20) & _measured_below(df_chunk, 'wbc_count', 12)), ((df_chunk['respiratory_rate'] > 18) & _measured_below(df_chunk, 'wbc_count', 12))],
        'ckd': [(df_chunk['creatinine'] > 2.0), (df_chunk['creatinine'] > 1.5), (df_chunk['history_ckd'] == 1)],
        'copd_exacerbation': [((df_chunk['history_copd'] == 1) & (df_chunk['respiratory_rate'] > 22)), ((df_chunk['history_copd'] == 1) & (df_chunk['respiratory_rate'] > 20)), (df_chunk['history_copd'] == 1)],
        'liver_disease': [(df_chunk['bilirubin'] > 2.5), ((df_chunk['alt'] > 100) | (df_chunk['ast'] > 100)), (df_chunk['bilirubin'] > 1.8)],
        'hypokalemia': [_measured_below(df_chunk, 'potassium', 3.0), _measured_below(df_chunk, 'potassium', 3.5), _measured_below(df_chunk, 'potassium', 3.7)],
        'hypernatremia': [(df_chunk['sodium'] > 150), (df_chunk['sodium'] > 145), (df_chunk['sodium'] > 142)],
        'obesity': [((df_chunk['weight_kg'] / safe_height_sq) > 35), ((df_chunk['weight_kg'] / safe_height_sq) > 30), (df_chunk['history_obesity'] == 1)],
        'dehydration': [(df_chunk['bun_creatinine_ratio'] > 25), (df_chunk['bun_creatinine_ratio'] > 20), (df_chunk['sodium'] > 145)],
        'thrombocytopenia': [_measured_below(df_chunk, 'platelet_count', 100), _measured_below(df_chunk, 'platelet_count', 150), _measured_below(df_chunk, 'platelet_count', 180)],
        'hyperkalemia': [(df_chunk['potassium'] > 5.5), (df_chunk['potassium'] > 5.0), (df_chunk['potassium'] > 4.8)],
        'hyponatremia': [_measured_below(df_chunk, 'sodium', 130), _measured_below(df_chunk, 'sodium', 135), _measured_below(df_chunk, 'sodium', 137)],
        'leukopenia': [_measured_below(df_chunk, 'wbc_count', 3.0), _measured_below(df_chunk, 'wbc_count', 4.0), _measured_below(df_chunk, 'wbc_count', 4.5)],

        # --- Original 23 New Conditions ---
        'uti_risk': [((df_chunk['urine_wbc'] > 50) & (df_chunk['urine_nitrite'] == 1)), ((df_chunk['urine_wbc'] > 10) | (df_chunk['urine_leukocyte'] == 1)), (df_chunk['history_uti'] == 1)],
        'pulmonary_embolism': [((df_chunk['d_dimer'] > 500) & (df_chunk['heart_rate'] > 100)), (df_chunk['d_dimer'] > 250), (df_chunk['history_dvt_pe'] == 1)],
        'atrial_fibrillation': [((df_chunk['heart_rate'] > 120) & (df_chunk['delta_heart_rate'].abs() > 30)), ((df_chunk['heart_rate'] > 100) & (df_chunk['age'] > 65)), (df_chunk['history_afib'] == 1)],
        'gi_bleed': [(df_chunk['delta_hemoglobin'] < -2.0), (df_chunk['stool_occult_blood'] == 1), (df_chunk['history_gi_bleed'] == 1)],
        'pancreatitis': [((df_chunk['lipase'] > 300) | (df_chunk['amylase'] > 300)), ((df_chunk['lipase'] > 150) | (df_chunk['amylase'] > 150)), (df_chunk['history_pancreatitis'] == 1)],
        'coagulopathy': [(df_chunk['inr'] > 3.0), (df_chunk['inr'] > 1.5), (df_chunk['on_anticoagulant'] == 1)],
        'dvt_risk': [(df_chunk['d_dimer'] > 500), (df_chunk['d_dimer'] > 250), (df_chunk['history_dvt_pe'] == 1)],
        'hypothyroidism': [(df_chunk['tsh'] > 10.0), (df_chunk['tsh'] > 4.5), (df_chunk['history_thyroid'] == 1)],
        'hyperthyroidism': [_measured_below(df_chunk, 'tsh', 0.1), _measured_below(df_chunk, 'tsh', 0.4), ((df_chunk['history_thyroid'] == 1) & (df_chunk['heart_rate'] > 100))],
        'gout': [(df_chunk['uric_acid'] > 9.0), (df_chunk['uric_acid'] > 7.0), (df_chunk['history_gout'] == 1)],
        'arthritis': [(df_chunk['rheumatoid_factor'] > 20), (df_chunk['crp'] > 10), (df_chunk['history_arthritis'] == 1)],
        'ibd_exacerbation': [((df_chunk['history_ibd'] == 1) & (df_chunk['crp'] > 20)), ((df_chunk['history_ibd'] == 1) & _measured_below(df_chunk, 'hemoglobin', 10)), (df_chunk['history_ibd'] == 1)],
        'cellulitis': [((df_chunk['history_cellulitis'] == 1) & (df_chunk['wbc_count'] > 12)), ((df_chunk['history_cellulitis'] == 1) & (df_chunk['temperature'] > 38.0)), (df_chunk['history_cellulitis'] == 1)],
        'seizure_risk': [(df_chunk['history_seizure'] == 1), (df_chunk['history_seizure'] == 1), (df_chunk['history_seizure'] == 1)],
        'dementia_risk': [(df_chunk['history_dementia'] == 1), (df_chunk['history_dementia'] == 1), (df_chunk['history_dementia'] == 1)],
        'malnutrition': [_measured_below(df_chunk, 'albumin', 2.5), _measured_below(df_chunk, 'albumin', 3.4), ((df_chunk['age'] > 75) & _measured_below(df_chunk, 'albumin', 3.5))],
        'drug_side_effect_renal': [((df_chunk['on_lisinopril'] == 1) & (df_chunk['delta_creatinine'] > 0.3)), ((df_chunk['on_lisinopril'] == 1) & (df_chunk['potassium'] > 5.0)), (df_chunk['on_lisinopril'] == 1)],
        'drug_side_effect_bleed': [(((df_chunk['on_aspirin'] == 1) | (df_chunk['on_anticoagulant'] == 1)) & (df_chunk['delta_hemoglobin'] < -1.0)), (((df_chunk['on_aspirin'] == 1) | (df_chunk['on_anticoagulant'] == 1)) & _measured_below(df_chunk, 'platelet_count', 150)), ((df_chunk['on_aspirin'] == 1) | (df_chunk['on_anticoagulant'] == 1))],
        'drug_side_effect_bradycardia': [((df_chunk['on_beta_blocker'] == 1) & _measured_below(df_chunk, 'heart_rate', 50)), ((df_chunk['on_beta_blocker'] == 1) & _measured_below(df_chunk, 'heart_rate', 60)), (df_chunk['on_beta_blocker'] == 1)],
        'hypomagnesemia': [_measured_below(df_chunk, 'magnesium', 1.2), _measured_below(df_chunk, 'magnesium', 1.8), (df_chunk['on_diuretic'] == 1)],
        'hypophosphatemia': [_measured_below(df_chunk, 'phosphate', 2.0), _measured_below(df_chunk, 'phosphate', 2.5), (df_chunk['history_diabetes'] == 1)],
        'acidosis': [(df_chunk['anion_gap'] > 20), (df_chunk['anion_gap'] > 16), (df_chunk['lactate'] > 2.0)],
        'rhabdomyolysis': [((df_chunk['creatinine'] > 2.0) & (df_chunk['potassium'] > 5.5)), (df_chunk['on_statin'] == 1), (df_chunk['on_statin'] == 1)],

        # --- NEW 14 BASIC CONDITIONS ---
        'gerd_risk': [(df_chunk['on_ppi'] == 1), (df_chunk['history_gerd'] == 1), (df_chunk['history_gerd'] == 1)],
        'allergy_risk': [(df_chunk['history_allergies'] == 1), (df_chunk['history_allergies'] == 1), (df_chunk['history_allergies'] == 1)],
        'migraine_risk': [(df_chunk['history_migraine'] == 1), (df_chunk['history_migraine'] == 1), (df_chunk['history_migraine'] == 1)],
        'sleep_apnea_risk': [((df_chunk['history_sleep_apnea'] == 1) & (df_chunk['history_obesity'] == 1)), (df_chunk['history_sleep_apnea'] == 1), (df_chunk['history_obesity'] == 1)],
        'low_back_pain_risk': [(df_chunk['history_low_back_pain'] == 1), (df_chunk['history_low_back_pain'] == 1), (df_chunk['history_low_back_pain'] == 1)],
        'smoker_risk': [(df_chunk['history_smoker'] == 1), (df_chunk['history_smoker'] == 1), (df_chunk['history_smoker'] == 1)],
        'alcohol_use_risk': [((df_chunk['history_alcohol_use'] == 1) & (df_chunk['alt'] > 50)), (df_chunk['history_alcohol_use'] == 1), (df_chunk['history_alcohol_use'] == 1)],
        'eczema_risk': [(df_chunk['history_eczema'] == 1), (df_chunk['history_eczema'] == 1), (df_chunk['history_eczema'] == 1)],
        'iron_deficiency_risk': [(df_chunk['history_iron_deficiency'] == 1), (df_chunk['history_iron_deficiency'] == 1), (df_chunk['history_anemia'] == 1)],
        'osteoporosis_risk': [(df_chunk['history_osteoporosis'] == 1), (df_chunk['history_osteoporosis'] == 1), (df_chunk['age'] > 65)],
        'insomnia_risk': [(df_chunk['history_insomnia'] == 1), (df_chunk['history_insomnia'] == 1), (df_chunk['history_anxiety'] == 1)],
        'ibs_risk': [(df_chunk['history_ibs'] == 1), (df_chunk['history_ibs'] == 1), (df_chunk['history_ibs'] == 1)],
        'sinusitis_risk': [(df_chunk['history_sinusitis'] == 1), (df_chunk['history_sinusitis'] == 1), (df_chunk['history_allergies'] == 1)],
        'dizziness_risk': [(df_chunk['history_dizziness'] == 1), (df_chunk['history_dizziness'] == 1), (df_chunk['history_dizziness'] == 1)]
    }
    n_conditions = len(conditions)

    # Build all label columns in a dict and attach them with one concat;
    # inserting ~200 columns one by one triggers pandas' PerformanceWarning.
    new_risk_cols = {}
    timeframes = ['_6h', '_24h', '_48h']
    for name, rules in conditions.items():
        for idx, time in enumerate(timeframes):
            new_risk_cols[f'risk_{name}{time}'] = rules[idx].astype('int8')

    df_chunk = pd.concat([df_chunk, pd.DataFrame(new_risk_cols, index=df_chunk.index)], axis=1)

    df_chunk.fillna(0, inplace=True)
    df_chunk = optimize_dtypes(df_chunk)

    output_filename = os.path.join(final_chunks_dir, f"processed_chunk_{current_chunk_num}.parquet")
    df_chunk.to_parquet(output_filename, engine='pyarrow', compression='snappy', index=False)

    print(f"    -> Chunk processed and written to {output_filename}. RAM freed.")
    # `conditions` holds boolean Series for every rule, so drop it as well.
    del df_chunk, new_risk_cols, conditions
    gc.collect()

# Clean up temp directory
print("\n--- Removing temporary parquet files... ---")
if os.path.exists(temp_output_dir):
    shutil.rmtree(temp_output_dir)
gc.collect()

print(f"\n✅ All chunks processed and saved to '{final_chunks_dir}'")
print(f"Total chunks saved: {total_chunks}. Total conditions: {n_conditions}")

In [None]:
import pandas as pd
import numpy as np
import json
import joblib
import gc
import os
import math
import pyarrow.parquet as pq # Needed to read file metadata
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# --- 1. Get File Lists and Column Names ---
# Collect every processed parquet chunk (sorted for a deterministic order) and
# derive the feature / target column lists from the first chunk.
final_chunks_dir = '/content/final_parquet_chunks/'
all_chunk_files = [os.path.join(final_chunks_dir, f) for f in os.listdir(final_chunks_dir) if f.endswith('.parquet')]
all_chunk_files.sort()

print("--- Reading column names from first chunk... ---")
if not all_chunk_files:
    print("❌ ERROR: No .parquet files found in /content/final_parquet_chunks/")
    raise FileNotFoundError("No processed chunk files found. Please run the previous script.")

# Only the column names are needed; the frame is released right afterwards.
sample_df = pd.read_parquet(all_chunk_files[0])
target_diseases_multitime = [col for col in sample_df.columns if col.startswith('risk_')]
feature_columns = [col for col in sample_df.columns if col not in target_diseases_multitime and col not in ['patient_id', 'obs_time', 'obs_time_utc']]

# Target columns are named 'risk_<condition>_<timeframe>' (e.g.
# 'risk_iron_deficiency_risk_6h'). Downstream code rebuilds the column name as
# 'risk_' + base + '_6h', so the base must be the FULL condition name: strip only
# the leading 'risk_' prefix and the trailing '_<timeframe>' suffix.
# (The previous `replace('risk_', '').split('_')[0]` removed every 'risk_'
# occurrence and kept only the first word, e.g. 'iron', so lookups never matched.)
base_diseases = sorted({col[len('risk_'):].rsplit('_', 1)[0] for col in target_diseases_multitime})
del sample_df
gc.collect()

print(f"Found {len(feature_columns)} features and {len(target_diseases_multitime)} targets.")

# --- 2. Fit the Scaler (Memory-Safe) ---
# The scaler is updated one chunk at a time via partial_fit, so only a single
# chunk's feature columns are held in memory at any moment.
print("--- Fitting StandardScaler in chunks (memory-safe)... ---")
scaler = StandardScaler()
total_rows = 0

for chunk_path in all_chunk_files:
    chunk_features = pd.read_parquet(chunk_path, columns=feature_columns)
    total_rows = total_rows + len(chunk_features)
    scaler.partial_fit(chunk_features)
    del chunk_features  # release the chunk before loading the next one

print(f"Scaler fit on {total_rows} total rows.")

# Persist the fitted scaler so the inference cells can reuse it.
print("--- Saving scaler... ---")
joblib.dump(scaler, 'data_scaler.joblib')

# --- 3. Split files for training and validation ---
# The split is done at file level (whole chunks), not at row level.
if len(all_chunk_files) <= 1:
    # A single chunk cannot be split: reuse it on both sides so the pipeline
    # can still be exercised end to end.
    print(f"⚠️ Warning: Only {len(all_chunk_files)} data chunk found. Using it for BOTH training and validation to test the pipeline.")
    train_files = all_chunk_files
    val_files = all_chunk_files
else:
    print(f"Splitting {len(all_chunk_files)} chunks into train/val sets...")
    train_files, val_files = train_test_split(all_chunk_files, test_size=0.2, random_state=42)

print(f"Using {len(train_files)} file(s) for training, {len(val_files)} file(s) for validation.")


# --- 4. Define the Keras Data Generator (FIX #1: Solves System RAM OOM) ---
def data_generator(file_list, scaler, feature_cols, target_cols, batch_size):
    """Endlessly stream scaled (features, targets) batches from parquet files.

    Each pass visits the files in a fresh random order and reads them with
    ``ParquetFile.iter_batches`` so a whole file is never loaded at once.

    Args:
        file_list: Paths of the parquet chunk files to stream from.
        scaler: Fitted scaler; its ``transform`` is applied to the feature frame.
        feature_cols: Feature column names, in the order the model expects.
        target_cols: Target column names; one dict entry is yielded per target.
        batch_size: Maximum number of rows per yielded batch.

    Yields:
        ``(X_scaled, y_dict)`` where ``X_scaled`` is the scaler output for the
        batch and ``y_dict`` maps each target name to its column for the batch.

    Raises:
        RuntimeError: If a complete pass over ``file_list`` yields no batch at
            all (empty list, or every file failed). Without this guard the
            generator would loop forever and the consumer would hang.
    """
    while True:
        batches_this_pass = 0
        shuffled_files = np.random.permutation(file_list)

        for f in shuffled_files:
            try:
                parquet_file = pq.ParquetFile(f)
                for batch in parquet_file.iter_batches(batch_size=batch_size):
                    batch_df = batch.to_pandas()
                    batch_df = batch_df.sample(frac=1)  # shuffle rows within the batch
                    # reindex guarantees column order and adds any missing column as NaN -> 0
                    X_pre_scale = batch_df.reindex(columns=feature_cols).fillna(0)
                    y = batch_df.reindex(columns=target_cols).fillna(0)
                    X_scaled = scaler.transform(X_pre_scale)
                    y_dict = {col: y[col] for col in y.columns}
                    batches_this_pass += 1
                    yield X_scaled, y_dict
            except Exception as e:
                # Deliberate best-effort: one bad file must not abort training.
                print(f"Warning: Error reading or streaming {f}: {e}. Skipping file.")
                continue

        if batches_this_pass == 0:
            raise RuntimeError(
                "data_generator produced no batches in a full pass over "
                f"{len(file_list)} file(s); check that the chunk files exist and are readable."
            )


# --- 5. Define the Multi-Task Neural Network (FIX #2: LIGHTEST GPU VRAM Model) ---
def create_multi_task_model(n_features, targets):
    """Build and compile a shared-trunk network with one sigmoid head per target.

    Args:
        n_features: Width of the input feature vector.
        targets: Iterable of target names; each becomes the name of an output layer.

    Returns:
        A compiled Keras ``Model`` using Adam, binary cross-entropy per head and
        an AUC metric per head (named ``<target>_auc``).
    """
    target_names = list(targets)

    features_in = Input(shape=(n_features,), name='input_features')

    # Shared trunk: normalise -> dropout -> Dense(64) -> dropout -> Dense(32).
    trunk = BatchNormalization()(features_in)
    trunk = Dropout(0.2)(trunk)
    trunk = Dense(64, activation='relu')(trunk)
    trunk = Dropout(0.3)(trunk)
    trunk = Dense(32, activation='relu')(trunk)

    # One single-unit sigmoid head per target, attached directly to the trunk.
    heads = [Dense(1, activation='sigmoid', name=target)(trunk) for target in target_names]
    losses_by_head = {target: 'binary_crossentropy' for target in target_names}
    metrics_by_head = {target: tf.keras.metrics.AUC(name=f'{target}_auc') for target in target_names}

    network = Model(inputs=features_in, outputs=heads)
    network.compile(optimizer='adam', loss=losses_by_head, metrics=metrics_by_head)
    return network

model = create_multi_task_model(len(feature_columns), target_diseases_multitime)
model.summary()


# --- 6. Train the Model using the Generator (MODIFIED FOR 1 EPOCH) ---
print("\n--- Training Multi-Task Neural Network from disk... ---")

# Using the Batch Size that worked in your previous log
BATCH_SIZE = 256


def _count_rows_and_batches(files, batch_size):
    """Return (total_rows, total_batches) for one full pass over `files`.

    The generator reads each parquet file separately, so a batch never spans
    two files and every file ends with its own (possibly partial) batch. The
    batch count is therefore the per-file ceil summed over files, which can be
    larger than ceil(total_rows / batch_size).
    """
    total_rows = 0
    total_batches = 0
    for path in files:
        n_rows = pq.ParquetFile(path).metadata.num_rows  # metadata only, no data read
        total_rows += n_rows
        total_batches += math.ceil(n_rows / batch_size)
    return total_rows, total_batches


# NOTE(review): if a file holds several row groups, iter_batches may emit even
# more (smaller) batches than counted here; the count is then a lower bound.
train_rows, steps_per_epoch = _count_rows_and_batches(train_files, BATCH_SIZE)
val_rows, validation_steps = _count_rows_and_batches(val_files, BATCH_SIZE)

print(f"Train Rows: {train_rows}, Val Rows: {val_rows}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Steps per Epoch: {steps_per_epoch}, Validation Steps: {validation_steps}")

# Define the data types and shapes for the tf.data.Dataset
n_features = len(feature_columns)
X_spec = tf.TensorSpec(shape=(None, n_features), dtype=tf.float32)
y_spec = {
    target_name: tf.TensorSpec(shape=(None,), dtype=tf.float32)
    for target_name in target_diseases_multitime
}
output_signature = (X_spec, y_spec)

# Create the Dataset objects (the lambdas let tf.data restart the generator)
train_ds = tf.data.Dataset.from_generator(
    lambda: data_generator(train_files, scaler, feature_columns, target_diseases_multitime, BATCH_SIZE),
    output_signature=output_signature
)
val_ds = tf.data.Dataset.from_generator(
    lambda: data_generator(val_files, scaler, feature_columns, target_diseases_multitime, BATCH_SIZE),
    output_signature=output_signature
)

# This is the fix for RAM paging (the slowdown): keep a few batches ready ahead of the model
train_ds = train_ds.prefetch(5)
val_ds = val_ds.prefetch(5)

print("\nStarting model.fit() for exactly 1 epoch...")

# --- START OF 1-EPOCH MODIFICATION ---
# The generators are infinite, so the step counts are what bound the epoch.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=1,  # Set to 1
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    callbacks=[], # No callbacks needed for 1 epoch
    verbose=1
)
# --- END OF 1-EPOCH MODIFICATION ---

print("✅ Multi-Task model training complete!\n")

# --- 7. Save Artifacts for GUI and Deployment ---
print("\n--- Saving model and helper files... ---")
model.save('patient_risk_model.keras')
# Scaler was already saved in step 2

# The JSON helpers let the inference cells rebuild the feature order and the
# list of base condition names without re-reading the parquet chunks.
for json_name, payload in (
    ('feature_columns_multitime.json', feature_columns),
    ('target_diseases_base.json', base_diseases),
):
    with open(json_name, 'w') as f:
        json.dump(payload, f)

print("✅ Keras model, scaler, and helper files saved.")

# --- 8. Convert to TensorFlow Lite (for Raspberry Pi) ---
print("\n--- Converting Keras model to TensorFlow Lite for edge deployment... ---")

# Build the converter from the in-memory Keras model with default optimizations.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# convert() returns the serialized model as bytes; write it out unchanged.
with open('patient_risk_model.tflite', 'wb') as f:
    f.write(tflite_model)

print(f"✅ Model converted and saved as 'patient_risk_model.tflite' ({len(tflite_model) / 1024:.2f} KB)")
print("This .tflite file is what you would deploy to the Raspberry Pi.")

In [None]:
import pandas as pd
import numpy as np
import joblib
import json
import tensorflow as tf

# --- 1. Load Artifacts ---
# Loads the trained model, the fitted scaler and the JSON helper files written
# by the training cell. On a missing file `model` is set to None so the rest of
# the cell is skipped.
print("--- Loading 51-condition model and artifacts... ---")
try:
    model = tf.keras.models.load_model('patient_risk_model.keras')
    scaler = joblib.load('data_scaler.joblib')
    # Use context managers so the JSON file handles are closed deterministically.
    with open('feature_columns_multitime.json') as fh:
        feature_columns = json.load(fh)
    with open('target_diseases_base.json') as fh:
        base_diseases = json.load(fh)
    target_diseases_multitime = model.output_names # Get targets directly from model
except FileNotFoundError as e:
    print(f"--- ERROR: Could not find a required file: {e.filename} ---")
    print("Please ensure you have run Block 4 to train and save the model first.")
    model = None

# --- 2. Create Data for a New Hypothetical Patient (with 51-condition data) ---
# Builds one sample patient, derives the same engineered features as training,
# scales them and prints a per-condition risk table for 6h / 24h / 48h.
if model:
    new_patient_data = pd.DataFrame({
        # --- Original Data ---
        'age': [72], 'gender': ['male'],
        'sbp': [160], 'dbp': [95], 'heart_rate': [95], 'respiratory_rate': [24],
        'pain_score': [6], 'height_m': [1.68], 'weight_kg': [80],
        'wbc_count': [11], 'hemoglobin': [12], 'hematocrit': [37], 'platelet_count': [210],
        'cholesterol_total': [230], 'triglycerides': [160], 'cholesterol_ldl': [150], 'cholesterol_hdl': [35],
        'hba1c': [7.0], 'phq2_score': [1], 'gad7_score': [10], 'temperature': [37.2], 'creatinine': [1.4], 'bun': [22],
        'sodium': [140], 'potassium': [4.2], 'glucose': [130], 'bilirubin': [0.8], 'spo2': [96],

        # --- Original History ---
        'history_hypertension': [1], 'history_diabetes': [1], 'history_chf': [1], 'history_copd': [1],
        'history_anemia': [0], 'history_hyperlipidemia': [1], 'history_cad': [1], 'history_ckd': [1],
        'history_asthma': [0], 'history_pneumonia': [0], 'history_flu': [0], 'history_obesity': [1],
        'on_statin': [1], 'on_metformin': [1], 'on_aspirin': [1], 'on_lisinopril': [1],
        'history_cabg': [0], 'history_appendectomy': [0],

        # --- NEW LABS for 51-condition model ---
        'troponin_i': [0.01], 'crp': [45.0], 'lactate': [1.9], 'procalcitonin': [0.3], 'd_dimer': [350],
        'urine_wbc': [10], 'urine_nitrite': [0], 'urine_leukocyte': [1], 'stool_occult_blood': [0],
        'albumin': [3.8], 'alt': [30], 'ast': [35], 'tsh': [4.0], 'free_t4': [1.1],
        'lipase': [80], 'amylase': [75], 'inr': [1.0], 'ptt': [30], 'magnesium': [1.9],
        'phosphate': [3.0], 'rheumatoid_factor': [10],

        # --- NEW HISTORY/MEDS for 51-condition model ---
        'history_afib': [1], 'history_dvt_pe': [0], 'history_thyroid': [1], 'history_gout': [0],
        'history_arthritis': [1], 'history_pancreatitis': [0], 'history_gi_bleed': [0], 'history_ibd': [0],
        'history_cellulitis': [0], 'history_uti': [1], 'history_seizure': [0], 'history_dementia': [0],
        'on_anticoagulant': [1], 'on_beta_blocker': [1], 'on_diuretic': [1], 'on_ccb': [1],
        'on_insulin': [0], 'on_ppi': [1], 'on_thyroid_med': [1], 'on_gout_med': [0],
        'proc_colonoscopy': [1], 'proc_egd': [1], 'proc_dialysis': [0]
    })

    # --- 3. PREPARE DATA FOR PREDICTION (Keras Version) ---
    new_patient_processed = new_patient_data.copy()

    # a. Create trend features (assume stability for a single snapshot)
    vitals_to_track = [
        'heart_rate', 'sbp', 'dbp', 'respiratory_rate', 'temperature', 'glucose',
        'wbc_count', 'hemoglobin', 'creatinine', 'potassium', 'lactate', 'troponin_i'
    ]
    for vital in vitals_to_track:
        if vital in new_patient_processed.columns:
            current_val = new_patient_processed[vital].iloc[0]
            new_patient_processed[f'{vital}_avg_3hr'] = current_val
            new_patient_processed[f'{vital}_std_3hr'] = 0.0

    # b. Create other derived features (must match Block 3)
    # Encode gender against the TRAINING dummy columns ('gender_<value>').
    # Calling get_dummies(drop_first=True) on a single-row frame drops the only
    # level present, so the gender column vanished and was later zero-filled.
    patient_gender = new_patient_processed['gender'].iloc[0]
    new_patient_processed = new_patient_processed.drop(columns=['gender'])
    for col in feature_columns:
        if col.startswith('gender_'):
            new_patient_processed[col] = int(col[len('gender_'):] == patient_gender)

    # Zero denominators become NaN so the ratios below yield NaN (later filled
    # with 0) instead of inf.
    safe_sbp = new_patient_processed['sbp'].replace(0, np.nan)
    safe_creatinine = new_patient_processed['creatinine'].replace(0, np.nan)

    new_patient_processed['map'] = ((2 * new_patient_processed['dbp']) + new_patient_processed['sbp']) / 3
    new_patient_processed['shock_index'] = new_patient_processed['heart_rate'] / safe_sbp
    new_patient_processed['bun_creatinine_ratio'] = new_patient_processed['bun'] / safe_creatinine
    new_patient_processed['anion_gap'] = new_patient_processed['sodium'] - (105 + 24) # Using defaults
    new_patient_processed['pulse_pressure'] = new_patient_processed['sbp'] - new_patient_processed['dbp']

    # c. Create delta features (assume 0 for a snapshot)
    delta_cols_list = [
        'sbp', 'dbp', 'heart_rate', 'wbc_count', 'hemoglobin', 'map',
        'creatinine', 'lactate', 'troponin_i'
    ]
    for col in delta_cols_list:
        new_patient_processed[f'delta_{col}'] = 0

    # d. Align columns with the training set (extra columns dropped, missing ones -> 0)
    new_patient_processed = new_patient_processed.reindex(columns=feature_columns).fillna(0)

    # e. Scale the data using the saved scaler
    new_patient_scaled = scaler.transform(new_patient_processed)

    # --- 4. GET PREDICTIONS (Keras Version) ---
    print("--- Making predictions with the Keras multi-task model... ---")
    probabilities_list = model.predict(new_patient_scaled)

    # One array per output head; [0][0] picks the single probability for the
    # single patient row.
    results_map = {name: probabilities_list[i][0][0] for i, name in enumerate(target_diseases_multitime)}

    # --- 5. DISPLAY RESULTS IN A TABLE ---
    output_df_data = []
    for disease_base in base_diseases:
        clean_name = disease_base.replace('_',' ').title()
        output_df_data.append({
            "Condition": clean_name,
            "6-Hour Risk": f"{results_map.get('risk_' + disease_base + '_6h', 0):.2%}",
            "24-Hour Risk": f"{results_map.get('risk_' + disease_base + '_24h', 0):.2%}",
            "48-Hour Risk": f"{results_map.get('risk_' + disease_base + '_48h', 0):.2%}",
            "_sort_key": results_map.get('risk_' + disease_base + '_24h', 0)
        })

    # Highest 24-hour risk first; the helper sort column is dropped before display.
    sorted_results = sorted(output_df_data, key=lambda x: x["_sort_key"], reverse=True)
    results_df = pd.DataFrame(sorted_results).drop(columns=['_sort_key'])

    print("\n--- Multi-Timeframe Predictive Risk Assessment (51 Conditions) ---")
    print(results_df.to_string(max_rows=len(sorted_results)))


In [None]:
import pandas as pd
import numpy as np
import joblib
import json
import tensorflow as tf
import os
import math
import gc
import pyarrow.parquet as pq
from sklearn.model_selection import train_test_split

print("--- Evaluating Multi-Task Model Performance on Unseen Test Data ---")

# Reloads the saved model/scaler, rebuilds the same file-level train/val split
# as the training cell, and reports per-output AUC on the validation files.
try:
    # --- 1. Load Artifacts ---
    print("Loading model, scaler, and file lists...")
    model = tf.keras.models.load_model('patient_risk_model.keras')
    scaler = joblib.load('data_scaler.joblib')

    # --- 2. Re-create File Lists and Column Names ---
    final_chunks_dir = '/content/final_parquet_chunks/'
    all_chunk_files = [os.path.join(final_chunks_dir, f) for f in os.listdir(final_chunks_dir) if f.endswith('.parquet')]
    all_chunk_files.sort()

    sample_df = pd.read_parquet(all_chunk_files[0])
    target_diseases_multitime = [col for col in sample_df.columns if col.startswith('risk_')]
    feature_columns = [col for col in sample_df.columns if col not in target_diseases_multitime and col not in ['patient_id', 'obs_time', 'obs_time_utc']]
    del sample_df

    # Mirror the training cell: same seed/test_size so the same files land in
    # the validation set, and the same single-chunk fallback (train_test_split
    # cannot split a single file).
    if len(all_chunk_files) > 1:
        train_files, val_files = train_test_split(all_chunk_files, test_size=0.2, random_state=42)
    else:
        train_files = all_chunk_files
        val_files = all_chunk_files

    # --- 3. Re-create Data Generator (for validation) ---
    def data_generator(file_list, scaler, feature_cols, target_cols, batch_size):
        """Yield (scaled features, {target: column}) batches, file by file, forever."""
        while True: # Keras needs this to be infinite
            for f in file_list: # No need to shuffle for evaluation
                df = pd.read_parquet(f)

                for i in range(0, len(df), batch_size):
                    batch_df = df.iloc[i:i+batch_size]
                    X = batch_df[feature_cols]
                    y = batch_df[target_cols]

                    X_scaled = scaler.transform(X)
                    y_dict = {col: y[col] for col in y.columns}

                    yield X_scaled, y_dict

                del df
                gc.collect()

    # --- 4. Evaluate the Model ---
    BATCH_SIZE = 256 # Use the same batch size as training

    # The generator slices each file on its own, so one pass produces
    # ceil(rows / BATCH_SIZE) batches PER FILE. Summing per file makes
    # `steps` cover every validation row exactly once; ceil(total / BATCH_SIZE)
    # would stop before the last batches whenever files end in partial batches.
    val_rows_per_file = [pq.ParquetFile(f).metadata.num_rows for f in val_files]
    val_rows = sum(val_rows_per_file)
    validation_steps = sum(math.ceil(n / BATCH_SIZE) for n in val_rows_per_file)

    print(f"\nFound {val_rows} validation rows. Using {validation_steps} steps.")

    # Re-create the generator
    val_gen = data_generator(val_files, scaler, feature_columns, target_diseases_multitime, BATCH_SIZE)

    print("Running model.evaluate()...")
    # return_dict=True makes Keras return {metric_name: value} directly, instead
    # of a flat list that has to be paired with model.metrics_names (which does
    # not enumerate per-output metrics on every Keras version).
    results_map = model.evaluate(
        val_gen,
        steps=validation_steps,
        verbose=1,
        return_dict=True
    )

    # --- 5. Display Results ---
    print("\n--- Evaluation Complete ---")

    # Filter just for the AUC scores
    auc_scores = {name: score for name, score in results_map.items() if 'auc' in name}

    if auc_scores:
        # Sort by AUC score, descending
        sorted_auc = sorted(auc_scores.items(), key=lambda item: item[1], reverse=True)

        print(f"✅ Average AUC-ROC across all {len(auc_scores)} outputs: {np.mean(list(auc_scores.values())):.4f}\n")

        print("--- Best Performing Predictions (Top 10) ---")
        for name, score in sorted_auc[:10]:
            print(f"{name}: {score:.4f}")

        print("\n--- Worst Performing Predictions (Bottom 10) ---")
        for name, score in sorted_auc[-10:]:
            print(f"{name}: {score:.4f}")

    else:
        print("Could not parse AUC scores from evaluation results.")

except Exception as e:
    # Top-level boundary of the cell: report and let the notebook continue.
    print(f"--- ERROR during evaluation: {e} ---")
    print("Please ensure Block 4 has been run successfully and all files are present.")


In [None]:
# --- 1. INSTALL GRADIO (if not already installed) ---
!pip install gradio -q
import gradio as gr
import pandas as pd
import numpy as np
import joblib
import json
import tensorflow as tf

# --- 2. LOAD SAVED ARTIFACTS ---
# Loads everything the GUI callback needs. On any failure the globals are reset
# to empty placeholders and `model` is None, which keeps the app from launching.
print("--- Loading 51-condition model, scaler, and helper files... ---")
try:
    model = tf.keras.models.load_model('patient_risk_model.keras')
    scaler = joblib.load('data_scaler.joblib')
    # Context managers close the JSON file handles deterministically.
    with open('feature_columns_multitime.json') as fh:
        feature_columns = json.load(fh)
    with open('target_diseases_base.json') as fh:
        base_diseases = json.load(fh)
    target_diseases_multitime = model.output_names

    # Imputation defaults: the per-feature MEAN learned by the scaler
    # (scaler.mean_), indexed by feature name. Blank/ignored inputs fall back
    # to these values.
    imputation_values = pd.Series(scaler.mean_, index=feature_columns)

    print("✅ Keras model and all necessary files loaded successfully.")
except Exception as e:
    print(f"--- ERROR: Could not load a required file. {e} ---")
    model, scaler, feature_columns, base_diseases, target_diseases_multitime = None, None, [], [], []
    # Explicit dtype: an empty Series without one has a version-dependent default.
    imputation_values = pd.Series(dtype=float)

# --- 3. CREATE THE PREDICTION FUNCTION (51-Condition Version) ---
# Input labels whose normalised form differs from the training feature name.
# Without these entries the value typed by the user was silently discarded
# (the generated key was not a feature, so the scaler mean was used instead).
_LABEL_KEY_OVERRIDES = {
    'systolic_bp': 'sbp',
    'diastolic_bp': 'dbp',
    'temperature_c': 'temperature',
    'pain_score_0_10': 'pain_score',
    'total_cholesterol': 'cholesterol_total',
    'ldl_cholesterol': 'cholesterol_ldl',
    'hdl_cholesterol': 'cholesterol_hdl',
    'phq_2_score': 'phq2_score',
    'gad_7_score': 'gad7_score',
}


def label_to_feature_key(label):
    """Translate a Gradio input label (e.g. "Pain Score (0-10)") to a feature name."""
    key = (
        label.lower()
        .replace(' ', '_').replace('(', '').replace(')', '')
        .replace('°c', '').replace('/', '_').replace('-', '_')
    )
    return _LABEL_KEY_OVERRIDES.get(key, key)


def predict_risk_multitime(ignore_features, *args):
    """Gradio callback: turn the form inputs into a per-condition risk table.

    Args:
        ignore_features: Labels ticked in the "Ignore Inputs" box; those fields
            keep the imputed value (scaler mean) regardless of what was typed.
        *args: Values of every other component in ``all_inputs``, in the same
            order. The last five are the checkbox groups: medical history,
            surgical history, common meds, diuretics, other meds.

    Returns:
        A DataFrame with columns "Condition", "6-Hour Risk", "24-Hour Risk",
        "48-Hour Risk", sorted by 24-hour risk (highest first); or a one-column
        "Error" DataFrame when no model is loaded.
    """
    if model is None: return pd.DataFrame({"Error": ["Model not loaded. Please train the model first."]})

    ignored = set(ignore_features or [])

    # Map all the Gradio inputs back to a dictionary
    input_names = [inp.label for inp in all_inputs if inp.label and inp.label != "Ignore Inputs"]
    input_values = dict(zip(input_names, args))

    # Use a copy of the imputation values (scaler means)
    input_data = imputation_values.copy()

    # Overwrite the means with user-provided data
    for name, value in input_values.items():
        # Blank (None) or explicitly ignored fields keep the imputed mean.
        if value is None or name in ignored:
            continue

        if name == "Gender":
            # Gender is categorical: set the training dummy column(s)
            # 'gender_<value>' instead of looking for a 'gender' feature.
            for col in feature_columns:
                if col.startswith('gender_'):
                    input_data[col] = 1 if col[len('gender_'):] == value else 0
            continue

        key = label_to_feature_key(name)
        # Checkbox groups produce keys that are not features and are skipped
        # here; they are handled explicitly below.
        if key in input_data:
            input_data[key] = value

    # --- Handle Checkboxes for History and Medications ---
    # args[-5] = medical_history, args[-4] = surgical_history,
    # args[-3] = current_meds, args[-2] = diuretic_meds, args[-1] = other_meds
    medical_history_sel, surgical_history_sel, current_meds_sel, diuretic_meds_sel, other_meds_sel = (
        set(group or []) for group in args[-5:]
    )

    hist_med_map = {
        'Hypertension': 'history_hypertension', 'Diabetes': 'history_diabetes', 'CHF': 'history_chf',
        'COPD': 'history_copd', 'Anemia': 'history_anemia', 'Hyperlipidemia': 'history_hyperlipidemia',
        'CAD': 'history_cad', 'CKD': 'history_ckd', 'Asthma': 'history_asthma', 'Atrial Fibrillation': 'history_afib',
        'PE/DVT History': 'history_dvt_pe', 'Thyroid Disease': 'history_thyroid', 'Gout': 'history_gout',
        'Arthritis': 'history_arthritis', 'Pancreatitis': 'history_pancreatitis', 'GI Bleed': 'history_gi_bleed',
        'IBD (Crohn\'s/Colitis)': 'history_ibd', 'Cellulitis': 'history_cellulitis', 'UTI': 'history_uti',
        'Seizure': 'history_seizure', 'Dementia': 'history_dementia'
    }
    for choice, key in hist_med_map.items():
        input_data[key] = 1 if choice in medical_history_sel else 0

    proc_map = {'CABG': 'history_cabg', 'Appendectomy': 'history_appendectomy', 'Colonoscopy': 'proc_colonoscopy', 'EGD': 'proc_egd', 'Dialysis': 'proc_dialysis'}
    for choice, key in proc_map.items():
        input_data[key] = 1 if choice in surgical_history_sel else 0

    med_map_1 = {'Statin': 'on_statin', 'Metformin': 'on_metformin', 'Aspirin': 'on_aspirin', 'Lisinopril (ACE-I)': 'on_lisinopril', 'Insulin': 'on_insulin'}
    for choice, key in med_map_1.items():
        input_data[key] = 1 if choice in current_meds_sel else 0

    # Any of the listed diuretics sets the single 'on_diuretic' flag.
    diuretic_choices = ('Furosemide', 'Hydrochlorothiazide', 'Spironolactone')
    input_data['on_diuretic'] = 1 if any(choice in diuretic_meds_sel for choice in diuretic_choices) else 0

    med_map_3 = {
        'Anticoagulant (Warfarin/Eliquis/Xarelto)': 'on_anticoagulant', 'Beta Blocker (Metoprolol/Atenolol)': 'on_beta_blocker',
        'Calcium Channel Blocker (Amlodipine)': 'on_ccb', 'PPI (Omeprazole/Pantoprazole)': 'on_ppi',
        'Levothyroxine': 'on_thyroid_med', 'Allopurinol': 'on_gout_med'
    }
    for choice, key in med_map_3.items():
        input_data[key] = 1 if choice in other_meds_sel else 0

    df = pd.DataFrame([input_data])

    # --- Re-create Derived Features (must match Block 3) ---
    # Note: We are overwriting the imputed values with more accurate calculations.
    # Zero denominators become NaN so the ratios yield NaN (filled with 0 below)
    # rather than inf.
    safe_sbp = df['sbp'].replace(0, np.nan)
    safe_creatinine = df['creatinine'].replace(0, np.nan)

    df['map'] = ((2 * df['dbp']) + df['sbp']) / 3
    df['shock_index'] = df['heart_rate'] / safe_sbp
    df['bun_creatinine_ratio'] = df['bun'] / safe_creatinine
    df['anion_gap'] = df['sodium'] - (105 + 24) # Using defaults
    df['pulse_pressure'] = df['sbp'] - df['dbp']

    # Impute trend features (assume stability for snapshot prediction)
    vitals_to_track = [
        'heart_rate', 'sbp', 'dbp', 'respiratory_rate', 'temperature', 'glucose',
        'wbc_count', 'hemoglobin', 'creatinine', 'potassium', 'lactate', 'troponin_i'
    ]
    for vital in vitals_to_track:
        if vital in df.columns:
            current_val = df[vital].iloc[0]
            df[f'{vital}_avg_3hr'] = current_val
            df[f'{vital}_std_3hr'] = 0.0

    # Impute delta features (assume 0 for a snapshot)
    delta_cols_list = ['sbp', 'dbp', 'heart_rate', 'wbc_count', 'hemoglobin', 'map', 'creatinine', 'lactate', 'troponin_i']
    for col in delta_cols_list:
        if col in feature_columns:
            df[f'delta_{col}'] = 0

    # Align columns, fill any *new* NaNs created by derived features
    df = df.reindex(columns=feature_columns).fillna(0)

    # --- Scale the data ---
    df_scaled = scaler.transform(df)

    # --- Make Predictions ---
    # One array per output head; [0][0] is the probability for the single row.
    probabilities_list = model.predict(df_scaled)
    results_map = {name: probabilities_list[i][0][0] for i, name in enumerate(target_diseases_multitime)}

    # Format the output table
    output_df_data = []
    for disease_base in base_diseases:
        clean_name = disease_base.replace('_',' ').title()
        output_df_data.append({
            "Condition": clean_name,
            "6-Hour Risk": f"{results_map.get('risk_' + disease_base + '_6h', 0):.2%}",
            "24-Hour Risk": f"{results_map.get('risk_' + disease_base + '_24h', 0):.2%}",
            "48-Hour Risk": f"{results_map.get('risk_' + disease_base + '_48h', 0):.2%}",
            "_sort_key": results_map.get('risk_' + disease_base + '_24h', 0)
        })
    sorted_results = sorted(output_df_data, key=lambda x: x["_sort_key"], reverse=True)
    return pd.DataFrame(sorted_results).drop(columns=['_sort_key'])

# --- 4. DEFINE THE GRADIO INTERFACE (51-Condition Version) ---
# Label groups offered in the "Ignore Inputs" box. Each string must match the
# label of the corresponding numeric input component exactly.
vitals_labels = [
    "Systolic BP", "Diastolic BP", "Heart Rate", "Respiratory Rate",
    "Temperature (C)", "Pain Score (0-10)", "Height (m)", "Weight (kg)", "SpO2",
]
core_labs_labels = [
    "WBC Count", "Hemoglobin", "Hematocrit", "Platelet Count", "Glucose",
    "Creatinine", "BUN", "Sodium", "Potassium", "Bilirubin",
]
lipid_panel_labels = [
    "Total Cholesterol", "Triglycerides", "LDL Cholesterol", "HDL Cholesterol",
]
# --- NEW LABS ---
new_labs_labels = [
    "Troponin I", "CRP", "Lactate", "Procalcitonin", "D-dimer",
    "Urine WBC", "Urine Nitrite", "Urine Leukocyte", "Stool Occult Blood",
    "Albumin", "ALT", "AST", "TSH", "Free T4",
    "Lipase", "Amylase", "INR", "PTT",
    "Magnesium", "Phosphate", "Rheumatoid Factor",
]
scores_labels = ["HbA1c", "PHQ-2 Score", "GAD-7 Score"]

# Flat list of every numeric label, in group order.
numerical_inputs_labels = [
    *vitals_labels,
    *core_labs_labels,
    *lipid_panel_labels,
    *new_labs_labels,
    *scores_labels,
]

# Build the Gradio UI. The component labels are not just cosmetic:
# predict_risk_multitime reads `inp.label` from `all_inputs` to pair each
# positional argument with a name, so renaming a label changes which feature
# (if any) the value is written to.
with gr.Blocks() as demo:
    gr.Markdown("# Patient Multi-Timeframe Risk Prediction (51 Conditions)")
    gr.Markdown("Enter patient data. Check a box in 'Ignore Inputs' to use an imputed value (the dataset mean) for any field you don't have data for.")

    with gr.Row():
        # Left column: labels of numeric fields to impute instead of using the typed value.
        ignore_box = gr.CheckboxGroup(choices=numerical_inputs_labels, label="Ignore Inputs", scale=1)

        # Right column: all patient inputs, grouped into collapsible sections.
        with gr.Column(scale=2):
            with gr.Row():
                age = gr.Number(label="Age")
                gender = gr.Radio(label="Gender", choices=["male", "female"], value="male")

            with gr.Accordion("Vitals & Measurements", open=False):
                with gr.Row():
                    sbp = gr.Number(label="Systolic BP")
                    dbp = gr.Number(label="Diastolic BP")
                    heart_rate = gr.Number(label="Heart Rate")
                    respiratory_rate = gr.Number(label="Respiratory Rate")
                with gr.Row():
                    temperature = gr.Number(label="Temperature (C)")
                    spo2 = gr.Number(label="SpO2")
                    pain_score = gr.Number(label="Pain Score (0-10)")
                with gr.Row():
                    height_m = gr.Number(label="Height (m)")
                    weight_kg = gr.Number(label="Weight (kg)")

            with gr.Accordion("Core Labs", open=False):
                with gr.Row():
                    wbc_count = gr.Number(label="WBC Count")
                    hemoglobin = gr.Number(label="Hemoglobin")
                    hematocrit = gr.Number(label="Hematocrit")
                    platelet_count = gr.Number(label="Platelet Count")
                with gr.Row():
                    glucose = gr.Number(label="Glucose")
                    creatinine = gr.Number(label="Creatinine")
                    bun = gr.Number(label="BUN")
                with gr.Row():
                    sodium = gr.Number(label="Sodium")
                    potassium = gr.Number(label="Potassium")
                    bilirubin = gr.Number(label="Bilirubin")

            # --- NEW LABS SECTION ---
            with gr.Accordion("Specialty Labs (Cardiac, Inflammatory, Coag, ...)", open=False):
                with gr.Row():
                    troponin_i = gr.Number(label="Troponin I")
                    crp = gr.Number(label="CRP")
                    lactate = gr.Number(label="Lactate")
                    procalcitonin = gr.Number(label="Procalcitonin")
                with gr.Row():
                    d_dimer = gr.Number(label="D-dimer")
                    inr = gr.Number(label="INR")
                    ptt = gr.Number(label="PTT")
                with gr.Row():
                    albumin = gr.Number(label="Albumin")
                    alt = gr.Number(label="ALT")
                    ast = gr.Number(label="AST")
                with gr.Row():
                    lipase = gr.Number(label="Lipase")
                    amylase = gr.Number(label="Amylase")
                    rheumatoid_factor = gr.Number(label="Rheumatoid Factor")
                with gr.Row():
                    magnesium = gr.Number(label="Magnesium")
                    phosphate = gr.Number(label="Phosphate")

            with gr.Accordion("Lipid Panel, Thyroid & Urinalysis", open=False):
                with gr.Row():
                    cholesterol_total = gr.Number(label="Total Cholesterol")
                    triglycerides = gr.Number(label="Triglycerides")
                    cholesterol_ldl = gr.Number(label="LDL Cholesterol")
                    cholesterol_hdl = gr.Number(label="HDL Cholesterol")
                with gr.Row():
                    tsh = gr.Number(label="TSH")
                    free_t4 = gr.Number(label="Free T4")
                    stool_occult_blood = gr.Number(label="Stool Occult Blood")
                with gr.Row():
                    urine_wbc = gr.Number(label="Urine WBC")
                    urine_nitrite = gr.Number(label="Urine Nitrite")
                    urine_leukocyte = gr.Number(label="Urine Leukocyte")

            with gr.Accordion("Scores & History", open=False):
                with gr.Row():
                    hba1c = gr.Number(label="HbA1c")
                    phq2_score = gr.Number(label="PHQ-2 Score")
                    gad7_score = gr.Number(label="GAD-7 Score")

                # --- NEW HISTORY/MEDS ---
                # The choice strings below are matched verbatim against the
                # lookup dicts inside predict_risk_multitime.
                medical_history = gr.CheckboxGroup(label="Medical History", choices=[
                    "Hypertension", "Diabetes", "CHF", "COPD", "Anemia", "Hyperlipidemia", "CAD", "CKD", "Asthma",
                    "Atrial Fibrillation", "PE/DVT History", "Thyroid Disease", "Gout", "Arthritis",
                    "Pancreatitis", "GI Bleed", "IBD (Crohn's/Colitis)", "Cellulitis", "UTI", "Seizure", "Dementia"
                ], value=[]) # Set default to empty list
                surgical_history = gr.CheckboxGroup(label="Surgical/Procedure History", choices=[
                    "CABG", "Appendectomy", "Colonoscopy", "EGD", "Dialysis"
                ], value=[])
                current_meds = gr.CheckboxGroup(label="Current Medications (Common)", choices=[
                    "Statin", "Metformin", "Aspirin", "Lisinopril (ACE-I)", "Insulin"
                ], value=[])
                diuretic_meds = gr.CheckboxGroup(label="Current Medications (Diuretics)", choices=[
                    "Furosemide", "Hydrochlorothiazide", "Spironolactone"
                ], value=[])
                other_meds = gr.CheckboxGroup(label="Current Medications (Other)", choices=[
                    "Anticoagulant (Warfarin/Eliquis/Xarelto)", "Beta Blocker (Metoprolol/Atenolol)",
                    "Calcium Channel Blocker (Amlodipine)", "PPI (Omeprazole/Pantoprazole)",
                    "Levothyroxine", "Allopurinol"
                ], value=[])

    submit_btn = gr.Button("Predict Risk")
    output = gr.Dataframe(headers=["Condition", "6-Hour Risk", "24-Hour Risk", "48-Hour Risk"], label="Predicted Condition Risks", wrap=True, row_count=51)
    # --- UPDATED all_inputs list ---
    # Order is significant: `ignore_box` must come first (it becomes the
    # callback's `ignore_features` argument) and the five checkbox groups must
    # stay last, because the callback reads them as args[-5] .. args[-1].
    all_inputs = [
        ignore_box, age, gender,
        sbp, dbp, heart_rate, respiratory_rate, temperature, spo2, pain_score, height_m, weight_kg,
        wbc_count, hemoglobin, hematocrit, platelet_count, glucose, creatinine, bun, sodium, potassium, bilirubin,
        troponin_i, crp, lactate, procalcitonin, d_dimer, inr, ptt, albumin, alt, ast, lipase, amylase, rheumatoid_factor, magnesium, phosphate,
        cholesterol_total, triglycerides, cholesterol_ldl, cholesterol_hdl,
        tsh, free_t4, stool_occult_blood, urine_wbc, urine_nitrite, urine_leukocyte,
        hba1c, phq2_score, gad7_score,
        medical_history, surgical_history, current_meds, diuretic_meds, other_meds
    ]

    # Wire the button: component values are passed positionally in `all_inputs` order.
    submit_btn.click(fn=predict_risk_multitime, inputs=all_inputs, outputs=output)

# --- 5. LAUNCH THE GUI ---
# Only start the app when the artifacts loaded; otherwise explain why not.
if not model:
    print("Gradio app not launched because the model failed to load.")
else:
    demo.launch(share=True, debug=True)
