### Visual Predictive Check (VPC) and Sensitivity Analysis

Purpose:
- To assess goodness-of-fit and predictive performance of the mixed Hidden Markov Model (mHMM).
- To perform **Visual Predictive Checks (VPCs)** by simulating new datasets from the fitted model and comparing observed vs simulated summaries.
- To conduct **sensitivity analyses** by perturbing key parameters (variances, correlations, transition probabilities) to evaluate model robustness.


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import seaborn as sns 
from pathlib import Path 
from tqdm.auto import tqdm
import os 

import sys
sys.path.insert(0, str(Path('..').resolve()))
from mHMM.src.emissions import EmissionModel  
from mHMM.src.transitions import TransitionModel

# Project paths, resolved relative to the notebook's working directory.
BASE_DIR = Path('..')
DATA_DIR = BASE_DIR / 'data'
RESULTS_DIR = DATA_DIR / 'results' / 'vpc_sensitivity'

# Create the output folder (and any missing parents) up front; no-op if it exists.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

print("VPC + sensitivity framework is up")

Load fitted population parameters

In [None]:
# Locate the posterior summary CSVs under data/results/summary.
fit_summary_path = DATA_DIR / 'results' / 'summary'
posterior_files = sorted(fit_summary_path.glob("posterior_summary_*.csv"))
if not posterior_files:
    raise FileNotFoundError(
        f"No 'posterior_summary_*.csv' files found in {fit_summary_path.resolve()}"
    )
# sorted() is lexicographic: the last entry is the most recent run only if the
# filename suffix sorts chronologically (e.g. a timestamp) -- NOTE(review): confirm.
posterior_file = posterior_files[-1]
post_summary = pd.read_csv(posterior_file, index_col=0)
if 'mean' not in post_summary.columns:
    raise KeyError(
        f"{posterior_file.name} has no 'mean' column; found {list(post_summary.columns)}"
    )

# Posterior means, keyed by the CSV's index column (parameter names), are used
# as the point estimates for simulation.
theta_hat = post_summary['mean'].to_dict()
print('Loading population parameter estimates:')
# Print only the first 10 parameters as a quick sanity check.
for k, v in list(theta_hat.items())[:10]:
    print(f"{k}: {v:.3f}")

Visual Predictive Check (VPC):

-> Simulate new datasets using the fitted parameters, then compare the median and percentile bands of the simulated FEV1/PRO trajectories with the observed data.

In [None]:
def simulate_vpc_dataset(em_params, trans_params, N_subj=100, T_weeks=60, seed=123,
                         init_probs=(0.9, 0.1)):
    """Simulate one replicate dataset from the fitted two-state mHMM.

    For each subject, individual effects are drawn first. A hidden-state path is
    then simulated from the transition matrix. Finally, one bivariate
    (FEV1, PRO) observation is drawn per week, conditional on that week's state.

    Parameters
    ----------
    em_params : dict
        Keyword arguments forwarded to ``EmissionModel``.
    trans_params : dict
        Keyword arguments forwarded to ``TransitionModel``.
    N_subj : int
        Number of subjects; IDs run from 1 to ``N_subj``.
    T_weeks : int
        Number of weekly time points; weeks run from 0 to ``T_weeks - 1``.
    seed : int
        Seed for ``numpy.random.default_rng``; the same seed gives the same dataset.
    init_probs : sequence of 2 floats
        Initial state distribution over states (0, 1). Defaults to (0.9, 0.1).

    Returns
    -------
    pandas.DataFrame
        One row per subject-week with columns ``ID``, ``Week``, ``FEV1``, ``PRO``.
    """
    columns = ['ID', 'Week', 'FEV1', 'PRO']
    if N_subj < 1 or T_weeks < 1:
        # Nothing to simulate; return an empty frame with the expected schema.
        return pd.DataFrame(columns=columns)

    rng = np.random.default_rng(seed)
    em = EmissionModel(**em_params)
    tm = TransitionModel(**trans_params)
    trans_mat = tm.transition_matrix()
    init_probs = np.asarray(init_probs, dtype=float)

    data = []
    for sid in range(1, N_subj + 1):
        # Subject-level effects, shared by all of this subject's weeks.
        g = em.sample_individual_effects(rng)

        # 1) Hidden-state path: initial draw, then one Markov step per week.
        states = np.zeros(T_weeks, dtype=int)
        states[0] = rng.choice([0, 1], p=init_probs)
        for t in range(1, T_weeks):
            # Row states[t-1] of trans_mat is used as the next-state distribution.
            states[t] = rng.choice([0, 1], p=trans_mat[states[t - 1], :])

        # 2) Observations are drawn only after the full path exists, so every
        #    week uses its own sampled state and yields exactly one row.
        for t in range(T_weeks):
            mu = np.array([em.individual_fev1(g, states[t]),
                           em.individual_pro(g, t, states[t])])
            cov = em.emission_cov(states[t])
            y = rng.multivariate_normal(mu, cov)
            data.append({'ID': sid, 'Week': t, 'FEV1': y[0], 'PRO': y[1]})

    return pd.DataFrame(data, columns=columns)

