In [None]:
import pandas as pd
import os
from scipy.stats import pearsonr
from statsmodels.miscmodels.ordinal_model import OrderedModel
import seaborn as sns
import matplotlib.pyplot as plt
from utils import create_registry_case_identification_column, create_ehr_case_identification_column, patient_selection
from utils import load_data_from_main_dir
from lab_preprocessing import preprocess_labs
from outcome_preprocessing import preprocess_outcomes


In [None]:
# Data locations. Each path can be overridden with an environment variable so the
# notebook can run on machines other than the original author's; when the variables
# are unset the values are exactly the original hard-coded paths.
ehr_data_path = os.environ.get(
    'STROKE_EHR_DATA_PATH',
    '/Users/jk1/stroke_datasets/stroke_unit_dataset/per_value/Extraction_20221117/',
)
# The EDS export lives inside the EHR extraction directory.
eds_path = os.environ.get('STROKE_EDS_PATH', os.path.join(ehr_data_path, 'eds_j1.csv'))
registry_path = os.environ.get(
    'STROKE_REGISTRY_PATH',
    '/Users/jk1/Library/CloudStorage/OneDrive-unige.ch/stroke_research/geneva_stroke_unit_dataset/data/stroke_registry/post_hoc_modified/stroke_registry_post_hoc_modified.xlsx',
)

In [None]:
# Read both raw sources with every column as string; typing happens downstream.
eds_df = pd.read_csv(
    eds_path,
    delimiter=';',
    encoding='utf-8',
    dtype=str,
)
registry_df = pd.read_excel(
    registry_path,
    dtype=str,
)


In [None]:
# Give both sources a common admission key so they can be linked.
eds_df['case_admission_id'] = create_ehr_case_identification_column(eds_df)
registry_df['case_admission_id'] = create_registry_case_identification_column(registry_df)

In [None]:
eds_df.head(n=5)

In [None]:
# Apply the cohort inclusion criteria; returns the included registry rows and the excluded ones.
inclusion_registry_df, excluded_patients_df = patient_selection(
    eds_path=eds_path,
    registry_path=registry_path,
    exclude_patients_under_18=True,
    exclude_non_acute_stroke=True,
    exclude_non_ischemic_stroke=True,
    verbose=True,
)

In [None]:
# Keep only EDS rows whose admission survived patient selection. Filtering the
# frame (rather than re-assigning the id column) is what makes the count below
# reflect the selected cohort.
eds_df = eds_df[eds_df['case_admission_id'].isin(inclusion_registry_df['case_admission_id'])].copy()
print(f'Number of patients in EDS after selection: {eds_df.patient_id.nunique()}')

In [None]:
registry_df['case_admission_id'].nunique()

In [None]:
# Filename prefix passed to the loader to pick the lab files in the extraction directory.
lab_file_start: str = 'labo'
lab_df = load_data_from_main_dir(
    ehr_data_path,
    lab_file_start,
)

In [None]:
# Same admission key as built for the EDS frame, so labs can be linked to the registry.
lab_df['case_admission_id'] = create_ehr_case_identification_column(
    lab_df
)

In [None]:
# Run the project lab preprocessing restricted to the 'lactate' label.
preprocessed_lactate_df = preprocess_labs(lab_df, ['lactate'])

In [None]:
# Restrict lactate measurements to the selected cohort.
preprocessed_lactate_df = preprocessed_lactate_df.loc[
    preprocessed_lactate_df['case_admission_id'].isin(inclusion_registry_df['case_admission_id'])
]

In [None]:
preprocessed_lactate_df.head(n=5)

In [None]:
preprocessed_lactate_df['unit_of_measure'].unique()

In [None]:
preprocessed_lactate_df['value'].hist(bins=100)

In [None]:
inclusion_registry_df.head(n=5)

In [None]:
tuple(inclusion_registry_df[col].isna().sum() for col in ('stroke_dt', 'arrival_dt'))

In [None]:
# T0 is the registry 'stroke_dt', falling back to 'arrival_dt' when it is missing.
inclusion_registry_df['T0'] = inclusion_registry_df['stroke_dt'].fillna(inclusion_registry_df['arrival_dt'])
# Drop a T0 left over from a previous run of this cell so re-running it does not
# produce duplicated 'T0_x' / 'T0_y' columns.
preprocessed_lactate_df = preprocessed_lactate_df.drop(columns=['T0'], errors='ignore').merge(
    inclusion_registry_df[['case_admission_id', 'T0']],
    on='case_admission_id',
    how='left',
)

In [None]:
# Both timestamps are parsed with the same day-first format; the offset is reported in hours.
dt_format = '%d.%m.%Y %H:%M'
preprocessed_lactate_df['relative_sample_date'] = (
    pd.to_datetime(preprocessed_lactate_df['sample_date'], format=dt_format)
    .sub(pd.to_datetime(preprocessed_lactate_df['T0'], format=dt_format))
    .dt.total_seconds()
    .div(3600)  # seconds -> hours
)


In [None]:
preprocessed_lactate_df.sample_date.values

In [None]:
import numpy as np

# Whole-hour bin (floor) of the sample time relative to T0; NaN offsets stay NaN.
preprocessed_lactate_df['relative_sample_date_hcat'] = np.floor(preprocessed_lactate_df['relative_sample_date'])

In [None]:
preprocessed_lactate_df.loc[:, ['T0', 'sample_date', 'relative_sample_date']]

In [None]:

# sns.set(style="whitegrid")
# plt.figure(figsize=(10, 6))
# ax = sns.lineplot(x='relative_sample_date_hcat', y='value', data=preprocessed_lactate_df)

# ax.set_xlim(-24, 7*24)

In [None]:
preprocessed_lactate_df['case_admission_id'].nunique()

In [None]:
# Admissions with at least one lactate in each window (open intervals, hours from T0).
n_patients_with_lactate_in_first_24h = (
    preprocessed_lactate_df
    .query('-12 < relative_sample_date < 24')['case_admission_id']
    .nunique()
)
n_patients_with_lactate_in_24_to_72h = (
    preprocessed_lactate_df
    .query('24 < relative_sample_date < 3 * 24')['case_admission_id']
    .nunique()
)

print(f'Number of patients with lactate in first 24h: {n_patients_with_lactate_in_first_24h}')
print(f'Number of patients with lactate in 24 to 72h: {n_patients_with_lactate_in_24_to_72h}')

In [None]:
# Outcomes for the included cohort, one row per admission (first occurrence kept).
outcome_df = (
    preprocess_outcomes(registry_path)
    .pipe(lambda frame: frame[frame['case_admission_id'].isin(inclusion_registry_df['case_admission_id'].unique())])
    .drop_duplicates(subset='case_admission_id', keep='first')
)

In [None]:
# Attach the 3-month mRS per admission. Any '3M mRS' column from a previous run of
# this cell is dropped first so re-running does not create '3M mRS_x' / '3M mRS_y'.
preprocessed_lactate_df = preprocessed_lactate_df.drop(columns=['3M mRS'], errors='ignore').merge(
    outcome_df[['case_admission_id', '3M mRS']],
    on='case_admission_id',
    how='left',
)

In [None]:
# Lactate subsets by day relative to T0 (open intervals, in hours).
# Day boundaries are multiples of 24 h: day 2 = (24, 48), day 3 = (48, 72).
_rel_t = preprocessed_lactate_df['relative_sample_date']
early_lactate_df = preprocessed_lactate_df[(_rel_t > -12) & (_rel_t < 24)]
lactate_d2_df = preprocessed_lactate_df[(_rel_t > 1 * 24) & (_rel_t < 2 * 24)]
lactate_d3_df = preprocessed_lactate_df[(_rel_t > 2 * 24) & (_rel_t < 3 * 24)]
lactate_d_2_3_df = preprocessed_lactate_df[(_rel_t > 1 * 24) & (_rel_t < 3 * 24)]
del _rel_t

## Lactate trajectories
Group-based trajectory modeling (GBTM) analysis of lactate over time to evaluate the predictive potential of dynamic lactate trajectories for all-cause mortality in patients with ischemic stroke.


## Group-based trajectory modeling (GBTM)
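We fit finite mixtures of per-patient lactate trajectories (time in hours from T0; linear when `only_use_linear_trajectories` is True, otherwise quadratic) to identify latent classes, then test whether class membership is associated with 3-month mortality and 3-month functional outcome (mRS).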

In [None]:
# Build the GBTM input: lactate samples inside an open window around T0 (hours),
# each carrying the admission's 3-month mortality label.
timing_lower_bound = -12
timing_upper_bound = 72

analysis_window = preprocessed_lactate_df.query(
    '@timing_lower_bound < relative_sample_date < @timing_upper_bound'
).copy()

# Attach the mortality label only if a previous step has not already done so.
if '3M Death' not in analysis_window:
    analysis_window = analysis_window.merge(
        outcome_df[['case_admission_id', '3M Death']],
        how='left',
        on='case_admission_id',
    )

# Non-numeric labels become NaN.
analysis_window['3M Death'] = pd.to_numeric(analysis_window['3M Death'], errors='coerce')
print(f"Patients with lactate in window: {analysis_window['case_admission_id'].nunique()}")
print(f"Patients with mortality label: {analysis_window.loc[analysis_window['3M Death'].notna(), 'case_admission_id'].nunique()}")

In [None]:
# Fit patient-level quadratic trajectories and mixture model to identify latent classes
import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler

# --- Mixture-model configuration --------------------------------------------
# Backend switch: False -> sklearn GaussianMixture on per-patient coefficients,
#                 True  -> bayes_traj Dirichlet-process regression mixture.
use_bayes_traj = False
bayes_traj_k = 2        # number of trajectory components for bayes_traj
bayes_traj_iters = 300  # fitting iterations for bayes_traj

# An integer in 2..5 forces that many GaussianMixture classes; None selects automatically.
manual_k = None

# When True, per-patient trajectories are straight lines (quadratic term fixed at 0).
only_use_linear_trajectories = True

# Derive per-patient polynomial coefficients (falls back to lower order if too few time points)
def fit_quadratic(group: pd.DataFrame, linear_only: "bool | None" = None) -> pd.Series:
    """Fit one admission's lactate values against time and return the coefficients.

    Parameters
    ----------
    group : DataFrame with columns 'relative_sample_date' (hours from T0) and 'value'.
    linear_only : if True fit at most a straight line; if False allow a quadratic.
        None (default) uses the notebook-level ``only_use_linear_trajectories`` flag.

    Returns
    -------
    Series with 'beta2', 'beta1', 'beta0' (coefficients of t**2, t, 1), 'n_obs'
    (number of finite samples used) and 't_span' (hours between first and last sample).

    The polynomial degree is capped by the number of *distinct* time points: a
    slope needs two, a curvature three. Repeated timestamps would otherwise make
    the least-squares problem rank-deficient and yield arbitrary coefficients, so
    an admission sampled at a single instant falls back to the mean value.
    """
    if linear_only is None:
        linear_only = only_use_linear_trajectories

    t = group['relative_sample_date'].to_numpy(dtype=float)
    y = group['value'].to_numpy(dtype=float)
    # Non-finite samples would turn every coefficient into NaN and break the mixture fit.
    finite = np.isfinite(t) & np.isfinite(y)
    t, y = t[finite], y[finite]

    coef = np.zeros(3)
    if y.size == 0:
        coef[2] = np.nan
    else:
        max_degree = 1 if linear_only else 2
        degree = min(max_degree, np.unique(t).size - 1)
        if degree < 1:
            coef[2] = np.mean(y)
        else:
            try:
                fitted = np.polyfit(t, y, degree)  # highest power first
                coef[3 - fitted.size:] = fitted
            except (np.linalg.LinAlgError, ValueError):
                coef[:] = [0.0, 0.0, np.mean(y)]

    return pd.Series({
        'beta2': coef[0],
        'beta1': coef[1],
        'beta0': coef[2],
        'n_obs': len(t),
        't_span': t.max() - t.min() if len(t) > 1 else 0.0
    })

# One row of trajectory coefficients per admission.
traj_features = (
    analysis_window
    .groupby('case_admission_id')
    .apply(fit_quadratic)
    .reset_index()
)
if only_use_linear_trajectories:
    feature_cols = ['beta1', 'beta0']
else:
    feature_cols = ['beta2', 'beta1', 'beta0']

# Clip each coefficient to its 1st-99th percentile range so a handful of extreme
# fits cannot form their own tiny mixture component, then standardize.
q_low, q_high = (traj_features[feature_cols].quantile(q) for q in (0.01, 0.99))
X = traj_features[feature_cols].clip(lower=q_low, upper=q_high, axis=1)

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

if use_bayes_traj:
    # Optional backend: bayes_traj Dirichlet-process regression mixture fitted on the
    # raw (time, lactate) observations instead of on per-patient coefficients.
    try:
        from bayes_traj.mult_dp_regression import MultDPRegression
        from bayes_traj import generate_prior
        from bayes_traj.fit_stats import compute_waic2
    except ImportError as e:
        raise ImportError("bayes_traj not installed; run `pip install bayes-traj` in the project venv.") from e

    # Build design matrix for bayes_traj (intercept, time, optional time^2)
    design_df = analysis_window[['case_admission_id', 'relative_sample_date', 'value']].dropna().copy()
    design_df = design_df.rename(columns={'value': 'lactate_value'})
    design_df['intercept'] = 1.0
    design_df['time'] = design_df['relative_sample_date']
    if not only_use_linear_trajectories:
        design_df['time_sq'] = design_df['relative_sample_date'] ** 2
    pred_cols = ['intercept', 'time']
    if not only_use_linear_trajectories:
        pred_cols.append('time_sq')
    target_col = 'lactate_value'

    # Default priors; replaced below by data-driven priors (bayes_traj helper, or an
    # OLS-based fallback when the helper is not available in this bayes_traj version).
    prior_info = {
        'w_mu0': {target_col: {p: 0.0 for p in pred_cols}},
        'w_var0': {target_col: {p: 1.0 for p in pred_cols}},
        'lambda_a0': {target_col: 1.0},
        'lambda_b0': {target_col: 1.0},
    }
    num_trajs = np.array([max(1, bayes_traj_k - 1), bayes_traj_k + 1])
    try:
        prior_info = generate_prior.prior_info_from_df_gaussian(design_df, target_col, pred_cols, num_trajs, prior_info)
    except AttributeError:
        # Fallback: centre coefficient priors on a pooled OLS fit, with variances
        # inflated from its HC0 standard errors, and a residual-precision prior.
        import statsmodels.api as sm
        res_tmp = sm.OLS(design_df[target_col], design_df[pred_cols], missing='drop').fit()
        fudge = 0.002
        vars_est = fudge * design_df.shape[0] * res_tmp.HC0_se.values ** 2
        for (i, p) in enumerate(pred_cols):
            prior_info['w_mu0'][target_col][p] = res_tmp.params.values[i]
            prior_info['w_var0'][target_col][p] = vars_est[i]
        prec_mean = 1.0 / max(np.var(res_tmp.resid.values), 1e-8)
        prec_var = (prec_mean * 0.5) ** 2
        prior_info['lambda_b0'][target_col] = prec_mean / prec_var
        prior_info['lambda_a0'][target_col] = prec_mean ** 2 / prec_var

    # If prior_info generation failed or missed keys, fall back to simple defaults
    if prior_info is None:
        prior_info = {
            'w_mu0': {target_col: {p: 0.0 for p in pred_cols}},
            'w_var0': {target_col: {p: 1.0 for p in pred_cols}},
            'lambda_a0': {target_col: 1.0},
            'lambda_b0': {target_col: 1.0},
        }
    else:
        for p in pred_cols:
            prior_info.setdefault('w_mu0', {}).setdefault(target_col, {}).setdefault(p, 0.0)
            prior_info.setdefault('w_var0', {}).setdefault(target_col, {}).setdefault(p, 1.0)
        prior_info.setdefault('lambda_a0', {}).setdefault(target_col, 1.0)
        prior_info.setdefault('lambda_b0', {}).setdefault(target_col, 1.0)

    # Dirichlet-process concentration, scaled by the number of admissions.
    alpha = bayes_traj_k / np.log10(max(design_df['case_admission_id'].nunique(), 2))

    w_mu0 = np.array([prior_info['w_mu0'][target_col][p] for p in pred_cols]).reshape(len(pred_cols), 1)
    w_var0 = np.array([prior_info['w_var0'][target_col][p] for p in pred_cols]).reshape(len(pred_cols), 1)
    lambda_a0 = np.array([prior_info['lambda_a0'][target_col]])
    lambda_b0 = np.array([prior_info['lambda_b0'][target_col]])

    mm = MultDPRegression(w_mu0, w_var0, lambda_a0, lambda_b0, 0.25, alpha, K=bayes_traj_k, prob_thresh=0.001)
    mm.fit(target_names=[target_col], predictor_names=pred_cols, df=design_df,
           groupby='case_admission_id', iters=bayes_traj_iters, verbose=False)

    # mm.R_ is treated as one row of component probabilities per design_df row
    # (that is how ids are attached below); average them per admission.
    r_cols = [f"traj_{i}" for i in range(bayes_traj_k)]
    r_df = pd.DataFrame(mm.R_, columns=r_cols)
    r_df['case_admission_id'] = design_df['case_admission_id'].values
    patient_probs = r_df.groupby('case_admission_id')[r_cols].mean()
    patient_probs['traj_class'] = patient_probs[r_cols].idxmax(axis=1).str.replace('traj_', '', regex=False).astype(int) + 1
    patient_probs['traj_class_prob'] = patient_probs[r_cols].max(axis=1)

    # Count every component, including ones no admission was assigned to, so that
    # min_frac reports 0 for an empty class instead of silently skipping it.
    class_sizes = patient_probs['traj_class'].value_counts().reindex(range(1, bayes_traj_k + 1), fill_value=0)
    min_frac = class_sizes.min() / patient_probs.shape[0]
    # Relative entropy in [0, 1]; 1 = perfectly separated posterior class probabilities.
    entropy = 1 + (patient_probs[r_cols].values * np.log(patient_probs[r_cols].values + 1e-12)).sum() / (patient_probs.shape[0] * np.log(bayes_traj_k))
    waic2 = compute_waic2(mm)
    n_obs = design_df.shape[0]
    n_params = bayes_traj_k * len(pred_cols) + bayes_traj_k  # coefficients + variance terms
    bic = np.nan
    try:
        log_like_attr = next((attr for attr in ['log_like_', 'log_like', 'loglik_', 'loglik', 'log_likelihood_', 'log_likelihood'] if hasattr(mm, attr)), None)
        if log_like_attr is not None:
            log_like_vals = np.atleast_1d(getattr(mm, log_like_attr))
            log_like = float(np.ravel(log_like_vals)[-1])
            bic = -2.0 * log_like + n_params * np.log(max(n_obs, 1))
    except Exception:
        bic = np.nan
    # If bayes_traj does not expose log-likelihood, approximate BIC from WAIC2 (-2 * elpd) plus penalty
    if np.isnan(bic):
        bic = waic2 + n_params * np.log(max(n_obs, 1))

    results_df = pd.DataFrame([{
        'k': bayes_traj_k,
        'bic': bic,
        'waic2': waic2,
        'entropy': entropy,
        'min_frac': min_frac
    }])

    traj_features = traj_features.merge(patient_probs.reset_index()[['case_admission_id', 'traj_class', 'traj_class_prob']],
                                        on='case_admission_id', how='left')
    best_k = bayes_traj_k

    print('bayes_traj fit completed (Dirichlet process mixture).')
    print(results_df)
    print(traj_features['traj_class'].value_counts().sort_index())
else:
    # Default backend: Gaussian mixtures with k = 2..5 on the standardized coefficients.
    results = []
    models = {}
    for k in range(2, 6):
        gm = GaussianMixture(n_components=k, covariance_type='full', random_state=0, n_init=10, reg_covar=1e-4)
        gm.fit(X_scaled)
        proba = gm.predict_proba(X_scaled)
        pred = proba.argmax(axis=1)
        n = len(pred)
        # Fraction of admissions in the smallest class (minlength counts empty classes as 0).
        min_frac = np.bincount(pred, minlength=k).min() / n
        # Relative entropy in [0, 1]; 1 = perfectly separated posterior class probabilities.
        entropy = 1 + (proba * np.log(proba + 1e-12)).sum() / (n * np.log(k))
        bic = gm.bic(X_scaled)
        results.append({
            'k': k,
            'bic': bic,
            'entropy': entropy,
            'min_frac': min_frac
        })
        models[k] = gm

    results_df = pd.DataFrame(results).sort_values('bic')
    candidates = results_df[(results_df['entropy'] > 0.7) & (results_df['min_frac'] >= 0.05)].sort_values('bic')

    # Selection priority: manual override > lowest-BIC model meeting both constraints > lowest BIC overall.
    if manual_k is not None:
        if manual_k not in models:
            raise ValueError(f"manual_k={manual_k} not in fitted range 2-5")
        best_row = results_df[results_df['k'] == manual_k].iloc[0]
        print(f'Manual override to k={manual_k}. Constraints not enforced for selection; metrics for chosen k shown below.')
    elif not candidates.empty:
        best_row = candidates.iloc[0]
        print('Selected by constraints (entropy>0.7 & min_frac>=0.05) and lowest BIC among them.')
    else:
        best_row = results_df.iloc[0]
        print('No model met both entropy > 0.7 and min group size >= 5%. Falling back to overall lowest BIC.')

    best_k = int(best_row['k'])
    best_model = models[best_k]
    proba = best_model.predict_proba(X_scaled)
    traj_features['traj_class'] = proba.argmax(axis=1) + 1  # 1-based class labels
    traj_features['traj_class_prob'] = proba.max(axis=1)

    print('Model selection table (sorted by BIC):')
    print(results_df)
    # Report the chosen model's actual metrics (they need not satisfy the constraints
    # when selected manually or by the fallback).
    print(f"Selected {best_k} classes (entropy={best_row['entropy']:.3f}, min_frac={best_row['min_frac']:.3f}, bic={best_row['bic']:.1f}).")
    print(traj_features['traj_class'].value_counts().sort_index())

In [None]:
# Visualize observed lactate values and class-average fitted trajectories.
# Remove class labels attached by a previous run before merging the current ones.
analysis_window = analysis_window.drop(columns=[c for c in analysis_window.columns if c.startswith('traj_class')], errors='ignore')
analysis_window = analysis_window.merge(traj_features[['case_admission_id', 'traj_class']],
                                        on='case_admission_id', how='left')

# Evaluate each class's mean coefficients on a grid spanning the analysis window,
# so the curves follow timing_lower_bound / timing_upper_bound if they are changed.
time_grid = np.linspace(timing_lower_bound, timing_upper_bound, 60)
coef_cols = ['beta2', 'beta1', 'beta0']
class_curves = []
# dropna: admissions without a class label (possible with the bayes_traj backend) get no curve.
for c in sorted(traj_features['traj_class'].dropna().unique()):
    coefs = traj_features.loc[traj_features['traj_class'] == c, coef_cols].mean()
    y_hat = coefs['beta2'] * time_grid ** 2 + coefs['beta1'] * time_grid + coefs['beta0']
    class_curves.append(pd.DataFrame({
        'relative_sample_date': time_grid,
        'lactate_pred': y_hat,
        'traj_class': c
    }))

plot_df = pd.concat(class_curves, ignore_index=True)

fig, ax = plt.subplots(figsize=(8, 5))
sns.scatterplot(data=analysis_window, x='relative_sample_date', y='value', hue='traj_class',
                palette='tab10', alpha=0.1, ax=ax)
sns.lineplot(data=plot_df, x='relative_sample_date', y='lactate_pred', hue='traj_class',
             palette='tab10', linewidth=3, legend=False, ax=ax)
ax.axvline(0, color='k', linestyle='--', linewidth=1)  # T0
ax.set_xlabel('Hours from T0')
ax.set_ylabel('Lactate (mmol/L)')
ax.set_title('Lactate trajectories by latent class')
ax.set_xlim(timing_lower_bound, timing_upper_bound)
ax.set_ylim(0, 10)
fig.tight_layout()
plt.show()

In [None]:
# Association between trajectory class and 3-month mortality
import statsmodels.api as sm

mortality_df = traj_features.merge(outcome_df[['case_admission_id', '3M Death']], on='case_admission_id', how='left')
mortality_df['3M Death'] = pd.to_numeric(mortality_df['3M Death'], errors='coerce')
# Keep admissions with a binary (0/1) death label and an assigned class
# (isin also drops NaN labels).
mortality_df = mortality_df[mortality_df['3M Death'].isin([0.0, 1.0])].dropna(subset=['traj_class']).copy()
mortality_df['traj_class'] = mortality_df['traj_class'].astype(int)

# Class-level mortality rates
class_mortality = mortality_df.groupby('traj_class')['3M Death'].agg(['count', 'mean']).rename(columns={'mean': 'mortality_rate'})
print(class_mortality)

# Logistic regression with class dummies (reference = class 1).
# The design matrix is built from the dummies alone: selecting columns by the
# 'traj_class_' prefix would also pick up 'traj_class_prob' (posterior probability)
# and silently add it as a covariate.
class_dummies = pd.get_dummies(mortality_df['traj_class'], prefix='traj_class', drop_first=True, dtype=float)
X_cols = list(class_dummies.columns)
X = sm.add_constant(class_dummies[X_cols])
logit_model = sm.Logit(mortality_df['3M Death'], X).fit(disp=False)
print(logit_model.summary())

or_vals = np.exp(logit_model.params)
ci = np.exp(logit_model.conf_int())
pvals = logit_model.pvalues
or_table = pd.DataFrame({'OR': or_vals, 'CI_lower': ci[0], 'CI_upper': ci[1], 'p_value': pvals})
print('\nOdds ratios (trajectory classes vs class 1 reference):')
print(or_table)

In [None]:
# Association between trajectory class and 3-month mRS
mrs_df = traj_features.merge(outcome_df[['case_admission_id', '3M mRS']], on='case_admission_id', how='left')
mrs_df['3M mRS'] = pd.to_numeric(mrs_df['3M mRS'], errors='coerce')
# Admissions without an mRS value or without a class label cannot enter the model
# (and a NaN class would make the int cast below fail).
mrs_df = mrs_df.dropna(subset=['3M mRS', 'traj_class']).copy()
mrs_df['traj_class'] = mrs_df['traj_class'].astype(int)

# Class-level distribution of mRS scores
mrs_counts = mrs_df.groupby('traj_class')['3M mRS'].value_counts().unstack(fill_value=0)
mrs_summary = mrs_df.groupby('traj_class')['3M mRS'].agg(['count', 'median', 'mean'])
print('mRS counts by trajectory class (rows=class, columns=mRS score):')
print(mrs_counts)
print('\nClass-level mRS summary:')
print(mrs_summary)

# Ordinal logistic regression (proportional odds) with class 1 as reference.
# Dummies are float, as in the mortality model, so the design matrix is numeric
# rather than boolean.
exog = pd.get_dummies(mrs_df['traj_class'], prefix='traj_class', drop_first=True, dtype=float)
ord_model = OrderedModel(mrs_df['3M mRS'], exog, distr='logit')
ord_res = ord_model.fit(method='bfgs', disp=False)
print(ord_res.summary())

# Odds ratios for the class dummies only (threshold parameters omitted).
# A positive coefficient / OR > 1 shifts the distribution toward higher mRS categories.
coef_idx = list(exog.columns)
or_vals = np.exp(ord_res.params[coef_idx])
ci = np.exp(ord_res.conf_int().loc[coef_idx])
or_table = pd.DataFrame({'OR': or_vals, 'CI_lower': ci[0], 'CI_upper': ci[1]})
print('\nOdds ratios (trajectory classes vs class 1 reference):')
print(or_table)

## Model selection note
- Lower (more negative) BIC is better.
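- Constrained candidates (entropy > 0.7, min group size ≥ 5%) in the most recent run: k=2 (BIC≈-12016) and k=3 (BIC≈-14121). The 3-class model has the lower BIC, so it is selected. These values come from that specific run and will change if the data or settings change.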
- Models with k=4 or 5 have lower BIC but fail the 5% minimum group size criterion.

## Methods summary (stepwise)
1. Cohort assembly: Linked EDS labs to registry via `case_admission_id`; excluded non-ischemic, non-acute, under 18, intra-hospital strokes, and refusals. Added T0 (stroke or arrival time).
2. Lactate preprocessing: Loaded all lab files, harmonized labels/units, removed implausible/non-numeric values, restricted to lactate. Computed relative sample time (hours from T0) and limited analysis window to -12 to +72 hours.
3. Outcome derivation: From registry, derived 3-month mortality (`3M Death`), setting hospital deaths to mRS=6 and binarizing death indicators.
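4. Feature construction: For each admission, fit patient-level lactate trajectories to obtain coefficients `beta2`, `beta1`, `beta0`. When `only_use_linear_trajectories` is True (the current setting), fits are linear if ≥2 points and the mean if 1, with `beta2` = 0. Otherwise fits are quadratic if ≥3 points, linear if 2 and the mean if 1. The coefficients were winsorized at the 1st/99th percentiles and standardized.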
5. Mixture modeling (GBTM proxy): Default uses Gaussian mixtures (k=2..5) on standardized coefficients with constraint filter (entropy>0.7, min class size≥5%), lowest-BIC candidate selected unless overridden by `manual_k`; optional `use_bayes_traj=True` fits a Dirichlet-process regression mixture (via `bayes_traj`) on raw trajectories with data-driven priors and reports WAIC2/entropy/min_frac.
6. Class assignment: Assigned each admission to the class with highest posterior probability (`traj_class`) and stored maximum posterior as `traj_class_prob` (classification certainty).
7. Visualization: Plotted observed lactate values over time (scatter coloured by class) and class-mean fitted curves over -12 to +72h.
8. Outcome association: Merged class labels with outcomes. Summarized class-specific mortality and ran a logistic regression (reference = class 1) with class dummies to estimate odds ratios and 95% CIs. Tabulated 3-month mRS by class and fit a proportional-odds ordinal logistic regression (reference = class 1).
9. Reporting: Printed the model selection table (BIC, entropy, min_frac) and class counts, plus odds-ratio tables for the mortality and mRS associations.