### Bayesian CUPED (Assertions) Analysis

This script reproduces the CUPED-style analysis for binary assertion data
using the updated `stan` interface (PyStan 3) rather than the older
`pystan` interface.  The goal is to mirror the structure of the original
``cuped_analysis.ipynb`` while accommodating binary outcomes and
pre/post comparisons on logical assertions.  If the `stan` package is
available, the script will fit a hierarchical logistic model with a
Bayesian CUPED adjustment.  When `stan` is not installed, the script
falls back to a classical CUPED adjustment using control variates.

The expected input files are two CSV or Parquet datasets with the
columns:

    - ``RunId``: an identifier for the evaluation run
    - ``TestCaseId``: the test case identifier
    - ``AssertionId``: the assertion identifier
    - ``IsTrue``: binary indicator (1 if the assertion passes, 0 otherwise)

The script performs the following steps:

1. Reads the previous and next datasets, coercing ``IsTrue`` to binary.
2. Produces a coverage snapshot between the two periods.
3. Aggregates the data to compute baseline pass rates per
   ``(TestCaseId, AssertionId)`` pair.
4. If available, fits a Bayesian CUPED model using ``stan`` with the
   appropriate array syntax.  Otherwise, estimates a CUPED coefficient
   via covariance and adjusts the raw binary outcomes accordingly.
5. Summarises per‑pair differences, produces bootstrap confidence
   intervals, and computes variance reduction diagnostics.
6. Generates plots analogous to those in the original notebook:
   distribution KDEs, bootstrap distributions, and bar charts of mean
   differences with confidence intervals.
7. Writes summary tables and plots to an ``outputs`` directory.

Note: Running the Bayesian model requires that the ``stan`` package
be installed (PyStan ≥ 3).  If ``stan`` is not available, the code
will still execute the classical CUPED steps and generate the plots
accordingly.  The Stan model code uses the new array declarations in
its ``data`` block:

    ``array[N] int<lower=0> y_prev``
    ``array[N] int<lower=0> n_prev``
    ``array[N] int<lower=0> y_next``
    ``array[N] int<lower=0> n_next``

and can be adjusted to your particular modelling needs.


In [1]:
import sys
import warnings
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import stan  # type: ignore[import]
import nest_asyncio
nest_asyncio.apply()

from statsmodels.api import OLS, add_constant
import re

  import pkg_resources


In [2]:
# Input locations: set these paths to point at your previous and next
# assertion datasets (CSV here; the loader in this file also accepts Parquet).
PREVIOUS_PATH = Path('data/assertions_previous.csv')
NEXT_PATH     = Path('data/assertions_next.csv')

# Random seed for reproducibility.  It seeds NumPy's global generator here
# and is also handed to `stan.build` in the sampling cell.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# Number of bootstrap iterations for uncertainty quantification in the
# classical CUPED fallback.
# NOTE(review): no bootstrap code is visible in this part of the file —
# confirm this constant is still used.
N_BOOT = 10_000

# Output directories.  They are created eagerly (including missing parents,
# and without failing if they already exist) so later cells can write
# draws and figures without checking first.
OUTPUT_DIR = Path('outputs_assertions')
PLOT_DIR   = OUTPUT_DIR / 'plots'
STAN_DIR   = OUTPUT_DIR / 'stan_models'
for _dir in [OUTPUT_DIR, PLOT_DIR, STAN_DIR]:
    _dir.mkdir(parents=True, exist_ok=True)

# Plotting defaults (match original notebook aesthetics): seaborn whitegrid
# theme and a 10x6 default figure size.
sns.set_theme(style='whitegrid')
plt.rcParams['figure.figsize'] = (10, 6)

