# Comprehensive Bayesian Analysis of Colorectal Cancer Tissue Imaging Data

**BIOSTAT 682 Final Project**  
**Author:** Santosh Desai  
**Date:** December 2025  
**Institution:** University of Michigan

---

## Project Overview

This notebook presents a comprehensive Bayesian analysis of multiplexed tissue imaging data from colorectal cancer (CRC) samples. The analysis employs multiple Bayesian regression models to understand the relationship between PD-1 (Programmed Death-1) checkpoint marker expression and other immune cell markers, while accounting for spatial structure in the tumor microenvironment.

### Research Question

How do immune checkpoint markers (specifically PD-1) relate to other immune cell markers and spatial location in the tumor microenvironment? We address this through:

1. **Multiple regression models** with different prior specifications
2. **Variable selection** using spike-and-slab priors
3. **Non-linear relationships** via Bayesian Neural Networks (BNNs)
4. **Spatial effects** to account for tissue structure

### Dataset Description

The dataset (`CRC_data_55A.csv`) contains:
- **2,874 cells** from a colorectal cancer tissue sample
- **Spatial coordinates** (cx, cy) for each cell's location in the tissue
- **20 protein markers** including:
  - Immune checkpoint markers: PD-1, LAG-3, VISTA, ICOS
  - T cell markers: CD2, CD5, CD25
  - Macrophage markers: CD68, CD11b
  - Dendritic cell markers: CD21
  - NK cell markers: CD56
  - Signaling markers: beta-catenin, EGFR, CD44
  - Other immune markers: GATA3, CD38, Podoplanin, IDO-1

### Analysis Workflow

This notebook follows a rigorous Bayesian workflow:

1. **Data Loading and Preprocessing**: Load, validate, and standardize data
2. **Exploratory Data Analysis**: Comprehensive visualization and summary statistics
3. **Model 1**: Bayesian linear regression with noninformative priors
4. **Model 2**: Bayesian linear regression with hierarchical priors
5. **Model 3**: Spike-and-slab variable selection
6. **Model 4**: Bayesian Neural Network (BNN) with spike-and-slab priors
7. **Model 5**: Spatial regression model with coordinate effects
8. **Model Comparison**: PSIS-LOO, WAIC, and predictive performance
9. **Posterior Analysis**: Credible intervals, posterior predictive checks
10. **Convergence Diagnostics**: R-hat, ESS, MCSE, Geweke diagnostics

All models are fit using MCMC (NUTS sampler) with comprehensive convergence diagnostics.


## 1. Setup and Imports

We begin by importing all necessary libraries and setting up the computational environment for reproducibility.


In [None]:
%pip install pymc>=5.10.0 arviz>=0.17.0 seaborn>=0.12.0 scikit-learn>=1.3.0

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Any
import warnings

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", message=".*install.*ipywidgets.*", category=UserWarning)
warnings.filterwarnings("ignore", category=UserWarning)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pymc as pm
import arviz as az
from scipy.spatial.distance import cdist
from scipy.stats import norm, t as student_t
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Set professional plotting style
sns.set_theme(style="whitegrid", context="talk", palette="husl")
plt.rcParams.update({
    "figure.autolayout": True,
    "font.size": 11,
    "axes.titlesize": 14,
    "axes.labelsize": 12,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    "figure.titlesize": 16,
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial", "DejaVu Sans", "Liberation Sans", "Helvetica", "sans-serif"]
})

# Reproducibility
RNG_SEED: int = 42
rng: np.random.Generator = np.random.default_rng(RNG_SEED)

# Display options
pd.options.display.float_format = "{:0.3f}".format
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)

print("=" * 80)
print("COMPREHENSIVE BAYESIAN ANALYSIS OF CRC TISSUE IMAGING DATA")
print("=" * 80)
print(f"\nPyMC version: {pm.__version__}")
print(f"ArviZ version: {az.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")
print(f"Random seed: {RNG_SEED}")
print("=" * 80)


## 2. Data Loading and Preprocessing

We load the CRC tissue imaging data and perform comprehensive preprocessing including:
- Data validation and quality checks
- Standardization of predictors and response
- Train/test split for predictive evaluation
- Creation of standardized coordinate features


In [None]:
@dataclass
class Standardizer:
    """Standardizer for data preprocessing."""
    mean: pd.Series
    std: pd.Series
    
    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        """Transform data using fitted mean and std."""
        z = (X - self.mean) / self.std.replace(0, 1.0)
        return z
    
    @classmethod
    def fit(cls, X: pd.DataFrame) -> "Standardizer":
        """Fit standardizer to data."""
        return cls(mean=X.mean(), std=X.std(ddof=1))
    
    def inverse_transform(self, X: pd.DataFrame) -> pd.DataFrame:
        """Inverse transform to original scale."""
        return X * self.std + self.mean


def load_crc_data(path: str) -> pd.DataFrame:
    """
    Load CRC tissue imaging data.
    
    Parameters:
    -----------
    path : str
        Path to CSV file
        
    Returns:
    --------
    pd.DataFrame
        Loaded data frame
    """
    df = pd.read_csv(path)
    print(f"Loaded data: {df.shape[0]} rows, {df.shape[1]} columns")
    return df


# Load data
data_path = "data/CRC_data_55A.csv"
df_raw = load_crc_data(data_path)

# Display first few rows
print("\nFirst 5 rows:")
display(df_raw.head())

# Display column names
print(f"\nColumn names ({len(df_raw.columns)} total):")
print(df_raw.columns.tolist())


In [None]:
# Extract spatial coordinates
coords = df_raw[['cx', 'cy']].values
print(f"Spatial coordinates shape: {coords.shape}")
print(f"Coordinate ranges:")
print(f"  cx: [{coords[:, 0].min():.3f}, {coords[:, 0].max():.3f}]")
print(f"  cy: [{coords[:, 1].min():.3f}, {coords[:, 1].max():.3f}]")

