---
title: "Aalen's Dynamic Path Model"
subtitle: "Causal Inference with Time Varying Effects in PyMC"
categories: ["path-models", "sem", "causal inference"]
keep-ipynb: true
self-contained: true
draft: true
toc: true
execute: 
  freeze: auto 
  execute: true
  eval: false
jupyter: applied-bayesian-regression-modeling-env
image: 'forking_paths.jpg'
author:
    - url: https://nathanielf.github.io/
    - affiliation: PyMC dev
citation: true
---




In [1]:
import pymc as pm
import numpy as np 
import pandas as pd
import arviz as az
import pytensor.tensor as pt
from scipy.interpolate import BSpline

If you look to Odysseus on the morning the gates of Troy fell, he is well set up for a happy journey home. He is the architect of victory, his ships are loaded with spoils, and the wind is at his back. Yet an odyssey can't be completed in a single day, and conclusions drawn at the outset rarely survive the journey's end.

When we rely on static snapshots, like a single blood draw or a particular sales campaign, we are essentially watching Odysseus board his ship and guessing how the story ends. We ignore the __consequences emerging in time.__ This is an apt observation with which to begin the year. Will your New Year's resolutions survive January? This is the kind of question we'll assess here.

## Interventions and Attenuated Effects

To track how intention precedes and predicts evolving outcomes we need a statistical framework that doesn't just record the departure but tracks the entire voyage. Enter **Aalen’s Additive Model**, formulated as a **Dynamic Path Model**. The question is when will we achieve our goals? When will Odysseus get home? How does the risk of attaining our goal vary over time?

While traditional models, like the Cox Proportional Hazards or Accelerated Failure Time models, often assume that an intervention’s effect is a constant "multiplier" throughout the study period, Aalen’s approach treats effects as a living process. It allows the impact of a policy or treatment to wax, wane, or even reverse as the narrative unfolds.

### The Machinery of Change

In a dynamic path system, we decompose the total risk into a series of additive "layers." If we are interested in how an intervention ($X$) works through a mediator ($M$), we model the instantaneous hazard $\lambda(t)$ as:

$$\lambda(t | X, M) = \alpha_0(t) + \alpha_1(t)X + \alpha_2(t)M$$

Where:

- **$\alpha_0(t)$** is the **Baseline Hazard**, representing the background "tension" of the story.
- **$\alpha_1(t)$** is the **Time-Varying Direct Effect**, showing how the intervention influences risk at every specific moment.
- **$\alpha_2(t)$** is the **Time-Varying Mediator Effect**, capturing how the intermediate variable (the "storm" or the "detour") contributes to the outcome.

All three components are modelled as time-varying functions that distil the effects over time. 
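As a minimal sketch of how the three layers combine, the snippet below evaluates the additive hazard on a time grid for a treated and an untreated subject. The coefficient shapes (`alpha_0`, `alpha_1`, `alpha_2`) are invented purely for illustration and are not estimates from any data.

```python
import numpy as np

# Time grid (arbitrary units) for the illustration
t = np.linspace(0.0, 10.0, 101)

# Hypothetical coefficient functions, chosen only to show the mechanics
alpha_0 = 0.05 + 0.0 * t            # constant baseline hazard
alpha_1 = -0.03 * np.exp(-t / 3.0)  # protective direct effect that fades
alpha_2 = 0.002 * t                 # mediator effect that grows with time


def additive_hazard(x, m):
    """lambda(t | X, M) = alpha_0(t) + alpha_1(t) * X + alpha_2(t) * M."""
    return alpha_0 + alpha_1 * x + alpha_2 * m


hazard_control = additive_hazard(x=0, m=1.0)
hazard_treated = additive_hazard(x=1, m=1.0)

# With M held fixed, the gap between the two curves is exactly alpha_1(t)
hazard_difference = hazard_treated - hazard_control
```

Because the layers are additive, the treated-minus-control gap does not depend on the baseline, which is what lets $\alpha_1(t)$ be read as an increment to risk.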

### Visualizing the Evolving DAG

Standard causal inference often relies on a static Directed Acyclic Graph (DAG) to represent the "rules" of the system. But in a survival context, the DAG itself is dynamic. We can think of the model as a sequence of DAGs—one for each "scene" in our odyssey—where the causal arrows between $X$, $M$, and the Hazard ($\lambda$) strengthen or weaken over time.


We can visualize this "filmstrip" of causality using `networkx`. We have a series of dynamic causal DAGs representing our assumptions of the relationships:


In [2]:
# | code-fold: true
import matplotlib.pyplot as plt
import networkx as nx

def plot_temporal_dag(stages=[0.1, 0.5, 0.9], labels=["Act I", "Act II", "Act III"]):
    fig, axs = plt.subplots(1, 3, figsize=(18, 5))
    
    # Define nodes: Exposure, Mediator, Hazard
    nodes = {'X': (0, 1), 'M': (1, 2), 'H': (2, 1)}
    
    for i, (stage, label) in enumerate(zip(stages, labels)):
        G = nx.DiGraph()
        G.add_nodes_from(nodes.keys())
        
        # We vary the weights to represent alpha_1(t) and alpha_2(t)
        # Act I: Direct effect strong, Mediator weak
        # Act III: Mediator dominates or effects attenuate
        direct_w = 4 * (1 - stage) 
        indirect_w = 6 * stage
        
        edges = [('X', 'H', direct_w), ('X', 'M', 3), ('M', 'H', indirect_w)]
        
        for u, v, w in edges:
            G.add_edge(u, v, weight=w)
            
        pos = nodes
        nx.draw_networkx_nodes(G, pos, ax=axs[i], node_color='maroon', node_size=2000)
        nx.draw_networkx_labels(G, pos, ax=axs[i], font_color='white', font_weight='bold')
        
        # Draw edges with widths corresponding to causal strength
        weights = [G[u][v]['weight'] for u, v in G.edges()]
        nx.draw_networkx_edges(G, pos, ax=axs[i], width=weights, 
                               edge_color='gray', arrowsize=30, connectionstyle="arc3,rad=0.1")
        
        axs[i].set_title(f"{label}\n(Time-varying Causal Strengths)", fontsize=14)
        axs[i].axis('off')

    plt.tight_layout()
    return fig

fig = plot_temporal_dag()
fig.savefig("evolving_dag.png")

![](evolving_dag.png)

Seeing it this way makes it clearer that we are estimating a system of equations over time. To truly capture the journey, we cannot look at these variables in isolation. In our Odyssey, the "storm" ($M$) is not an independent accident; it is a consequence of the path Odysseus ($X$) chose to take. To model this, we must treat the intervention and the outcome as a **system of simultaneous equations**.

By solving these equations in tandem, we ensure that the mediator is not just another covariate, but a character with its own backstory, influenced by the intervention even as it influences the final risk.

#### 1. The Mediator Equation (The Will of Poseidon)
Before we can calculate the risk of shipwreck, we must determine how the intervention has altered the environment. We model the mediator as a function of the exposure:

$$M_i = \beta_0 + \beta_{1}(t)X_i + \epsilon_i, \quad \epsilon_i \sim \mathcal{N}(0, \sigma^2)$$

Where $\beta_{1}(t)$ tells us how effectively the intervention "recruits" or "triggers" the mediator.

#### 2. The Hazard Equation (The Risk of Shipwreck)
The instantaneous risk at time $t$ is then a combination of the background noise, the direct path from $X$, and the indirect path from the now-defined $M$:

$$\lambda(t | X, M) = g\left( \alpha_0(t) + \alpha_1(t)X_i + \alpha_2(t)M_i \right)$$

where $g$ is a link function that translates the linear predictor onto the hazard scale. Taken together, these observations should make it clear why this modelling strategy is known as _dynamic path analysis_. We essentially have a structural equation model (SEM) with specific path coefficients that trace out the influence between variables. But we also need to capture the time evolution of these relationships. It's not enough to say of Odysseus that he had good intentions and it all ended well. We lose something if we don't understand the journey.
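A linear predictor built from additive pieces can dip below zero, which a hazard cannot. The sketch below shows one candidate for $g$, the softplus function $\log(1 + e^{\eta})$, which is the transformation the model code later in this post applies via `pm.math.log1pexp`. The predictor values are made up for illustration.

```python
import numpy as np


def softplus(eta):
    """Numerically stable log(1 + exp(eta)); maps any real number to a positive one."""
    return np.logaddexp(0.0, eta)


# Hypothetical values of the linear predictor alpha_0 + alpha_1 * X + alpha_2 * M
eta = np.array([-5.0, -1.0, 0.0, 1.0, 5.0])
hazard = softplus(eta)

# For large positive eta the link is close to the identity, so effects are
# approximately additive there; for negative eta it flattens towards zero.
gap_from_identity = hazard - eta
```

One consequence of this choice is that coefficients are only approximately "increments to risk": the approximation is good when the predictor is comfortably positive and weaker near zero.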

### Causal Inference and Additive Effects

Aalen's main concern was how to represent mechanisms that unfold in time without forcing them into a proportional straitjacket, as in Cox-style regressions. In many applied settings, an intervention does not exert a constant relative effect. Instead, we act and the effects accumulate, sometimes attenuating or disappearing over time.

By modeling covariate effects as _increments to risk_, the coefficients themselves become causal estimands. A time-varying coefficient $\alpha_{1}(t)$ answers a concrete question: what is the instantaneous risk attributable to the exposure at time $t$? Once effects are additive, they can be decomposed and recombined. Direct, indirect, and total effects become sums of paths traced through time. The model reads naturally as a time-indexed causal graph with estimable path strengths.

> "We have a sequence of dynamic path models, one for each time t when we collect information. The estimation of each dynamic path model is done by recursive least squares regression as usual in path analysis" - Fosen et al. in _"Dynamic path analysis – a new approach to analyzing time-dependent covariates"_

The central objects in dynamic path analysis are the time-varying regression functions that link treatment, mediator, and survival. In the canonical `dpasurv` formulation, these are the functions $\beta_1(t), \alpha_2(t)$ and $\alpha_1(t)$, which parameterize the direct path from treatment to the mediator, the effect of mediator on the hazard, and the effect of the treatment on the hazard, respectively.

#### Cumulative Effects in `dpasurv`

Dynamic path analysis is concerned not with isolated coefficients, but with how these functions evolve over time. They form the building blocks from which direct, indirect, and total effects are constructed.

Because effects act continuously, interpretation is naturally expressed in cumulative terms. A cumulative coefficient is like a running total of a covariate’s effect on risk — it sums up, over time, how much that treatment or mediator has nudged the chance of an event happening. The **cumulative direct effect** up to time $t$ is defined as

$$\text{cumdir}(t) = \int_0^t \alpha_1(s)\, ds,$$

representing the accumulated contribution of the treatment along the direct path to the hazard. In an additive hazards framework like Aalen’s model, we directly estimate the cumulative coefficient function, that is, the total accumulated effect up to time $t$, rather than first estimating the instantaneous effect $\alpha_{1}(t)$ and then integrating it. Similarly, the **cumulative indirect effect** is defined as

$$\text{cumind}(t) = \int_0^t \beta_1(s)\, \alpha_2(s)\, ds,$$

which aggregates the mediated influence of treatment on survival. At each instant, the indirect effect is obtained by multiplying the strength of the treatment–mediator link with the strength of the mediator–hazard link, mirroring the logic of path analysis. A defining feature of dynamic path analysis is that these quantities satisfy an exact analytical decomposition:

$$\text{cumtot}(t) = \text{cumdir}(t) + \text{cumind}(t).$$
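To make the bookkeeping concrete, here is a small numerical sketch that accumulates path coefficients on a time grid. It follows the labelling used earlier in this post ($\alpha_1$ for treatment to hazard, $\beta_1$ for treatment to mediator, $\alpha_2$ for mediator to hazard). The functional forms are invented for illustration, and the integrals are approximated by a simple Riemann sum rather than by the `dpasurv` estimator.

```python
import numpy as np

dt = 1.0
t = np.arange(0.0, 200.0, dt)

# Hypothetical instantaneous path coefficients
alpha_1 = np.full_like(t, -0.002)        # treatment -> hazard (direct path)
beta_1 = np.where(t < 100.0, 0.0, 0.5)   # treatment -> mediator, switches on at t = 100
alpha_2 = np.full_like(t, 0.003)         # mediator -> hazard

# Riemann-sum approximations to the cumulative effects
cum_direct = np.cumsum(alpha_1 * dt)
cum_indirect = np.cumsum(beta_1 * alpha_2 * dt)
cum_total = cum_direct + cum_indirect

# The same total, accumulated from the summed instantaneous effect instead
cum_total_check = np.cumsum((alpha_1 + beta_1 * alpha_2) * dt)
```

Here `cum_total` and `cum_total_check` agree up to floating-point error, which is the identity above in discrete form.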

This identity holds because the model is additive. The `dpasurv` package demonstrates this decomposition with the following data.


In [3]:
df = pd.read_csv("aalen_simdata.csv")
df = df[['subject', 'x', 'dose', 'M', 'start', 'stop', 'event']]
df.head()

The data is in "long format", with multiple rows per individual `subject` and running values for the treatment indicator `x` and the mediator `M`. Crucially, the `event` flag denotes whether or not the terminal event occurred. This is simulated data from the `dpasurv` R package, but the abstract names serve to illustrate the ubiquity of the mediator relationship!
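For readers who have not met the start-stop layout before, the toy frame below shows the idea with two made-up subjects. The column names mirror those in the simulated data, but every value is invented.

```python
import pandas as pd

# Two hypothetical subjects followed over consecutive intervals
toy = pd.DataFrame(
    {
        "subject": [1, 1, 1, 2, 2],
        "x":       [0, 0, 0, 1, 1],            # treatment is fixed per subject
        "M":       [2.1, 2.4, 2.0, 3.5, 3.9],  # mediator is re-measured per interval
        "start":   [0, 10, 20, 0, 10],
        "stop":    [10, 20, 30, 10, 20],
        "event":   [0, 0, 1, 0, 0],            # subject 1 has the event at t = 30
    }
)

# Collapse back to one row per subject: total follow-up and event status
per_subject = toy.groupby("subject").agg(
    follow_up=("stop", "max"),
    had_event=("event", "max"),
)
```

Each row is one interval at risk, so a subject contributes as many rows as intervals they were observed for, and the event flag can only be 1 on a subject's final row.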

### Replicating the `dpasurv` Benchmark

The figure below shows the cumulative direct, indirect, and total effects estimated by `dpasurv` from our simulated dataset. The direct effect (left panel) traces the immediate influence of the treatment on the outcome, while the indirect effect (middle panel) captures the pathway mediated through $M$. The total effect (right panel) is simply the sum of the two.

![](dpasurv_benchmark.png)

Notice the jumpy, step-like patterns in all three panels. This is characteristic of the `dpasurv` estimator, which produces cumulative coefficients by summing contributions at discrete event time intervals. Each step corresponds to a unique event time, with the height of the step reflecting the combined contribution of all events that occurred at that time. The gray lines indicate the approximate bootstrap confidence intervals, which widen as events become sparse, reflecting increasing uncertainty over time.
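The step pattern follows from how Aalen-type estimators are usually built: at each event time, a least-squares regression of the event indicators on the covariates of those still at risk yields an increment, and the increments are summed. The sketch below illustrates that idea on simulated data with a constant additive treatment effect. It is a bare-bones illustration written for this post, not the `dpasurv` implementation, and it assumes no censoring and no tied event times.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 400

# Simulated data: binary treatment with a constant additive effect on the hazard
x = rng.integers(0, 2, size=n)
true_baseline, true_effect = 0.10, -0.04
times = rng.exponential(1.0 / (true_baseline + true_effect * x))

# Sort by event time so that rows j, j+1, ... are the risk set at the j-th event
order = np.argsort(times)
times, x = times[order], x[order]
design = np.column_stack([np.ones(n), x])  # intercept + treatment

increments = []
for j in range(n):
    at_risk = design[j:]
    if np.linalg.matrix_rank(at_risk) < 2:
        break  # stop once one treatment group is exhausted
    d_n = np.zeros(n - j)
    d_n[0] = 1.0  # subject j has the event at this time
    # Least-squares increment: (X'X)^{-1} X' dN
    step, *_ = np.linalg.lstsq(at_risk, d_n, rcond=None)
    increments.append(step)

cumulative_coefs = np.cumsum(increments, axis=0)  # one row per event time
event_times = times[: len(cumulative_coefs)]
```

Plotting `cumulative_coefs[:, 1]` against `event_times` with a step style gives a jagged curve of the kind seen in the benchmark panels, with each jump tied to a single event.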

The main pattern is the direct effect of `x` in the first panel. The curve dips below zero almost immediately and maintains a consistent downward slope; this negative cumulative effect indicates that the intervention itself is actively __reducing the hazard.__ This is good, but contrast it with the slowly emerging effect in the central panel. For the first 100 days, the mediator is a bystander: it is either not being triggered by the intervention or has no impact on survival. Around Day 100, the curve "breaks" and spikes upward. This represents a __positive contribution to the hazard.__ The total effect reflects this too: it becomes less negative (closer to zero) after Day 100. This tells us that while the intervention is still helpful overall, its net benefit is being attenuated by the mediator.


These are the patterns we will try to replicate in our Bayesian dynamic path model. But first we'll look a bit more closely at the data.

## Exploring the Data


In [4]:
df.groupby(['x', 'dose'])[['event', 'M']].agg(['mean', 'sum'])

In [5]:
# | code-fold: true
# | output: false
# | echo: false

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# -----------------------
# Color palette (muted & clean)
# -----------------------
COLOR_X0 = "#6B7280"    # muted slate gray
COLOR_X1 = "#2563EB"    # soft rich blue
EVENT_COLOR = "#991B1B" # deep muted red

# -----------------------
# Split subject ordering by treatment group
# -----------------------
subject_info = (
    df.groupby('subject')
      .agg(
          x=('x', 'first'),
          max_stop=('stop', 'max')
      )
)

subject_info_0 = (
    subject_info[subject_info['x'] == 0]
    .sort_values('max_stop')
)

subject_info_1 = (
    subject_info[subject_info['x'] == 1]
    .sort_values('max_stop')
)

subjects_0 = subject_info_0.index.tolist()
subjects_1 = subject_info_1.index.tolist()

subject_to_y_0 = {s: i for i, s in enumerate(subjects_0)}
subject_to_y_1 = {s: i for i, s in enumerate(subjects_1)}

# -----------------------
# Create side-by-side plots
# -----------------------
fig, axes = plt.subplots(
    ncols=2,
    figsize=(12, 0.1 * max(len(subjects_0), len(subjects_1))),
    sharex=True
)

ax0, ax1 = axes

# -----------------------
# Plot x = 0 group
# -----------------------
for _, row in df[df['x'] == 0].iterrows():
    y = subject_to_y_0[row['subject']]

    ax0.hlines(
        y=y,
        xmin=row['start'],
        xmax=row['stop'],
        color=COLOR_X0,
        linewidth=2.5
    )

    if row['event'] == 1:
        ax0.plot(
            row['stop'],
            y,
            marker='o',
            color=EVENT_COLOR,
            markersize=6,
            zorder=3
        )

ax0.set_yticks(range(len(subjects_0)))
ax0.set_yticklabels(subjects_0)
ax0.set_title("x = 0")
ax0.set_ylabel("Subject")

# -----------------------
# Plot x = 1 group
# -----------------------
for _, row in df[df['x'] == 1].iterrows():
    y = subject_to_y_1[row['subject']]

    ax1.hlines(
        y=y,
        xmin=row['start'],
        xmax=row['stop'],
        color=COLOR_X1,
        linewidth=2.5
    )

    if row['event'] == 1:
        ax1.plot(
            row['stop'],
            y,
            marker='o',
            color=EVENT_COLOR,
            markersize=6,
            zorder=3
        )

ax1.set_yticks(range(len(subjects_1)))
ax1.set_yticklabels(subjects_1)
ax1.set_title("x = 1")

# -----------------------
# Shared formatting & polish
# -----------------------
for ax in axes:
    ax.set_xlabel("Time")
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.grid(False)

# -----------------------
# Legend
# -----------------------
legend_elements = [
    Line2D([0], [0], color=COLOR_X1, lw=2.5, label='x = 1'),
    Line2D([0], [0], color=COLOR_X0, lw=2.5, label='x = 0'),
    Line2D([0], [0], marker='o', color=EVENT_COLOR,
           lw=0, label='Event', markersize=6),
]

ax0.tick_params(axis='y', labelsize=8)
ax1.tick_params(axis='y', labelsize=8)

fig.legend(
    handles=legend_elements,
    loc='upper center',
    ncol=3,
    frameon=False,
    bbox_to_anchor=(0.5, 0.98)
)

fig.suptitle(
    "Subject Timelines by Treatment Group",
    y=1.08
)

plt.tight_layout(rect=[0, 0, 1, 0.95])

fig.savefig("timelines_plot.png")

![](timelines_plot.png)

### Data Preparation


In [6]:
def prepare_aalen_dpa_data(
    df,
    subject_col="subject",
    start_col="start",
    stop_col="stop",
    event_col="event",
    x_col="x",
    m_col="M",
):
    """
    Prepare Andersen–Gill / Aalen dynamic path data for PyMC.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format start–stop survival data
    subject_col : str
        Subject identifier
    start_col, stop_col : str
        Interval boundaries
    event_col : str
        Event indicator (0/1)
    x_col : str
        Exposure / treatment
    m_col : str
        Mediator measured at interval start

    Returns
    -------
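    dict
        ``bins`` (time-bin lookup table) and ``df_long`` (the long-format
        frame with ``dt``, ``bin_idx``, centered and lagged mediator, and
        dose indicators added)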
    """

    df = df.copy()

    # -------------------------------------------------
    # 1. Basic quantities
    # -------------------------------------------------
    df["dt"] = df[stop_col] - df[start_col]

    if (df["dt"] <= 0).any():
        raise ValueError("Non-positive interval lengths detected.")

    N = df[event_col].astype(int).values
    Y = np.ones(len(df), dtype=int)  # Andersen–Gill at-risk indicator

    # -------------------------------------------------
    # 2. Time-bin indexing (piecewise-constant effects)
    # -------------------------------------------------
    bins = (
        df[[start_col, stop_col]]
        .drop_duplicates()
        .sort_values([start_col, stop_col])
        .reset_index(drop=True)
    )
    bins["bin_idx"] = np.arange(len(bins))

    df = df.merge(
        bins,
        on=[start_col, stop_col],
        how="left",
        validate="many_to_one"
    )

    bin_idx = df["bin_idx"].values
    n_bins = bins.shape[0]

    # -------------------------------------------------
    # 3. Center the mediator (treatment stays on its 0/1 scale)
    # -------------------------------------------------
    df["x_c"] = df[x_col]
    df["m_c"] = df[m_col] - df[m_col].mean()

    x = df["x_c"].values
    m = df["m_c"].values

    # -------------------------------------------------
    # 4. Predictable mediator (lag within subject)
    # -------------------------------------------------
    df = df.sort_values([subject_col, start_col])

    df["m_lag"] = (
        df.groupby(subject_col)["m_c"]
          .shift(1)
          .fillna(0.0)
    )

    m_lag = df["m_lag"].values

    df["I_low"]  = (df["dose"] == "low").astype(int)
    df["I_high"] = (df["dose"] == "high").astype(int)

    # -------------------------------------------------
    # 5. Assemble output
    # -------------------------------------------------
    data = {
        "bins": bins,     # useful for plotting
        "df_long": df     # optional: debugging / inspection
    }

    return data


In [7]:
data = prepare_aalen_dpa_data(df)
df_long = data['df_long']
df_long[['subject', 'x', 'dose', 'M', 'event', 'dt', 'bin_idx']].head(14)
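The "predictable mediator" step in the helper above lags the mediator within each subject, so that a row only sees the value recorded in that subject's previous interval. A toy illustration of the same `groupby(...).shift(1)` pattern, with invented values:

```python
import pandas as pd

toy = pd.DataFrame(
    {
        "subject": [1, 1, 1, 2, 2],
        "start":   [0, 10, 20, 0, 10],
        "m_c":     [0.5, -0.2, 0.1, 1.0, 0.8],
    }
)

toy = toy.sort_values(["subject", "start"])

# Shift within subject so each row sees the previous interval's mediator value;
# a subject's first interval has no history and is filled with 0.0
toy["m_lag"] = toy.groupby("subject")["m_c"].shift(1).fillna(0.0)
```

Grouping before shifting matters: a plain `shift(1)` on the whole column would leak subject 1's last value into subject 2's first row.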

In [8]:
def create_bspline_basis(n_bins, n_knots=10, degree=3):
    """
    Create B-spline basis functions for smooth time-varying effects.
    
    Parameters
    ----------
    n_bins : int
        Number of time bins
    n_knots : int
        Number of equally spaced knot locations, boundaries included (fewer = smoother)
    degree : int
        Degree of spline (3 = cubic, recommended)
    
    Returns
    -------
    basis : np.ndarray
        Matrix of shape (n_bins, n_basis) with basis function values
    """
    # Create knot sequence
    # Internal knots equally spaced across time range
    internal_knots = np.linspace(0, n_bins-1, n_knots)
    
    # Add boundary knots (repeated degree+1 times for clamped spline)
    knots = np.concatenate([
        np.repeat(internal_knots[0], degree),
        internal_knots,
        np.repeat(internal_knots[-1], degree)
    ])
    
    # Number of basis functions
    n_basis = len(knots) - degree - 1
    
    # Evaluate each basis function at each time point
    t = np.arange(n_bins, dtype=float)
    basis = np.zeros((n_bins, n_basis))
    
    for i in range(n_basis):
        # Create coefficient vector (indicator for basis i)
        coef = np.zeros(n_basis)
        coef[i] = 1.0
        
        # Evaluate B-spline
        spline = BSpline(knots, coef, degree, extrapolate=False)
        basis[:, i] = spline(t)
    
    return basis

n_knots = 10
n_bins = data['bins'].shape[0]
basis = create_bspline_basis(n_bins, n_knots=n_knots, degree=3)
n_cols = basis.shape[1]
basis_df = pd.DataFrame(basis, columns=[f'feature_{i}' for i in range(n_cols)])
basis_df.head(10)
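A quick sanity check on a basis like this is that, with a clamped knot vector, the basis functions sum to one at every point of the time range (the partition-of-unity property). The sketch below rebuilds a small basis in vectorised form, passing an identity matrix as coefficients so that each output column is one basis function. The helper name `clamped_bspline_basis` is ours for illustration and is not part of SciPy.

```python
import numpy as np
from scipy.interpolate import BSpline


def clamped_bspline_basis(n_points, n_knots=6, degree=3):
    """Evaluate a clamped B-spline basis at the integers 0..n_points-1."""
    inner = np.linspace(0, n_points - 1, n_knots)
    knots = np.concatenate(
        [np.repeat(inner[0], degree), inner, np.repeat(inner[-1], degree)]
    )
    n_basis = len(knots) - degree - 1
    # Identity coefficients: column i of the result is the i-th basis function
    spline = BSpline(knots, np.eye(n_basis), degree, extrapolate=False)
    return spline(np.arange(n_points, dtype=float))


small_basis = clamped_bspline_basis(n_points=50)
row_sums = small_basis.sum(axis=1)
```

If `row_sums` deviates from one, or the matrix contains `NaN`, the evaluation points have probably strayed outside the knot range.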

In [9]:
# | output: false

def make_model(data, basis, sample=True, observed=True): 
    df_long = data['df_long'].copy()
    n_basis = basis.shape[1]
    n_obs = data['df_long'].shape[0]
    time_bins = data['bins']['bin_idx'].values
    b = df_long['bin_idx'].values  # integer bin index per row, as a NumPy array

    observed_mediator = df_long["m_c"].values
    observed_events = df_long['event'].astype(int).values
    observed_treatment = df_long['x'].astype(int).values
    observed_mediator_lag = df_long['m_lag'].values

    coords = {'tv': ['intercept', 'direct', 'mediator'], 
            'splines': [f'spline_f_{i}' for i in range(n_basis)], 
            'obs': range(n_obs), 
            'time_bins': time_bins}

    with pm.Model(coords=coords) as aalen_dpa_model:

        trt = pm.Data("trt", observed_treatment, dims="obs")
        med = pm.Data("mediator", observed_mediator, dims="obs")
        med_lag = pm.Data("mediator_lag", observed_mediator_lag, dims="obs")
        events = pm.Data("events", observed_events, dims="obs")
        I_low  = pm.Data("I_low",  df_long["I_low"].values,  dims="obs")
        I_high = pm.Data("I_high", df_long["I_high"].values, dims="obs")
        dt = pm.Data("duration", df_long['dt'].values, dims='obs')
        ## because our long data format has a cell per obs
        at_risk = pm.Data("at_risk", np.ones(len(observed_events)), dims="obs")
        basis_ = pm.Data("basis", basis, dims=('time_bins', 'splines') )

        # -------------------------------------------------
        # 1. B-spline coefficients for HAZARD model
        # -------------------------------------------------
        # Prior on spline coefficients
        # Smaller sigma = less wiggliness
        # Random Walk 1 (RW1) Prior for coefficients
        # This is the Bayesian version of the smoothing penalty in R's 'mgcv' or 'timereg'
        sigma_smooth = pm.Exponential("sigma_smooth", [1, 1, 1], dims='tv')
        beta_raw = pm.Normal("beta_raw", 0, 1, dims=('splines', 'tv'))

        # Cumulative sum makes it a Random Walk
        # This ensures coefficients evolve smoothly over time
        coef_alpha = pm.Deterministic("coef_alpha", pt.cumsum(beta_raw * sigma_smooth, axis=0), dims=('splines', 'tv'))

        # Construct smooth time-varying functions
        alpha_0_t = pt.dot(basis_, coef_alpha[:, 0])
        alpha_1_t = pt.dot(basis_, coef_alpha[:, 1])
        alpha_2_t = pt.dot(basis_, coef_alpha[:, 2])
        
        # -------------------------------------------------
        # 2. B-spline coefficients for MEDIATOR model
        # -------------------------------------------------
        sigma_beta_smooth = pm.Exponential("sigma_beta_smooth", 0.1)
        beta_raw = pm.Normal("beta_raw_m", 0, 1, dims=('splines'))
        coef_beta = pt.cumsum(beta_raw * sigma_beta_smooth)
        
        beta_t = pt.dot(basis_, coef_beta)

        # -------------------------------------------------
        # 3. Mediator model (A path: x → M)
        # -------------------------------------------------
        sigma_m = pm.HalfNormal("sigma_m", 1.0)
        
        # Autoregressive component
        rho = pm.Beta("rho", 2, 2)
        
        mu_m = beta_t[b] * trt + rho * med_lag

        pm.Normal(
            "obs_m",
            mu=mu_m,
            sigma=sigma_m,
            observed=med,
            dims='obs'
        )

        # -------------------------------------------------
        # 4. Hazard model (direct + B path)
        # -------------------------------------------------
        beta_low  = pm.Normal("beta_low",  0, 0.1)
        beta_high = pm.Normal("beta_high", 0, 0.1)
        # Additive linear predictor; softplus below maps it to a positive hazard
        log_lambda_t = (alpha_0_t[b] 
                        + alpha_1_t[b] * trt # direct effect
                        + alpha_2_t[b] * med  # mediator effect
                        + beta_low  * I_low
                        + beta_high * I_high
        )
        
        # Expected number of events
        time_at_risk = at_risk * dt
        Lambda = time_at_risk * pm.math.log1pexp(log_lambda_t)

        if observed:
            pm.Poisson(
                "obs_event",
                mu=Lambda,
                observed=events, 
                dims='obs'
            )
        else: 
            pm.Poisson(
                "obs_event",
                mu=Lambda,
                dims='obs'
            )

        # -------------------------------------------------
        # 5. Causal path effects
        # -------------------------------------------------
        # Store time-varying coefficients
        pm.Deterministic("alpha_0_t", alpha_0_t, dims='time_bins')
        pm.Deterministic("alpha_1_t", alpha_1_t, dims='time_bins')  # direct effect
        pm.Deterministic("alpha_2_t", alpha_2_t, dims='time_bins')  # B path
        pm.Deterministic("beta_t", beta_t, dims='time_bins')        # A path
        
        # Time-varying direct effect (instantaneous, per time bin)
        cum_de = pm.Deterministic(
            "tv_direct_effect",
            alpha_1_t, 
            dims='time_bins'
        )

        # Time-varying indirect effect (product of paths, per time bin)
        cum_ie = pm.Deterministic(
            "tv_indirect_effect",
            beta_t * alpha_2_t, 
            dims='time_bins'
        )

        # Total effect
        cum_te = pm.Deterministic(
            "tv_total_effect",
            cum_de + cum_ie,
            dims='time_bins'
        )

        pm.Deterministic('tv_baseline_hazard', pm.math.log1pexp(alpha_0_t), 
            dims='time_bins')

        pm.Deterministic('tv_hazard_with_exposure', pm.math.log1pexp(alpha_0_t + alpha_1_t), 
            dims='time_bins')

        pm.Deterministic(
        "tv_RR",
        pm.math.log1pexp(alpha_0_t + alpha_1_t) /
        pm.math.log1pexp(alpha_0_t),
        dims="time_bins"
        )

        # -------------------------------------------------
        # 6. Sample
        # -------------------------------------------------
        idata = None  # stays None when sample=False, so the return below is always defined
        if sample:
            idata = pm.sample_prior_predictive()
            idata.extend(pm.sample(
                draws=2000,
                tune=2000,
                target_accept=0.95,
                chains=4,
                nuts_sampler="numpyro",
                random_seed=42,
                init="adapt_diag", 
                idata_kwargs={"log_likelihood": True}
            ))
            idata.extend(pm.sample_posterior_predictive(idata))
    
    return aalen_dpa_model, idata

basis = create_bspline_basis(n_bins, n_knots=12, degree=3)
aalen_dpa_model, idata_aalen =  make_model(data, basis)
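The random-walk prior in the model above is written in non-centred form: standard normal innovations are scaled by `sigma_smooth` and then cumulatively summed. The NumPy sketch below mimics that construction outside PyMC to show how the scale controls the wiggliness of the spline coefficients. The particular scale values are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(1)
n_basis = 14

# Standard-normal innovations, shared across scales to isolate the effect of sigma
z = rng.normal(size=n_basis)


def rw1_coefficients(innovations, sigma):
    """Non-centred first-order random walk: cumulative sum of scaled innovations."""
    return np.cumsum(innovations * sigma)


coef_smooth = rw1_coefficients(z, sigma=0.1)
coef_wiggly = rw1_coefficients(z, sigma=1.0)

# Successive differences recover the scaled innovations
step_smooth = np.diff(coef_smooth)
step_wiggly = np.diff(coef_wiggly)
```

A small scale ties neighbouring coefficients together, which is why shrinking `sigma_smooth` pulls the fitted time-varying curves towards something flatter.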

In [10]:
pm.model_to_graphviz(aalen_dpa_model)
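The Poisson likelihood in the hazard block relies on a standard device for piecewise-constant hazards: over an interval of length `dt` with hazard `lam` and event indicator `d`, the survival likelihood `lam**d * exp(-lam * dt)` and a Poisson likelihood with mean `lam * dt` differ only by a term that does not involve `lam`. A small numerical check, with arbitrary illustrative values:

```python
import numpy as np
from math import lgamma


def survival_loglik(lam, dt, d):
    """Log-likelihood of one interval under a constant hazard lam."""
    return d * np.log(lam) - lam * dt


def poisson_loglik(lam, dt, d):
    """Poisson log-pmf of d events with mean lam * dt."""
    mu = lam * dt
    return d * np.log(mu) - mu - lgamma(d + 1)


dt, d = 2.5, 1
hazards = np.array([0.01, 0.1, 0.5, 2.0])

# The gap between the two is d * log(dt) for every hazard value
gap = poisson_loglik(hazards, dt, d) - survival_loglik(hazards, dt, d)
```

Since the gap is constant in the hazard, both likelihoods lead to the same posterior for the hazard parameters.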

In [11]:
#| eval: false

models = {}
idatas = {}
for i in range(4, 15, 2):
    basis = create_bspline_basis(n_bins, n_knots=i, degree=3)
    aalen_dpa_model, idata = make_model(data, basis)
    models[i] = aalen_dpa_model
    idatas[f"splines_{i}"] = idata

compare_df = az.compare(idatas, var_name='obs_event')
az.plot_compare(compare_df, figsize=(8, 6), plot_ic_diff=True)

![](spline_loo_comparison.png)


In [12]:
# | eval: false

ax = az.plot_forest([idatas[k] for k in idatas.keys()], combined=True, var_names=['tv_direct_effect'], model_names=list(idatas.keys()), coords={'time_bins': [180, 181, 182, 183, 184, 185, 186, 187, 188]}, 
figsize=(12, 10),  r_hat=True)
ax[0].set_title("Time-Varying Direct Effects \n Comparing Models on Final Time Intervals", fontsize=15)
ax[0].set_ylabel("Nth Time Interval", fontsize=15)
fig = ax[0].figure
fig.savefig('forest_plot_comparing_tv_direct.png')

![](forest_plot_comparing_tv_direct.png)


In [13]:
az.plot_trace(idata_aalen, var_names=['tv_direct_effect', 'tv_indirect_effect', 'tv_total_effect', 'beta_high', 'beta_low'], divergences=False);
plt.tight_layout()


In [14]:
vars_to_plot = ['tv_direct_effect', 'tv_indirect_effect', 'tv_total_effect']
labels = ['Time varying Direct Effect', 'Time varying Indirect Effect', 'Time varying Total Effect']

def plot_effects(idata, vars_to_plot, labels, scale="Log Hazard Ratio Scale"):
    fig, axs = plt.subplots(1, 3, figsize=(20, 10))
    color='teal'
    if scale != "Log Hazard Ratio Scale":
        color='darkred'

    for i, var in enumerate(vars_to_plot):
        # 1. Extract the posterior samples for this variable
        # Shape will be (chain * draw, time)
        post_samples = az.extract(idata, var_names=[var]).values.T
        
        # 2. Calculate the mean and the 94% HDI across the chains/draws
        mean_val = post_samples.mean(axis=0)
        hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
        
        # 3. Plot the Mean line
        x_axis = np.arange(len(mean_val))
        axs[i].plot(x_axis, mean_val, label=labels[i], color=color, lw=2)
        
        # 4. Plot the Shaded HDI region
        axs[i].fill_between(x_axis, hdi_val[:, 0], hdi_val[:, 1], color=color, alpha=0.2, label='94% HDI')
        
        # Formatting
        axs[i].set_title(labels[i])
        axs[i].legend()
        axs[i].grid(alpha=0.3)
        axs[i].set_ylabel(scale)
    plt.tight_layout()
    return fig

plot_effects(idata_aalen, vars_to_plot, labels);

In [15]:
vars_to_plot = ['tv_baseline_hazard', 'tv_hazard_with_exposure', 'tv_RR']
labels = ['Time varying Baseline Hazard', 'Time varying Hazard + Exposure', 'Time varying RR']
plot_effects(idata_aalen, vars_to_plot, labels, scale='Hazard Scale');
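The quantities plotted above are per-bin (instantaneous) effects, whereas the `dpasurv` benchmark figure is on a cumulative scale. One way to move between the two, sketched here with made-up draws standing in for posterior samples, is to accumulate each draw over the bins, weighted by bin width, before summarising. The array `fake_draws` and the unit bin widths are assumptions for illustration only, and the interval shown is an equal-tailed quantile interval rather than an HDI.

```python
import numpy as np

rng = np.random.default_rng(2)
n_draws, n_bins = 500, 60

# Stand-in for posterior draws of an instantaneous effect, shape (draw, time_bin)
fake_draws = rng.normal(loc=-0.01, scale=0.005, size=(n_draws, n_bins))
bin_widths = np.ones(n_bins)  # assumed equal-width bins

# Accumulate within each draw first, then summarise across draws
cumulative_draws = np.cumsum(fake_draws * bin_widths, axis=1)
cumulative_mean = cumulative_draws.mean(axis=0)
lower, upper = np.quantile(cumulative_draws, [0.03, 0.97], axis=0)
```

Accumulating inside each draw keeps the uncertainty bands honest, because it preserves the correlation between neighbouring bins that a bin-by-bin summary would discard.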