# --- Simulate N_REP replicate datasets from the posterior-mean parameters ---
N_REP = 200
T_WEEKS = 60

# Emission-model parameters, pulled by name from the posterior means.
EMISSION_PARAM_NAMES = (
    "hFEV1R", "hFEV1E", "x2_FEV1R", "x2_FEV1E",
    "hPROR", "hPROE", "x2_PROR", "x2_PROE",
    "r2_FEV1", "r2_PRO", "qR", "qE", "PE", "PHL",
)
em_params = {name: theta_hat[name] for name in EMISSION_PARAM_NAMES}
# Transition-model parameters.
trans_params = {name: theta_hat[name] for name in ('hpRE', 'hpER')}

# Replicate index doubles as the RNG seed, so each replicate is reproducible.
replicates = []
for rep in tqdm(range(N_REP), desc="Simulating VPC datasets"):
    replicates.append(
        simulate_vpc_dataset(em_params, trans_params, N_subj=100, T_weeks=T_WEEKS, seed=rep)
    )

# Reference dataset the simulations are compared against.
# NOTE(review): this is read from data/simulated/, not from real observations -- confirm intent.
obs_path = DATA_DIR / 'simulated' / 'ref_scenario.csv'
df_obs = pd.read_csv(obs_path)

Compute predictive percentiles and compare them with the observed data

In [None]:
def _quantile_label(q):
    """Return a column label for quantile q: 0.05 -> 'p05', 0.5 -> 'p50', 0.025 -> 'p2.5'."""
    pct = q * 100
    nearest = int(round(pct))
    if abs(pct - nearest) < 1e-9:
        return f"p{nearest:02d}"
    return f"p{pct:g}"


def compute_vpc_summary(df_list, variable='FEV1', quantiles=(0.05, 0.5, 0.95)):
    """Pool simulated replicates and compute per-week quantiles of one variable.

    Parameters
    ----------
    df_list : iterable of pandas.DataFrame
        Simulated datasets, each with a ``Week`` column and a ``variable`` column.
    variable : str
        Column to summarise (e.g. ``'FEV1'`` or ``'PRO'``).
    quantiles : sequence of float in [0, 1]
        Quantiles to compute. Duplicates are dropped and output columns are in
        ascending order. The default gives the original ``p05``/``p50``/``p95``.

    Returns
    -------
    pandas.DataFrame
        A ``Week`` column followed by one ``p..`` column per quantile.

    Raises
    ------
    ValueError
        If ``df_list`` contains no DataFrames.
    """
    frames = list(df_list)
    if not frames:
        raise ValueError("compute_vpc_summary needs at least one simulated DataFrame")
    # Sorted so the positional column labels below line up with the unstacked order.
    qs = sorted(set(quantiles))
    df_concat = pd.concat(frames, ignore_index=True)
    q = df_concat.groupby("Week")[variable].quantile(qs).unstack()
    q.columns = [_quantile_label(x) for x in qs]
    return q.reset_index()

# Simulated per-week percentile bands for each outcome.
vpc_fev1, vpc_pro = (compute_vpc_summary(replicates, var) for var in ("FEV1", "PRO"))

# Observed per-week medians (Series indexed by Week) for overlay on the bands.
obs_by_week = df_obs.groupby("Week")
obs_fev1 = obs_by_week['FEV1'].median()
obs_pro = obs_by_week["PRO"].median()

Plot VPC results

In [None]:
def _plot_vpc(vpc, obs, variable, color):
    """Plot the simulated 90% band and median of *variable* against the observed median.

    ``vpc`` is a frame from ``compute_vpc_summary`` (Week, p05, p50, p95);
    ``obs`` is a Series of observed medians indexed by Week.
    """
    plt.figure(figsize=(8, 4))
    # p05..p95 spans the central 90% of the pooled simulated values.
    plt.fill_between(vpc["Week"], vpc["p05"], vpc["p95"], alpha=0.3, color=color,
                     label="Sim 90% PI")
    plt.plot(vpc["Week"], vpc["p50"], color=color, label="Sim median")
    plt.plot(obs.index, obs, "k--", label="Observed median")
    plt.title(f"Visual Predictive Check - {variable}")
    plt.xlabel("Week")
    plt.ylabel(variable)
    plt.legend()
    plt.grid(True)
    plt.show()