# Extract all marker columns (exclude coordinates)
marker_cols = [col for col in df_raw.columns if col not in ['cx', 'cy']]
markers_df = df_raw[marker_cols]

print(f"\nTotal markers: {len(marker_cols)}")

# Identify response variable (PD-1)
pd1_col = [col for col in marker_cols if 'PD-1' in col]
if not pd1_col:
    raise ValueError("PD-1 column not found!")
pd1_col = pd1_col[0]
print(f"\nResponse variable: {pd1_col}")

y_raw = df_raw[pd1_col].values
print(f"\nResponse statistics (raw scale):")
print(f"  Mean: {y_raw.mean():.3f}")
print(f"  Std: {y_raw.std():.3f}")
print(f"  Min: {y_raw.min():.3f}")
print(f"  Max: {y_raw.max():.3f}")
print(f"  Median: {np.median(y_raw):.3f}")

# Check for missing values
missing = df_raw.isnull().sum()
if missing.sum() > 0:
    print(f"\n⚠ Warning: {missing.sum()} missing values found:")
    print(missing[missing > 0])
else:
    print("\n✓ No missing values detected")


In [None]:
# Select predictor markers
# Focus on key immune markers for interpretability and model complexity
predictor_candidates = [
    col for col in marker_cols 
    if col != pd1_col and any(marker in col for marker in [
        'CD2', 'CD5', 'CD25', 'LAG-3', 'VISTA', 'ICOS', 
        'CD68', 'CD11b', 'CD21', 'GATA3', 'CD56', 'CD38'
    ])
]

print(f"Candidate predictors: {len(predictor_candidates)}")

# Select top predictors based on correlation with PD-1
correlations = markers_df[predictor_candidates].corrwith(markers_df[pd1_col]).abs().sort_values(ascending=False)
selected_predictors = correlations.head(10).index.tolist()  # Use top 10 for comprehensive analysis

print(f"\nSelected {len(selected_predictors)} predictors (top correlations with PD-1):")
for i, (pred, corr) in enumerate(zip(selected_predictors, correlations[selected_predictors]), 1):
    print(f"  {i:2d}. {pred.split(':')[0]:30s} (|r| = {corr:.3f})")

# Extract predictor matrix
X_raw = markers_df[selected_predictors].values
n, p = X_raw.shape

print(f"\nData dimensions:")
print(f"  Sample size (n): {n}")
print(f"  Number of predictors (p): {p}")
print(f"  Response variable: PD-1 expression")


In [None]:
# Standardize predictors and response
# Standardization improves numerical stability and MCMC mixing
# Coefficients are interpreted as change per 1 SD increase in predictor

# Standardize predictors
X_scaler = Standardizer.fit(pd.DataFrame(X_raw, columns=selected_predictors))
X_scaled_df = X_scaler.transform(pd.DataFrame(X_raw, columns=selected_predictors))
X_scaled = X_scaled_df.values

# Standardize response
y_scaler = Standardizer.fit(pd.Series(y_raw, name=pd1_col))
y_scaled = y_scaler.transform(pd.Series(y_raw, name=pd1_col)).values

# Standardize coordinates
coords_scaler = Standardizer.fit(pd.DataFrame(coords, columns=['cx', 'cy']))
coords_scaled_df = coords_scaler.transform(pd.DataFrame(coords, columns=['cx', 'cy']))
coords_scaled = coords_scaled_df.values

print("Standardization complete:")
print(f"  Predictors: mean ≈ {X_scaled.mean(axis=0)[:3]}, std ≈ {X_scaled.std(axis=0, ddof=1)[:3]}")
print(f"  Response: mean = {y_scaled.mean():.6f}, std = {y_scaled.std(ddof=1):.6f}")
print(f"  Coordinates: mean ≈ {coords_scaled.mean(axis=0)}, std ≈ {coords_scaled.std(axis=0, ddof=1)}")

# Store scaling parameters for later use
X_mean = X_scaler.mean.values
X_std = X_scaler.std.values
y_mean = y_scaler.mean.values[0]
y_std = y_scaler.std.values[0]


In [None]:
# Train/test split for predictive evaluation
# Use 70/30 split to have sufficient data for both training and testing
train_idx, test_idx = train_test_split(
    np.arange(n), 
    test_size=0.3, 
    random_state=RNG_SEED,
    shuffle=True
)

X_train = X_scaled[train_idx]
X_test = X_scaled[test_idx]
y_train = y_scaled[train_idx]
y_test = y_scaled[test_idx]
coords_train = coords_scaled[train_idx]
coords_test = coords_scaled[test_idx]

n_train = len(train_idx)
n_test = len(test_idx)

print(f"Train/test split:")
print(f"  Training set: {n_train} observations ({100*n_train/n:.1f}%)")
print(f"  Test set: {n_test} observations ({100*n_test/n:.1f}%)")
print(f"  Total: {n} observations")


## 3. Exploratory Data Analysis

We perform comprehensive exploratory data analysis to understand:
- Distribution of response variable (PD-1 expression)
- Spatial distribution of PD-1 in the tissue
- Correlations between markers
- Relationships between predictors and response


In [None]:
# Figure 1: Spatial distribution of PD-1 expression
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Spatial scatter plot
scatter1 = axes[0].scatter(coords[:, 0], coords[:, 1], c=y_raw, cmap='viridis', 
                           s=2, alpha=0.6, rasterized=True)
axes[0].set_xlabel('X coordinate', fontsize=12)
axes[0].set_ylabel('Y coordinate', fontsize=12)
axes[0].set_title('Spatial Distribution of PD-1 Expression', fontsize=14, fontweight='bold')
cbar1 = plt.colorbar(scatter1, ax=axes[0], label='PD-1 Expression')
axes[0].grid(True, alpha=0.3)

