In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA
import os

# Column that identifies a patient in every time-window file.
PATIENT_ID = "person_id"

# ------------------------------------------------
# FILES FOR EACH TIME WINDOW
# ------------------------------------------------
# Twelve consecutive 6-hour windows covering hours 0-72. Keys have the form
# "<start>_<end>" ("0_6", "6_12", ..., "66_72") and map to the file
# "patients_<start>_<end>hrs.parquet". Insertion order is chronological.
files = {
    f"{start}_{start + 6}": f"patients_{start}_{start + 6}hrs.parquet"
    for start in range(0, 72, 6)
}

# ------------------------------------------------
# PHYSIOLOGICALLY PLAUSIBLE RANGES
# ------------------------------------------------
# Inclusive (low, high) bounds per variable; values outside a bound are
# treated as abnormal by the filtering step.
# NOTE(review): units are not stated anywhere in this cell (e.g. the
# creatinine bounds look like umol/L rather than mg/dL) - confirm against
# the data dictionary.
PLAUSIBLE_RANGES = dict(
    fio2_mean=(21, 100),
    peep_mean=(3, 25),
    peak_mean=(10, 50),
    pao2_mean=(50, 500),
    spo2_mean=(70, 100),
    ph_mean=(7.0, 7.6),
    crp_mean=(0, 400),
    wbc_mean=(1, 50),
    temp_mean=(35, 42),
    map_mean=(50, 120),
    sbp_mean=(70, 200),
    dbp_mean=(40, 120),
    creatinine_mean=(20, 800),
    platelets_mean=(20, 800),
    age=(18, 100),
)

# ------------------------------------------------
# STORAGE
# ------------------------------------------------
# Per-window results, keyed by the window label used in `files`.
patient_level_results, phenotype_summaries = {}, {}
# One dict of row/patient counts per processed window.
data_quality_summary = []

# Create output directories (no error if they already exist)
for _output_dir in ("phenotype_summaries", "phenotype_plots"):
    os.makedirs(_output_dir, exist_ok=True)

# Color palette: one fixed color per phenotype label (0 -> blue, 1 -> red)
palette = {0: "#1f77b4", 1: "#d62728"}

# ============================================================
# LOOP OVER TIME WINDOWS
# ============================================================
# For every time window: load the file, drop rows with implausible values,
# collapse to one row per patient, cluster patients into two GMM phenotypes,
# then save a per-window summary CSV and a PCA scatter plot.

# How repeated rows of one patient are collapsed to a single value. The
# direction is chosen per variable ("max" for some, "min" for others) and
# "first" is used for age. Defined once because it does not depend on the window.
AGG_RULES_ALL = {
    "fio2_mean": "max",
    "peep_mean": "max",
    "peak_mean": "max",
    "pao2_mean": "min",
    "spo2_mean": "min",
    "ph_mean": "min",
    "crp_mean": "max",
    "wbc_mean": "max",
    "temp_mean": "max",
    "map_mean": "min",
    "sbp_mean": "min",
    "dbp_mean": "min",
    "creatinine_mean": "max",
    "platelets_mean": "min",
    "age": "first"
}

# Windows with fewer patients than this are not clustered.
MIN_PATIENTS_FOR_CLUSTERING = 10
# The 2-component PCA used for plotting needs at least two features.
MIN_FEATURES_FOR_CLUSTERING = 2

# The seaborn style is global state, so set it once rather than per window.
sns.set_style("whitegrid")

