---
title: "Aalen's Dynamic Path Model"
subtitle: "Causal Inference with Time Varying Effects in PyMC"
categories: ["path-models", "sem", "causal inference"]
keep-ipynb: true
self-contained: true
draft: true
toc: true
execute:
  freeze: auto
  eval: true
jupyter: applied-bayesian-regression-modeling-env
image: 'forking_paths.jpg'
author:
    - url: https://nathanielf.github.io/
      affiliation: PyMC dev
citation: true
---


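This post works through Aalen's dynamic path model, a causal mediation model whose effects vary over time, and fits it in PyMC.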


In [None]:
import pymc as pm
import numpy as np 
import pandas as pd
import arviz as az
import pytensor.tensor as pt
from scipy.interpolate import BSpline

If you looked at Odysseus on the morning the gates of Troy fell, you would think him well set up for a happy journey home. He is the architect of victory, his ships are loaded with spoils, and the wind is at his back. Yet an odyssey can't be completed in a single day, and conclusions drawn at the outset rarely survive to the journey's end.

When we rely on static snapshots, like a single blood draw or a particular sales campaign, we are essentially watching Odysseus board his ship and guessing how the story ends. We ignore the __consequences that emerge over time.__


In [None]:
# Load the simulated start–stop data, keeping only the columns used below.
df = pd.read_csv("aalen_simdata.csv").loc[:, ['subject', 'x', 'dose', 'M', 'start', 'stop', 'event']]
df.head()

In [None]:
# Mean and sum of the event indicator and the mediator M for every (x, dose) cell.
df.groupby(['x', 'dose'])[['event', 'M']].agg(['mean', 'sum'])

In [None]:
# | code-fold: true

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D

# Subject-level summary used to order the rows of the plot: treatment level
# first, then length of follow-up.
subject_info = (
    df.groupby('subject')
      .agg(
          x=('x', 'first'),
          max_stop=('stop', 'max')
      )
      .sort_values(['x', 'max_stop'])
)

subjects = subject_info.index.tolist()
subject_to_y = {s: i for i, s in enumerate(subjects)}

# 0.1 inch per subject, but never shorter than 3 inches so small cohorts
# (or an empty frame) still give a usable figure.
fig, ax = plt.subplots(figsize=(8, max(3, 0.1 * len(subjects))))

# Draw every interval in one vectorised call instead of one matplotlib call
# per row of `df`.
row_y = df['subject'].map(subject_to_y).to_numpy()
row_colors = ['tab:blue' if v == 1 else 'tab:orange' for v in df['x']]
ax.hlines(
    y=row_y,
    xmin=df['start'].to_numpy(),
    xmax=df['stop'].to_numpy(),
    colors=row_colors,
    linewidth=3
)

# Mark events at the end of the interval in which they occurred.
is_event = (df['event'] == 1).to_numpy()
ax.plot(
    df['stop'].to_numpy()[is_event],
    row_y[is_event],
    linestyle='none',
    marker='o',
    color='red',
    markersize=6,
    zorder=3
)

# Axis formatting
ax.set_yticks(range(len(subjects)))
ax.set_yticklabels(subjects)
ax.set_xlabel("Time")
ax.set_ylabel("Subject")

# Dashed separator between the x = 0 band (sorted first, so at the bottom)
# and the remaining subjects.
x0_count = (subject_info['x'] == 0).sum()
ax.axhline(x0_count - 0.5, color='black', linestyle='--', linewidth=1)

# Legend
legend_elements = [
    Line2D([0], [0], color='tab:blue', lw=3, label='x = 1'),
    Line2D([0], [0], color='tab:orange', lw=3, label='x = 0'),
    Line2D([0], [0], marker='o', color='red', lw=0, label='Event', markersize=6)
]

ax.legend(handles=legend_elements, loc='upper right')

ax.set_title("Subject Timelines Ordered by Treatment Level")

plt.tight_layout()
plt.show()


## Data Preparation


In [None]:
def prepare_aalen_dpa_data(
    df,
    subject_col="subject",
    start_col="start",
    stop_col="stop",
    event_col="event",
    x_col="x",
    m_col="M",
    dose_col="dose",
):
    """
    Prepare Andersen–Gill / Aalen dynamic path data for PyMC.

    The input frame is not modified; all derived columns are added to a copy.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format start–stop survival data, one row per subject-interval.
    subject_col : str
        Subject identifier.
    start_col, stop_col : str
        Interval boundaries; every interval must satisfy ``stop > start``.
    event_col : str
        Event indicator (0/1). Passed through unchanged; not used by the
        transformation itself (kept for interface compatibility).
    x_col : str
        Exposure / treatment.
    m_col : str
        Mediator.
    dose_col : str
        Dose column. Rows equal to ``"low"`` / ``"high"`` are encoded as the
        0/1 indicators ``I_low`` / ``I_high``; any other value is the
        reference level.

    Returns
    -------
    dict
        ``"bins"`` : pd.DataFrame
            The unique ``(start, stop)`` intervals sorted by time, with an
            integer ``bin_idx`` column (useful for plotting).
        ``"df_long"`` : pd.DataFrame
            Copy of ``df`` sorted by subject and interval start, with added
            columns ``dt`` (interval length), ``bin_idx``, ``x_c`` (treatment,
            not centred), ``m_c`` (mean-centred mediator), ``m_lag`` (previous
            interval's ``m_c`` within subject, 0.0 for a subject's first
            interval), ``I_low`` and ``I_high``.

    Raises
    ------
    ValueError
        If any interval length is missing or non-positive.
    """
    df = df.copy()

    # -------------------------------------------------
    # 1. Interval lengths (exposure time for the likelihood)
    # -------------------------------------------------
    df["dt"] = df[stop_col] - df[start_col]

    if df["dt"].isna().any():
        raise ValueError("Missing interval boundaries detected.")
    if (df["dt"] <= 0).any():
        raise ValueError("Non-positive interval lengths detected.")

    # -------------------------------------------------
    # 2. Time-bin indexing (one bin per unique interval)
    # -------------------------------------------------
    bins = (
        df[[start_col, stop_col]]
        .drop_duplicates()
        .sort_values([start_col, stop_col])
        .reset_index(drop=True)
    )
    bins["bin_idx"] = np.arange(len(bins))

    df = df.merge(
        bins,
        on=[start_col, stop_col],
        how="left",
        validate="many_to_one"
    )

    # -------------------------------------------------
    # 3. Covariates: treatment kept on its original scale,
    #    mediator centred on its overall mean
    # -------------------------------------------------
    df["x_c"] = df[x_col]
    df["m_c"] = df[m_col] - df[m_col].mean()

    # -------------------------------------------------
    # 4. Predictable mediator: previous interval's value within subject.
    #    The first interval has no predecessor and gets 0.0 (the mean,
    #    since m_c is centred).
    # -------------------------------------------------
    df = df.sort_values([subject_col, start_col])
    df["m_lag"] = (
        df.groupby(subject_col)["m_c"]
          .shift(1)
          .fillna(0.0)
    )

    # -------------------------------------------------
    # 5. Dose indicators (reference = any level other than low/high)
    # -------------------------------------------------
    df["I_low"] = (df[dose_col] == "low").astype(int)
    df["I_high"] = (df[dose_col] == "high").astype(int)

    return {
        "bins": bins,     # useful for plotting
        "df_long": df,    # model input; also handy for inspection
    }


In [None]:
# Build the model-ready long table and show its first 14 rows.
data = prepare_aalen_dpa_data(df)
df_long = data['df_long']
df_long.loc[:, ['subject', 'x', 'dose', 'M', 'event', 'dt', 'bin_idx']].head(14)

In [None]:
def create_bspline_basis(n_bins, n_knots=10, degree=3):
    """
    Create a clamped B-spline basis evaluated at each time-bin index.

    The basis is evaluated at the integer points ``0, 1, ..., n_bins - 1``.
    Each row is non-negative and sums to 1 (partition of unity).

    Parameters
    ----------
    n_bins : int
        Number of time bins (must be at least 2).
    n_knots : int
        Total number of distinct knots spread evenly over ``[0, n_bins - 1]``.
        This count includes the two boundary knots, so there are
        ``n_knots - 2`` interior knots. Fewer knots give a smoother curve.
        Must be at least 2.
    degree : int
        Degree of the spline (3 = cubic, recommended). Must be >= 0.

    Returns
    -------
    basis : np.ndarray
        Matrix of shape ``(n_bins, n_knots + degree - 1)``. Column ``i``
        holds basis function ``i`` evaluated at every bin.

    Raises
    ------
    ValueError
        If ``n_bins < 2``, ``n_knots < 2`` or ``degree < 0``. These inputs
        would yield a degenerate knot vector.
    """
    if n_bins < 2:
        raise ValueError(f"n_bins must be at least 2, got {n_bins}.")
    if n_knots < 2:
        raise ValueError(f"n_knots must be at least 2, got {n_knots}.")
    if degree < 0:
        raise ValueError(f"degree must be non-negative, got {degree}.")

    # Evenly spaced knots over the bin-index range; the first and last are
    # the boundary knots.
    knot_grid = np.linspace(0, n_bins - 1, n_knots)

    # Clamp the spline: each boundary knot appears degree + 1 times in total
    # (degree extra copies plus the one already in knot_grid).
    knots = np.concatenate([
        np.repeat(knot_grid[0], degree),
        knot_grid,
        np.repeat(knot_grid[-1], degree),
    ])

    # Number of basis functions (== n_knots + degree - 1).
    n_basis = len(knots) - degree - 1

    # An identity coefficient matrix evaluates all basis functions in one
    # call. The result has shape (n_bins, n_basis), and column i is basis i.
    t = np.arange(n_bins, dtype=float)
    basis = BSpline(knots, np.eye(n_basis), degree, extrapolate=False)(t)

    return basis

# Preview the cubic B-spline design matrix evaluated on the observed time bins.
n_knots = 10
n_bins = len(data['bins'])
basis = create_bspline_basis(n_bins, n_knots=n_knots, degree=3)
n_cols = basis.shape[1]
basis_df = pd.DataFrame(basis, columns=[f'feature_{j}' for j in range(n_cols)])
basis_df.head(10)

In [None]:
# | output: false

def make_model(data, basis, sample=True, observed=True):
    """
    Build and, optionally, fit the Bayesian Aalen dynamic path model.

    Two linked sub-models share the spline basis over time bins:

    * Mediator model (A path, x -> M):
      ``M ~ Normal(beta(t) * x + rho * M_lag, sigma_m)``.
    * Poisson piecewise hazard model (direct effect and B path). The
      expected count per interval is
      ``dt * softplus(alpha_0(t) + alpha_1(t) * x + alpha_2(t) * M + dose terms)``.

    Parameters
    ----------
    data : dict
        Output of ``prepare_aalen_dpa_data``. Reads ``data["bins"]["bin_idx"]``.
        Also reads ``data["df_long"]`` columns ``bin_idx``, ``m_c``, ``m_lag``,
        ``event``, ``x``, ``I_low``, ``I_high`` and ``dt``.
    basis : np.ndarray
        Spline design matrix of shape ``(n_bins, n_basis)``, one row per
        time bin.
    sample : bool
        If True, draw prior predictive, posterior (numpyro NUTS) and
        posterior predictive samples.
    observed : bool
        If False, ``obs_event`` is created without observed data. The
        mediator likelihood is always observed.

    Returns
    -------
    aalen_dpa_model : pm.Model
    idata : az.InferenceData or None
        Sampling results, or ``None`` when ``sample=False``.

    Raises
    ------
    ValueError
        If ``basis`` does not have exactly one row per time bin.
    """
    df_long = data['df_long']
    time_bins = data['bins']['bin_idx'].values
    n_basis = basis.shape[1]
    n_obs = df_long.shape[0]

    # The basis rows are indexed by time bin, so their counts must match.
    if basis.shape[0] != len(time_bins):
        raise ValueError(
            f"basis has {basis.shape[0]} rows but there are {len(time_bins)} time bins."
        )

    # Plain integer array, so tensor indexing below is purely positional.
    b = df_long['bin_idx'].to_numpy()

    observed_mediator = df_long["m_c"].values
    observed_events = df_long['event'].astype(int).values
    observed_treatment = df_long['x'].astype(int).values
    observed_mediator_lag = df_long['m_lag'].values

    coords = {
        'tv': ['intercept', 'direct', 'mediator'],
        # f-string so each spline coefficient gets a distinct label.
        'splines': [f'spline_f_{i}' for i in range(n_basis)],
        'obs': range(n_obs),
        'time_bins': time_bins,
    }

    # Defined up front so the function can also return when sample=False.
    idata = None

    with pm.Model(coords=coords) as aalen_dpa_model:

        trt = pm.Data("trt", observed_treatment, dims="obs")
        med = pm.Data("mediator", observed_mediator, dims="obs")
        med_lag = pm.Data("mediator_lag", observed_mediator_lag, dims="obs")
        events = pm.Data("events", observed_events, dims="obs")
        I_low = pm.Data("I_low", df_long["I_low"].values, dims="obs")
        I_high = pm.Data("I_high", df_long["I_high"].values, dims="obs")
        dt = pm.Data("duration", df_long['dt'].values, dims='obs')
        # Every row of the long format is an interval during which the
        # subject is at risk.
        at_risk = pm.Data("at_risk", np.ones(len(observed_events)), dims="obs")
        basis_ = pm.Data("basis", basis, dims=('time_bins', 'splines'))

        # -------------------------------------------------
        # 1. B-spline coefficients for the HAZARD model
        # -------------------------------------------------
        # Random Walk 1 (RW1) prior: the cumulative sum of scaled standard
        # normals. It plays the role of the smoothing penalty in R's 'mgcv'
        # or 'timereg'. A smaller sigma_smooth gives smoother curves.
        sigma_smooth = pm.Exponential("sigma_smooth", [1, 1, 1], dims='tv')
        beta_raw = pm.Normal("beta_raw", 0, 1, dims=('splines', 'tv'))
        coef_alpha = pm.Deterministic(
            "coef_alpha",
            pt.cumsum(beta_raw * sigma_smooth, axis=0),
            dims=('splines', 'tv'),
        )

        # Smooth time-varying functions, one value per time bin.
        alpha_0_t = pt.dot(basis_, coef_alpha[:, 0])  # baseline
        alpha_1_t = pt.dot(basis_, coef_alpha[:, 1])  # direct effect of x
        alpha_2_t = pt.dot(basis_, coef_alpha[:, 2])  # mediator effect (B path)

        # -------------------------------------------------
        # 2. B-spline coefficients for the MEDIATOR model (same RW1 form)
        # -------------------------------------------------
        # NOTE(review): Exponential(0.1) has prior mean 10. That is a far
        # looser smoothing prior than the hazard's Exponential(1); confirm
        # this is intended.
        sigma_beta_smooth = pm.Exponential("sigma_beta_smooth", 0.1)
        beta_raw_m = pm.Normal("beta_raw_m", 0, 1, dims='splines')
        coef_beta = pt.cumsum(beta_raw_m * sigma_beta_smooth)

        beta_t = pt.dot(basis_, coef_beta)

        # -------------------------------------------------
        # 3. Mediator model (A path: x → M) with a lagged-mediator term
        # -------------------------------------------------
        sigma_m = pm.HalfNormal("sigma_m", 1.0)
        rho = pm.Beta("rho", 2, 2)

        mu_m = beta_t[b] * trt + rho * med_lag

        pm.Normal(
            "obs_m",
            mu=mu_m,
            sigma=sigma_m,
            observed=med,
            dims='obs'
        )

        # -------------------------------------------------
        # 4. Hazard model (direct + B path)
        # -------------------------------------------------
        beta_low = pm.Normal("beta_low", 0, 0.1)
        beta_high = pm.Normal("beta_high", 0, 0.1)
        # Additive linear predictor. It is mapped to a positive hazard through
        # softplus (log1pexp) below, so this is not a log link.
        eta_t = (
            alpha_0_t[b]
            + alpha_1_t[b] * trt   # direct effect
            + alpha_2_t[b] * med   # mediator effect
            + beta_low * I_low
            + beta_high * I_high
        )

        # Expected number of events = hazard * time at risk in the interval.
        time_at_risk = at_risk * dt
        Lambda = time_at_risk * pm.math.log1pexp(eta_t)

        # With observed=False, obs_event is left as an unobserved variable.
        pm.Poisson(
            "obs_event",
            mu=Lambda,
            observed=events if observed else None,
            dims='obs'
        )

        # -------------------------------------------------
        # 5. Causal path effects (per time bin, linear-predictor scale)
        # -------------------------------------------------
        pm.Deterministic("alpha_0_t", alpha_0_t, dims='time_bins')
        pm.Deterministic("alpha_1_t", alpha_1_t, dims='time_bins')  # direct effect
        pm.Deterministic("alpha_2_t", alpha_2_t, dims='time_bins')  # B path
        pm.Deterministic("beta_t", beta_t, dims='time_bins')        # A path

        # These are instantaneous per-bin effects, not integrals over time.
        de_t = pm.Deterministic("tv_direct_effect", alpha_1_t, dims='time_bins')

        # Indirect effect: product of the A and B paths.
        ie_t = pm.Deterministic("tv_indirect_effect", beta_t * alpha_2_t, dims='time_bins')

        # Total effect = direct + indirect.
        pm.Deterministic("tv_total_effect", de_t + ie_t, dims='time_bins')

        # Hazard-scale summaries that omit the mediator and dose terms
        # (i.e. m_c = 0 and the reference dose level).
        baseline_hazard = pm.math.log1pexp(alpha_0_t)
        exposed_hazard = pm.math.log1pexp(alpha_0_t + alpha_1_t)
        pm.Deterministic('tv_baseline_hazard', baseline_hazard, dims='time_bins')
        pm.Deterministic('tv_hazard_with_exposure', exposed_hazard, dims='time_bins')
        pm.Deterministic("tv_RR", exposed_hazard / baseline_hazard, dims="time_bins")

        # -------------------------------------------------
        # 6. Sample
        # -------------------------------------------------
        if sample:
            idata = pm.sample_prior_predictive()
            idata.extend(pm.sample(
                draws=2000,
                tune=2000,
                target_accept=0.95,
                chains=4,
                nuts_sampler="numpyro",
                random_seed=42,
                init="adapt_diag",
                idata_kwargs={"log_likelihood": True}
            ))
            idata.extend(pm.sample_posterior_predictive(idata))

    return aalen_dpa_model, idata

# Fit the main model on a 12-knot cubic B-spline basis over the time bins.
basis = create_bspline_basis(n_bins, degree=3, n_knots=12)
aalen_dpa_model, idata_aalen = make_model(data, basis)

In [None]:
# Render the dependency graph of the fitted Aalen dynamic path model.
pm.model_to_graphviz(aalen_dpa_model)

In [None]:
#| eval: false

# Refit the model for knot counts 4, 6, ..., 14 and compare the fits on the
# event likelihood with az.compare.
models = {}
idatas = {}
for k in range(4, 15, 2):
    # Loop-local names, so the main `basis` / `aalen_dpa_model` globals from
    # the earlier cells are not overwritten by the last iteration.
    knot_basis = create_bspline_basis(n_bins, n_knots=k, degree=3)
    model_k, idata_k = make_model(data, knot_basis)
    models[k] = model_k
    idatas[f"splines_{k}"] = idata_k

compare_df = az.compare(idatas, var_name='obs_event')
az.plot_compare(compare_df, figsize=(8, 6), plot_ic_diff=True)

![LOO comparison of models fitted with different numbers of spline knots](spline_loo_comparison.png)


In [None]:
# | eval: false

# Compare the direct effect across knot counts on the last nine time bins.
# The bins are taken from the bin table, so no label is duplicated or skipped.
final_bins = data['bins']['bin_idx'].to_numpy()[-9:].tolist()

ax = az.plot_forest(
    [idatas[k] for k in idatas],
    combined=True,
    var_names=['tv_direct_effect'],
    model_names=list(idatas),
    coords={'time_bins': final_bins},
    figsize=(12, 10),
    r_hat=True,
)
ax[0].set_title("Time Vary Direct Effects \n Comparing Models on Final Time Intervals", fontsize=15)
ax[0].set_ylabel("Nth Time Interval", fontsize=15)
fig = ax[0].figure
fig.savefig('forest_plot_comparing_tv_direct.png')

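![Forest plot of the time-varying direct effect on the final time intervals, compared across knot counts](forest_plot_comparing_tv_direct.png)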


In [None]:
# Trace plots for the time-varying path effects and the two dose coefficients,
# with divergence markers switched off. The trailing ';' suppresses the
# notebook's echo of the returned axes.
az.plot_trace(idata_aalen, var_names=['tv_direct_effect', 'tv_indirect_effect', 'tv_total_effect', 'beta_high', 'beta_low'], divergences=False);
plt.tight_layout()


In [None]:
vars_to_plot = ['tv_direct_effect', 'tv_indirect_effect', 'tv_total_effect']
labels = ['Time varying Direct Effect', 'Time varying Indirect Effect', 'Time varying Total Effect']


def plot_effects(idata, vars_to_plot, labels, scale="Log Hazard Ratio Scale", hdi_prob=0.94):
    """
    Plot the posterior mean and HDI band of time-varying quantities.

    One panel is drawn per entry of ``vars_to_plot``, in a single row.

    Parameters
    ----------
    idata : az.InferenceData
        Must contain every variable in ``vars_to_plot`` in its posterior
        group. Presumably each has a single time dimension
        (``time_bins``); confirm this when plotting new variables.
    vars_to_plot : list of str
        Posterior variable names, one per panel.
    labels : list of str
        Panel titles / legend labels, aligned with ``vars_to_plot``.
    scale : str
        Y-axis label. The default value is drawn in teal; any other value in
        dark red. NOTE(review): the model maps its linear predictor through
        softplus rather than exp, so "Log Hazard Ratio Scale" is only an
        approximate description.
    hdi_prob : float
        Probability mass of the shaded highest-density interval.

    Returns
    -------
    matplotlib.figure.Figure
    """
    n_panels = max(1, len(vars_to_plot))
    # squeeze=False keeps `axs` a 2-D array even for a single panel.
    fig, axs = plt.subplots(1, n_panels, figsize=(20, 10), squeeze=False)
    color = 'teal' if scale == "Log Hazard Ratio Scale" else 'darkred'

    for ax, var, label in zip(axs[0], vars_to_plot, labels):
        # Pool chains and draws; after the transpose the shape is
        # (samples, time bins).
        post_samples = az.extract(idata, var_names=[var]).values.T
        mean_val = post_samples.mean(axis=0)

        # HDI per time bin, computed from the InferenceData itself rather
        # than a bare 2-D array, whose (draw, shape) reading arviz deprecates.
        # Result shape: (time bins, 2) -> [lower, upper].
        hdi_val = az.hdi(idata, var_names=[var], hdi_prob=hdi_prob)[var].values

        x_axis = np.arange(len(mean_val))
        ax.plot(x_axis, mean_val, label=label, color=color, lw=2)
        ax.fill_between(
            x_axis, hdi_val[:, 0], hdi_val[:, 1],
            color=color, alpha=0.2, label=f'{hdi_prob:.0%} HDI'
        )

        # Formatting
        ax.set_title(label)
        ax.legend()
        ax.grid(alpha=0.3)
        ax.set_xlabel("Time bin")
        ax.set_ylabel(scale)
    plt.tight_layout()
    return fig


plot_effects(idata_aalen, vars_to_plot, labels);

In [None]:
# Same panel layout on the hazard scale: baseline, exposed, and their ratio.
vars_to_plot, labels = (
    ['tv_baseline_hazard', 'tv_hazard_with_exposure', 'tv_RR'],
    ['Time varying Baseline Hazard', 'Time varying Hazard + Exposure', 'Time varying RR'],
)
plot_effects(idata_aalen, vars_to_plot, labels, scale='Hazard Scale');