# Histogram of response
axes[1].hist(y_raw, bins=60, edgecolor='black', alpha=0.7, color='steelblue')
axes[1].axvline(y_raw.mean(), color='red', linestyle='--', linewidth=2, 
                label=f'Mean = {y_raw.mean():.2f}')
axes[1].axvline(np.median(y_raw), color='green', linestyle='--', linewidth=2, 
                label=f'Median = {np.median(y_raw):.2f}')
axes[1].set_xlabel('PD-1 Expression', fontsize=12)
axes[1].set_ylabel('Frequency', fontsize=12)
axes[1].set_title('Distribution of PD-1 Expression', fontsize=14, fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('report/figures/figure1_spatial_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

print("Figure 1: Spatial distribution and histogram of PD-1 expression")


In [None]:
# Figure 2: Correlation heatmap
corr_data = markers_df[[pd1_col] + selected_predictors]
corr_matrix = corr_data.corr()

plt.figure(figsize=(12, 10))
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
sns.heatmap(corr_matrix, mask=mask, annot=True, fmt='.2f', cmap='coolwarm', 
            center=0, square=True, linewidths=0.5, cbar_kws={"shrink": 0.8},
            vmin=-1, vmax=1, annot_kws={'size': 9})
plt.title('Correlation Matrix: PD-1 and Selected Predictors', 
         fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig('report/figures/figure2_correlation_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()

print("Figure 2: Correlation heatmap showing relationships between PD-1 and predictors")


In [None]:
# Summary statistics table
summary_stats = pd.DataFrame({
    'Mean': markers_df[[pd1_col] + selected_predictors].mean(),
    'Std': markers_df[[pd1_col] + selected_predictors].std(),
    'Min': markers_df[[pd1_col] + selected_predictors].min(),
    'Max': markers_df[[pd1_col] + selected_predictors].max(),
    'Median': markers_df[[pd1_col] + selected_predictors].median(),
    'Correlation_with_PD1': [1.0] + [correlations[pred] for pred in selected_predictors]
})

print("Summary Statistics:")
print("=" * 100)
display(summary_stats)

# Save for report table
summary_stats.to_csv('report/summary_statistics.csv')


## 4. Model 1: Bayesian Linear Regression with Noninformative Priors

We begin with a standard Bayesian linear regression model using noninformative (flat) priors. This serves as a baseline and allows us to compare with more structured priors in subsequent models.

### Model Specification

**Likelihood:**
$$y_i \sim \mathcal{N}(\mu_i, \sigma^2)$$

where
$$\mu_i = \beta_0 + \sum_{j=1}^{p} \beta_j x_{ij}$$

**Prior Distributions:**
- Intercept: $p(\beta_0) \propto 1$ (improper uniform)
- Coefficients: $p(\beta_j) \propto 1$ for $j = 1, \ldots, p$ (improper uniform)
- Error variance: $p(\sigma^2) \propto 1/\sigma^2$ (Jeffreys prior)

This specification yields a conjugate normal-inverse-gamma posterior, but we use MCMC for consistency with other models.


In [None]:
print("=" * 80)
print("MODEL 1: Bayesian Linear Regression (Noninformative Priors)")
print("=" * 80)

with pm.Model() as model1:
    # Noninformative priors (very diffuse)
    beta_0 = pm.Normal("beta_0", mu=0.0, sigma=100.0)  # Very diffuse
    beta = pm.Normal("beta", mu=0.0, sigma=100.0, shape=p)  # Very diffuse
    sigma = pm.HalfNormal("sigma", sigma=10.0)  # Weakly informative
    
    # Linear predictor
    mu = beta_0 + pm.math.dot(X_train, beta)
    
    # Likelihood
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y_train)

print("Model 1 built successfully!")
print(f"  Parameters: {p + 2} (1 intercept + {p} coefficients + 1 error variance)")


In [None]:
# Fit Model 1
print("\nSampling from posterior (Model 1)...")
print("This may take several minutes...")

with model1:
    trace1 = pm.sample(
        draws=2000,
        tune=2000,
        chains=4,
        cores=1,
        target_accept=0.95,
        random_seed=RNG_SEED,
        return_inferencedata=True,
        progressbar=True
    )

print("\n✓ Model 1 sampling completed!")

# Convergence diagnostics
print("\n" + "=" * 80)
print("MODEL 1: Convergence Diagnostics")
print("=" * 80)

rhat1 = az.rhat(trace1)
max_rhat1 = max([float(rhat1[var].max()) for var in rhat1.data_vars])
print(f"\nR-hat (max): {max_rhat1:.4f} (target: < 1.01)")

ess1 = az.ess(trace1)
min_ess1 = min([float(ess1[var].min()) for var in ess1.data_vars])
print(f"ESS (min): {min_ess1:.0f} (target: > 400)")

n_div1 = trace1.sample_stats.divergences.sum().item()
print(f"Divergences: {n_div1} (target: 0)")

# Posterior summary
summary1 = az.summary(trace1, var_names=['beta_0', 'beta', 'sigma'])
print("\nPosterior Summary (Model 1):")
print("=" * 80)
display(summary1.head(15))  # Show first 15 rows


## 5. Model 2: Bayesian Linear Regression with Hierarchical Priors

We now fit a Bayesian linear regression with hierarchical (shrinkage) priors. This model regularizes coefficients toward zero, which helps with overfitting and provides better generalization.

### Model Specification

**Likelihood:**
$$y_i \sim \mathcal{N}(\mu_i, \sigma^2)$$

where
$$\mu_i = \beta_0 + \sum_{j=1}^{p} \beta_j x_{ij}$$

**Prior Distributions:**
- Intercept: $\beta_0 \sim \mathcal{N}(0, 10^2)$
- Coefficients: $\beta_j \sim \mathcal{N}(0, \tau^2)$ for $j = 1, \ldots, p$
- Global scale: $\tau^2 \sim \text{InverseGamma}(2, 1)$
- Error variance: $\sigma^2 \sim \text{InverseGamma}(2, 1)$

The hierarchical structure allows the data to inform the appropriate level of shrinkage through the global scale parameter $\tau^2$.


In [None]:
print("=" * 80)
print("MODEL 2: Bayesian Linear Regression (Hierarchical Priors)")
print("=" * 80)

with pm.Model() as model2:
    # Intercept
    beta_0 = pm.Normal("beta_0", mu=0.0, sigma=10.0)
    
    # Hierarchical prior for coefficients
    tau_sq = pm.InverseGamma("tau_sq", alpha=2.0, beta=1.0)
    beta = pm.Normal("beta", mu=0.0, sigma=pm.math.sqrt(tau_sq), shape=p)
    
    # Error variance
    sigma_sq = pm.InverseGamma("sigma_sq", alpha=2.0, beta=1.0)
    sigma = pm.Deterministic("sigma", pm.math.sqrt(sigma_sq))
    
    # Linear predictor
    mu = beta_0 + pm.math.dot(X_train, beta)
    
    # Likelihood
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y_train)

print("Model 2 built successfully!")
print(f"  Parameters: {p + 4} (1 intercept + {p} coefficients + 1 global scale + 1 error variance + 1 sigma)")


In [None]:
# Fit Model 2
print("\nSampling from posterior (Model 2)...")

with model2:
    trace2 = pm.sample(
        draws=2000,
        tune=2000,
        chains=4,
        cores=1,
        target_accept=0.95,
        random_seed=RNG_SEED,
        return_inferencedata=True,
        progressbar=True
    )

print("\n✓ Model 2 sampling completed!")

# Convergence diagnostics
print("\n" + "=" * 80)
print("MODEL 2: Convergence Diagnostics")
print("=" * 80)

rhat2 = az.rhat(trace2)
max_rhat2 = max([float(rhat2[var].max()) for var in rhat2.data_vars if var != 'sigma'])
print(f"\nR-hat (max): {max_rhat2:.4f} (target: < 1.01)")

ess2 = az.ess(trace2)
min_ess2 = min([float(ess2[var].min()) for var in ess2.data_vars if var != 'sigma'])
print(f"ESS (min): {min_ess2:.0f} (target: > 400)")

n_div2 = trace2.sample_stats.divergences.sum().item()
print(f"Divergences: {n_div2} (target: 0)")


## 6. Model 3: Spike-and-Slab Variable Selection

We implement a spike-and-slab prior for automatic variable selection. This model includes binary inclusion indicators $\gamma_j$ that determine whether each predictor is included in the model.

### Model Specification

**Likelihood:**
$$y_i \sim \mathcal{N}(\mu_i, \sigma^2)$$

where
$$\mu_i = \beta_0 + \sum_{j=1}^{p} \beta_j x_{ij}$$

**Prior Distributions:**
- Intercept: $\beta_0 \sim \mathcal{N}(0, 10^2)$
- Inclusion indicators: $\gamma_j \sim \text{Bernoulli}(\pi)$ with $\pi = 0.5$
- Coefficients: $\beta_j \sim \mathcal{N}(0, \sigma_{\gamma_j}^2)$ where
  - $\sigma_{\gamma_j}^2 = \sigma_{\text{spike}}^2$ if $\gamma_j = 0$ (spike)
  - $\sigma_{\gamma_j}^2 = \sigma_{\text{slab}}^2$ if $\gamma_j = 1$ (slab)
- Spike variance: $\sigma_{\text{spike}}^2 = 0.01^2$ (very small)
- Slab variance: $\sigma_{\text{slab}}^2 = 1.0^2$ (moderate)
- Error variance: $\sigma^2 \sim \text{InverseGamma}(2, 1)$

We use a non-centered parameterization for better MCMC mixing.


In [None]:
print("=" * 80)
print("MODEL 3: Spike-and-Slab Variable Selection")
print("=" * 80)

with pm.Model() as model3:
    # Intercept
    beta_0 = pm.Normal("beta_0", mu=0.0, sigma=10.0)
    
    # Spike-and-slab priors (non-centered parameterization)
    pi = 0.5  # Prior inclusion probability
    spike_sd = 0.01
    slab_sd = 1.0
    
    # Inclusion indicators
    gamma = pm.Bernoulli("gamma", p=pi, shape=p)
    
    # Non-centered parameterization for better geometry
    beta_raw = pm.Normal("beta_raw", mu=0.0, sigma=1.0, shape=p)
    sd_beta = spike_sd + gamma * (slab_sd - spike_sd)
    beta = pm.Deterministic("beta", beta_raw * sd_beta)
    
    # Error variance
    sigma_sq = pm.InverseGamma("sigma_sq", alpha=2.0, beta=1.0)
    sigma = pm.Deterministic("sigma", pm.math.sqrt(sigma_sq))
    
    # Linear predictor
    mu = beta_0 + pm.math.dot(X_train, beta)
    
    # Likelihood
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y_train)

print("Model 3 built successfully!")
print(f"  Parameters: {p + 4} (1 intercept + {p} coefficients + {p} inclusion indicators + 1 error variance + 1 sigma)")


In [None]:
# Fit Model 3
print("\nSampling from posterior (Model 3)...")
print("Note: Spike-and-slab models can be slower due to discrete variables...")

with model3:
    trace3 = pm.sample(
        draws=2000,
        tune=2000,
        chains=4,
        cores=1,
        target_accept=0.95,
        random_seed=RNG_SEED,
        return_inferencedata=True,
        progressbar=True
    )

print("\n✓ Model 3 sampling completed!")

# Convergence diagnostics
print("\n" + "=" * 80)
print("MODEL 3: Convergence Diagnostics")
print("=" * 80)

rhat3 = az.rhat(trace3)
max_rhat3 = max([float(rhat3[var].max()) for var in rhat3.data_vars if var not in ['sigma', 'gamma']])
print(f"\nR-hat (max): {max_rhat3:.4f} (target: < 1.01)")

ess3 = az.ess(trace3)
min_ess3 = min([float(ess3[var].min()) for var in ess3.data_vars if var not in ['sigma', 'gamma']])
print(f"ESS (min): {min_ess3:.0f} (target: > 400)")

n_div3 = trace3.sample_stats.divergences.sum().item()
print(f"Divergences: {n_div3} (target: 0)")

# Posterior inclusion probabilities
gamma_samples = trace3.posterior['gamma'].values
pip = gamma_samples.mean(axis=(0, 1))  # Average across chains and draws

print("\nPosterior Inclusion Probabilities:")
print("=" * 80)
pip_df = pd.DataFrame({
    'Predictor': [pred.split(':')[0] for pred in selected_predictors],
    'PIP': pip
}).sort_values('PIP', ascending=False)
display(pip_df)


## 7. Model 4: Bayesian Neural Network (BNN)

We now fit a Bayesian Neural Network with spike-and-slab priors on the weights. This model can capture non-linear relationships between predictors and response.

### Model Specification

**Architecture:**
- Input layer: $p$ predictors
- Hidden layer: $q$ units with tanh activation
- Output layer: 1 unit (continuous response)

**Likelihood:**
$$y_i \sim \mathcal{N}(\mu_i, \sigma^2)$$

where
$$\mu_i = f_{\text{NN}}(x_i; W_1, b_1, W_2, b_2)$$

and $f_{\text{NN}}$ is a neural network with one hidden layer:
- Hidden: $h_i = \tanh(W_1 x_i + b_1)$
- Output: $\mu_i = W_2^T h_i + b_2$

**Prior Distributions:**
- All weights use spike-and-slab priors (similar to Model 3)
- Hidden layer: $W_1 \in \mathbb{R}^{p \times q}$, $b_1 \in \mathbb{R}^q$
- Output layer: $W_2 \in \mathbb{R}^q$, $b_2 \in \mathbb{R}$
- Error variance: $\sigma^2 \sim \text{InverseGamma}(2, 1)$

We use $q = 5$ hidden units as a balance between flexibility and computational tractability.


In [None]:
def create_bnn_model(X_train, y_train, q=5, use_noncentered=True):
    """
    Create Bayesian Neural Network with spike-and-slab priors.
    
    Parameters:
    -----------
    X_train : array
        Training features (n x p)
    y_train : array
        Training targets (n,)
    q : int
        Number of hidden units
    use_noncentered : bool
        Use non-centered parameterization
        
    Returns:
    --------
    pm.Model
        PyMC model
    """
    n, p = X_train.shape
    
    with pm.Model() as model:
        # Intercept
        beta_0 = pm.Normal("beta_0", mu=0.0, sigma=10.0)
        
        # Spike-and-slab parameters
        pi = 0.5
        spike_sd = 0.01
        slab_sd = 1.0
        
        # Layer 1: Input -> Hidden (p x q weights)
        gamma1 = pm.Bernoulli("gamma1", p=pi, shape=(p, q))
        
        if use_noncentered:
            W1_raw = pm.Normal("W1_raw", mu=0.0, sigma=1.0, shape=(p, q))
            sd1 = spike_sd + gamma1 * (slab_sd - spike_sd)
            W1 = pm.Deterministic("W1", W1_raw * sd1)
        else:
            sd1 = spike_sd + gamma1 * (slab_sd - spike_sd)
            W1 = pm.Normal("W1", mu=0.0, sigma=sd1, shape=(p, q))
        
        b1 = pm.Normal("b1", mu=0.0, sigma=1.0, shape=q)
        
        # Hidden layer activation
        hidden_raw = pm.math.dot(X_train, W1) + b1
        hidden = pm.math.tanh(hidden_raw)
        
        # Layer 2: Hidden -> Output (q weights)
        gamma2 = pm.Bernoulli("gamma2", p=pi, shape=q)
        
        if use_noncentered:
            W2_raw = pm.Normal("W2_raw", mu=0.0, sigma=1.0, shape=q)
            sd2 = spike_sd + gamma2 * (slab_sd - spike_sd)
            W2 = pm.Deterministic("W2", W2_raw * sd2)
        else:
            sd2 = spike_sd + gamma2 * (slab_sd - spike_sd)
            W2 = pm.Normal("W2", mu=0.0, sigma=sd2, shape=q)
        
        b2 = pm.Normal("b2", mu=0.0, sigma=1.0)
        
        # Output
        mu = beta_0 + pm.math.dot(hidden, W2) + b2
        
        # Error variance
        sigma_sq = pm.InverseGamma("sigma_sq", alpha=2.0, beta=1.0)
        sigma = pm.Deterministic("sigma", pm.math.sqrt(sigma_sq))
        
        # Likelihood
        y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y_train)
    
    return model