for window, path in files.items():

    print(f"\n{'='*60}")
    print(f"Processing window: {window} hrs")
    print(f"{'='*60}")

    # LOAD DATA
    if path.endswith(".parquet"):
        df = pd.read_parquet(path)
    else:
        df = pd.read_csv(path)

    n_rows_original = len(df)
    n_patients_original = df[PATIENT_ID].nunique()
    print(f"Original data: {n_rows_original} rows, {n_patients_original} patients")

    # FILTER OUT ABNORMAL VALUES
    # Coerce every range-checked column to numeric on the fresh copy first, so
    # no column is ever assigned on a filtered slice (SettingWithCopyWarning).
    df_clean = df.copy()
    range_cols = [col for col in PLAUSIBLE_RANGES if col in df_clean.columns]
    for col in range_cols:
        df_clean[col] = pd.to_numeric(df_clean[col], errors='coerce')

    # Filters are applied one column after another, so each reported count
    # refers to the rows still present after the preceding filters.
    # NOTE(review): the whole row is discarded when any single variable is out
    # of range, so valid measurements of other variables in that row are lost
    # too - confirm this is intended rather than blanking only the bad value.
    for col in range_cols:
        low, high = PLAUSIBLE_RANGES[col]
        values = df_clean[col]
        mask = values.notna() & ((values < low) | (values > high))
        n_filtered = int(mask.sum())
        if n_filtered > 0:
            df_clean = df_clean.loc[~mask]
            print(f"  Filtered {n_filtered} rows with abnormal {col} values")

    n_rows_clean = len(df_clean)
    n_patients_clean = df_clean[PATIENT_ID].nunique()
    # Guard the percentage against an empty input file.
    pct_rows_kept = 100*n_rows_clean/n_rows_original if n_rows_original else 0.0
    print(f"After filtering: {n_rows_clean} rows ({pct_rows_kept:.1f}%), {n_patients_clean} patients")

    data_quality_summary.append({
        'window': window,
        'original_rows': n_rows_original,
        'clean_rows': n_rows_clean,
        'rows_removed': n_rows_original - n_rows_clean,
        'original_patients': n_patients_original,
        'clean_patients': n_patients_clean,
        'patients_removed': n_patients_original - n_patients_clean
    })

    # AGGREGATION RULES (restricted to the columns present in this file)
    AGG_RULES = {k: v for k, v in AGG_RULES_ALL.items() if k in df_clean.columns}
    if len(AGG_RULES) < MIN_FEATURES_FOR_CLUSTERING:
        print(f"WARNING: Only {len(AGG_RULES)} clustering features present - skipping clustering")
        continue

    # AGGREGATE TO PATIENT LEVEL
    df_patient = df_clean.groupby(PATIENT_ID, as_index=False).agg(AGG_RULES)
    print(f"Patient-level data: {len(df_patient)} patients")

    # Decide whether to cluster before doing any imputation work.
    if len(df_patient) < MIN_PATIENTS_FOR_CLUSTERING:
        print(f"WARNING: Only {len(df_patient)} patients - skipping clustering")
        continue

    # PREPARE FEATURES FOR CLUSTERING
    X = df_patient[list(AGG_RULES.keys())].apply(pd.to_numeric, errors="coerce")

    # A feature with no valid value in this window has a NaN mean, so mean
    # imputation cannot fill it. Drop it here instead of failing the assertion.
    empty_features = X.columns[X.isna().all()].tolist()
    if empty_features:
        print(f"WARNING: Dropping features with no valid values: {empty_features}")
        X = X.drop(columns=empty_features)
        # Keep the stored patient table consistent with the clustered features.
        df_patient = df_patient.drop(columns=empty_features)
        AGG_RULES = {k: v for k, v in AGG_RULES.items() if k not in empty_features}
    if X.shape[1] < MIN_FEATURES_FOR_CLUSTERING:
        print(f"WARNING: Only {X.shape[1]} usable features - skipping clustering")
        continue

    missing_before_imputation = int(X.isna().sum().sum())
    if missing_before_imputation > 0:
        print(f"Missing values before imputation: {missing_before_imputation}")
        # Column-mean imputation; only X is filled, df_patient keeps its NaNs.
        X = X.fillna(X.mean())

    assert X.isna().sum().sum() == 0, "Still have missing values after imputation!"

    # STANDARDIZE
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    # GMM PHENOTYPING
    # NOTE(review): component indices come from an independent fit per window,
    # so "Phenotype 0" in one window is not guaranteed to be the same clinical
    # group as "Phenotype 0" in another - confirm before comparing windows.
    gmm = GaussianMixture(
        n_components=2,
        covariance_type="full",
        random_state=42,
        n_init=10
    )
    df_patient["icu_phenotype"] = gmm.fit_predict(X_scaled)

    phenotype_counts = df_patient["icu_phenotype"].value_counts().sort_index()
    print(f"\nPhenotype distribution:")
    for pheno, count in phenotype_counts.items():
        print(f"  Phenotype {pheno}: {count} patients ({100*count/len(df_patient):.1f}%)")

    # PHENOTYPE SUMMARY
    # Mean of each feature per phenotype, computed on the un-imputed patient
    # values; transposed so rows are variables and columns are phenotypes.
    phenotype_summary = (
        df_patient
        .groupby("icu_phenotype")[X.columns]
        .mean()
        .T
    )

    # STORE RESULTS
    patient_level_results[window] = df_patient
    phenotype_summaries[window] = phenotype_summary

    # Save summary CSV
    csv_filename = f"phenotype_summaries/phenotype_summary_{window}hrs.csv"
    phenotype_summary.to_csv(csv_filename)
    print(f"Saved summary to: {csv_filename}")

    # ============================================================
    # CREATE PCA CLUSTER PLOT
    # ============================================================
    # PCA is used only to project the standardized features to 2-D for
    # plotting; the clustering itself was done on all features.
    pca = PCA(n_components=2, random_state=42)
    X_pca = pca.fit_transform(X_scaled)

    df_plot = df_patient.copy()
    df_plot["PC1"] = X_pca[:, 0]
    df_plot["PC2"] = X_pca[:, 1]

    # Create plot
    fig, ax = plt.subplots(figsize=(10, 7))

    sns.scatterplot(
        data=df_plot,
        x="PC1",
        y="PC2",
        hue="icu_phenotype",
        palette=palette,
        s=80,
        alpha=0.75,
        edgecolor="black",
        linewidth=0.5,
        ax=ax
    )

    # Title with patient counts
    title = f"GMM ICU Phenotypes: {window} Hours\n"
    title += f"Phenotype 0: n={phenotype_counts.get(0, 0):,} | Phenotype 1: n={phenotype_counts.get(1, 0):,}"

    ax.set_title(title, fontsize=14, weight="bold", pad=20)
    ax.set_xlabel(
        f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% variance)",
        fontsize=12
    )
    ax.set_ylabel(
        f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% variance)",
        fontsize=12
    )
    ax.legend(
        title="Phenotype",
        title_fontsize=11,
        fontsize=10,
        frameon=True,
        loc='best'
    )

    fig.tight_layout()
    plot_filename = f"phenotype_plots/cluster_plot_{window}hrs.png"
    fig.savefig(plot_filename, dpi=300, bbox_inches='tight')
    print(f"Saved plot to: {plot_filename}")
    # Close this specific figure so figures do not pile up across windows.
    plt.close(fig)

    print("\nPhenotype Summary:")
    print(phenotype_summary)