In [3]:
def load_assertion_dataset(path: Path) -> pd.DataFrame:
    """Load a CSV or Parquet assertion dataset and coerce `IsTrue` to {0,1}.

    Coercion rules for ``IsTrue`` (applied per value unless the column is
    already of ``bool`` dtype):

    * ``'1'``, ``'true'``, ``'t'``, ``'yes'``, ``'y'`` (case-insensitive,
      surrounding whitespace ignored) map to 1;
    * ``'0'``, ``'false'``, ``'f'``, ``'no'``, ``'n'`` map to 0;
    * any other finite number maps to 1 when it is non-zero and to 0 when
      it is zero;
    * everything else (missing values, non-numeric text, ``inf``/``nan``)
      is treated as invalid and makes the function raise.

    Parameters
    ----------
    path : pathlib.Path
        The path to the dataset.  The file type is chosen from the suffix:
        ``.csv`` or ``.parquet``/``.pq``.

    Returns
    -------
    pd.DataFrame
        A DataFrame with columns ``RunId``, ``TestCaseId``, ``AssertionId``
        and ``IsTrue`` (0 or 1).

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.
    ValueError
        If the suffix is unsupported, a required column is missing, or any
        ``IsTrue`` value cannot be coerced to 0/1.
    """
    if not path.exists():
        raise FileNotFoundError(f"Input file not found: {path}")
    suffix = path.suffix.lower()
    if suffix == '.csv':
        df = pd.read_csv(path)
    elif suffix in {'.parquet', '.pq'}:
        df = pd.read_parquet(path)
    else:
        raise ValueError(f"Unsupported file type for {path}. Use CSV or Parquet.")

    required_cols = {'RunId', 'TestCaseId', 'AssertionId', 'IsTrue'}
    missing = required_cols - set(df.columns)
    if missing:
        raise ValueError(f"{path} missing required columns: {sorted(missing)}")

    truthy = {'1', 'true', 't', 'yes', 'y'}
    falsy = {'0', 'false', 'f', 'no', 'n'}

    def _to_binary(x):
        """Map one raw value to 1, 0, or NaN (NaN marks an invalid value)."""
        if pd.isna(x):
            return np.nan
        s = str(x).strip().lower()
        if s in truthy:
            return 1
        if s in falsy:
            return 0
        try:
            number = float(s)
        except ValueError:
            return np.nan
        # 'inf' / 'nan' text parses as a float but is not a usable indicator.
        if not np.isfinite(number):
            return np.nan
        # Compare the float itself: truncating to int first would turn
        # fractional non-zero values such as 0.5 into 0.
        return 1 if number != 0 else 0

    # Coerce IsTrue to binary 0/1
    if df['IsTrue'].dtype == bool:
        df['IsTrue'] = df['IsTrue'].astype(int)
    else:
        df['IsTrue'] = df['IsTrue'].map(_to_binary)

    if df['IsTrue'].isna().any():
        bad = df[df['IsTrue'].isna()].head()
        raise ValueError(
            f"Non-binary values detected in 'IsTrue'. First offenders:\n{bad}"
        )

    # No NaN remains at this point, so the column can safely be integer typed
    # (this also covers an empty column, which `map` leaves non-integer).
    df['IsTrue'] = df['IsTrue'].astype(int)

    return df[['RunId', 'TestCaseId', 'AssertionId', 'IsTrue']].copy()

In [4]:
def coverage_snapshot(prev: pd.DataFrame, nxt: pd.DataFrame) -> pd.DataFrame:
    """Compute coverage statistics between previous and next datasets.

    Pairs are identified by the string forms of ``TestCaseId`` and
    ``AssertionId``.  They are compared as ``(test_case, assertion)`` tuples
    rather than as one delimiter-joined string, so identifiers that happen
    to contain the delimiter cannot be mistaken for a different pair.

    Parameters
    ----------
    prev : pd.DataFrame
        Previous-period data with ``TestCaseId`` and ``AssertionId`` columns.
    nxt : pd.DataFrame
        Next-period data with the same two columns.

    Returns
    -------
    pd.DataFrame
        A single-row frame with columns ``pairs_prev``, ``pairs_next``,
        ``pairs_intersection``, ``coverage_rate_next`` (intersection / next,
        NaN when the next period has no pairs) and ``low_overlap_flag``
        (True when the coverage rate is defined and below 0.7).
    """
    def _pair_keys(frame: pd.DataFrame) -> set:
        """Return the distinct (TestCaseId, AssertionId) string tuples."""
        return set(zip(frame['TestCaseId'].astype(str),
                       frame['AssertionId'].astype(str)))

    prev_pairs = _pair_keys(prev)
    next_pairs = _pair_keys(nxt)
    inter_pairs = prev_pairs & next_pairs
    # Coverage is undefined when the next period contributes no pairs.
    coverage_rate_next = (len(inter_pairs) / len(next_pairs)) if next_pairs else np.nan
    return pd.DataFrame({
        'pairs_prev': [len(prev_pairs)],
        'pairs_next': [len(next_pairs)],
        'pairs_intersection': [len(inter_pairs)],
        'coverage_rate_next': [coverage_rate_next],
        'low_overlap_flag': [coverage_rate_next < 0.7 if not np.isnan(coverage_rate_next) else False]
    })