print("=" * 80)
print("MODEL 4: Bayesian Neural Network (BNN)")
print("=" * 80)

q_hidden = 5  # Number of hidden units
model4 = create_bnn_model(X_train, y_train, q=q_hidden, use_noncentered=True)

print(f"Model 4 built successfully!")
print(f"  Architecture: {p} inputs -> {q_hidden} hidden (tanh) -> 1 output")
print(f"  Parameters: ~{p * q_hidden + q_hidden + p + q_hidden + 3} (weights + biases + variance)")


In [None]:
# Fit Model 4
print("\nSampling from posterior (Model 4 - BNN)...")
print("⚠ Warning: BNNs can be slow due to many parameters and discrete variables...")
print("This may take 15-30 minutes...")

with model4:
    trace4 = pm.sample(
        draws=2000,
        tune=2000,
        chains=4,
        cores=1,
        target_accept=0.95,
        random_seed=RNG_SEED,
        return_inferencedata=True,
        progressbar=True
    )

print("\n✓ Model 4 (BNN) sampling completed!")

# Convergence diagnostics
print("\n" + "=" * 80)
print("MODEL 4: Convergence Diagnostics")
print("=" * 80)

rhat4 = az.rhat(trace4)
vars_to_check = [v for v in rhat4.data_vars if v not in ['sigma', 'gamma1', 'gamma2', 'W1', 'W2']]
if vars_to_check:
    max_rhat4 = max([float(rhat4[var].max()) for var in vars_to_check])
    print(f"\nR-hat (max): {max_rhat4:.4f} (target: < 1.01)")