# ============================================================
# SAVE DATA QUALITY SUMMARY
# ============================================================
# One row per processed window with row/patient counts before and after
# filtering, written without the DataFrame index.
quality_csv_path = "phenotype_summaries/data_quality_summary.csv"
quality_df = pd.DataFrame(data_quality_summary)
quality_df.to_csv(quality_csv_path, index=False)
print("\n" + "=" * 60)
print(f"Data quality summary saved to: {quality_csv_path}")

# ============================================================
# SAVE ALL PHENOTYPE SUMMARIES IN ONE FILE
# ============================================================
def _label_summary(window_label, summary_table):
    """Return a copy of one window's summary tagged with its window and variable names.

    The window label becomes the first column ('time_window') and the index
    (the variable names) is repeated as a trailing 'variable' column.
    """
    labelled = summary_table.copy()
    labelled.insert(0, 'time_window', window_label)
    labelled['variable'] = labelled.index
    return labelled


all_summaries = [
    _label_summary(window_label, summary_table)
    for window_label, summary_table in phenotype_summaries.items()
]

# Nothing is written when no window produced a phenotype summary.
if len(all_summaries) > 0:
    combined_csv_path = "phenotype_summaries/all_phenotypes_combined.csv"
    combined_df = pd.concat(all_summaries, ignore_index=False)
    combined_df.to_csv(combined_csv_path)
    print(f"Combined summary saved to: {combined_csv_path}")

# ============================================================
# CREATE COMBINED PLOT (ALL WINDOWS)
# ============================================================
# One PCA scatter panel per clustered time window, laid out on a 4-column
# grid in chronological order and saved as a single image.
print("\nCreating combined plot with all time windows...")


def _window_sort_key(label):
    """Sort key that orders "<start>_<end>" window labels by their numeric hours.

    Sorting the labels as plain strings puts "6_12" after "54_60". Labels that
    do not parse as integers are placed last, ordered alphabetically.
    """
    try:
        return (0, tuple(int(part) for part in str(label).split("_")), "")
    except ValueError:
        return (1, (), str(label))


combined_plot_path = "phenotype_plots/combined_cluster_plots.png"
n_windows = len(patient_level_results)