In [5]:
def compute_baseline(prev: pd.DataFrame) -> pd.DataFrame:
    """Return the previous-period pass rate of every assertion pair.

    Parameters
    ----------
    prev : pd.DataFrame
        Previous-period assertion rows; must carry ``TestCaseId``,
        ``AssertionId`` and ``IsTrue``.

    Returns
    -------
    pd.DataFrame
        One row per ``(TestCaseId, AssertionId)`` pair with the mean of
        ``IsTrue`` for that pair stored in ``X_baseline``.
    """
    pair_keys = ['TestCaseId', 'AssertionId']
    # Mean of the 0/1 indicator within each pair, indexed by the pair keys.
    pass_rate = prev.groupby(pair_keys)['IsTrue'].mean()
    # Move the keys back into ordinary columns and name the value column.
    return pass_rate.reset_index(name='X_baseline')

# 1. Load the previous and next assertion datasets

In [6]:
# Read both periods through the same loader, so each frame holds exactly
# RunId / TestCaseId / AssertionId / IsTrue with IsTrue already coerced to 0/1.
prev_df = load_assertion_dataset(PREVIOUS_PATH)
next_df = load_assertion_dataset(NEXT_PATH)

# 2. Coverage snapshot and baseline pass rates

In [7]:
# One-row summary of how many (TestCaseId, AssertionId) pairs the two periods share.
coverage_df = coverage_snapshot(prev_df, next_df)
# Mean IsTrue per pair in the previous period (column `X_baseline`).
baseline_df = compute_baseline(prev_df)

# 3. Fit the Bayesian CUPED model with Stan

In [8]:
# NOTE(review): this message is printed unconditionally — no availability check
# is visible in this cell, and `import stan` at the top of the file is not guarded.
print('stan package detected. Attempting to build and sample Bayesian model...')
# Prepare data for Stan: one row per (TestCaseId, AssertionId) pair holding the
# number of passes (sum of IsTrue) and the number of trials (row count) in the
# previous period ...
agg = (prev_df.groupby(['TestCaseId', 'AssertionId'], as_index=False)['IsTrue']
       .agg(y_prev='sum', n_prev='size'))
# ... and the same two counts for the next period.
agg_next = (next_df.groupby(['TestCaseId', 'AssertionId'], as_index=False)['IsTrue']
            .agg(y_next='sum', n_next='size'))