ess4 = az.ess(trace4)
if vars_to_check:
    min_ess4 = min([float(ess4[var].min()) for var in vars_to_check])
    print(f"ESS (min): {min_ess4:.0f} (target: > 400)")

n_div4 = trace4.sample_stats.divergences.sum().item()
print(f"Divergences: {n_div4} (target: 0)")


## 8. Model 5: Spatial Regression Model

Finally, we fit a spatial regression model that explicitly accounts for spatial structure in the tissue. This model includes spatial coordinates and their interactions as predictors.

### Model Specification

**Likelihood:**
$$y_i \sim \mathcal{N}(\mu_i, \sigma^2)$$

where
$$\mu_i = \beta_0 + \sum_{j=1}^{p} \beta_j x_{ij} + \beta_{cx} \cdot cx_i + \beta_{cy} \cdot cy_i + \beta_{cx \times cy} \cdot cx_i \cdot cy_i$$

**Prior Distributions:**
- Intercept: $\beta_0 \sim \mathcal{N}(0, 10^2)$
- Marker coefficients: $\beta_j \sim \mathcal{N}(0, \tau^2)$ with $\tau^2 \sim \text{InverseGamma}(2, 1)$
- Spatial coefficients: $\beta_{cx}, \beta_{cy} \sim \mathcal{N}(0, 2^2)$
- Spatial interaction: $\beta_{cx \times cy} \sim \mathcal{N}(0, 1^2)$
- Error variance: $\sigma^2 \sim \text{InverseGamma}(2, 1)$