# 'C0'/'C1' are matplotlib colour-cycle references ('CO' with a letter O is invalid).
_plot_vpc(vpc_fev1, obs_fev1, "FEV1", color="C0")
_plot_vpc(vpc_pro, obs_pro, "PRO", color="C1")



Sensitivity Analysis (parameter perturbations):

-> Vary key parameters (e.g. variances, correlations, transition probabilities) by ±20% and examine
the effect on model predictions. The `factors` dictionary below lists exactly which parameters are perturbed.

In [None]:
def pertub_params(base_params, factor_dict):
    """Return a copy of ``base_params`` with selected entries scaled by a factor.

    Parameters
    ----------
    base_params : dict
        Parameter name -> value (scalar or array). It is not modified.
    factor_dict : dict
        Parameter name -> multiplicative factor. Names absent from
        ``base_params`` are ignored. This lets one factor dict be applied to
        several parameter sets, each picking up only the keys it owns.

    Returns
    -------
    dict
        New dict with ``value * factor`` for each matching key.
    """
    perturbed = base_params.copy()
    for key, factor in factor_dict.items():
        if key in perturbed:
            # Rebind rather than `*=`: the copy is shallow, so an in-place
            # multiply would also modify array values shared with base_params
            # (and fails for int arrays scaled by a float).
            perturbed[key] = perturbed[key] * factor
    return perturbed

# Multiplicative factors (low, high) applied to each perturbed parameter.
# Keys may name emission parameters (em_params) or transition parameters
# (trans_params); each run perturbs exactly one parameter.
factors = {
    "r2_FEV1": [0.8, 1.2],
    "r2_PRO": [0.8, 1.2],
    "qR": [0.8, 1.2],
    'qE': [0.8, 1.2],
}

sens_results = []
for param, (low, high) in factors.items():
    if param not in em_params and param not in trans_params:
        raise KeyError(f"Unknown parameter in factors: {param!r}")
    for f in (low, high):
        # pertub_params ignores keys it does not own, so the factor reaches
        # whichever parameter set contains `param`.
        em_mod = pertub_params(em_params, {param: f})
        tr_mod = pertub_params(trans_params, {param: f})
        # The seed depends only on the factor, so all parameters' low (or high)
        # runs share a random stream; curve differences then reflect the
        # perturbation rather than Monte Carlo noise.
        sim = simulate_vpc_dataset(em_mod, tr_mod, N_subj=100, T_weeks=T_WEEKS,
                                   seed=int(round(1000 * f)))
        med_fev1 = sim.groupby("Week")['FEV1'].median().values
        sens_results.append({'param': param, 'factor': f, "FEV1_median": med_fev1})

Plot sensitivity curves

In [None]:
# One FEV1-median trajectory per (parameter, factor) run.
plt.figure(figsize=(8, 4))
for param, levels in factors.items():
    for f in levels:
        curve = next(s['FEV1_median'] for s in sens_results
                     if s['param'] == param and s['factor'] == f)
        # x-axis follows the curve's own length (one point per simulated week),
        # so a simulation horizon different from T_WEEKS cannot cause a shape error.
        plt.plot(np.arange(len(curve)), curve, label=f"{param} x{f}")

plt.xlabel("Week")
plt.ylabel("FEV1 median")
plt.title("Sensitivity of FEV1 to key parameters")
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# SAVE: VPC percentile tables as CSV, sensitivity curves as a pickle.
out_dir = RESULTS_DIR
out_dir.mkdir(parents=True, exist_ok=True)
vpc_fev1.to_csv(out_dir / "vpc_fev1_summary.csv", index=False)
vpc_pro.to_csv(out_dir / "vpc_pro_summary.csv", index=False)
# Pickled because each row holds an array (FEV1_median) that CSV cannot round-trip.
# The filename has no extension; kept unchanged for any existing readers.
pd.DataFrame(sens_results).to_pickle(out_dir / "Sensitivity")
print(f"VPC saved to {out_dir}")

##  Summary

In this notebook we have done the following: 

-> Conducted visual predictive checks using posterior mean parameters to validate model fit

-> Compared observed and simulated distributions of FEV1 and PRO across time 

-> Performed **sensitivity analysis** by perturbing key model parameters (±20%), e.g. variance, correlation and transition parameters, and comparing the resulting simulated FEV1 median trajectories

-> Identified the parameters that most influence predictive behavior and variance stability

-> Exported VPC and sensitivity summaries. 