# The outer merge keeps pairs that appear in only one period; fillna(0) gives
# those pairs zero passes and zero trials for the period they are missing from.
full = agg.merge(agg_next, on=['TestCaseId', 'AssertionId'], how='outer').fillna(0)
N_pairs = len(full)
# Data for the Stan `data` block.  Counts are cast to int (fillna leaves floats
# behind) and passed as plain lists taken in the row order of `full`, so Stan
# index i corresponds to the i-th row of `full`.
stan_data = {
    'N': N_pairs,
    'y_prev': full['y_prev'].astype(int).tolist(),
    'n_prev': full['n_prev'].astype(int).tolist(),
    'y_next': full['y_next'].astype(int).tolist(),
    'n_next': full['n_next'].astype(int).tolist(),
}
# CUPED model code with array syntax.  As written in the program below:
#   * each pair has a previous-period logit eta = mu_eta + sigma_eta * eta_raw,
#     with eta_raw ~ std_normal(), and q = inv_logit(eta);
#   * the next-period logit is eta + tau + theta * (q - mean(q)), i.e. a common
#     shift tau plus a slope theta on the centred previous-period probability;
#   * both periods use a binomial likelihood on the aggregated counts;
#   * generated quantities expose p_prev, p_next, delta = p_next - p_prev per
#     pair and overall_delta = mean(delta), which the plotting cells read.
stan_code_cuped = """
data {
  int<lower=1> N;
  array[N] int<lower=0> y_prev;
  array[N] int<lower=0> n_prev;
  array[N] int<lower=0> y_next;
  array[N] int<lower=0> n_next;
}
parameters {
  real mu_eta;
  real<lower=0> sigma_eta;
  vector[N] eta_raw;
  real tau;
  real theta;
}
transformed parameters {
  vector[N] eta = mu_eta + sigma_eta * eta_raw;
  vector[N] q   = inv_logit(eta);
  real qbar     = mean(q);
  vector[N] logit_p_next = eta + tau + theta * (q - qbar);
}
model {
  mu_eta    ~ normal(0, 2);
  sigma_eta ~ normal(0, 1);
  eta_raw   ~ std_normal();
  tau       ~ normal(0, 1.5);
  theta     ~ normal(0, 3);
  y_prev ~ binomial(n_prev, q);
  y_next ~ binomial(n_next, inv_logit(logit_p_next));
}
generated quantities {
  vector[N] p_prev = q;
  vector[N] p_next = inv_logit(logit_p_next);
  vector[N] delta  = p_next - p_prev;
  real overall_delta = mean(delta);
}
"""
# Build the model with the data and the shared seed, then draw 1000 samples
# from each of 4 chains.
# NOTE(review): the captured log reports 8000 iterations in total, presumably
# 4 chains x (warmup + 1000 draws) with the library's default warmup — confirm.
posterior = stan.build(stan_code_cuped, data=stan_data, random_seed=RANDOM_SEED)
fit = posterior.sample(num_chains=4, num_samples=1000)
# Extract draws to DataFrame
stan_df = fit.to_frame()
# Save the draws for further analysis (optional)
stan_draws_path = OUTPUT_DIR / 'stan_draws.parquet'
stan_df.to_parquet(stan_draws_path)
print(f"Stan draws saved to {stan_draws_path}")

Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/usr/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x7c0a34467400> is already entered
Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/usr/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x7c0a34467400> is already entered
Exception in callback Task.__step()
handle: <Handle Task.__step()>
Traceback (most recent call last):
  File "/usr/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x7c0a34467400> is already entered
Exception in callback Task.__step()
h

stan package detected. Attempting to build and sample Bayesian model...