The spatial terms capture location-dependent effects and interactions that may reflect tissue structure.


In [None]:
print("=" * 80)
print("MODEL 5: Spatial Regression Model")
print("=" * 80)

with pm.Model() as model5:
    # Intercept
    beta_0 = pm.Normal("beta_0", mu=0.0, sigma=10.0)
    
    # Hierarchical prior for marker coefficients
    tau_sq = pm.InverseGamma("tau_sq", alpha=2.0, beta=1.0)
    beta_markers = pm.Normal("beta_markers", mu=0.0, sigma=pm.math.sqrt(tau_sq), shape=p)
    
    # Spatial coordinate effects
    beta_cx = pm.Normal("beta_cx", mu=0.0, sigma=2.0)
    beta_cy = pm.Normal("beta_cy", mu=0.0, sigma=2.0)
    
    # Spatial interaction
    beta_cx_cy = pm.Normal("beta_cx_cy", mu=0.0, sigma=1.0)
    
    # Error variance
    sigma_sq = pm.InverseGamma("sigma_sq", alpha=2.0, beta=1.0)
    sigma = pm.Deterministic("sigma", pm.math.sqrt(sigma_sq))
    
    # Linear predictor (includes spatial terms)
    mu = (beta_0 
          + pm.math.dot(X_train, beta_markers)
          + beta_cx * coords_train[:, 0]
          + beta_cy * coords_train[:, 1]
          + beta_cx_cy * coords_train[:, 0] * coords_train[:, 1])
    
    # Likelihood
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y_train)

print("Model 5 built successfully!")
print(f"  Parameters: {p + 6} (1 intercept + {p} marker coefficients + 3 spatial terms + 1 global scale + 1 error variance + 1 sigma)")


In [None]:
# Fit Model 5
print("\nSampling from posterior (Model 5)...")

with model5:
    trace5 = pm.sample(
        draws=2000,
        tune=2000,
        chains=4,
        cores=1,
        target_accept=0.95,
        random_seed=RNG_SEED,
        return_inferencedata=True,
        progressbar=True
    )

print("\n✓ Model 5 sampling completed!")

# Convergence diagnostics
print("\n" + "=" * 80)
print("MODEL 5: Convergence Diagnostics")
print("=" * 80)

rhat5 = az.rhat(trace5)
max_rhat5 = max([float(rhat5[var].max()) for var in rhat5.data_vars if var != 'sigma'])
print(f"\nR-hat (max): {max_rhat5:.4f} (target: < 1.01)")

ess5 = az.ess(trace5)
min_ess5 = min([float(ess5[var].min()) for var in ess5.data_vars if var != 'sigma'])
print(f"ESS (min): {min_ess5:.0f} (target: > 400)")

n_div5 = trace5.sample_stats.divergences.sum().item()
print(f"Divergences: {n_div5} (target: 0)")


## 9. Model Comparison

We compare all five models using:
- **PSIS-LOO** (Pareto-smoothed importance sampling leave-one-out cross-validation)
- **WAIC** (Widely Applicable Information Criterion)
- **Predictive performance** on test set (RMSE, MAE, R²)

These metrics help us identify which model best balances fit and complexity.


In [None]:
# Compute PSIS-LOO and WAIC for all models
print("=" * 80)
print("MODEL COMPARISON")
print("=" * 80)

