
# Geo Experiments for Incrementality Measurement

This notebook is a **playbook for geo experiments** (a.k.a. geographic lift tests).

Instead of randomizing at the **user** level, we randomize at the **geo / market / region**
level (e.g., cities, DMAs, countries) and measure incremental impact on an aggregate KPI.

We cover:

1. Simulating geo-level time series data (pre and post period).  
2. Simple geo **difference-in-differences** (DiD) estimator.  
3. Geo-level **regression adjustment** for more power.  
4. Basic **power analysis via simulation**.  
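5. Practical notes on when and how to use geo experiments.  
6. Stratified (pair-matched) randomization by pre-period KPI.  
7. A synthetic-control flavour for per-geo counterfactuals.  
8. A simple Bayesian model for the lift.  
9. Shrunken (empirical-Bayes) per-geo effects.  
10. A decision layer combining the Bayesian lift with guardrail rules.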

All code is typed, documented, and designed to be adapted to your own geo datasets.


In [None]:
from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple, Dict, Any

import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Default figure size and grid for every plot in this notebook.
plt.rcParams.update({"figure.figsize": (8, 4.5), "axes.grid": True})


# -------------------------------------------------------------------
# 1) Simulated geo-level panel data
# -------------------------------------------------------------------

@dataclass
class GeoSimConfig:
    """Configuration for simulating a geo experiment panel.

    Attributes
    ----------
    n_geos : int
        Number of geos (clusters). Must be at least 2 so both arms are non-empty.
    pre_days : int
        Length of the pre period in days (at least 1).
    post_days : int
        Length of the post period in days (at least 1).
    baseline_mean : float
        Average baseline KPI per geo per day.
    baseline_geo_sd : float
        Standard deviation for geo-level baseline heterogeneity (non-negative).
    day_noise_sd : float
        Standard deviation for day-to-day noise around each geo's baseline
        (non-negative).
    treatment_uplift : float
        Multiplicative uplift in the post period for treatment geos
        (e.g., 0.05 means +5%).
    trend_per_day : float
        Optional linear time trend (same for all geos).

    Raises
    ------
    ValueError
        If the configuration cannot produce a usable pre/post panel.
    """
    n_geos: int = 40
    pre_days: int = 28
    post_days: int = 14
    baseline_mean: float = 1000.0
    baseline_geo_sd: float = 400.0
    day_noise_sd: float = 250.0
    treatment_uplift: float = 0.05
    trend_per_day: float = 0.0

    def __post_init__(self) -> None:
        # The simulator treats n_geos // 2 geos, so two geos is the minimum
        # that leaves both treatment and control non-empty.
        if self.n_geos < 2:
            raise ValueError(f"n_geos must be at least 2, got {self.n_geos}.")
        # Pre/post aggregation needs at least one day in each period.
        if self.pre_days < 1 or self.post_days < 1:
            raise ValueError(
                "pre_days and post_days must both be at least 1, "
                f"got {self.pre_days} and {self.post_days}."
            )
        if self.baseline_geo_sd < 0 or self.day_noise_sd < 0:
            raise ValueError(
                "baseline_geo_sd and day_noise_sd must be non-negative, "
                f"got {self.baseline_geo_sd} and {self.day_noise_sd}."
            )


def simulate_geo_panel(
    config: GeoSimConfig,
    seed: int | None = 123,
) -> pd.DataFrame:
    """Generate a synthetic daily KPI panel for a two-arm geo experiment.

    Each geo gets a Normal baseline level. Half of the geos (``n_geos // 2``,
    chosen by shuffling) form the treatment arm. Every day's KPI is the geo
    baseline times a shared linear trend factor. In the post period,
    treatment geos are also multiplied by ``1 + treatment_uplift``. Normal
    day-level noise is then added.

    Parameters
    ----------
    config : GeoSimConfig
        Simulation settings (sizes, noise levels, uplift, trend).
    seed : int | None
        Seed for ``numpy.random.default_rng``; ``None`` draws fresh entropy.

    Returns
    -------
    DataFrame
        Long panel with columns geo_id, group, day, period, kpi. Rows are
        ordered by day, then by the shuffled geo order.
    """
    rng = np.random.default_rng(seed)
    n_geos = config.n_geos

    # Baselines are drawn before the shuffle; keeping this order keeps the
    # random stream (and therefore seeded output) stable.
    baselines = rng.normal(
        loc=config.baseline_mean,
        scale=config.baseline_geo_sd,
        size=n_geos,
    )

    order = np.arange(n_geos)
    rng.shuffle(order)
    is_treated = np.zeros(n_geos, dtype=bool)
    is_treated[order[: n_geos // 2]] = True
    labels = np.where(is_treated, "treatment", "control")

    rows: list[dict[str, Any]] = []
    for day in range(config.pre_days + config.post_days):
        in_post = day >= config.pre_days
        trend = 1.0 + config.trend_per_day * day
        for geo in order:
            lift = config.treatment_uplift if (in_post and is_treated[geo]) else 0.0
            expected = baselines[geo] * (1.0 + lift) * trend
            draw = rng.normal(loc=expected, scale=config.day_noise_sd)
            rows.append(
                {
                    "geo_id": int(geo),
                    "group": str(labels[geo]),
                    "day": int(day),
                    "period": "post" if in_post else "pre",
                    "kpi": float(draw),
                }
            )

    return pd.DataFrame(rows)


# Example configuration: an 8% true uplift on top of a mild upward trend.
config = GeoSimConfig(
    n_geos=40,
    pre_days=28,
    post_days=14,
    baseline_mean=1000.0,
    baseline_geo_sd=400.0,
    day_noise_sd=250.0,
    treatment_uplift=0.08,
    trend_per_day=0.001,
)

df = simulate_geo_panel(config)
display(df.head())

# Per-arm, per-period summary of the raw daily KPI.
panel_summary = df.groupby(["group", "period"])["kpi"].agg(["mean", "std", "count"])
display(panel_summary)


# -------------------------------------------------------------------
# 2) Geo-level pre/post aggregation
# -------------------------------------------------------------------

def aggregate_geo_pre_post(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse a daily geo panel into one pre/post summary row per geo.

    Parameters
    ----------
    df : DataFrame
        Long panel with columns geo_id, group, period ('pre' / 'post') and kpi.

    Returns
    -------
    DataFrame
        One row per (geo_id, group) with columns:
        - geo_id, group
        - kpi_mean_pre, kpi_mean_post
        - kpi_total_pre, kpi_total_post
        - delta_mean = post_mean - pre_mean
        - delta_total = post_total - pre_total
        The pre and post summaries are inner-joined, so a geo that lacks
        either period is not returned.
    """
    keys = ["geo_id", "group"]
    per_period = (
        df.groupby(keys + ["period"])["kpi"]
        .agg(kpi_mean="mean", kpi_total="sum")
        .reset_index()
    )

    pre_part = per_period.loc[per_period["period"] == "pre"].drop(columns="period")
    post_part = per_period.loc[per_period["period"] == "post"].drop(columns="period")

    wide = pre_part.merge(
        post_part,
        on=keys,
        suffixes=("_pre", "_post"),
        validate="one_to_one",
    )
    wide["delta_mean"] = wide["kpi_mean_post"] - wide["kpi_mean_pre"]
    wide["delta_total"] = wide["kpi_total_post"] - wide["kpi_total_pre"]

    output_cols = [
        "geo_id",
        "group",
        "kpi_mean_pre",
        "kpi_mean_post",
        "kpi_total_pre",
        "kpi_total_post",
        "delta_mean",
        "delta_total",
    ]
    return wide[output_cols]


geo_agg = aggregate_geo_pre_post(df)
display(geo_agg.head())

# Per-arm summary of the geo-level pre/post means and their change.
summary_cols = ["kpi_mean_pre", "kpi_mean_post", "delta_mean"]
display(geo_agg.groupby("group")[summary_cols].agg(["mean", "std", "count"]))


# -------------------------------------------------------------------
# 3) Geo-level DiD estimator
# -------------------------------------------------------------------

@dataclass
class DiDResult:
    """Result of a two-sample comparison (treatment minus control)."""

    estimate: float  # mean(treatment) - mean(control)
    se: float  # standard error of the difference in means
    ci_low: float  # lower bound of the (1 - alpha) confidence interval
    ci_high: float  # upper bound of the (1 - alpha) confidence interval
    dof: float  # Welch-Satterthwaite degrees of freedom (NaN if undefined)
    p_value: float  # two-sided p-value (NaN if the test is undefined)


def two_sample_ttest_from_groups(
    x_treat: np.ndarray,
    x_ctrl: np.ndarray,
    alpha: float = 0.05,
) -> DiDResult:
    """Welch two-sample t-test for the difference in means (treat - control).

    The p-value and the (1 - alpha) confidence interval use the Student-t
    distribution with Welch-Satterthwaite degrees of freedom. If scipy is not
    installed, both fall back to the standard normal (large-sample)
    approximation.

    Parameters
    ----------
    x_treat, x_ctrl : array-like
        Observations for the treatment and control groups (flattened).
    alpha : float
        Significance level; the confidence interval has level 1 - alpha.

    Returns
    -------
    DiDResult
        If both groups have zero spread, the difference is returned with a
        degenerate interval ``[estimate, estimate]`` and a NaN p-value.

    Raises
    ------
    ValueError
        If ``alpha`` is not in (0, 1) or either group has fewer than two
        observations (the sample variance is then undefined).
    """
    if not 0.0 < alpha < 1.0:
        raise ValueError(f"alpha must be in (0, 1), got {alpha}.")

    x_treat = np.asarray(x_treat, dtype=float).ravel()
    x_ctrl = np.asarray(x_ctrl, dtype=float).ravel()
    n_t, n_c = x_treat.size, x_ctrl.size
    if n_t < 2 or n_c < 2:
        raise ValueError(
            "Welch t-test needs at least two observations per group "
            f"(got n_treat={n_t}, n_ctrl={n_c})."
        )

    diff = float(x_treat.mean()) - float(x_ctrl.mean())

    # Per-group contributions to the variance of the difference in means.
    v_t = float(x_treat.var(ddof=1)) / n_t
    v_c = float(x_ctrl.var(ddof=1)) / n_c
    se = math.sqrt(v_t + v_c)

    # Welch-Satterthwaite approximation to the degrees of freedom.
    den = v_t**2 / (n_t - 1) + v_c**2 / (n_c - 1)
    dof = (v_t + v_c) ** 2 / den if den > 0 else float("nan")

    if se == 0.0 or not math.isfinite(dof):
        # No within-group spread: the difference is exact but a test statistic
        # is undefined.
        return DiDResult(
            estimate=diff,
            se=se,
            ci_low=diff,
            ci_high=diff,
            dof=dof,
            p_value=float("nan"),
        )

    t_stat = diff / se
    try:
        from scipy import stats
    except ImportError:
        # Without scipy, use the standard-normal reference distribution.
        from statistics import NormalDist

        std_normal = NormalDist()
        p_value = 2.0 * std_normal.cdf(-abs(t_stat))
        crit = std_normal.inv_cdf(1.0 - alpha / 2.0)
    else:
        p_value = float(2.0 * stats.t.sf(abs(t_stat), dof))
        crit = float(stats.t.ppf(1.0 - alpha / 2.0, dof))

    return DiDResult(
        estimate=diff,
        se=se,
        ci_low=diff - crit * se,
        ci_high=diff + crit * se,
        dof=dof,
        p_value=p_value,
    )


def geo_did_from_agg(
    geo_agg: pd.DataFrame,
    value_col: str = "delta_mean",
    alpha: float = 0.05,
) -> DiDResult:
    """Geo-level difference-in-differences on per-geo pre/post changes.

    The per-geo change in ``value_col`` is compared between treatment and
    control geos with ``two_sample_ttest_from_groups``. Geos are the units of
    analysis.
    """
    arm = geo_agg["group"]
    changes = geo_agg[value_col]
    treated_changes = changes[arm == "treatment"].to_numpy()
    control_changes = changes[arm == "control"].to_numpy()
    return two_sample_ttest_from_groups(treated_changes, control_changes, alpha=alpha)


did_res = geo_did_from_agg(geo_agg, value_col="delta_mean", alpha=0.05)
display(did_res)

# Express the absolute lift relative to the average pre-period KPI level.
pre_overall = float(geo_agg["kpi_mean_pre"].mean())
rel_uplift = did_res.estimate / pre_overall
lift_summary = {
    "abs_lift": did_res.estimate,
    "rel_lift": rel_uplift,
    "ci_rel_low": did_res.ci_low / pre_overall,
    "ci_rel_high": did_res.ci_high / pre_overall,
}
display(lift_summary)


# -------------------------------------------------------------------
# 4) Geo-level regression adjustment
# -------------------------------------------------------------------

@dataclass
class OLSResult:
    """Coefficient table from ``ols_with_intercept`` (intercept first)."""

    coef: np.ndarray  # OLS point estimates
    se: np.ndarray  # classical (homoskedastic) standard errors
    ci_low: np.ndarray  # lower bounds of the (1 - alpha) confidence intervals
    ci_high: np.ndarray  # upper bounds of the (1 - alpha) confidence intervals


def ols_with_intercept(
    X: np.ndarray,
    y: np.ndarray,
    alpha: float = 0.05,
) -> OLSResult:
    """Fit ``y ~ 1 + X`` by ordinary least squares with classical inference.

    Parameters
    ----------
    X : array-like, shape (n,) or (n, p)
        Regressors without an intercept column. A 1-D array is treated as a
        single regressor.
    y : array-like, shape (n,)
        Response.
    alpha : float
        The intervals have level 1 - alpha. They use the Student-t critical
        value with ``n - p - 1`` degrees of freedom, or the standard normal if
        scipy is not installed.

    Returns
    -------
    OLSResult
        Arrays of length p + 1 ordered [intercept, X columns...]. Standard
        errors and intervals are NaN when there are no residual degrees of
        freedom (n <= p + 1).

    Raises
    ------
    ValueError
        If ``alpha`` is outside (0, 1) or X and y have different lengths.
    numpy.linalg.LinAlgError
        If the design matrix is rank deficient (X'X not invertible).
    """
    if not 0.0 < alpha < 1.0:
        raise ValueError(f"alpha must be in (0, 1), got {alpha}.")

    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        X = X[:, None]
    y = np.asarray(y, dtype=float).ravel()

    n, p = X.shape
    if y.size != n:
        raise ValueError(f"X has {n} rows but y has {y.size} elements.")

    X1 = np.column_stack([np.ones(n), X])
    XtX_inv = np.linalg.inv(X1.T @ X1)
    beta_hat = XtX_inv @ (X1.T @ y)

    residuals = y - X1 @ beta_hat
    dof = n - (p + 1)
    if dof > 0:
        sigma2_hat = float(residuals @ residuals) / dof
        try:
            from scipy import stats
        except ImportError:
            # Without scipy, use the large-sample normal critical value.
            from statistics import NormalDist

            crit = NormalDist().inv_cdf(1.0 - alpha / 2.0)
        else:
            crit = float(stats.t.ppf(1.0 - alpha / 2.0, dof))
    else:
        # A saturated model leaves no residual variance estimate.
        sigma2_hat = float("nan")
        crit = float("nan")

    se_beta = np.sqrt(np.diag(sigma2_hat * XtX_inv))
    ci_low = beta_hat - crit * se_beta
    ci_high = beta_hat + crit * se_beta

    return OLSResult(
        coef=beta_hat,
        se=se_beta,
        ci_low=ci_low,
        ci_high=ci_high,
    )


def geo_regression_adjustment(
    geo_agg: pd.DataFrame,
    alpha: float = 0.05,
) -> Tuple[OLSResult, Dict[str, Any]]:
    """Estimate the treatment effect on ``delta_mean`` adjusting for pre KPI.

    Fits ``delta_mean ~ 1 + treat_flag + kpi_mean_pre`` across geos with
    ``ols_with_intercept``. The ``treat_flag`` slope is the adjusted lift.

    Parameters
    ----------
    geo_agg : DataFrame
        One row per geo with columns ``group``, ``delta_mean`` and
        ``kpi_mean_pre``.
    alpha : float
        Passed through to ``ols_with_intercept`` for the (1 - alpha) intervals.

    Returns
    -------
    (OLSResult, dict)
        The full OLS result, with coefficients ordered intercept,
        treat_flag, kpi_mean_pre. Also a dict summarizing the treatment
        coefficient in absolute terms and relative to the mean pre-period KPI.
    """
    df = geo_agg.copy()
    df["treat_flag"] = (df["group"] == "treatment").astype(float)

    y = df["delta_mean"].to_numpy()
    X = df[["treat_flag", "kpi_mean_pre"]].to_numpy()
    ols_res = ols_with_intercept(X, y, alpha=alpha)

    # Index 1 is the treat_flag slope (index 0 is the intercept).
    beta1 = ols_res.coef[1]
    se1 = ols_res.se[1]
    ci1_low, ci1_high = ols_res.ci_low[1], ols_res.ci_high[1]

    # Relative figures are scaled by the average pre-period KPI level.
    pre_overall = float(df["kpi_mean_pre"].mean())

    info = {
        "beta1_treat_effect": beta1,
        "beta1_se": se1,
        "beta1_ci_low": ci1_low,
        "beta1_ci_high": ci1_high,
        "rel_lift": beta1 / pre_overall,
        "rel_ci_low": ci1_low / pre_overall,
        "rel_ci_high": ci1_high / pre_overall,
    }
    return ols_res, info


# Regression adjustment: treatment effect on delta_mean, controlling for pre KPI.
ols_res, adj_info = geo_regression_adjustment(geo_agg, alpha=0.05)
display(adj_info)


# -------------------------------------------------------------------
# 5) Power analysis via simulation
# -------------------------------------------------------------------

def simulate_geo_power(
    config: GeoSimConfig,
    n_sims: int = 500,
    alpha: float = 0.05,
    seed: int | None = 999,
) -> Dict[str, float]:
    """Estimate the power of the geo DiD test by Monte Carlo simulation.

    Each replicate simulates a fresh panel, using a seed drawn from a master
    RNG seeded with ``seed``. It aggregates the panel to geo level and runs
    ``geo_did_from_agg`` on ``delta_mean``. A replicate counts as a rejection
    when its p-value is below ``alpha``.

    Parameters
    ----------
    config : GeoSimConfig
        Simulation settings; ``treatment_uplift`` is the effect being detected.
    n_sims : int
        Number of simulated experiments (must be at least 1).
    alpha : float
        Significance level of the test.
    seed : int | None
        Seed for the master RNG that generates per-replicate seeds.

    Returns
    -------
    dict
        ``power_hat`` (rejection rate), ``mean_estimate`` and ``sd_estimate``
        of the DiD estimates. ``sd_estimate`` is NaN when ``n_sims == 1``.

    Raises
    ------
    ValueError
        If ``n_sims < 1``.
    """
    if n_sims < 1:
        raise ValueError(f"n_sims must be at least 1, got {n_sims}.")

    rng = np.random.default_rng(seed)

    rejections = 0
    estimates = np.empty(n_sims, dtype=float)

    for i in range(n_sims):
        sim_seed = int(rng.integers(0, 1_000_000))
        df_sim = simulate_geo_panel(config, seed=sim_seed)
        geo_agg_sim = aggregate_geo_pre_post(df_sim)
        did = geo_did_from_agg(geo_agg_sim, value_col="delta_mean", alpha=alpha)
        estimates[i] = did.estimate

        # A NaN p-value compares False, so it never counts as a rejection.
        if did.p_value < alpha:
            rejections += 1

    return {
        "power_hat": rejections / n_sims,
        "mean_estimate": float(np.mean(estimates)),
        # The sample SD needs at least two replicates.
        "sd_estimate": float(np.std(estimates, ddof=1)) if n_sims > 1 else float("nan"),
    }


# Monte Carlo power estimate for the geo DiD test at the example configuration.
power_summary = simulate_geo_power(config=config, n_sims=200, alpha=0.05, seed=2025)
display(power_summary)



## Practical notes for real geo experiments

1. **Cluster randomization**
   - Randomize at the geo level, not at the user level.  
   - Make sure treatment and control geos are balanced in terms of baseline KPI and other covariates.

2. **Pre-period is crucial**
   - Use a sufficiently long pre-period to measure baseline differences.  
   - Use pre-period KPI as a covariate (regression) or for matching / stratification.

3. **Analysis at the geo level**
   - Respect the unit of randomization: geos are the experimental units.  
   - Perform inference at the geo level (e.g., t-test on geo-level deltas).

4. **Guardrails and multiple metrics**
   - You can define geo-level guardrail metrics (e.g., refund rate, visits) and apply
     the same multi-metric decision rules as in user-level A/B tests.

5. **Extensions**
   - More advanced methods include synthetic control or Bayesian hierarchical models
     for geo experiments. This notebook stays with simple, transparent estimators that
     are often enough for many practical use cases.



## 6) Matching / stratified randomization by pre-KPI

In real geo experiments you rarely just flip a coin for each geo.

A common pattern is **stratified randomization**:

1. Use pre-period KPI to build a baseline score per geo.  
2. Sort geos by this baseline.  
3. Form small blocks (pairs or quadruples).  
4. Randomize treatment vs control **within each block**.

This keeps treatment and control balanced on past performance.
We will implement a simple **pair-matching by pre-KPI** helper.


In [None]:

from typing import Sequence

def assign_geos_stratified_by_pre_kpi(
    pre_df: pd.DataFrame,
    geo_col: str = "geo_id",
    kpi_col: str = "kpi_mean_pre",
    rng_seed: int | None = 2024,
) -> pd.DataFrame:
    """Assign treatment/control using stratified randomization on pre KPI.

    Steps:
    - Reduce the input to one row per geo. If a geo appears more than once,
      its baseline is the mean of its ``kpi_col`` values.
    - Sort geos by baseline (stable sort, so ties keep first-appearance order).
    - Form pairs (blocks of size 2) of adjacent geos.
    - Within each pair, randomly assign one geo to treatment and one to
      control. With an odd number of geos, the last geo is assigned by a coin
      flip.

    Parameters
    ----------
    pre_df : DataFrame
        Must contain columns [geo_col, kpi_col].
    geo_col : str
        Name of the geo id column.
    kpi_col : str
        Column storing the pre-period KPI per geo.
    rng_seed : int | None
        Random seed for reproducibility.

    Returns
    -------
    DataFrame
        One row per geo with columns [geo_col, 'group'], where group is
        'treatment' or 'control'. Rows are ordered by baseline KPI.
    """
    df = (
        pre_df.groupby(geo_col, as_index=False, sort=False)[kpi_col]
        .mean()
        .sort_values(kpi_col, kind="mergesort")
        .reset_index(drop=True)
    )

    rng = np.random.default_rng(rng_seed)
    groups: list[str] = [""] * len(df)

    # Pair geos: (0, 1), (2, 3), ... in baseline order.
    for start in range(0, len(df), 2):
        block = rng.permutation(list(range(start, min(start + 2, len(df)))))
        if len(block) == 1:
            # Odd geo out: coin flip.
            groups[block[0]] = str(rng.choice(["treatment", "control"]))
        else:
            groups[block[0]] = "treatment"
            groups[block[1]] = "control"

    df["group"] = groups
    return df[[geo_col, "group"]]


# Example: build a stratified assignment from the geo-level pre-period means.
pre_baseline = geo_agg.loc[:, ["geo_id", "kpi_mean_pre"]].copy()
assignments = assign_geos_stratified_by_pre_kpi(pre_baseline)
assignments.head()



You would use this **before** running the experiment:

1. Use historical data to build `kpi_mean_pre` per geo.  
2. Call `assign_geos_stratified_by_pre_kpi`.  
3. Feed the assigned `group` into your ad server / experiment system.

In our simulation, we generated `group` internally, but this section shows how you would
design a geo experiment prospectively.



## 7) Synthetic control flavour (per-geo)

Synthetic control methods build a **weighted combination of control geos** that mimics
the pre-period trajectory of each treated geo.

Here we implement a simple flavour:

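- Let \(Y^{(T)}_{g} = \left(Y^{(T)}_{g,t}\right)_{t \in \text{pre}}\) be the pre-period KPI series of treated geo \(g\).  
- Let \(C_{\text{pre}}\) be the \(T_{\text{pre}} \times n_C\) matrix of control-geo KPIs in the pre period (rows = days, columns = control geos), and \(C_{\text{post}}\) the analogous post-period matrix.  
- For each treated geo, we find weights \(w_g\) such that
  \(C_{\text{pre}} w_g \approx Y^{(T)}_{g}\) in the least-squares sense.  
- We then use \(C_{\text{post}} w_g\) as a **synthetic control series** for geo \(g\) in the post period.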

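This is not a full constrained synthetic control (e.g. non-negative weights,
sum-to-one), but a readable starting point. Because the weights are unconstrained,
they can overfit the pre period when there are many control geos relative to the
number of pre-period days.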


In [None]:

def build_synthetic_control_weights(
    df: pd.DataFrame,
    treated_geo: int,
    period_pre: str = "pre",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute unconstrained synthetic-control weights for one treated geo.

    Solves ``min_w ||C w - y||^2`` by least squares. Here ``y`` is the treated
    geo's pre-period KPI series and ``C`` holds the control geos' KPI on the
    same days (rows = days, columns = control geos). Series are aligned by
    the ``day`` column, not by row position. Duplicate (geo, day) rows are
    averaged.

    Parameters
    ----------
    df : DataFrame
        Panel with columns [geo_id, group, period, day, kpi].
    treated_geo : int
        Geo id of the treated geo.
    period_pre : str
        Label for the pre period.

    Returns
    -------
    (weights, ctrl_geos, days_pre)
        weights: np.ndarray of shape (n_ctrl,), one weight per control geo.
        ctrl_geos: control geo ids in order of first appearance in ``df``.
        days_pre: sorted day indices of the treated geo used in the fit.

    Raises
    ------
    ValueError
        If there are no control geos, the treated geo has no pre-period data,
        or a control geo is missing KPI on one of the treated geo's days.
    """
    pre = df[df["period"] == period_pre]

    ctrl_geos = (
        pre.loc[pre["group"] == "control", "geo_id"]
        .drop_duplicates()
        .to_numpy()
    )
    if ctrl_geos.size == 0:
        raise ValueError(f"No control geos found in period {period_pre!r}.")

    # Wide layout: one row per day, one column per geo.
    wide = pre.pivot_table(index="day", columns="geo_id", values="kpi", aggfunc="mean")
    if treated_geo not in wide.columns:
        raise ValueError(f"Treated geo {treated_geo!r} has no KPI in period {period_pre!r}.")

    y_series = wide[treated_geo].dropna()
    days_pre = y_series.index.to_numpy()

    # Align every control geo on exactly the treated geo's days.
    ctrl_wide = wide.reindex(index=days_pre, columns=ctrl_geos)
    if ctrl_wide.isna().to_numpy().any():
        raise ValueError(
            "Some control geos are missing KPI on pre-period days observed for "
            f"treated geo {treated_geo!r}; cannot align series."
        )

    C = ctrl_wide.to_numpy(dtype=float)  # shape (T_pre, n_ctrl)
    y_pre = y_series.to_numpy(dtype=float)

    # lstsq avoids squaring the condition number (as forming C^T C would) and
    # returns the minimum-norm solution when controls outnumber days.
    weights, *_ = np.linalg.lstsq(C, y_pre, rcond=None)

    return weights, ctrl_geos, days_pre


def synthetic_control_post_series(
    df: pd.DataFrame,
    weights: np.ndarray,
    ctrl_geos: np.ndarray,
    period_post: str = "post",
) -> Tuple[np.ndarray, np.ndarray]:
    """Build the synthetic-control KPI series for a treated geo's post period.

    Control series are aligned by the ``day`` column, with duplicate
    (geo, day) rows averaged. Only days on which every control geo has a KPI
    are kept, because the synthetic value needs all weighted controls.

    Parameters
    ----------
    df : DataFrame
        Panel with columns [geo_id, group, period, day, kpi].
    weights : np.ndarray
        Weights for each control geo (shape (n_ctrl,)).
    ctrl_geos : np.ndarray
        Control geo ids corresponding to the weights.
    period_post : str
        Label for the post period.

    Returns
    -------
    (days_post, y_synth_post)
        days_post: sorted day indices for the post period.
        y_synth_post: synthetic control KPI series for those days.

    Raises
    ------
    ValueError
        If there are no control geos, the weights and geos differ in length,
        or a control geo has no post-period data at all.
    """
    weights = np.asarray(weights, dtype=float).ravel()
    ctrl_geos = np.asarray(ctrl_geos)
    if ctrl_geos.size == 0:
        raise ValueError("ctrl_geos is empty.")
    if weights.size != ctrl_geos.size:
        raise ValueError(
            f"Got {weights.size} weights for {ctrl_geos.size} control geos."
        )

    post = df[df["period"] == period_post]
    # Wide layout: one row per day, one column per geo.
    wide = post.pivot_table(index="day", columns="geo_id", values="kpi", aggfunc="mean")

    missing = [g for g in ctrl_geos.tolist() if g not in wide.columns]
    if missing:
        raise ValueError(f"Control geos without {period_post!r} data: {missing}.")

    # Keep only days on which every control geo is observed.
    ctrl_wide = wide.reindex(columns=ctrl_geos).dropna(how="any")

    days_post = ctrl_wide.index.to_numpy()
    C_post = ctrl_wide.to_numpy(dtype=float)  # shape (T_post, n_ctrl)
    y_synth_post = C_post @ weights
    return days_post, y_synth_post


# Example: synthetic control for the lowest-numbered treated geo.
treated_geos = np.sort(df.loc[df["group"] == "treatment", "geo_id"].unique())

example_treated = int(treated_geos[0])
w_sc, ctrl_ids, days_pre = build_synthetic_control_weights(df, treated_geo=example_treated)
days_post, y_synth_post = synthetic_control_post_series(df, w_sc, ctrl_ids)

# Observed post-period series of the treated geo.
is_example_post = (df["geo_id"] == example_treated) & (df["period"] == "post")
treated_post = df.loc[is_example_post].sort_values("day")
y_treated_post = treated_post["kpi"].to_numpy()
days_post_true = treated_post["day"].to_numpy()

fig, ax = plt.subplots()
ax.plot(days_post_true, y_treated_post, label=f"Treated geo {example_treated}")
ax.plot(days_post, y_synth_post, label="Synthetic control", linestyle="--")
ax.set_xlabel("day")
ax.set_ylabel("kpi")
ax.set_title("Synthetic control flavour: treated vs synthetic control (post period)")
ax.legend()
fig.tight_layout()
plt.show()



This plot shows a **per-geo counterfactual trajectory**: an estimate of what would have
happened in that treated geo without treatment, built from the weighted combination of
control geos that best matches its pre-period behaviour.

You can summarize impact as, for example, the **difference in average post-period KPI**
between treated and synthetic control for each treated geo, and then aggregate across them.



## 8) Bayesian hierarchical geo model (partial pooling)

We can also analyze geo experiments with a simple **Bayesian model** that pools
information across the geos within each arm (a lightweight stand-in for a full
hierarchical model).

Here we take a lightweight Normal–Normal model on the **per-geo changes**:

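- For control geos: \(D_g^{(C)} \sim \mathcal{N}(\mu_C, \sigma_C^2)\).  
- For treatment geos: \(D_g^{(T)} \sim \mathcal{N}(\mu_T, \sigma_T^2)\).  

The variances \(\sigma_C^2\) and \(\sigma_T^2\) are treated as known: they are plugged in
from the sample variances of the per-geo changes in each arm.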

We place vague Normal priors on the group means:

\[
\mu_C \sim \mathcal{N}(m_0, s_0^2), \quad
\mu_T \sim \mathcal{N}(m_0, s_0^2).
\]

Given observed \(D_g\), the posteriors \(p(\mu_C \mid D)\) and \(p(\mu_T \mid D)\)
are also Normal (conjugate).

We can then derive the posterior for the **lift**:

\[
\Delta = \mu_T - \mu_C
\]

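The posteriors of \(\mu_T\) and \(\mu_C\) are independent Normals \(\mathcal{N}(m_k, s_k^2)\), \(k \in \{T, C\}\), with
\(s_k^2 = \left(1/s_0^2 + n_k/\sigma_k^2\right)^{-1}\) and \(m_k = s_k^2\left(m_0/s_0^2 + n_k \bar D_k/\sigma_k^2\right)\),
where \(n_k\) is the number of geos and \(\bar D_k\) the mean change in arm \(k\). The lift is therefore Normal as well:
\(\Delta \mid D \sim \mathcal{N}\left(m_T - m_C,\; s_T^2 + s_C^2\right)\).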


In [None]:

@dataclass
class BayesGeoPosterior:
    """Posterior summaries returned by ``bayes_geo_hierarchical_normal``."""

    mu_t_mean: float  # posterior mean of the treatment-arm mean change mu_T
    mu_t_sd: float  # posterior SD of mu_T
    mu_c_mean: float  # posterior mean of the control-arm mean change mu_C
    mu_c_sd: float  # posterior SD of mu_C
    delta_mean: float  # posterior mean of delta = mu_T - mu_C
    delta_sd: float  # posterior SD of delta
    prob_delta_gt_0: float  # posterior probability that delta > 0
    ci_low: float  # lower end of the symmetric (1 - alpha) credible interval
    ci_high: float  # upper end of the symmetric (1 - alpha) credible interval


def bayes_geo_hierarchical_normal(
    geo_agg: pd.DataFrame,
    value_col: str = "delta_mean",
    prior_mean: float = 0.0,
    prior_sd: float = 100.0,
    alpha: float = 0.05,
) -> BayesGeoPosterior:
    """Conjugate Normal model for the per-arm mean of per-geo changes.

    Each arm's mean change (mu_T, mu_C) gets an independent
    ``N(prior_mean, prior_sd**2)`` prior. The within-arm variance of
    ``value_col`` is plugged in from the sample variance and treated as
    known, so each posterior is Normal in closed form, and so is
    ``delta = mu_T - mu_C``. Only the standard library is needed (no scipy).

    Parameters
    ----------
    geo_agg : DataFrame
        Geo-level aggregated data with columns ['group', value_col].
    value_col : str
        Column storing per-geo changes (e.g. delta_mean).
    prior_mean : float
        Prior mean m0 for both group means.
    prior_sd : float
        Prior standard deviation s0 for both group means (must be > 0).
    alpha : float
        The credible interval has mass 1 - alpha.

    Returns
    -------
    BayesGeoPosterior
        Posterior summaries for mu_T, mu_C and delta = mu_T - mu_C.

    Raises
    ------
    ValueError
        If ``alpha`` is outside (0, 1), ``prior_sd <= 0``, or either arm has
        no geos.
    """
    from statistics import NormalDist

    if not 0.0 < alpha < 1.0:
        raise ValueError(f"alpha must be in (0, 1), got {alpha}.")
    if prior_sd <= 0.0:
        raise ValueError(f"prior_sd must be positive, got {prior_sd}.")

    s0_sq = prior_sd**2

    def _arm_posterior(values: np.ndarray) -> tuple[float, float]:
        """Return (posterior mean, posterior variance) of one arm's mean."""
        n = values.size
        mean = float(values.mean())
        # NOTE: with a single geo the sample variance is undefined; 1.0 is an
        # arbitrary placeholder kept for backward compatibility.
        var = float(values.var(ddof=1)) if n > 1 else 1.0
        if var <= 0.0:
            # Identical observations give an infinitely precise likelihood,
            # so the posterior collapses onto the sample mean.
            return mean, 0.0
        post_var = 1.0 / (1.0 / s0_sq + n / var)
        post_mean = post_var * (prior_mean / s0_sq + n * mean / var)
        return post_mean, post_var

    arm = geo_agg["group"]
    d_t = geo_agg.loc[arm == "treatment", value_col].to_numpy(dtype=float)
    d_c = geo_agg.loc[arm == "control", value_col].to_numpy(dtype=float)
    if d_t.size == 0 or d_c.size == 0:
        raise ValueError(
            f"Need at least one treatment and one control geo "
            f"(got {d_t.size} and {d_c.size})."
        )

    mn_t, sn_t_sq = _arm_posterior(d_t)
    mn_c, sn_c_sq = _arm_posterior(d_c)

    # Independent Normal posteriors => delta ~ N(mn_t - mn_c, sn_t^2 + sn_c^2).
    delta_mean = mn_t - mn_c
    delta_sd = math.sqrt(sn_t_sq + sn_c_sq)

    std_normal = NormalDist()
    if delta_sd > 0.0:
        prob_delta_gt_0 = std_normal.cdf(delta_mean / delta_sd)
        half_width = std_normal.inv_cdf(1.0 - alpha / 2.0) * delta_sd
    else:
        # Point-mass posterior: the sign of delta is known exactly.
        if delta_mean > 0.0:
            prob_delta_gt_0 = 1.0
        elif delta_mean < 0.0:
            prob_delta_gt_0 = 0.0
        else:
            prob_delta_gt_0 = 0.5
        half_width = 0.0

    return BayesGeoPosterior(
        mu_t_mean=mn_t,
        mu_t_sd=math.sqrt(sn_t_sq),
        mu_c_mean=mn_c,
        mu_c_sd=math.sqrt(sn_c_sq),
        delta_mean=delta_mean,
        delta_sd=delta_sd,
        prob_delta_gt_0=prob_delta_gt_0,
        ci_low=delta_mean - half_width,
        ci_high=delta_mean + half_width,
    )


try:
    bayes_post = bayes_geo_hierarchical_normal(
        geo_agg,
        value_col="delta_mean",
        prior_mean=0.0,
        prior_sd=100.0,
        alpha=0.05,
    )
except ImportError as e:
    # Only a missing optional dependency is a reason to skip; other errors
    # (bad data, bad parameters) should surface instead of being mislabelled.
    print("Bayesian geo model skipped (scipy not available):", e)
else:
    # A bare expression inside try/else is not auto-displayed by Jupyter.
    display(bayes_post)



This model:

- **Pools information across geos** within each group to estimate group means \(\mu_T\) and \(\mu_C\).  
- Produces a posterior for the **lift** \(\Delta = \mu_T - \mu_C\) with:
  - mean and standard deviation,  
  - probability that lift is positive,  
  - a credible interval.

You can combine this with the decision rules from the other notebooks, e.g.:

- Ship if \(P(\Delta > 0) \ge 0.9\) and geo-level guardrails are acceptable.



## 9) Per-geo posterior effects (shrunken deltas)

The group-level Bayesian model gives us a posterior for the **average** lift across
treatment geos. Sometimes we also want **per-geo effects**, but with some **shrinkage**
towards the group average to avoid over-reacting to noisy geos.

Here we build a simple empirical-Bayes shrinker for per-geo deltas:

For each group separately (treatment or control):

- Let \(D_g\) be the raw delta for geo \(g\) (e.g. `delta_mean`).  
- Let \(\bar D\) and \(s^2\) be the sample mean and variance of \(D_g\) across geos.  
- Place a Normal prior on the geo's true effect \(\theta_g\):  
  \(\theta_g \sim \mathcal{N}(\bar D, s_0^2)\) with some prior variance \(s_0^2\).  
- Assume the observation noise variance is approximately \(s^2\).  

Then the posterior mean for \(\theta_g\) is:

\[
\mathbb{E}[\theta_g \mid D_g] = w D_g + (1 - w) \bar D, \quad
w = \frac{s_0^2}{s_0^2 + s^2}.
\]

Each geo's delta is **shrunk** towards its group average by the same weight \(w\).


In [None]:

from dataclasses import dataclass

@dataclass
class ShrunkenGeoEffects:
    """Output of ``compute_shrunken_geo_deltas``."""

    df: pd.DataFrame  # copy of geo_agg plus 'delta_raw' and 'delta_shrunken'
    shrink_weight_treat: float  # weight w on raw deltas in the treatment arm
    shrink_weight_control: float  # weight w on raw deltas in the control arm


def compute_shrunken_geo_deltas(
    geo_agg: pd.DataFrame,
    value_col: str = "delta_mean",
    prior_sd: float = 100.0,
) -> ShrunkenGeoEffects:
    """Compute shrunken per-geo deltas within each group (empirical Bayes).

    For each group (treatment/control), raw per-geo deltas are pulled towards
    the group's mean delta:

        shrunk_g = w * D_g + (1 - w) * mean(D),   w = s0^2 / (s0^2 + s^2)

    where:
    - s0^2 = prior_sd^2 is the prior variance of each geo's true effect
      around its group mean,
    - s^2 is the sample variance of the per-geo deltas in that group, used
      as the observation-noise variance.

    A group with at most one geo, or with zero variance, is left unshrunk
    (w = 1). Geos whose group is neither 'treatment' nor 'control' get NaN in
    ``delta_shrunken``.

    Parameters
    ----------
    geo_agg : DataFrame
        Geo-level aggregated data, with columns ['group', value_col].
    value_col : str
        Column containing per-geo effects (e.g., delta_mean).
    prior_sd : float
        Prior standard deviation of each geo's true effect around its group
        mean. Larger values mean less shrinkage (w closer to 1).

    Returns
    -------
    ShrunkenGeoEffects
        A copy of ``geo_agg`` with added columns 'delta_raw' and
        'delta_shrunken', plus the shrink weight used in each group.
    """
    df = geo_agg.copy()
    df["delta_raw"] = df[value_col]
    # Start from NaN so geos outside the two arms stay explicitly undefined.
    df["delta_shrunken"] = np.nan

    s0_sq = float(prior_sd) ** 2
    shrink_weights: dict[str, float] = {}

    for grp in ("control", "treatment"):
        mask = df["group"] == grp
        d = df.loc[mask, "delta_raw"].to_numpy(dtype=float)

        var_d = float(d.var(ddof=1)) if d.size > 1 else 0.0
        if var_d <= 0.0:
            # Nothing to shrink: single geo, empty group, or no spread.
            shrink_weights[grp] = 1.0
            df.loc[mask, "delta_shrunken"] = d
            continue

        w = s0_sq / (s0_sq + var_d)
        shrink_weights[grp] = w
        df.loc[mask, "delta_shrunken"] = w * d + (1.0 - w) * float(d.mean())

    return ShrunkenGeoEffects(
        df=df,
        shrink_weight_treat=shrink_weights.get("treatment", 1.0),
        shrink_weight_control=shrink_weights.get("control", 1.0),
    )


# Empirical-Bayes shrinkage of the per-geo deltas within each arm.
shrunken = compute_shrunken_geo_deltas(geo_agg, value_col="delta_mean", prior_sd=100.0)
shrunken.df.head()


In [None]:

# Plot per-geo raw vs shrunken deltas. Raw deltas are drawn in grey so they are
# not confused with the per-group colours used for the shrunken deltas.
df_plot = shrunken.df.sort_values("delta_shrunken").reset_index(drop=True)
group_colors = {"control": "C0", "treatment": "C1"}

plt.figure(figsize=(10, 5))
# Thin connectors show how far each geo is pulled towards its group mean.
plt.vlines(
    df_plot.index,
    df_plot["delta_raw"],
    df_plot["delta_shrunken"],
    colors="lightgrey",
    linewidth=0.8,
)
plt.scatter(df_plot.index, df_plot["delta_raw"], label="raw delta", alpha=0.4, marker="o", color="grey")
for grp, color in group_colors.items():
    sub = df_plot[df_plot["group"] == grp]
    plt.scatter(
        sub.index,
        sub["delta_shrunken"],
        label=f"shrunken delta ({grp})",
        alpha=0.8,
        marker="x",
        color=color,
    )
plt.axhline(0.0, color="black", linewidth=1, linestyle="--")
plt.xlabel("geo (sorted by shrunken delta)")
plt.ylabel("delta_mean")
plt.title("Per-geo raw vs shrunken deltas")
plt.legend()
plt.tight_layout()
plt.show()

shrunken.shrink_weight_treat, shrunken.shrink_weight_control



This plot gives you a sense of which geos have extreme raw deltas that get pulled back
towards their group averages. You can use the **shrunken deltas** for:

- Ranking geos by performance while de-emphasizing noise.  
- Identifying outlier geos that still look strong/weak **after** shrinkage.



## 10) Decision layer: Bayesian geo model + guardrail rules

We now add a small **decision layer** that combines:

- The Bayesian geo model for the **primary metric lift** \(\Delta\).  
- A simple **guardrail** constraint for another metric (e.g. refund rate, complaints).

The idea is to encode something like:

> Ship if P(primary lift > 0) ≥ 0.9 **and** guardrail not degraded beyond threshold.


In [None]:

from dataclasses import dataclass

@dataclass
class GeoDecisionConfig:
    """Thresholds for ``decide_geo_bayes_with_guardrail``."""

    main_prob_threshold: float  # e.g. 0.9: required posterior P(lift > 0)
    min_abs_lift: float | None  # e.g. 5.0 (units of delta_mean), or None for no size check
    guardrail_max_degradation: float  # maximum allowed degradation (absolute units)
    guardrail_direction: str  # 'lower_is_better' or 'higher_is_better'


@dataclass
class GeoDecisionOutcome:
    """Decision plus a human-readable reason and the numbers behind it."""

    decision: str  # 'ship', 'hold', 'do_not_ship'
    reason: str
    details: Dict[str, float]


def decide_geo_bayes_with_guardrail(
    bayes_post: "BayesGeoPosterior",
    guardrail_estimate: float,
    guardrail_ci_high: float,
    config: GeoDecisionConfig,
    guardrail_ci_low: float | None = None,
) -> GeoDecisionOutcome:
    """Combine a Bayesian lift posterior and a guardrail rule into a decision.

    The primary metric passes when ``P(delta > 0) >= main_prob_threshold``
    and, if ``min_abs_lift`` is set, the posterior mean lift is at least
    ``min_abs_lift`` (a positive lift, in absolute units of delta).

    The guardrail is checked against its pessimistic bound:

    - ``lower_is_better``: degradation is an increase, so the upper bound
      ``guardrail_ci_high`` must be <= ``guardrail_max_degradation``.
    - ``higher_is_better``: degradation is a decrease, so the lower bound must
      be >= ``-guardrail_max_degradation``. Pass ``guardrail_ci_low`` for this
      case. If it is omitted, ``guardrail_ci_high`` is used as the
      conservative bound, which matches the earlier behaviour.

    Parameters
    ----------
    bayes_post : BayesGeoPosterior
        Posterior summary for the primary metric lift. Reads ``delta_mean``,
        ``delta_sd`` and ``prob_delta_gt_0``.
    guardrail_estimate : float
        Estimated change in the guardrail (e.g. +0.005 absolute).
    guardrail_ci_high : float
        Upper bound of a CI (or a conservative bound) on the guardrail change.
    config : GeoDecisionConfig
        Decision thresholds.
    guardrail_ci_low : float | None
        Lower bound of a CI on the guardrail change; used for
        ``higher_is_better`` guardrails.

    Returns
    -------
    GeoDecisionOutcome
        Decision ('ship', 'hold', 'do_not_ship'), reason, and numeric details.

    Raises
    ------
    ValueError
        If ``config.guardrail_direction`` is not one of the two supported values.
    """
    # Guardrail check against the pessimistic bound for the metric's direction.
    if config.guardrail_direction == "lower_is_better":
        guardrail_ok = guardrail_ci_high <= config.guardrail_max_degradation
    elif config.guardrail_direction == "higher_is_better":
        worst_case = guardrail_ci_low if guardrail_ci_low is not None else guardrail_ci_high
        guardrail_ok = worst_case >= -config.guardrail_max_degradation
    else:
        raise ValueError("guardrail_direction must be 'lower_is_better' or 'higher_is_better'.")

    # Primary metric checks. The size check is directional: a large *negative*
    # lift must not pass, even when main_prob_threshold is set below 0.5.
    prob_ok = bayes_post.prob_delta_gt_0 >= config.main_prob_threshold
    size_ok = config.min_abs_lift is None or bayes_post.delta_mean >= config.min_abs_lift

    details = {
        "delta_mean": bayes_post.delta_mean,
        "delta_sd": bayes_post.delta_sd,
        "prob_delta_gt_0": bayes_post.prob_delta_gt_0,
        "guardrail_estimate": guardrail_estimate,
        "guardrail_ci_high": guardrail_ci_high,
    }
    if guardrail_ci_low is not None:
        details["guardrail_ci_low"] = guardrail_ci_low

    primary_ok = prob_ok and size_ok

    if primary_ok and guardrail_ok:
        return GeoDecisionOutcome(
            decision="ship",
            reason=(
                "Primary metric shows positive lift with high posterior probability, "
                "and guardrail is within acceptable degradation."
            ),
            details=details,
        )

    if primary_ok:
        return GeoDecisionOutcome(
            decision="hold",
            reason=(
                "Primary metric looks good, but guardrail may be degraded beyond the "
                "acceptable threshold. Investigate or adjust before rolling out."
            ),
            details=details,
        )

    return GeoDecisionOutcome(
        decision="do_not_ship",
        reason=(
            "Primary metric lift is not sufficiently positive or not large enough in magnitude "
            "given the configured thresholds."
        ),
        details=details,
    )


# Example: mock guardrail effect and decision.
try:
    _ = bayes_post  # type: ignore[name-defined]
except NameError:
    print("bayes_post is not available; run the Bayesian geo model cell first.")
else:
    # Suppose the guardrail is 'refund_rate' (lower is better).
    guardrail_estimate = 0.005    # +0.5pp
    guardrail_ci_high = 0.015     # worst-case +1.5pp

    dec_config = GeoDecisionConfig(
        main_prob_threshold=0.9,
        min_abs_lift=None,            # no minimum absolute size check here
        guardrail_max_degradation=0.02,  # allow up to +2pp
        guardrail_direction="lower_is_better",
    )

    decision = decide_geo_bayes_with_guardrail(
        bayes_post=bayes_post,      # from the Bayesian geo model
        guardrail_estimate=guardrail_estimate,
        guardrail_ci_high=guardrail_ci_high,
        config=dec_config,
    )
    # A bare expression inside try/else is not auto-displayed by Jupyter.
    display(decision)



This gives you a small, explicit **decision object** summarizing:

- Whether to **ship**, **hold**, or **not ship**.  
- A human-readable reason string.  
- Numeric details: posterior lift stats and guardrail summary.

You can hook this into the **story-layer notebook** we built earlier to auto-generate
a decision memo for geo experiments as well (primary effect, guardrails, and rollout plan).