[1A[0J[36mSampling:[0m   0% (1/8000)
[1A[0J[36mSampling:[0m   1% (101/8000)
[1A[0J[36mSampling:[0m   3% (201/8000)
[1A[0J[36mSampling:[0m   4% (301/8000)
[1A[0J[36mSampling:[0m  12% (1000/8000)
[1A[0J[36mSampling:[0m  22% (1800/8000)
[1A[0J[36mSampling:[0m  46% (3700/8000)
[1A[0J[36mSampling:[0m  70% (5600/8000)
[1A[0J[36mSampling:[0m  86% (6900/8000)
[1A[0J[36mSampling:[0m 100% (8000/8000)
[1A[0J[32mSampling:[0m 100% (8000/8000), done.
[36mMessages received during sampling:[0m
  Gradient evaluation took 0.000325 seconds
  1000 transitions using 10 leapfrog steps per transition would take 3.25 seconds.
  Adjust your expectations accordingly!
  Gradient evaluation took 0.000361 seconds
  1000 transitions using 10 leapfrog steps per transition would take 3.61 seconds.
  Adjust your expectations accordingly!
  Gradient evaluation took 0.000257 seconds
  1000 transitions using 10 leapfrog steps per transition would take 2.57 seconds.
  Adjust y

Stan draws saved to outputs_assertions/stan_draws.parquet


In [None]:
# Echo the one-row coverage table (index column suppressed) under a heading.
print('\nCoverage overview:', coverage_df.to_string(index=False), sep='\n')

# 4. STAN Plots

In [None]:
def plot_overall_delta(stan_df: pd.DataFrame, plot_dir: Path = PLOT_DIR) -> Path:
    """
    Draw the posterior density of ``overall_delta`` and save it as a PNG.

    The figure shows a filled KDE of the draws, a dashed reference line at
    zero, the posterior mean, the 2.5% / 97.5% percentiles, and a text box
    with the mean, the interval and the share of draws above zero.

    Parameters
    ----------
    stan_df : pd.DataFrame
        Posterior draws; must contain an ``overall_delta`` column.
    plot_dir : pathlib.Path
        Directory that receives ``overall_delta_posterior.png``.

    Returns
    -------
    pathlib.Path
        Location of the written image.

    Raises
    ------
    KeyError
        If ``overall_delta`` is not a column of ``stan_df``.
    """
    if "overall_delta" not in stan_df.columns:
        raise KeyError("Column 'overall_delta' not found in stan_df. "
                       "Check fit.to_frame().columns to confirm the name.")

    # Posterior summaries shown on the figure.
    samples = stan_df["overall_delta"].values
    post_mean = samples.mean()
    lower, upper = np.percentile(samples, [2.5, 97.5])
    share_positive = np.mean(samples > 0)

    fig = plt.figure(figsize=(10, 6))
    # kdeplot draws on the current axes and hands them back.
    ax = sns.kdeplot(samples, fill=True)
    ax.axvline(0.0, color="black", linestyle="--", label="No change (0)")
    ax.axvline(post_mean, color="tab:blue", linestyle="-", label=f"Posterior mean = {post_mean:.3f}")
    # Only the lower bound carries a label so the legend lists the interval once.
    ax.axvline(lower, color="tab:red", linestyle="--", label="95% CI")
    ax.axvline(upper, color="tab:red", linestyle="--")

    ax.set_title("Posterior of Overall Change (Next − Previous)")
    ax.set_xlabel("Overall Δ (mean pass prob, Next − Previous)")
    ax.set_ylabel("Posterior density")

    # Summary box anchored to the top-right corner in axes coordinates.
    summary_box = (
        f"Posterior mean Δ = {post_mean:.3f}\n"
        f"95% credible interval = [{lower:.3f}, {upper:.3f}]\n"
        f"P(Δ > 0) = {share_positive:.3f}"
    )
    ax.text(
        0.98, 0.95, summary_box,
        transform=ax.transAxes,
        ha="right", va="top",
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8)
    )

    ax.legend(loc="upper left")
    fig.tight_layout()
    out_path = plot_dir / "overall_delta_posterior.png"
    fig.savefig(out_path)
    plt.close(fig)
    return out_path


In [None]:
def compute_mean_prev_next(stan_df: pd.DataFrame) -> pd.DataFrame:
    """
    Add per-draw average pass probabilities for the two periods.

    For every row (draw) the per-pair ``p_prev`` columns are averaged into
    ``mean_p_prev`` and the per-pair ``p_next`` columns into ``mean_p_next``.
    Both bracketed (``p_prev[1]``) and dotted (``p_prev.1``) column names
    are recognised.  The input frame is not modified; a copy with the two
    extra columns is returned.

    Raises
    ------
    KeyError
        If no per-pair column is found for either ``p_prev`` or ``p_next``.
    """

    def _indexed_columns(base: str):
        # Recognise base[1], base[2], ... as well as base.1, base.2, ...
        matcher = re.compile(rf"^{base}(\[|\.)(\d+)(\])?$")
        hits = ((matcher.match(name), name) for name in stan_df.columns)
        numbered = [(int(hit.group(2)), name) for hit, name in hits if hit]
        # Stable sort on the numeric index keeps columns in pair order.
        numbered.sort(key=lambda item: item[0])
        return [name for _, name in numbered]

    prev_cols = _indexed_columns("p_prev")
    next_cols = _indexed_columns("p_next")

    if not (prev_cols and next_cols):
        candidates = [c for c in stan_df.columns if "p_prev" in c or "p_next" in c]
        raise KeyError(
            "Could not find per-pair 'p_prev'/'p_next' columns in stan_df.\n"
            "Columns containing 'p_prev' or 'p_next' were:\n"
            f"{candidates[:50]}"
        )

    # Work on a copy so the caller's frame keeps its original columns.
    augmented = stan_df.copy()
    for target, members in (("mean_p_prev", prev_cols), ("mean_p_next", next_cols)):
        augmented[target] = augmented[members].mean(axis=1)
    return augmented


def plot_old_vs_new_mean_pass_rate(stan_df: pd.DataFrame, plot_dir: Path = PLOT_DIR) -> Path:
    """
    Save a violin plot contrasting the posterior of the average pass rate
    in the previous period with that of the next period.

    The per-draw averages come from ``compute_mean_prev_next``.  A text box
    reports the mean of (next − previous), its 2.5% / 97.5% percentiles and
    the share of draws in which next exceeds previous.

    Parameters
    ----------
    stan_df : pd.DataFrame
        Posterior draws containing the per-pair ``p_prev`` / ``p_next`` columns.
    plot_dir : pathlib.Path
        Directory that receives ``mean_pass_rate_old_vs_new.png``.

    Returns
    -------
    pathlib.Path
        Location of the written image.
    """
    per_draw = compute_mean_prev_next(stan_df)

    # Long format: one row per (draw, period) for seaborn.
    display_names = {
        "mean_p_prev": "Previous (old)",
        "mean_p_next": "Next (new)",
    }
    tidy = per_draw[["mean_p_prev", "mean_p_next"]].melt(
        var_name="Period",
        value_name="mean_pass_rate"
    )
    tidy["Period"] = tidy["Period"].map(display_names)

    # Per-draw improvement of the new system over the old one.
    gain = per_draw["mean_p_next"] - per_draw["mean_p_prev"]
    share_new_better = np.mean(gain > 0)
    avg_gain = gain.mean()
    lower, upper = np.percentile(gain, [2.5, 97.5])

    fig = plt.figure(figsize=(8, 6))
    # violinplot draws on the current axes and hands them back.
    ax = sns.violinplot(
        data=tidy,
        x="Period",
        y="mean_pass_rate",
        inner="quartile",
        cut=0
    )

    ax.set_title("Average Assertion Pass Rate\nPrevious vs Next (Posterior)")
    ax.set_ylabel("Mean pass probability")
    ax.set_xlabel("System")

    # Summary box anchored to the bottom-right corner in axes coordinates.
    summary_box = (
        f"Posterior mean (new − old) = {avg_gain:.3f}\n"
        f"95% CrI = [{lower:.3f}, {upper:.3f}]\n"
        f"P(new > old) = {share_new_better:.3f}"
    )
    ax.text(
        0.98, 0.05, summary_box,
        transform=ax.transAxes,
        ha="right", va="bottom",
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8)
    )

    fig.tight_layout()
    out_path = plot_dir / "mean_pass_rate_old_vs_new.png"
    fig.savefig(out_path)
    plt.close(fig)
    return out_path


In [None]:
def compute_per_pair_delta_summary(stan_df: pd.DataFrame, full_pairs: pd.DataFrame) -> pd.DataFrame:
    """
    Summarise posterior delta per (TestCaseId, AssertionId) pair.

    Per-pair draw columns are recognised in either naming style,
    ``delta[1], delta[2], ...`` or ``delta.1, delta.2, ...``; other columns
    (for example ``overall_delta``) are ignored.  Columns are ordered by
    their numeric index and matched positionally to the rows of
    ``full_pairs``.

    Parameters
    ----------
    stan_df : pd.DataFrame
        Posterior draws, one row per draw.
    full_pairs : pd.DataFrame
        One row per pair, in the same order as the model's pair index;
        must contain ``TestCaseId`` and ``AssertionId``.

    Returns
    -------
    pd.DataFrame
        Columns ``TestCaseId``, ``AssertionId``, ``delta_mean``,
        ``delta_ci_lo``, ``delta_ci_hi`` (2.5% / 97.5% percentiles) and
        ``prob_improved`` (share of draws with delta > 0).

    Raises
    ------
    KeyError
        If no per-pair delta column is present.
    ValueError
        If the number of delta columns differs from ``len(full_pairs)``.
    """
    # Accept both delta[i] and delta.i; anything else is not a per-pair column.
    pattern = re.compile(r"^delta(?:\[(\d+)\]|\.(\d+))$")
    indexed = []
    for col in stan_df.columns:
        m = pattern.match(str(col))
        if m:
            indexed.append((int(m.group(1) or m.group(2)), col))
    if not indexed:
        raise KeyError("No 'delta[...]' or 'delta.N' columns found in stan_df.")

    # Sort columns numerically so position matches pair index (1..N);
    # a plain string sort would put delta.10 before delta.2.
    indexed.sort(key=lambda item: item[0])
    delta_cols_sorted = [col for _, col in indexed]
    deltas = stan_df[delta_cols_sorted].to_numpy()  # shape: (S, N_pairs)

    delta_mean = deltas.mean(axis=0)
    delta_ci_lo = np.percentile(deltas, 2.5, axis=0)
    delta_ci_hi = np.percentile(deltas, 97.5, axis=0)
    prob_improved = (deltas > 0).mean(axis=0)

    if len(full_pairs) != len(delta_mean):
        raise ValueError("full_pairs length does not match number of delta columns.")

    summary = full_pairs[["TestCaseId", "AssertionId"]].copy()
    summary["delta_mean"] = delta_mean
    summary["delta_ci_lo"] = delta_ci_lo
    summary["delta_ci_hi"] = delta_ci_hi
    summary["prob_improved"] = prob_improved

    return summary


def plot_top_pairs_forest(
    stan_df: pd.DataFrame,
    full_pairs: pd.DataFrame,
    top_k: int = 30,
    plot_dir: Path = PLOT_DIR
) -> Path:
    """
    Save a forest plot of the pairs with the largest absolute posterior
    mean change.

    Each selected pair is drawn as a point at its posterior mean Δ with a
    horizontal bar spanning the 2.5% / 97.5% percentiles, and is annotated
    with the share of draws in which Δ is positive.

    Parameters
    ----------
    stan_df : pd.DataFrame
        Posterior draws, passed on to ``compute_per_pair_delta_summary``.
    full_pairs : pd.DataFrame
        Pair table, passed on to ``compute_per_pair_delta_summary``.
    top_k : int
        Maximum number of pairs to show.
    plot_dir : pathlib.Path
        Directory that receives ``top_pairs_forest.png``.

    Returns
    -------
    pathlib.Path
        Location of the written image.
    """
    summary = compute_per_pair_delta_summary(stan_df, full_pairs)

    # Rank pairs by |delta_mean|, largest first, and keep the first top_k.
    ranking = summary["delta_mean"].abs().sort_values(ascending=False).index
    top = summary.reindex(ranking).head(top_k).copy()
    top["label"] = top["TestCaseId"].astype(str) + "::" + top["AssertionId"].astype(str)

    # Figure height grows with the number of rows, but never drops below 6.
    fig = plt.figure(figsize=(10, max(6, 0.3 * len(top))))
    ax = fig.gca()
    y_positions = np.arange(len(top))

    # Asymmetric error bars: distance from the mean down to the lower bound
    # and up to the upper bound.
    below = top["delta_mean"] - top["delta_ci_lo"]
    above = top["delta_ci_hi"] - top["delta_mean"]
    ax.errorbar(
        x=top["delta_mean"],
        y=y_positions,
        xerr=[below, above],
        fmt="o",
        color="tab:blue",
        ecolor="black",
        capsize=3
    )
    ax.axvline(0.0, color="black", linestyle="--")
    plt.yticks(y_positions, top["label"])
    ax.set_xlabel("Δ (p_next − p_prev)")
    ax.set_title(f"Top {len(top)} Assertions by |Posterior Mean Δ|")

    # Right-aligned P(Δ>0) label for every row, placed at the current right x-limit.
    for row_y, share in zip(y_positions, top["prob_improved"]):
        ax.text(
            x=ax.get_xlim()[1],
            y=row_y,
            s=f"P(Δ>0)={share:.2f}",
            va="center",
            ha="right",
            fontsize=8
        )

    fig.tight_layout()
    out_path = plot_dir / "top_pairs_forest.png"
    fig.savefig(out_path)
    plt.close(fig)
    return out_path


In [None]:
# Rebuild the draws DataFrame from the fitted model; this repeats the extraction
# done right after sampling, presumably so the plotting cells can be re-run alone.
stan_df = fit.to_frame()

In [None]:
# Render the three headline figures; each helper returns the path of the PNG it wrote.
overall_path = plot_overall_delta(stan_df)
mean_rate_path = plot_old_vs_new_mean_pass_rate(stan_df)
forest_path = plot_top_pairs_forest(stan_df, full, top_k=30)

# Report where each figure was saved, one per line.
print(
    "Key plots:",
    f"  Overall delta posterior: {overall_path}",
    f"  Mean pass rate old vs new: {mean_rate_path}",
    f"  Top per-pair changes (forest): {forest_path}",
    sep="\n",
)