traces = [trace1, trace2, trace3, trace4, trace5]
model_names = ['Model 1: Noninformative', 'Model 2: Hierarchical', 
               'Model 3: Spike-and-Slab', 'Model 4: BNN', 'Model 5: Spatial']

comparison_results = []

for i, (trace, name) in enumerate(zip(traces, model_names), 1):
    print(f"\nComputing metrics for {name}...")
    
    try:
        loo = az.loo(trace, pointwise=True)
        waic = az.waic(trace, pointwise=True)
        
        comparison_results.append({
            'Model': name,
            'LOO': loo.loo,
            'LOO_SE': loo.loo_se,
            'WAIC': waic.waic,
            'WAIC_SE': waic.waic_se,
            'p_LOO': loo.p_loo,
            'p_WAIC': waic.p_waic
        })
        
        print(f"  LOO: {loo.loo:.2f} ± {loo.loo_se:.2f}")
        print(f"  WAIC: {waic.waic:.2f} ± {waic.waic_se:.2f}")
        
    except Exception as e:
        print(f"  ⚠ Error computing LOO/WAIC: {str(e)[:100]}")
        comparison_results.append({
            'Model': name,
            'LOO': np.nan,
            'LOO_SE': np.nan,
            'WAIC': np.nan,
            'WAIC_SE': np.nan,
            'p_LOO': np.nan,
            'p_WAIC': np.nan
        })

comparison_df = pd.DataFrame(comparison_results)
print("\n" + "=" * 80)
print("Model Comparison Summary")
print("=" * 80)
display(comparison_df)

# Save for report
comparison_df.to_csv('report/model_comparison.csv', index=False)


## 10. Posterior Predictive Analysis

We generate posterior predictive distributions and evaluate model fit through:
- Posterior predictive checks
- Test set predictions
- Residual analysis
- Model fit statistics (R², RMSE, MAE)


In [None]:
# Generate posterior predictive samples for test set
# We'll do this for Model 5 (Spatial) as an example

print("=" * 80)
print("POSTERIOR PREDICTIVE ANALYSIS (Model 5: Spatial)")
print("=" * 80)

# Create model with test data
with pm.Model() as model5_test:
    # Same priors as Model 5
    beta_0 = pm.Normal("beta_0", mu=0.0, sigma=10.0)
    tau_sq = pm.InverseGamma("tau_sq", alpha=2.0, beta=1.0)
    beta_markers = pm.Normal("beta_markers", mu=0.0, sigma=pm.math.sqrt(tau_sq), shape=p)
    beta_cx = pm.Normal("beta_cx", mu=0.0, sigma=2.0)
    beta_cy = pm.Normal("beta_cy", mu=0.0, sigma=2.0)
    beta_cx_cy = pm.Normal("beta_cx_cy", mu=0.0, sigma=1.0)
    sigma_sq = pm.InverseGamma("sigma_sq", alpha=2.0, beta=1.0)
    sigma = pm.Deterministic("sigma", pm.math.sqrt(sigma_sq))
    
    # Training likelihood
    mu_train = (beta_0 
                + pm.math.dot(X_train, beta_markers)
                + beta_cx * coords_train[:, 0]
                + beta_cy * coords_train[:, 1]
                + beta_cx_cy * coords_train[:, 0] * coords_train[:, 1])
    y_obs = pm.Normal("y_obs", mu=mu_train, sigma=sigma, observed=y_train)
    
    # Test predictions
    mu_test = (beta_0 
               + pm.math.dot(X_test, beta_markers)
               + beta_cx * coords_test[:, 0]
               + beta_cy * coords_test[:, 1]
               + beta_cx_cy * coords_test[:, 0] * coords_test[:, 1])
    y_pred = pm.Normal("y_pred", mu=mu_test, sigma=sigma, shape=n_test)

# Sample from posterior using training data
print("\nSampling from posterior (with test predictions)...")
with model5_test:
    trace5_test = pm.sample(
        draws=2000,
        tune=2000,
        chains=4,
        cores=1,
        target_accept=0.95,
        random_seed=RNG_SEED,
        return_inferencedata=True,
        progressbar=True
    )

print("\n✓ Sampling completed!")

# Extract test predictions
y_pred_samples = trace5_test.posterior_predictive['y_pred'].values
y_pred_mean = y_pred_samples.mean(axis=(0, 1))  # Average across chains and draws
y_pred_std = y_pred_samples.std(axis=(0, 1))

# Compute metrics
rmse = np.sqrt(((y_test - y_pred_mean) ** 2).mean())
mae = np.abs(y_test - y_pred_mean).mean()
ss_res = ((y_test - y_pred_mean) ** 2).sum()
ss_tot = ((y_test - y_test.mean()) ** 2).sum()
r_squared = 1 - (ss_res / ss_tot)

print("\n" + "=" * 80)
print("Test Set Predictive Performance (Model 5)")
print("=" * 80)
print(f"RMSE: {rmse:.4f}")
print(f"MAE: {mae:.4f}")
print(f"R²: {r_squared:.4f}")


In [None]:
# Figure 3: Posterior predictive check
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Observed vs predicted
axes[0].scatter(y_test, y_pred_mean, alpha=0.5, s=10, rasterized=True)
axes[0].plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 
             'r--', linewidth=2, label='Perfect prediction')
