# CHESS COVID-19 Causal Analysis
### McElreath + Pearl: Advanced Causal Inference for Critical Care

This notebook integrates **Richard McElreath's Statistical Rethinking** methodology with **Judea Pearl's do-calculus** principles for rigorous causal inference in COVID-19 critical care, using the `bayes_ordinal` package.

## Enhanced Methodology Combining:

1. **McElreath's Data Story** - Understanding COVID-19 progression
2. **Pearl's Causal Hierarchy** - Association → Intervention → Counterfactuals
3. **Enhanced DAG Analysis** - Do-calculus identification strategy
4. **Treatment Effect Estimation** - Individual vs Average Treatment Effects
5. **Confound Stratification** - Proper adjustment sets using Pearl's rules
6. **Bayesian Generative Models** - Uncertainty quantification for policy
7. **Advanced Counterfactuals** - CausalBGM framework integration
8. **Policy Simulation** - Resource allocation under uncertainty

---

> *"The combination of McElreath's Bayesian workflow with Pearl's causal reasoning provides the gold standard for evidence-based critical care decisions."*

### Research Questions:
- **Treatment Effectiveness**: Do steroids, antivirals, anticoagulants causally improve outcomes?
- **Resource Allocation**: Who benefits most from ICU vs ward care?
- **Policy Optimization**: How should we allocate treatments under scarcity?


In [None]:
# Setup: Advanced Causal Inference Framework
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az
import warnings
warnings.filterwarnings('ignore')

# Enhanced visualization for causal inference
import networkx as nx
from matplotlib.patches import FancyBboxPatch, Rectangle
import matplotlib.patches as mpatches
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import Axes3D

# Import bayes_ordinal package
import sys
sys.path.append('../')
import bayes_ordinal as bo

# Advanced scientific computing
from scipy import stats
from sklearn.preprocessing import StandardScaler
import itertools

# Set enhanced plotting style
plt.style.use('default')
sns.set_palette('colorblind')
az.style.use('arviz-whitegrid')

print(" CHESS COVID-19 Advanced Causal Analysis")
print(" McElreath's Statistical Rethinking + Pearl's Do-Calculus")
print(" Using bayes_ordinal for critical care causal inference")
print("=" * 65)


## Step 1: Enhanced Data Story + Pearl's Causal Hierarchy 

**McElreath + Pearl Framework:** *"Understand the data generation process AND the causal identification strategy."*

### COVID-19 Critical Care Data Generation Process:

**Pearl's Ladder of Causation:**
1. **Association (Level 1)**: Seeing - What symptoms predict severity?
2. **Intervention (Level 2)**: Doing - What happens if we give steroids?  
3. **Counterfactuals (Level 3)**: Imagining - What if this patient had received different treatment?

**Fundamental Patient Characteristics (Exogenous):**
- **Age** → Fundamental COVID-19 risk factor
- **Comorbidities** → Pre-existing health vulnerabilities
- **Viral Load** → Initial infection severity (unobserved)

**Disease Progression Pathway:**
- **Viral Load** → **Inflammation** (CRP, D-dimer)
- **Age + Comorbidities** → **Organ Function** (O2 saturation, respiratory rate)
- **Inflammation + Organ Function** → **Disease Severity**

**Treatment Decisions (Endogenous - Pearl's Key Insight!):**
- **Treatments** ← **Disease Severity** + **Physician Assessment** + **Resource Availability**
- **This creates confounding!** Sicker patients get more aggressive treatment

**Clinical Outcome:**
- **Patient Severity** ← All pathways + Treatment effects + Individual variation

### Pearl's Do-Calculus Questions:
1. **P(Severity | do(Steroids = 1))** - Causal effect of steroid intervention
2. **P(Severity | do(Steroids = 1), Age = elderly)** - Heterogeneous treatment effects
3. **Identification**: Which effects are identifiable from observational data?


In [None]:
# Step 2: Enhanced DAG with Pearl's Do-Calculus Analysis

def create_enhanced_covid_dag():
    """Create sophisticated COVID-19 DAG following Pearl's do-calculus framework"""
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 10))
    
    # DAG 1: Complete Causal Structure
    ax1.set_title('Complete COVID-19 Causal DAG\n(Pearl\'s Framework)', fontsize=14, fontweight='bold', pad=20)
    
    # Define enhanced node positions
    pos_full = {
        'Age': (1, 5),
        'Comorbidities': (3, 5),
        'Viral_Load': (5, 5),
        'CRP': (2, 3.5),
        'D_dimer': (4, 3.5),
        'O2_Saturation': (1, 2),
        'Resp_Rate': (3, 2),
        'Disease_Severity': (5, 2),
        'Steroids': (1, 0.5),
        'Antivirals': (3, 0.5),
        'Anticoagulants': (5, 0.5),
        'Patient_Severity': (3, -1)
    }
    
    # Create full DAG
    G_full = nx.DiGraph()
    G_full.add_nodes_from(pos_full.keys())
    
    # Enhanced edge structure with confounding
    edges_full = [
        # Fundamental causes
        ('Age', 'O2_Saturation'), ('Age', 'Disease_Severity'), ('Age', 'Patient_Severity'),
        ('Comorbidities', 'O2_Saturation'), ('Comorbidities', 'Disease_Severity'),
        ('Viral_Load', 'CRP'), ('Viral_Load', 'D_dimer'), ('Viral_Load', 'Disease_Severity'),
        
        # Biological pathways
        ('CRP', 'Disease_Severity'), ('D_dimer', 'Disease_Severity'),
        ('O2_Saturation', 'Disease_Severity'), ('Resp_Rate', 'Disease_Severity'),
        
        # CONFOUNDING: Disease severity affects treatment decisions (Pearl's insight!)
        ('Disease_Severity', 'Steroids'), ('Disease_Severity', 'Antivirals'), ('Disease_Severity', 'Anticoagulants'),
        ('Age', 'Steroids'),  # Age affects treatment decisions
        ('Comorbidities', 'Anticoagulants'),  # Comorbidities affect anticoagulation
        
        # Treatment effects on outcome
        ('Steroids', 'Patient_Severity'), ('Antivirals', 'Patient_Severity'), ('Anticoagulants', 'Patient_Severity'),
        ('Disease_Severity', 'Patient_Severity')
    ]
    G_full.add_edges_from(edges_full)
    
    # Enhanced color coding for causal inference
    colors_full = {
        'Age': '#E74C3C', 'Comorbidities': '#E67E22', 'Viral_Load': '#F39C12',  # Exogenous
        'CRP': '#3498DB', 'D_dimer': '#2ECC71', 'O2_Saturation': '#9B59B6', 'Resp_Rate': '#1ABC9C',  # Mediators
        'Disease_Severity': '#34495E',  # Confounder
        'Steroids': '#E91E63', 'Antivirals': '#FF5722', 'Anticoagulants': '#795548',  # Treatments
        'Patient_Severity': '#000000'  # Outcome
    }
    
    # Draw full DAG
    ax1.set_xlim(-0.5, 5.5)
    ax1.set_ylim(-1.5, 5.5)
    
    # Draw nodes
    for node, (x, y) in pos_full.items():
        circle = plt.Circle((x, y), 0.25, color=colors_full[node], alpha=0.7, zorder=3)
        ax1.add_patch(circle)
        # Break multi-word node names across lines ('\n' is a real newline, not a literal backslash-n)
        ax1.text(x, y, node.replace('_', '\n'), ha='center', va='center', 
                fontsize=8, fontweight='bold', color='white', zorder=4)
    
    # Draw edges with different styles for confounders
    for edge in edges_full:
        start = pos_full[edge[0]]
        end = pos_full[edge[1]]
        
        # Calculate arrow position
        dx, dy = end[0] - start[0], end[1] - start[1]
        length = np.sqrt(dx**2 + dy**2)
        dx_norm, dy_norm = dx/length, dy/length
        
        start_adj = (start[0] + 0.25 * dx_norm, start[1] + 0.25 * dy_norm)
        end_adj = (end[0] - 0.25 * dx_norm, end[1] - 0.25 * dy_norm)
        
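        # Different styles for treatment-selection (confounding) edges vs other causal edges.
        # Every edge pointing into a treatment is drawn as confounding, so Age -> Steroids
        # is styled the same way it is handled in the intervention graph.
        if edge[1] in ['Steroids', 'Antivirals', 'Anticoagulants']:
            # Confounding edges into a treatment (red, dashed)
            ax1.annotate('', xy=end_adj, xytext=start_adj,
                        arrowprops=dict(arrowstyle='->', lw=2, color='red', linestyle='--', alpha=0.8))
        else:
            # Causal edges (dark slate, solid)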
            ax1.annotate('', xy=end_adj, xytext=start_adj,
                        arrowprops=dict(arrowstyle='->', lw=1.5, color='#2C3E50', alpha=0.8))
    
    # DAG 2: Pearl's Intervention Graph (do-operator)
    ax2.set_title('Intervention Graph: do(Steroids = 1)\n(Confounding Edges Removed)', fontsize=14, fontweight='bold', pad=20)

    # Create intervention DAG (remove confounding edges TO treatments)
    pos_int = pos_full.copy()
    G_int = G_full.copy()

    # Remove confounding edges (Pearl's do-operator)
    confounding_edges = [('Disease_Severity', 'Steroids'), ('Age', 'Steroids')]
    G_int.remove_edges_from(confounding_edges)

    # Draw intervention DAG
    ax2.set_xlim(-0.5, 5.5)
    ax2.set_ylim(-1.5, 5.5)

    # Draw nodes (highlight intervened variable)
    for node, (x, y) in pos_int.items():
        if node == 'Steroids':
            # Highlight intervened variable
            circle = plt.Circle((x, y), 0.3, color='#FF0000', alpha=0.9, zorder=3)
            ax2.add_patch(circle)
            ax2.text(x, y, 'do(Steroids=1)', ha='center', va='center', 
                    fontsize=7, fontweight='bold', color='white', zorder=4)
        else:
            circle = plt.Circle((x, y), 0.25, color=colors_full[node], alpha=0.7, zorder=3)
            ax2.add_patch(circle)
            ax2.text(x, y, node.replace('_', '\n'), ha='center', va='center', 
                    fontsize=8, fontweight='bold', color='white', zorder=4)

    # Draw remaining edges
    remaining_edges = [e for e in edges_full if e not in confounding_edges]
    for edge in remaining_edges:
        start = pos_int[edge[0]]
        end = pos_int[edge[1]]

        # Calculate arrow position
        dx, dy = end[0] - start[0], end[1] - start[1]
        length = np.sqrt(dx**2 + dy**2)
        dx_norm, dy_norm = dx/length, dy/length

        start_adj = (start[0] + 0.25 * dx_norm, start[1] + 0.25 * dy_norm)
        end_adj = (end[0] - 0.25 * dx_norm, end[1] - 0.25 * dy_norm)

        ax2.annotate('', xy=end_adj, xytext=start_adj,
                    arrowprops=dict(arrowstyle='->', lw=1.5, color='#2C3E50', alpha=0.8))

    # Enhanced legends
    legend_elements_1 = [
        mpatches.Patch(color='#E74C3C', label='Patient Characteristics'),
        mpatches.Patch(color='#3498DB', label='Disease Biomarkers'),
        mpatches.Patch(color='#34495E', label='Disease Severity (Confounder)'),
        mpatches.Patch(color='#E91E63', label='Treatments'),
        mpatches.Patch(color='#000000', label='Outcome'),
        mpatches.Patch(color='red', label='Confounding Edges', alpha=0.8),
        mpatches.Patch(color='#2C3E50', label='Causal Edges')
    ]
    ax1.legend(handles=legend_elements_1, loc='upper left', bbox_to_anchor=(-0.1, 1))

    legend_elements_2 = [
        mpatches.Patch(color='#FF0000', label='Intervened Variable'),
        mpatches.Patch(color='#2C3E50', label='Remaining Causal Paths')
    ]
    ax2.legend(handles=legend_elements_2, loc='upper left', bbox_to_anchor=(-0.1, 1))

    for ax in [ax1, ax2]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    return G_full, G_int, pos_full

# Create enhanced DAG analysis
print(" STEP 2: ENHANCED DAG + PEARL'S DO-CALCULUS")
print("=" * 55)
print("Pearl: 'Draw the assumptions, then apply do-calculus for identification'")

dag_full, dag_intervention, positions = create_enhanced_covid_dag()


## Step 3: Pearl's Confound Identification Strategy 

**Pearl's Key Innovation:** *"Use the graph to determine what to control for - not statistical tests."*

### Identifying Confounders Using Causal Graphs:

**For Treatment Effects (Steroids → Patient Severity):**

**Backdoor Paths to Block:**
1. **Steroids ← Disease_Severity → Patient_Severity** (CONFOUNDING!)
2. **Steroids ← Age → Patient_Severity** (CONFOUNDING!)

**Pearl's Backdoor Criterion:**
- Control for **Disease_Severity** and **Age** to identify causal effect of steroids
- DO NOT control for **CRP, O2_Saturation** (they're mediators on causal path)
- DO NOT control for **Patient_Severity** (that's the outcome!)
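
The backdoor check described above can be automated directly from the graph. The sketch below is a minimal, standard-library illustration rather than part of `bayes_ordinal`: the edge set is copied from `edges_full` in the DAG cell, while the helper names (`open_backdoor_paths`, `is_blocked`, `satisfies_backdoor`) are hypothetical. It enumerates every path that enters Steroids through an incoming arrow and reports which ones stay open for a given adjustment set.

```python
# Edges copied from `edges_full` in the DAG cell above.
EDGES = {
    ('Age', 'O2_Saturation'), ('Age', 'Disease_Severity'), ('Age', 'Patient_Severity'),
    ('Comorbidities', 'O2_Saturation'), ('Comorbidities', 'Disease_Severity'),
    ('Viral_Load', 'CRP'), ('Viral_Load', 'D_dimer'), ('Viral_Load', 'Disease_Severity'),
    ('CRP', 'Disease_Severity'), ('D_dimer', 'Disease_Severity'),
    ('O2_Saturation', 'Disease_Severity'), ('Resp_Rate', 'Disease_Severity'),
    ('Disease_Severity', 'Steroids'), ('Disease_Severity', 'Antivirals'),
    ('Disease_Severity', 'Anticoagulants'), ('Age', 'Steroids'),
    ('Comorbidities', 'Anticoagulants'),
    ('Steroids', 'Patient_Severity'), ('Antivirals', 'Patient_Severity'),
    ('Anticoagulants', 'Patient_Severity'), ('Disease_Severity', 'Patient_Severity'),
}

def neighbours(node):
    """Nodes adjacent to `node`, ignoring edge direction."""
    return sorted({b for a, b in EDGES if a == node} | {a for a, b in EDGES if b == node})

def descendants(node):
    """All nodes reachable from `node` by following arrows forward."""
    found, stack = set(), [node]
    while stack:
        cur = stack.pop()
        for a, b in EDGES:
            if a == cur and b not in found:
                found.add(b)
                stack.append(b)
    return found

def simple_paths(src, dst, path=None):
    """Yield every undirected simple path from src to dst."""
    path = path or [src]
    if path[-1] == dst:
        yield list(path)
        return
    for nxt in neighbours(path[-1]):
        if nxt not in path:
            yield from simple_paths(src, dst, path + [nxt])

def is_blocked(path, z):
    """d-separation rule for a single path given the adjustment set z."""
    for prev, node, nxt in zip(path, path[1:], path[2:]):
        collider = (prev, node) in EDGES and (nxt, node) in EDGES
        if collider:
            # A collider blocks unless it, or one of its descendants, is conditioned on.
            if node not in z and not (descendants(node) & z):
                return True
        elif node in z:
            # A chain or fork node blocks when it is conditioned on.
            return True
    return False

def open_backdoor_paths(treatment, outcome, z):
    """Paths that start with an arrow INTO the treatment and are not blocked by z."""
    z = set(z)
    return [p for p in simple_paths(treatment, outcome)
            if (p[1], p[0]) in EDGES and not is_blocked(p, z)]

def satisfies_backdoor(treatment, outcome, z):
    """Backdoor criterion: no descendant of the treatment in z, and no open backdoor path."""
    z = set(z)
    return not (descendants(treatment) & z) and not open_backdoor_paths(treatment, outcome, z)

for z in [set(), {'Disease_Severity'}, {'Disease_Severity', 'Age'}]:
    n_open = len(open_backdoor_paths('Steroids', 'Patient_Severity', z))
    print(f"Adjusting for {sorted(z) or 'nothing'}: {n_open} open backdoor path(s)")
```

Under the DAG as drawn, adjusting for Disease_Severity alone still leaves the path through Age open, which is why both variables appear in the adjustment set.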

**Stratification Strategy:**
1. **Naive Model:** Patient_Severity ~ Steroids (confounded)
2. **Adjustment Model:** Patient_Severity ~ Steroids + Disease_Severity + Age (identified)
3. **Over-Control Model:** Patient_Severity ~ Steroids + Everything (blocks causal paths)

### do-Calculus Identification:
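- **P(Patient_Severity | do(Steroids = 1))** = **Σ P(Patient_Severity | Steroids = 1, Disease_Severity, Age) · P(Disease_Severity, Age)** (by backdoor adjustment)
- The sum runs over all values of the adjustment set (an integral for the continuous Disease_Severity and Age), weighted by their distribution in the whole population rather than among treated patients only.

Our ordinal regression estimates the conditional term, **P(Patient_Severity | Steroids, Disease_Severity, Age)**. Averaging its predictions over the observed covariates with Steroids set to 1 then gives the interventional distribution.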
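
Backdoor adjustment can be checked numerically on a toy example. The sketch below is an illustration under assumed numbers, not the notebook's data generator: a single binary confounder `z` (severe disease), a binary treatment `x`, a binary bad outcome `y`, and a true treatment effect of -0.15 on the risk scale are all choices made for readability. It compares the naive treated-vs-untreated difference with the stratify-and-reweight estimate.

```python
import random

rng = random.Random(7)
rows = []
for _ in range(40000):
    z = rng.random() < 0.5                                  # confounder: severe disease
    x = rng.random() < (0.8 if z else 0.2)                  # sicker patients are treated more often
    p_bad = (0.7 if z else 0.2) - (0.15 if x else 0.0)      # treatment lowers risk by 0.15
    y = rng.random() < p_bad
    rows.append((z, x, y))

def mean_y(x, z=None):
    """Observed outcome rate for treatment arm x, optionally within stratum z."""
    sel = [yy for (zz, xx, yy) in rows if xx == x and (z is None or zz == z)]
    return sum(sel) / len(sel)

# Population distribution of the confounder, P(z)
p_z = {z: sum(1 for r in rows if r[0] == z) / len(rows) for z in (False, True)}

# Naive contrast: P(y | x=1) - P(y | x=0), confounded by z
naive = mean_y(True) - mean_y(False)

# Backdoor adjustment: sum_z [P(y | x=1, z) - P(y | x=0, z)] * P(z)
adjusted = sum((mean_y(True, z) - mean_y(False, z)) * p_z[z] for z in (False, True))

print(f"Naive difference:    {naive:+.3f}")
print(f"Adjusted difference: {adjusted:+.3f}")
```

With these assumed numbers the naive contrast has the wrong sign, because treated patients are mostly the severe ones, while the adjusted contrast lands near the simulated effect of -0.15.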


In [None]:
# Step 4: Generate COVID-19 Data Following Our Causal Story

def generate_chess_covid_data_mcelreath_style(n_patients=2000, seed=42):
    """
    Generate synthetic COVID-19 critical care data following McElreath + Pearl methodology
    Data generation process matches our causal DAG assumptions
    """
    np.random.seed(seed)
    
    print(" Generating COVID-19 Data Following Causal Process...")
    print("=" * 55)
    
    # Fundamental patient characteristics (exogenous variables)
    age = np.random.normal(65, 15, n_patients)  # Mean age 65, COVID vulnerable population
    age = np.clip(age, 20, 95)  # Reasonable bounds
    age_scaled = (age - age.mean()) / age.std()
    
    # Comorbidities (correlated with age)
    comorbidities = np.random.binomial(1, stats.norm.cdf(0.3 + 0.4 * age_scaled), n_patients)
    
    # Viral load (unobserved, affects everything)
    viral_load = np.random.normal(0, 1, n_patients)
    
    # Disease biomarkers (mediators on causal paths)
    crp = np.random.normal(
        0.5 * age_scaled + 0.8 * viral_load + 0.3 * comorbidities, 0.5, n_patients
    )
    d_dimer = np.random.normal(
        0.4 * age_scaled + 0.7 * viral_load + 0.2 * comorbidities, 0.6, n_patients
    )
    
    # Organ function measures
    o2_saturation = np.random.normal(
        -0.6 * age_scaled - 0.5 * comorbidities - 0.3 * viral_load, 0.4, n_patients
    )
    resp_rate = np.random.normal(
        0.5 * age_scaled + 0.4 * comorbidities + 0.6 * viral_load, 0.5, n_patients
    )
    
    # Disease severity (key confounder!)
    disease_severity = (
        0.4 * age_scaled + 
        0.3 * comorbidities + 
        0.5 * viral_load + 
        0.3 * crp + 
        0.2 * d_dimer + 
        -0.4 * o2_saturation + 
        0.3 * resp_rate + 
        np.random.normal(0, 0.3, n_patients)
    )
    
    # Treatment decisions (endogenous - affected by confounders!)
    # Pearl's key insight: treatments depend on disease severity
    # Note: despite the `_logit` names, these scores are passed through the normal CDF (a probit link)
    steroid_logit = -0.5 + 1.2 * disease_severity + 0.3 * age_scaled + np.random.normal(0, 0.2, n_patients)
    steroids = np.random.binomial(1, stats.norm.cdf(steroid_logit), n_patients)
    
    antiviral_logit = -0.8 + 0.9 * disease_severity + np.random.normal(0, 0.2, n_patients)
    antivirals = np.random.binomial(1, stats.norm.cdf(antiviral_logit), n_patients)
    
    anticoag_logit = -0.6 + 0.8 * disease_severity + 0.4 * comorbidities + np.random.normal(0, 0.2, n_patients)
    anticoagulants = np.random.binomial(1, stats.norm.cdf(anticoag_logit), n_patients)
    
    # Patient severity outcome (ordinal, 0-based: 0=Mild, 1=Moderate, 2=Severe, 3=Critical, 4=Death)
    # Includes both disease progression AND treatment effects
    latent_severity = (
        0.8 * disease_severity +           # Disease progression
        -0.4 * steroids +                  # Steroid benefit
        -0.3 * antivirals +                # Antiviral benefit  
        -0.2 * anticoagulants +            # Anticoagulant benefit
        0.2 * age_scaled +                 # Direct age effect
        np.random.normal(0, 0.4, n_patients)  # Individual variation
    )
    
    # Convert to ordinal categories using cutpoints
    cutpoints = [-1.5, -0.5, 0.5, 1.2]  # 5 categories
    patient_severity = np.digitize(latent_severity, cutpoints)
    patient_severity = np.clip(patient_severity, 0, 4)  # 0-based for bayes_ordinal
    
    # Create DataFrame
    data = pd.DataFrame({
        'age': age,
        'age_scaled': age_scaled,
        'comorbidities': comorbidities,
        'viral_load': viral_load,  # Include for validation (usually unobserved)
        'crp': crp,
        'd_dimer': d_dimer,
        'o2_saturation': o2_saturation,
        'resp_rate': resp_rate,
        'disease_severity': disease_severity,
        'steroids': steroids,
        'antivirals': antivirals,
        'anticoagulants': anticoagulants,
        'patient_severity': patient_severity,
        'latent_severity': latent_severity  # For validation
    })
    
    # Summary statistics
    print(f"Generated {n_patients:,} COVID-19 patients")
    print(f"Patient Severity Distribution:")
    severity_labels = ['Mild', 'Moderate', 'Severe', 'Critical', 'Death']
    for i, label in enumerate(severity_labels):
        count = (data['patient_severity'] == i).sum()
        pct = count / len(data) * 100
        print(f"  {i}: {label:<8} {count:4d} ({pct:4.1f}%)")
    
    print(f"\nTreatment Rates:")
    print(f"  Steroids:      {data['steroids'].mean():.1%}")
    print(f"  Antivirals:    {data['antivirals'].mean():.1%}")
    print(f"  Anticoagulants: {data['anticoagulants'].mean():.1%}")
    
    # Demonstrate confounding
    print(f"\nConfounding Evidence:")
    print(f"  Steroid rate in severe patients: {data[data['disease_severity'] > 1]['steroids'].mean():.1%}")
    print(f"  Steroid rate in mild patients:   {data[data['disease_severity'] < -1]['steroids'].mean():.1%}")
    
    return data

# Generate the COVID-19 dataset
print("STEP 4: GENERATE COVID-19 DATA FOLLOWING CAUSAL STORY")
print("=" * 65)
print("McElreath: 'Simulate data from the causal process you believe in'")

covid_data = generate_chess_covid_data_mcelreath_style(n_patients=2000)


## Step 5: Statistical Models Following Pearl's Causal Strategy 

**McElreath + Pearl Framework:** *"Build multiple models to test different causal assumptions"*

Following the PyMC ordinal regression methodology and Pearl's identification strategy, we'll build three models:

### Model Strategy:
1. **M1_Naive**: `Patient_Severity ~ Steroids` (CONFOUNDED - demonstrates bias)
2. **M2_Adjusted**: `Patient_Severity ~ Steroids + Disease_Severity + Age` (IDENTIFIED per Pearl's backdoor criterion)
3. **M3_Overcontrol**: `Patient_Severity ~ Steroids + All_Variables` (BLOCKS causal pathways)

### Key Insights from PyMC Documentation:
- **Ordinal outcomes** represent discrete observations of latent continuous phenomena
- **Multiple thresholds** partition the latent scale into ordered categories  
- **Bayesian approach** provides full posterior distributions over causal effects
- **Prior specification** crucial for ordinal regression stability

### McElreath's 5-Step Plan Applied:
1. **What we're describing**: Causal effect of steroids on patient outcomes
2. **Ideal data**: Randomized controlled trial with steroid intervention
3. **Actual data**: Observational ICU data with treatment selection bias
4. **Causes of difference**: Physicians select treatments based on disease severity
5. **Estimation strategy**: Use observational data + causal model to estimate intervention effects
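
The threshold idea behind these models can be made concrete with a few lines of arithmetic. The sketch below assumes the common cumulative-logit convention, P(Y ≤ k) = logistic(c_k − η), where η is the linear predictor; whether `bayes_ordinal` uses exactly this sign convention is an assumption. The cutpoints are the ones used in the data generator, and the value η = −0.4 is an illustrative shift on the latent scale, not a fitted coefficient.

```python
import math

def inv_logit(x):
    """Logistic function mapping the real line to (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def ordinal_probs(eta, cutpoints):
    """Category probabilities for a cumulative-logit model.

    P(Y <= k) = inv_logit(cutpoints[k] - eta); each category probability is
    the difference between consecutive cumulative probabilities.
    """
    cdf = [inv_logit(c - eta) for c in cutpoints] + [1.0]
    probs, prev = [], 0.0
    for c in cdf:
        probs.append(c - prev)
        prev = c
    return probs

CUTPOINTS = [-1.5, -0.5, 0.5, 1.2]               # 4 thresholds -> 5 ordered categories
LABELS = ['Mild', 'Moderate', 'Severe', 'Critical', 'Death']

baseline = ordinal_probs(0.0, CUTPOINTS)         # linear predictor at zero
shifted = ordinal_probs(-0.4, CUTPOINTS)         # lower latent severity (illustrative)

for label, p0, p1 in zip(LABELS, baseline, shifted):
    print(f"{label:<9} eta=0.0: {p0:.3f}   eta=-0.4: {p1:.3f}")
```

A negative shift in η moves probability mass toward the milder categories, which is how a negative steroid coefficient should be read in the models below.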


In [None]:
# Step 5: Build Statistical Models Following Pearl's Strategy

print(" STEP 5: STATISTICAL MODELS FOLLOWING PEARL'S STRATEGY")
print("=" * 65)
print("Building three models to demonstrate confounding vs proper adjustment")

# Model 1: Naive (Confounded)
print("\n M1_NAIVE: Patient_Severity ~ Steroids (CONFOUNDED)")
M1_naive = bo.cumulative_model(
    data=covid_data,
    outcome='patient_severity',
    predictors=['steroids'],
    link='logit',
    name='covid_naive'
)

# Prior specification following PyMC ordinal regression best practices
M1_naive.set_priors({
    'beta': {'mu': 0, 'sigma': 1},      # Weakly informative for treatment
    'cutpoints': {'sigma': 2}            # Allow flexible thresholds
})

print("   Naive model built (demonstrates confounding bias)")

# Model 2: Properly Adjusted (Pearl's Backdoor Criterion)
print("\n M2_ADJUSTED: Patient_Severity ~ Steroids + Disease_Severity + Age (IDENTIFIED)")
M2_adjusted = bo.cumulative_model(
    data=covid_data,
    outcome='patient_severity',
    predictors=['steroids', 'disease_severity', 'age_scaled'],
    link='logit',
    name='covid_adjusted'
)

# Enhanced priors for adjusted model
M2_adjusted.set_priors({
    'beta': {'mu': [0, 0, 0], 'sigma': [1, 0.5, 0.5]},  # Different priors per covariate
    'cutpoints': {'sigma': 2}
})

print("   Adjusted model built (follows Pearl's backdoor criterion)")

# Model 3: Over-controlled (Blocks Causal Pathways)
print("\n M3_OVERCONTROL: All Variables (BLOCKS CAUSAL PATHWAYS)")
M3_overcontrol = bo.cumulative_model(
    data=covid_data,
    outcome='patient_severity',
    predictors=['steroids', 'antivirals', 'anticoagulants', 
               'disease_severity', 'age_scaled', 'comorbidities',
               'crp', 'd_dimer', 'o2_saturation', 'resp_rate'],
    link='logit',
    name='covid_overcontrol'
)

# Regularizing priors for high-dimensional model
M3_overcontrol.set_priors({
    'beta': {'mu': 0, 'sigma': 0.5},     # More regularization
    'cutpoints': {'sigma': 2}
})

print("   Over-controlled model built (demonstrates pathway blocking)")

print("\n MODEL COMPARISON SETUP:")
print("  M1 (Naive): Should show POSITIVE steroid effect (confounding)")
print("  M2 (Adjusted): Should show TRUE NEGATIVE steroid effect (causal)")
print("  M3 (Over-control): Should show ATTENUATED effect (blocked pathways)")
print("\nThis demonstrates Pearl's identification theory in practice!")

# Display model summaries
models = {'M1_Naive': M1_naive, 'M2_Adjusted': M2_adjusted, 'M3_Overcontrol': M3_overcontrol}
for name, model in models.items():
    print(f"\n {name}:")
    print(f"  Predictors: {model.predictors}")
    print(f"  Link: {model.link}")
    print(f"  Categories: {model.n_categories}")


## Step 6: Prior Predictive Simulation 

**McElreath's Principle:** *"Check whether your golem can predict the data before seeing it"*

Following PyMC ordinal regression best practices, we need to ensure our priors generate reasonable patient severity predictions. This is especially crucial for ordinal models where inappropriate priors can lead to extreme predictions.
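
What a prior predictive check does can be sketched without any modeling library. The code below is a hand-rolled illustration, not the `bo.run_prior_predictive` implementation: it assumes a Normal(0, 1) prior on a single coefficient, builds ordered cutpoints by sorting four independent Normal(0, 2) draws (an assumption; the package may construct them differently), and uses the cumulative-logit convention P(Y ≤ k) = logistic(c_k − η). It then measures how often a prior draw puts almost all probability in one category.

```python
import math
import random

def ordinal_probs(eta, cutpoints):
    """Category probabilities under P(Y <= k) = logistic(cutpoints[k] - eta)."""
    cdf = [1.0 / (1.0 + math.exp(-(c - eta))) for c in cutpoints] + [1.0]
    probs, prev = [], 0.0
    for c in cdf:
        probs.append(c - prev)
        prev = c
    return probs

def prior_predictive_draws(n_draws=1000, beta_sigma=1.0, cut_sigma=2.0, n_cut=4, seed=0):
    """Sample parameters from the priors and return implied category probabilities
    for a treated patient (predictor value 1)."""
    rng = random.Random(seed)
    draws = []
    for _ in range(n_draws):
        beta = rng.gauss(0.0, beta_sigma)
        cuts = sorted(rng.gauss(0.0, cut_sigma) for _ in range(n_cut))
        draws.append(ordinal_probs(beta * 1.0, cuts))
    return draws

draws = prior_predictive_draws()

# Share of prior draws that concentrate more than 90% of the mass in one category
degenerate_share = sum(max(p) > 0.9 for p in draws) / len(draws)

# Average implied probability of each category across prior draws
mean_probs = [sum(p[k] for p in draws) / len(draws) for k in range(5)]

print(f"Share of near-degenerate prior draws: {degenerate_share:.1%}")
print("Mean prior probability per category:", [round(m, 3) for m in mean_probs])
```

If the near-degenerate share is large, the cutpoint prior is probably too wide for the scale of the predictors, and tightening `cut_sigma` is one option to try.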


In [None]:
# Step 6: Prior Predictive Simulation

print(" STEP 6: PRIOR PREDICTIVE SIMULATION")
print("=" * 45)
print("McElreath: 'Does your golem predict reasonable data before seeing the real data?'")

# Run prior predictive checks for all three models
print("\n Running prior predictive checks...")

# M1 Naive
print("\n M1_NAIVE Prior Predictive Check:")
prior_pred_naive = bo.run_prior_predictive(M1_naive, samples=1000)

# M2 Adjusted  
print("\n M2_ADJUSTED Prior Predictive Check:")
prior_pred_adjusted = bo.run_prior_predictive(M2_adjusted, samples=1000)

# M3 Overcontrol
print("\n M3_OVERCONTROL Prior Predictive Check:")
prior_pred_overcontrol = bo.run_prior_predictive(M3_overcontrol, samples=1000)

# Custom plots to visualize prior predictions vs actual data
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

models_dict = {
    'M1_Naive': prior_pred_naive,
    'M2_Adjusted': prior_pred_adjusted, 
    'M3_Overcontrol': prior_pred_overcontrol
}

for i, (name, pred_data) in enumerate(models_dict.items()):
    ax = axes[i]
    
    # Plot prior predictive distributions
    prior_samples = pred_data['prior_predictive']['y_pred']
    
    # Calculate proportions for each category
    categories = range(5)  # 0-4 for our ordinal outcome
    prior_props = []
    for cat in categories:
        prop = (prior_samples == cat).mean(axis=1)  # Proportion across samples
        prior_props.append(prop)
    
    # Box plots of prior predictive proportions
    # (Axes.boxplot accepts no `alpha` keyword; transparency is set per box below)
    bp = ax.boxplot(prior_props, patch_artist=True)
    ax.set_xticklabels(['Mild', 'Moderate', 'Severe', 'Critical', 'Death'])
    
    # Color boxes
    colors = ['lightgreen', 'yellow', 'orange', 'red', 'darkred']
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.6)
    
    # Add actual data proportions as red dots
    actual_props = []
    for cat in categories:
        actual_prop = (covid_data['patient_severity'] == cat).mean()
        actual_props.append(actual_prop)
        ax.scatter(cat + 1, actual_prop, color='red', s=100, zorder=5, marker='D')
    
    ax.set_title(f'{name}\nPrior Predictive vs Actual', fontweight='bold')
    ax.set_ylabel('Proportion')
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3)
    
    # Add legend
    from matplotlib.lines import Line2D
    legend_elements = [
        Line2D([0], [0], marker='s', color='w', markerfacecolor='lightblue', 
               markersize=10, alpha=0.7, label='Prior Predictive'),
        Line2D([0], [0], marker='D', color='red', linestyle='None',
               markersize=8, label='Actual Data')
    ]
    ax.legend(handles=legend_elements)

plt.tight_layout()
plt.show()

# Summary interpretation
print("\n PRIOR PREDICTIVE INTERPRETATION:")
print(" Good priors should:")
print("   - Generate reasonable severity distributions")
print("   - Not concentrate all mass in extreme categories")
print("   - Allow for uncertainty about true effects")
print("\n Warning signs:")
print("   - All predictions in one category")
print("   - Extreme severity distributions")
print("   - No overlap with actual data range")

print("\n Prior Predictive Results Summary:")
for name, pred_data in models_dict.items():
    samples = pred_data['prior_predictive']['y_pred']
    mean_severity = samples.mean()
    print(f"  {name}: Mean predicted severity = {mean_severity:.2f}")

actual_mean = covid_data['patient_severity'].mean()
print(f"  Actual data: Mean severity = {actual_mean:.2f}")
print("\n Prior predictive checks completed - review the plots and summaries above before fitting.")


## Step 7: Model Fitting 

**McElreath's Sampling Strategy:** *"Use robust sampling with sufficient warmup for ordinal models"*

Following the PyMC ordinal regression documentation, we'll use robust MCMC settings to ensure proper convergence for our ordinal models, which can be more challenging than linear models due to the threshold parameters.
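
The convergence summaries printed during fitting rely on R-hat, which compares variation between chains with variation within chains. The sketch below implements the classic Gelman-Rubin version using only the standard library, as an aid to intuition. ArviZ reports a rank-normalized split variant, so its numbers will differ somewhat from this simplified form. The simulated chains are made-up inputs, not output from the models in this notebook.

```python
import random
import statistics

def rhat(chains):
    """Classic (non-split, non-rank-normalized) Gelman-Rubin R-hat.

    chains: list of equal-length lists of posterior draws for one parameter.
    """
    n = len(chains[0])
    chain_means = [statistics.fmean(c) for c in chains]
    within = statistics.fmean([statistics.variance(c) for c in chains])   # W
    between = n * statistics.variance(chain_means)                         # B
    var_hat = (n - 1) / n * within + between / n
    return (var_hat / within) ** 0.5

rng = random.Random(0)

# Four chains exploring the same distribution
mixed = [[rng.gauss(0.0, 1.0) for _ in range(500)] for _ in range(4)]

# Same setup, but one chain is stuck around a different value
stuck = [list(c) for c in mixed[:3]] + [[rng.gauss(3.0, 1.0) for _ in range(500)]]

print(f"R-hat, well-mixed chains: {rhat(mixed):.3f}")
print(f"R-hat, one stuck chain:   {rhat(stuck):.3f}")
```

Values close to 1 indicate that the chains agree; a commonly used rule of thumb treats values above roughly 1.01 as a reason to investigate before interpreting the posterior.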


In [None]:
# Step 7: Model Fitting with Robust MCMC

print(" STEP 7: MODEL FITTING")
print("=" * 35)
print("Fitting three models to demonstrate Pearl's identification theory")

# Robust MCMC settings for ordinal models (following PyMC best practices)
mcmc_config = {
    'draws': 2000,
    'tune': 1500, 
    'chains': 4,
    'cores': 4,
    'target_accept': 0.9,  # Higher acceptance for ordinal models
    'max_treedepth': 12
}

print(f"\n MCMC Configuration:")
print(f"  Draws: {mcmc_config['draws']:,} per chain")
print(f"  Tune: {mcmc_config['tune']:,} warmup samples")
print(f"  Chains: {mcmc_config['chains']}")
print(f"  Target Accept: {mcmc_config['target_accept']}")

# Fit M1: Naive Model (should show confounding bias)
print("\n Fitting M1_NAIVE (CONFOUNDED)...")
print("Expected: Positive steroid effect due to confounding")
idata_naive = bo.fit_ordinal_model(M1_naive, **mcmc_config)

# Quick convergence check
print("  Convergence summary:")
summary_naive = az.summary(idata_naive, var_names=['beta'])
print(f"    Max R-hat: {summary_naive['r_hat'].max():.3f}")
print(f"    Min ESS: {summary_naive['ess_bulk'].min():.0f}")

# Fit M2: Adjusted Model (Pearl's backdoor criterion)
print("\n Fitting M2_ADJUSTED (IDENTIFIED)...")
print("Expected: Negative steroid effect (true causal effect)")
idata_adjusted = bo.fit_ordinal_model(M2_adjusted, **mcmc_config)

# Quick convergence check
print("  Convergence summary:")
summary_adjusted = az.summary(idata_adjusted, var_names=['beta'])
print(f"    Max R-hat: {summary_adjusted['r_hat'].max():.3f}")
print(f"    Min ESS: {summary_adjusted['ess_bulk'].min():.0f}")

# Fit M3: Over-controlled Model (blocks causal pathways)
print("\n Fitting M3_OVERCONTROL (PATHWAY BLOCKING)...")
print("Expected: Attenuated steroid effect due to mediator control")
idata_overcontrol = bo.fit_ordinal_model(M3_overcontrol, **mcmc_config)

# Quick convergence check
print("  Convergence summary:")
summary_overcontrol = az.summary(idata_overcontrol, var_names=['beta'])
print(f"    Max R-hat: {summary_overcontrol['r_hat'].max():.3f}")
print(f"    Min ESS: {summary_overcontrol['ess_bulk'].min():.0f}")

print("\n INITIAL STEROID EFFECT COMPARISON:")
print("=" * 45)

# Extract steroid coefficients (first predictor in each model)
steroid_effect_naive = idata_naive.posterior['covid_naive::beta'].sel(beta_dim=0)
steroid_effect_adjusted = idata_adjusted.posterior['covid_adjusted::beta'].sel(beta_dim=0)
steroid_effect_overcontrol = idata_overcontrol.posterior['covid_overcontrol::beta'].sel(beta_dim=0)

print(f"M1_Naive steroid effect:      {steroid_effect_naive.mean().values:.3f} ± {steroid_effect_naive.std().values:.3f}")
print(f"M2_Adjusted steroid effect:   {steroid_effect_adjusted.mean().values:.3f} ± {steroid_effect_adjusted.std().values:.3f}")
print(f"M3_Overcontrol steroid effect: {steroid_effect_overcontrol.mean().values:.3f} ± {steroid_effect_overcontrol.std().values:.3f}")

print(f"\n PEARL'S IDENTIFICATION THEORY DEMONSTRATED:")
print(f"    Naive model: {'POSITIVE' if steroid_effect_naive.mean() > 0 else 'NEGATIVE'} effect (confounding)")
print(f"    Adjusted model: {'POSITIVE' if steroid_effect_adjusted.mean() > 0 else 'NEGATIVE'} effect (true causal)")
print(f"    Over-control: {'POSITIVE' if steroid_effect_overcontrol.mean() > 0 else 'NEGATIVE'} effect (blocked pathways)")

print(f"\n All models fitted successfully!")
print(f" Ready for comprehensive posterior validation...")


## Step 8: Enhanced Hierarchical Modeling for Hospital-Level Effects 

**Multilevel Framework:** *"Account for hospital-level variation in COVID-19 treatment effects"*

### Why Hierarchical Modeling for COVID-19 Data?

**Real-world COVID-19 data has hierarchical structure:**
- **Patients** are nested within **ICUs/Hospitals**
- **Hospitals** vary in:
  - Treatment protocols and experience
  - Patient case-mix and severity
  - Resource availability and staffing
  - Geographic location and variant prevalence

**Standard models assume independence** - but patients from the same hospital are more similar to each other than to patients from other hospitals.

**Hierarchical models capture:**
1. **Hospital-level random effects** - Each hospital has its own baseline severity
2. **Hospital-specific treatment effects** - Steroid effectiveness may vary by hospital
3. **Proper uncertainty quantification** - Accounts for hospital-level clustering
4. **Partial pooling** - Borrows strength across hospitals for better estimates
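
Partial pooling is easiest to see in the simplest possible setting. The sketch below uses the normal-normal shrinkage formula with known variances on a continuous scale, which is a deliberate simplification of the ordinal hierarchical model rather than what `bayes_ordinal` fits. The function name `partial_pool` and all numbers are illustrative assumptions.

```python
def partial_pool(group_mean, n, grand_mean, sigma_within, tau_between):
    """Precision-weighted compromise between a hospital's own mean and the grand mean.

    sigma_within: patient-level standard deviation inside a hospital
    tau_between:  standard deviation of true hospital means around the grand mean
    """
    precision_data = n / sigma_within ** 2        # information from this hospital's patients
    precision_prior = 1.0 / tau_between ** 2      # information from the population of hospitals
    return (precision_data * group_mean + precision_prior * grand_mean) / (
        precision_data + precision_prior
    )

GRAND_MEAN = 2.0   # average severity across all hospitals (illustrative)

# Two hospitals with the same raw mean but very different sample sizes
small = partial_pool(group_mean=3.0, n=5, grand_mean=GRAND_MEAN, sigma_within=1.0, tau_between=0.3)
large = partial_pool(group_mean=3.0, n=500, grand_mean=GRAND_MEAN, sigma_within=1.0, tau_between=0.3)

print(f"Small hospital (n=5):   raw mean 3.00 -> pooled estimate {small:.2f}")
print(f"Large hospital (n=500): raw mean 3.00 -> pooled estimate {large:.2f}")
```

The small hospital is pulled strongly toward the grand mean while the large one barely moves, which is the "borrowing strength" behaviour listed above.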


In [None]:
# Step 8: Implement Hierarchical Models for Hospital-Level Variation

def generate_covid_data_with_hospitals(base_data, n_hospitals=12, seed=42):
    """
    Enhance COVID-19 data with realistic hospital-level grouping structure
    """
    np.random.seed(seed)
    n_patients = len(base_data)
    
    print(" ADDING HOSPITAL-LEVEL STRUCTURE TO COVID-19 DATA")
    print("=" * 60)
    
    # Generate hospital assignments with uneven patient volumes
    # (sizes are drawn from a Dirichlet distribution; case-mix differences are added separately below)
    hospital_sizes = np.random.dirichlet(np.ones(n_hospitals) * 2, 1)[0]
    hospital_assignments = np.random.choice(n_hospitals, size=n_patients, p=hospital_sizes)
    
    # Hospital characteristics that affect patient outcomes
    hospital_effects = np.random.normal(0, 0.3, n_hospitals)  # Random hospital quality
    hospital_steroid_expertise = np.random.normal(0, 0.2, n_hospitals)  # Steroid protocol variation
    hospital_severity_mix = np.random.normal(0, 0.4, n_hospitals)  # Case-mix differences
    
    # Add hospital effects to the existing data
    enhanced_data = base_data.copy()
    enhanced_data['hospital_id'] = hospital_assignments
    
    # Hospital-level effects on outcomes (vectorized over patients)
    hospital_idx = enhanced_data['hospital_id'].to_numpy()

    # Hospital affects baseline severity (case-mix)
    enhanced_data['disease_severity'] = np.clip(
        enhanced_data['disease_severity'] + hospital_severity_mix[hospital_idx], -2, 3
    )

    # Hospital affects treatment effectiveness (applies only to steroid-treated patients)
    enhanced_data['latent_severity'] += (
        enhanced_data['steroids'].to_numpy() * hospital_steroid_expertise[hospital_idx]
    )
    
    # Recompute ordinal outcome with hospital effects
    enhanced_data['latent_severity'] += hospital_effects[enhanced_data['hospital_id']]
    cutpoints = [-1.5, -0.5, 0.5, 1.2]
    enhanced_data['patient_severity'] = np.digitize(enhanced_data['latent_severity'], cutpoints)
    enhanced_data['patient_severity'] = np.clip(enhanced_data['patient_severity'], 0, 4)
    
    print(f" Added {n_hospitals} hospitals to {n_patients:,} patients")
    print(f"Hospital patient distribution:")
    for h in range(n_hospitals):
        count = (enhanced_data['hospital_id'] == h).sum()
        pct = count / len(enhanced_data) * 100
        print(f"  Hospital {h+1}: {count:3d} patients ({pct:4.1f}%)")
    
    # Show hospital-level variation (the outcome is ordinal severity, not mortality)
    print("\nHospital-Level Variation:")
    hospital_severity = enhanced_data.groupby('hospital_id')['patient_severity'].agg(['mean', 'std'])
    print(f"  Severity across hospitals: mean={hospital_severity['mean'].mean():.2f}, std={hospital_severity['mean'].std():.2f}")
    
    steroid_by_hospital = enhanced_data.groupby('hospital_id')['steroids'].mean()
    print(f"  Steroid rate across hospitals: mean={steroid_by_hospital.mean():.2%}, std={steroid_by_hospital.std():.2%}")
    
    return enhanced_data

def create_hierarchical_covid_models(data):
    """
    Create hierarchical models using bayes_ordinal package
    Following the package's hierarchical modeling framework
    """
    
    print("\n BUILDING HIERARCHICAL MODELS WITH BAYES_ORDINAL")
    print("=" * 60)
    
    # Prepare data for hierarchical modeling
    y = data['patient_severity'].values
    X_basic = data[['steroids', 'disease_severity', 'age_scaled']].values
    hospital_ids = data['hospital_id'].values
    n_hospitals = len(data['hospital_id'].unique())
    
    print(f"Data prepared: {len(y)} patients, {X_basic.shape[1]} predictors, {n_hospitals} hospitals")
    
    # Model 1: Standard (Non-hierarchical) for comparison
    print("\n M1: STANDARD MODEL (ignores hospital clustering)")
    M1_standard = bo.cumulative_model(
        data=data,
        outcome='patient_severity', 
        predictors=['steroids', 'disease_severity', 'age_scaled'],
        link='logit',
        name='covid_standard'
    )
    
    M1_standard.set_priors({
        'beta': {'mu': 0, 'sigma': 1},
        'cutpoints': {'sigma': 2}
    })
    
    print(" Standard model: Assumes patient independence")
    
    # Model 2: Hierarchical Random Intercepts  
    print("\n M2: HIERARCHICAL RANDOM INTERCEPTS (hospital-level baseline effects)")
    
    # Use the package's built-in hierarchical modeling
    with pm.Model() as M2_hierarchical:
        # Use the cumulative_model function with hierarchical parameters
        pm_model = bo.models.cumulative.cumulative_model(
            y=y,
            X=X_basic, 
            K=5,
            link='logit',
            group_idx=hospital_ids,
            n_groups=n_hospitals,
            feature_names=['steroids', 'disease_severity', 'age_scaled'],
            model_name='covid_hierarchical'
        )
    
    print(" Hierarchical model: Hospital random intercepts")
    
    # Model 3: Hierarchical Random Slopes (hospital-specific treatment effects)
    print("\n M3: HIERARCHICAL RANDOM SLOPES (hospital-specific steroid effects)")
    
    # For random slopes, we'll create interaction terms and use hierarchical structure
    # Create hospital-steroid interaction data
    interaction_data = data.copy()
    steroid_hospital_interactions = []
    for h in range(n_hospitals):
        interaction_col = f'steroid_hospital_{h}'
        interaction_data[interaction_col] = (data['steroids'] * (data['hospital_id'] == h)).astype(int)
        steroid_hospital_interactions.append(interaction_col)
    
    M3_random_slopes = bo.cumulative_model(
        data=interaction_data,
        outcome='patient_severity',
        predictors=['steroids', 'disease_severity', 'age_scaled'] + steroid_hospital_interactions,
        link='logit', 
        name='covid_random_slopes'
    )
    
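    # Regularizing priors for the hospital-specific steroid offsets.
    # Note: sigma is fixed at 0.3 rather than learned from the data, so these are
    # independently shrunk offsets around the main steroid effect, not random slopes
    # with a shared, estimated variance.
    main_effects_sigma = [1, 0.5, 0.5]  # steroids, disease_severity, age_scaled
    random_slopes_sigma = [0.3] * n_hospitals  # one steroid offset per hospital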
    
    M3_random_slopes.set_priors({
        'beta': {
            'mu': 0,
            'sigma': main_effects_sigma + random_slopes_sigma
        },
        'cutpoints': {'sigma': 2}
    })
    
    print(" Random slopes model: Hospital-specific steroid effectiveness")
    
    models_dict = {
        'M1_Standard': M1_standard,
        'M2_Hierarchical': M2_hierarchical,
        'M3_RandomSlopes': M3_random_slopes
    }
    
    print("\nHIERARCHICAL MODEL COMPARISON FRAMEWORK:")
    print("=" * 50)
    print("  M1: Standard model (independence assumption)")
    print("  M2: Random intercepts (hospital baseline differences)")
    print("  M3: Random slopes (hospital-specific treatment effects)")
    print("\nExpected Results:")
    print("  - M2 should fit better than M1 if hospital baselines differ (accounts for clustering)")
    print("  - M3 should reveal any treatment effect heterogeneity across hospitals")
    print("  - Hierarchical models typically give wider, more realistic intervals than M1")
    
    return models_dict, interaction_data

# Generate enhanced COVID data with hospital structure
print(" STEP 8: HIERARCHICAL MODELING FOR HOSPITAL VARIATION")
print("=" * 65)

covid_data_hospitals = generate_covid_data_with_hospitals(covid_data, n_hospitals=12)
hierarchical_models, hospital_interaction_data = create_hierarchical_covid_models(covid_data_hospitals)