if n_windows == 0:
    # plt.subplots cannot build a grid with zero rows, so skip the figure.
    print("WARNING: No time window produced phenotypes - skipping combined plot")
else:
    n_cols = 4
    n_rows = (n_windows + n_cols - 1) // n_cols  # ceiling division

    # The style must be set before the axes are created to take effect.
    sns.set_style("whitegrid")

    # squeeze=False always returns a 2-D array, whatever the grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 5*n_rows), squeeze=False)
    axes = axes.flatten()

    ordered_windows = sorted(
        patient_level_results.items(),
        key=lambda item: _window_sort_key(item[0])
    )

    for idx, (window, df_patient) in enumerate(ordered_windows):
        ax = axes[idx]

        # Get features: everything except the id, the label and PCA columns
        FEATURE_COLS = [
            c for c in df_patient.columns
            if c not in [PATIENT_ID, "icu_phenotype", "PC1", "PC2"]
        ]

        X = df_patient[FEATURE_COLS].apply(pd.to_numeric, errors="coerce")
        # A column with no valid value cannot be mean-imputed and would pass
        # NaNs into PCA, so leave it out of the projection.
        X = X.dropna(axis=1, how="all")
        X = X.fillna(X.mean())
        X_scaled = StandardScaler().fit_transform(X)

        # PCA (2-D projection used only for display)
        pca = PCA(n_components=2, random_state=42)
        X_pca = pca.fit_transform(X_scaled)

        df_plot = df_patient.copy()
        df_plot["PC1"] = X_pca[:, 0]
        df_plot["PC2"] = X_pca[:, 1]

        phenotype_counts = df_plot["icu_phenotype"].value_counts().sort_index()

        # Plot one scatter series per phenotype so each gets a legend entry
        for phenotype in sorted(df_plot["icu_phenotype"].unique()):
            mask = df_plot["icu_phenotype"] == phenotype
            ax.scatter(
                df_plot.loc[mask, "PC1"],
                df_plot.loc[mask, "PC2"],
                color=palette[phenotype],
                s=30,
                alpha=0.6,
                edgecolor="black",
                linewidth=0.3,
                label=f"Pheno {phenotype} (n={phenotype_counts.get(phenotype, 0):,})"
            )

        ax.set_title(f"{window} hrs", fontsize=11, weight="bold")
        ax.set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]*100:.0f}%)", fontsize=9)
        ax.set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]*100:.0f}%)", fontsize=9)
        ax.legend(fontsize=7, loc='best')
        ax.grid(alpha=0.3)

    # Hide unused subplots
    for idx in range(n_windows, len(axes)):
        axes[idx].set_visible(False)

    fig.suptitle("GMM ICU Phenotypes Across Time Windows",
                 fontsize=16, weight="bold", y=0.995)
    fig.tight_layout()

    fig.savefig(combined_plot_path, dpi=300, bbox_inches='tight')
    print(f"Saved combined plot to: {combined_plot_path}")
    plt.close(fig)

print("\n" + "="*60)
print("Processing complete!")
print(f"Individual plots: phenotype_plots/cluster_plot_*hrs.png")
if n_windows > 0:
    print(f"Combined plot: {combined_plot_path}")
print("="*60)


Processing window: 0_6 hrs
Original data: 512637 rows, 65017 patients
  Filtered 1028 rows with abnormal fio2_mean values
  Filtered 4214 rows with abnormal peep_mean values
  Filtered 458 rows with abnormal peak_mean values
  Filtered 100247 rows with abnormal pao2_mean values
  Filtered 196 rows with abnormal spo2_mean values
  Filtered 810 rows with abnormal ph_mean values
  Filtered 225 rows with abnormal crp_mean values
  Filtered 512 rows with abnormal wbc_mean values
  Filtered 25546 rows with abnormal temp_mean values
  Filtered 13826 rows with abnormal map_mean values
  Filtered 1975 rows with abnormal sbp_mean values
  Filtered 129 rows with abnormal dbp_mean values
  Filtered 266 rows with abnormal creatinine_mean values
  Filtered 251 rows with abnormal platelets_mean values
After filtering: 362954 rows (70.8%), 62378 patients
Patient-level data: 62378 patients
Missing values before imputation: 462129

Phenotype distribution:
  Phenotype 0: 4812 patients (7.7%)
  Phenotype