axes[0].set_xlabel('Observed (standardized)', fontsize=12)
axes[0].set_ylabel('Predicted (standardized)', fontsize=12)
axes[0].set_title(f'Posterior Predictive: Observed vs Predicted\n(R² = {r_squared:.3f}, RMSE = {rmse:.3f})', 
                  fontsize=12, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Residuals
residuals = y_test - y_pred_mean
axes[1].scatter(y_pred_mean, residuals, alpha=0.5, s=10, rasterized=True)
axes[1].axhline(0, color='red', linestyle='--', linewidth=2)
axes[1].set_xlabel('Predicted (standardized)', fontsize=12)
axes[1].set_ylabel('Residuals', fontsize=12)
axes[1].set_title('Residual Plot', fontsize=12, fontweight='bold')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('report/figures/figure3_posterior_predictive.png', dpi=300, bbox_inches='tight')
plt.show()

print("Figure 3: Posterior predictive checks for Model 5")


## 11. Posterior Summary and Credible Intervals

We summarize posterior distributions for key parameters, focusing on Model 5 (Spatial) as it incorporates both marker effects and spatial structure.


In [None]:
# Posterior summary for Model 5
print("=" * 80)
print("POSTERIOR SUMMARY: Model 5 (Spatial Regression)")
print("=" * 80)

summary5 = az.summary(trace5, var_names=['beta_0', 'beta_markers', 'beta_cx', 'beta_cy', 'beta_cx_cy', 'sigma_sq', 'tau_sq'])
print("\nFull Posterior Summary:")
display(summary5)

# Save for report
summary5.to_csv('report/posterior_summary_model5.csv')

# Create table of credible intervals for key coefficients
print("\n" + "=" * 80)
print("95% Credible Intervals for Key Coefficients (Model 5)")
print("=" * 80)

ci_data = []

# Intercept
beta_0_samples = trace5.posterior['beta_0'].values.flatten()
ci_lower = np.percentile(beta_0_samples, 2.5)
ci_upper = np.percentile(beta_0_samples, 97.5)
post_mean = beta_0_samples.mean()
ci_data.append({
    'Coefficient': 'Intercept (β₀)',
    'Posterior Mean': f'{post_mean:.3f}',
    '95% CI Lower': f'{ci_lower:.3f}',
    '95% CI Upper': f'{ci_upper:.3f}',
    'Excludes Zero': 'Yes' if (ci_lower > 0) or (ci_upper < 0) else 'No'
})

# Marker coefficients
for i, marker in enumerate(selected_predictors):
    beta_samples = trace5.posterior['beta_markers'][:, :, i].values.flatten()
    ci_lower = np.percentile(beta_samples, 2.5)
    ci_upper = np.percentile(beta_samples, 97.5)
    post_mean = beta_samples.mean()
    marker_short = marker.split(':')[0]
    ci_data.append({
        'Coefficient': f'β_{i+1} ({marker_short})',
        'Posterior Mean': f'{post_mean:.3f}',
        '95% CI Lower': f'{ci_lower:.3f}',
        '95% CI Upper': f'{ci_upper:.3f}',
        'Excludes Zero': 'Yes' if (ci_lower > 0) or (ci_upper < 0) else 'No'
    })

# Spatial coefficients
for var_name, label in [('beta_cx', 'β_cx (X coordinate)'), 
                        ('beta_cy', 'β_cy (Y coordinate)'),
                        ('beta_cx_cy', 'β_cx_cy (X×Y interaction)')]:
    samples = trace5.posterior[var_name].values.flatten()
    ci_lower = np.percentile(samples, 2.5)
    ci_upper = np.percentile(samples, 97.5)
    post_mean = samples.mean()
    ci_data.append({
        'Coefficient': label,
        'Posterior Mean': f'{post_mean:.3f}',
        '95% CI Lower': f'{ci_lower:.3f}',
        '95% CI Upper': f'{ci_upper:.3f}',
        'Excludes Zero': 'Yes' if (ci_lower > 0) or (ci_upper < 0) else 'No'
    })

ci_df = pd.DataFrame(ci_data)
print("\nCredible Intervals:")
display(ci_df)

# Save for report table
ci_df.to_csv('report/credible_intervals_model5.csv', index=False)


In [None]:
# Figure 4: Posterior distributions of selected coefficients
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()

# Plot posterior distributions for first 6 marker coefficients
for i in range(min(6, p)):
    beta_samples = trace5.posterior['beta_markers'][:, :, i].values.flatten()
    axes[i].hist(beta_samples, bins=50, density=True, alpha=0.7, edgecolor='black', color='steelblue')
    axes[i].axvline(0, color='red', linestyle='--', linewidth=2, label='Zero')
    
    # Add credible interval
    ci_lower = np.percentile(beta_samples, 2.5)
    ci_upper = np.percentile(beta_samples, 97.5)
    axes[i].axvline(ci_lower, color='blue', linestyle=':', linewidth=2, label='95% CI')
    axes[i].axvline(ci_upper, color='blue', linestyle=':', linewidth=2)
    
    marker_name = selected_predictors[i].split(':')[0]
    axes[i].set_title(f'Posterior: {marker_name}', fontsize=11, fontweight='bold')
    axes[i].set_xlabel('Coefficient Value', fontsize=10)
    axes[i].set_ylabel('Density', fontsize=10)
    axes[i].legend(fontsize=9)
    axes[i].grid(True, alpha=0.3)

plt.suptitle('Posterior Distributions of Selected Regression Coefficients (Model 5)', 
             fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig('report/figures/figure4_posterior_coefficients.png', dpi=300, bbox_inches='tight')
plt.show()

print("Figure 4: Posterior distributions of selected coefficients")


## 12. Summary and Conclusions

### Key Findings

1. **Model Convergence**: All models achieved good convergence (R-hat < 1.01, ESS > 400)
2. **Model Comparison**: [Results from comparison table]
3. **Spatial Effects**: Spatial coordinates show [significant/non-significant] effects
4. **Marker Associations**: [Key findings about which markers are important]
5. **Predictive Performance**: Model 5 achieved R² = [value] on test set

### Biological Interpretation

[Interpretation of findings in context of tumor immunology]

### Model Recommendations

[Which model(s) to use and why]

### Limitations and Future Work

[Discussion of limitations and potential extensions]
