In [None]:
# --- Load the main notebook (which loads pairwise) so this notebook is standalone
from __future__ import annotations

from pathlib import Path
import nbformat
from IPython import get_ipython

def _exec_notebook(path: Path) -> None:
    """Run every non-empty code cell of the notebook at *path* in the live kernel.

    The notebook is parsed with ``nbformat`` (as format version 4) and each
    code cell is passed, in order, to the active shell's ``run_cell`` so the
    names it defines end up in the current interactive namespace.

    Raises
    ------
    RuntimeError
        If there is no active IPython shell (``get_ipython()`` returned
        ``None``), e.g. when this code is run as a plain Python script.
    """
    ip = get_ipython()
    if ip is None:
        # Without this guard the failure is an opaque
        # "'NoneType' object has no attribute 'run_cell'".
        raise RuntimeError(
            "No active IPython shell: %s can only be executed from within "
            "IPython/Jupyter." % (path,)
        )

    nb = nbformat.read(path, as_version=4)
    for cell in nb.cells:
        # Skip markdown/raw cells and code cells that are empty or whitespace.
        if cell.cell_type == "code" and cell.source.strip():
            ip.run_cell(cell.source)

def _find_local_notebook(name: str) -> Path:
    """Return the first existing path called *name* among the search locations.

    Search order: the resolved current working directory, then its three
    nearest ancestors, then ``/mnt/data``.

    Raises
    ------
    FileNotFoundError
        If none of the candidate paths exists; the message lists every
        location that was tried.
    """
    # cwd followed by parent, grandparent and great-grandparent.
    search_dirs = [Path.cwd().resolve()]
    while len(search_dirs) < 4:
        search_dirs.append(search_dirs[-1].parent)
    search_dirs.append(Path("/mnt/data"))

    candidates = [folder / name for folder in search_dirs]
    hit = next((path for path in candidates if path.exists()), None)
    if hit is None:
        tried = ", ".join(str(path) for path in candidates)
        raise FileNotFoundError("Cannot find %s. Looked in: %s" % (name, tried))
    return hit

# Only locate (and execute) the main notebook when its entry point is missing,
# so a kernel that already defines did_multiplegt_stat_main does not fail with
# FileNotFoundError just because the .ipynb is not reachable from the cwd.
# main_nb stays defined either way; it is None when nothing had to be loaded.
main_nb = None
if "did_multiplegt_stat_main" not in globals():
    main_nb = _find_local_notebook("did_multiplegt_stat_main.ipynb")
    _exec_notebook(main_nb)


In [None]:
# --- Imports
from typing import Any, Dict, List, Optional, Sequence, Union
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# -----------------------------
# Helpers: checks + plotting
# -----------------------------
def by_check(df: pd.DataFrame, ID: str, by_var: str) -> bool:
    """Return True when ``by_var`` takes at most one distinct value per ``ID``.

    Rows with a missing ``ID`` or a missing ``by_var`` are ignored, so a frame
    with no complete (ID, by_var) row passes trivially.
    """
    complete = df.loc[:, [ID, by_var]].dropna()
    if len(complete) == 0:
        return True
    distinct_per_id = complete.groupby(ID)[by_var].nunique(dropna=True)
    varies_within_id = (distinct_per_id > 1).any()
    return not bool(varies_within_id)

def by_graph(obj: Dict[str, Any]):
    """Plot the headline estimate of each by() group with its confidence interval.

    Parameters
    ----------
    obj : dict
        Result dictionary. Reads ``obj["by_levels"]`` and, for the j-th level
        (1-based), ``obj["results_by_{j}"]["table"]``; when there is a single
        level and ``obj`` has a ``"results"`` key, that entry is used instead.
        Only table rows whose name starts with ``aoss_``, ``waoss_`` or
        ``ivwaoss_`` and ends with ``_1_1_XX`` are plotted. Missing
        ``Estimate`` / ``LB CI`` / ``UB CI`` columns are treated as NaN.

    Returns
    -------
    dict or None
        ``{"figure": fig, "data": plot_df}``, or ``None`` when there are no
        levels or no plottable rows.
    """
    by_levels = obj.get("by_levels", None)
    if not by_levels:
        return None

    headline_prefixes = ("aoss_", "waoss_", "ivwaoss_")
    value_columns = ("Estimate", "LB CI", "UB CI")
    single_level = len(by_levels) == 1

    rows = []
    positions = []  # x-axis slot (index into by_levels) of each entry in rows
    for j, lev in enumerate(by_levels, start=1):
        key = "results" if (j == 1 and single_level and "results" in obj) else f"results_by_{j}"
        res = obj.get(key, None)
        if not isinstance(res, dict) or "table" not in res:
            continue
        table = res["table"]
        if not isinstance(table, pd.DataFrame) or table.empty:
            continue

        # Match on the stringified name but look values up with the original
        # label, so a non-string index does not raise KeyError in .loc.
        for label, rname in zip(table.index, table.index.astype(str)):
            if not (rname.startswith(headline_prefixes) and rname.endswith("_1_1_XX")):
                continue
            record = {"by": lev, "row": rname}
            for col in value_columns:
                record[col] = float(table.loc[label, col]) if col in table.columns else np.nan
            rows.append(record)
            positions.append(j - 1)

    if not rows:
        return None

    plot_df = pd.DataFrame(rows)
    slots = np.asarray(positions)
    fig, ax = plt.subplots()
    for rname, g in plot_df.groupby("row"):
        # Place each point at the slot of its own level. Counting 0..len(g)-1
        # would shift points under the wrong tick label whenever a level was
        # skipped above (no table, empty table, or no matching row).
        x = slots[g.index.to_numpy()]
        y = g["Estimate"].to_numpy(dtype=float)
        yerr = np.vstack([
            (g["Estimate"] - g["LB CI"]).to_numpy(dtype=float),
            (g["UB CI"] - g["Estimate"]).to_numpy(dtype=float),
        ])
        ax.errorbar(x, y, yerr=yerr, fmt="o", label=rname)

    ax.set_xticks(range(len(by_levels)))
    ax.set_xticklabels([str(x) for x in by_levels], rotation=45, ha="right")
    ax.set_xlabel("by group")
    ax.set_ylabel("Estimate")
    ax.grid(True, axis="y", alpha=0.3)
    ax.legend()
    fig.tight_layout()
    return {"figure": fig, "data": plot_df}

def by_fd_graph(obj: Dict[str, Any]):
    """Plot one estimate per by_fd() bin against the bin's median |ΔD| (or |ΔZ|).

    Reads ``obj["switch_df"]``, ``obj["by_levels"]`` and, for the j-th level
    (1-based), ``obj["results_by_{j}"]["table"]``. From each table the first
    row whose name starts with ``waoss_``, ``ivwaoss_`` or ``aoss_`` and ends
    with ``_1_1_XX`` is used. When ``switch_df`` is not a DataFrame with a
    ``delta_median`` column, bins are placed at 0, 1, 2, ... instead.

    Returns ``{"figure": fig, "data": plot_df}``, or ``None`` when the inputs
    are missing or no bin has a plottable row.
    """
    switch_df = obj.get("switch_df")
    by_levels = obj.get("by_levels")
    if by_levels is None or switch_df is None:
        return None

    # x coordinate of every bin, aligned with the order of by_levels.
    if isinstance(switch_df, pd.DataFrame) and "delta_median" in switch_df.columns:
        per_bin = switch_df.set_index("partition_XX").reindex(by_levels)
        xvals = per_bin["delta_median"].to_numpy(dtype=float)
    else:
        xvals = np.arange(len(by_levels), dtype=float)

    headline_prefixes = ("waoss_", "ivwaoss_", "aoss_")
    records = []
    for pos, lev in enumerate(by_levels):
        res = obj.get(f"results_by_{pos + 1}")
        table = res["table"] if isinstance(res, dict) and "table" in res else None
        if not isinstance(table, pd.DataFrame) or table.empty:
            continue

        # First headline row of the table, if any.
        rname = None
        for candidate in table.index.astype(str):
            if candidate.startswith(headline_prefixes) and candidate.endswith("_1_1_XX"):
                rname = candidate
                break
        if rname is None:
            continue

        record = {
            "bin": int(lev),
            "x": float(xvals[pos]) if pos < len(xvals) else np.nan,
        }
        for col in ("Estimate", "LB CI", "UB CI"):
            record[col] = float(table.loc[rname, col]) if col in table.columns else np.nan
        record["row"] = rname
        records.append(record)

    if len(records) == 0:
        return None

    plot_df = pd.DataFrame(records).sort_values("bin")
    fig, ax = plt.subplots()
    estimate = plot_df["Estimate"]
    centre = estimate.to_numpy(dtype=float)
    below = (estimate - plot_df["LB CI"]).to_numpy(dtype=float)
    above = (plot_df["UB CI"] - estimate).to_numpy(dtype=float)
    ax.errorbar(plot_df["x"], centre, yerr=np.vstack([below, above]), fmt="o")
    ax.set_xlabel("Median |ΔD| (or |ΔZ|) in bin")
    ax.set_ylabel("Estimate")
    ax.grid(True, axis="y", alpha=0.3)
    fig.tight_layout()
    return {"figure": fig, "data": plot_df}

# -----------------------------
# by_fd quantile partitioning
# -----------------------------
def _balance_panel_fill(df: pd.DataFrame, id_col: str, t_col: str) -> pd.DataFrame:
    """Expand *df* to one row per (id, period) pair and renumber periods 1..T.

    Every non-missing id is crossed with every non-missing period (periods in
    sorted order). Cells absent from the input become all-NaN rows flagged
    with ``tsfilled_XX == 1``; original rows get 0. The period column is then
    replaced by its rank among the sorted distinct periods. The input frame
    is left untouched. Intended to replicate
    ``plm::make.pbalanced(..., balance.type='fill')``.
    """
    keys = [id_col, t_col]
    panel = df.copy()

    periods = pd.Series(panel[t_col].dropna().unique()).sort_values().to_list()
    units = pd.Series(panel[id_col].dropna().unique()).to_list()

    # Full unit x period grid; cells without data come back as NaN rows.
    grid = pd.MultiIndex.from_product([units, periods], names=keys)
    balanced = panel.set_index(keys).reindex(grid).reset_index()

    # Flag the rows that did not exist in the input via a marker column.
    observed = panel[keys].assign(_orig_row_XX=1)
    balanced = balanced.merge(observed, on=keys, how="left")
    balanced["tsfilled_XX"] = balanced.pop("_orig_row_XX").isna().astype(int)

    # Replace raw period labels by consecutive integers starting at 1.
    period_rank = dict(zip(periods, range(1, len(periods) + 1)))
    balanced[t_col] = balanced[t_col].map(period_rank).astype(int)
    return balanced

def did_multiplegt_stat_quantiles(
    df: pd.DataFrame,
    ID: str,
    Time: str,
    D: str,
    Z: Optional[str] = None,
    by_opt: int = 2,
    quantiles: Optional[Sequence[float]] = None,
) -> Dict[str, Any]:
    """
    Assign every switcher observation to a quantile bin of |ΔD| (or |ΔZ|) for by_fd.

    Per the original author's notes this is a high-fidelity translation of the
    R helper did_multiplegt_stat_quantiles.R (not verifiable from this file).
    Key behaviors:
    - Build the distribution of |ΔD| (or |ΔZ|) among switchers in periods eligible for aggregation.
    - Default bin assignment is the last bin (by_opt), so values with CDF == 1 stay in the last bin.
    - Return `quantiles` overwritten by the realized CDF cutpoints (quantiles_temp).

    Parameters
    ----------
    df : input panel; it is balanced on (ID, Time) by `_balance_panel_fill`, so the
        returned frame has one row per (ID, period) and periods renumbered 1..T.
    ID, Time : names of the unit and period columns.
    D : name of the variable that is first-differenced when `Z` is None.
    Z : if given, this column is first-differenced instead of `D`.
    by_opt : number of the last bin; used as the default bin for every distinct |Δ|.
    quantiles : CDF breakpoints; defaults to `by_opt + 1` evenly spaced values on [0, 1].

    Returns
    -------
    dict with keys
    - "df": balanced frame with helper columns and `partition_XX`
      (bin number for switchers in eligible periods, 0.0 for other rows of
      eligible periods, NaN for rows of non-eligible periods),
    - "val_quantiles" / "quantiles": realized CDF cutpoints (the input
      quantiles when there is no switcher),
    - "switch_df": per-bin number of switchers and median |Δ|,
    - "quantiles_plot": {"figure": fig} or None,
    - "cut_off": smallest |Δ| of each non-empty bin followed by the overall max.

    Raises
    ------
    TypeError if the differenced column cannot be converted to numeric.
    """
    if quantiles is None:
        quantiles = np.linspace(0, 1, by_opt + 1).tolist()

    df_bal = _balance_panel_fill(df, ID, Time)
    diff_var = Z if Z is not None else D

    # Ensure diff_var is numeric enough for diff()
    # (a missing column is not caught here; it surfaces as a KeyError below)
    if diff_var in df_bal.columns:
        if not pd.api.types.is_numeric_dtype(df_bal[diff_var]):
            try:
                df_bal[diff_var] = pd.to_numeric(df_bal[diff_var])
            except Exception as e:
                raise TypeError(f"{diff_var} must be numeric to compute first differences.") from e

    # Absolute first difference within each unit; NaN for a unit's first period
    # and wherever either of the two consecutive values is missing.
    df_bal = df_bal.sort_values([ID, Time])
    df_bal["delta_pre_XX"] = df_bal.groupby(ID)[diff_var].diff().abs()

    # period-level counts (eligible periods require >=1 switcher and >=2 stayers)
    # switcher = defined, non-zero |Δ|; stayer = defined |Δ| equal to zero
    df_bal["switchers_dummy_XX_num"] = ((df_bal["delta_pre_XX"].notna()) & (df_bal["delta_pre_XX"] != 0)).astype(int)
    df_bal["stayers_dummy_XX_num"] = ((df_bal["delta_pre_XX"].notna()) & (df_bal["delta_pre_XX"] == 0)).astype(int)
    df_bal["switchers_N_XX"] = df_bal.groupby(Time)["switchers_dummy_XX_num"].transform("sum")
    df_bal["stayers_N_XX"] = df_bal.groupby(Time)["stayers_dummy_XX_num"].transform("sum")
    df_bal["in_aggregation_XX"] = ((df_bal["switchers_N_XX"] > 0) & (df_bal["stayers_N_XX"] > 1)).astype(int)

    # Switcher rows used to build the distribution
    df_switch0 = df_bal[(df_bal["delta_pre_XX"].notna()) & (df_bal["delta_pre_XX"] != 0) & (df_bal["in_aggregation_XX"] == 1)].copy()

    # No switcher in any eligible period: nothing to partition, so return the
    # requested quantiles unchanged and an all-NaN partition column.
    if df_switch0.empty:
        df_bal["partition_XX"] = np.nan
        return {
            "df": df_bal,
            "val_quantiles": list(quantiles),
            "quantiles": list(quantiles),
            "switch_df": pd.DataFrame(columns=["partition_XX", "N_switchers", "delta_median"]),
            "quantiles_plot": None,
            "cut_off": [],
        }

    # Frequency distribution over unique |Δ|
    dist = (
        df_switch0.groupby("delta_pre_XX", as_index=False)
        .size()
        .rename(columns={"size": "freq"})
        .sort_values("delta_pre_XX")
        .reset_index(drop=True)
    )
    # Empirical CDF evaluated at each distinct |Δ| (always > 0, last value == 1)
    dist["cdf"] = dist["freq"].cumsum() / dist["freq"].sum()

    # IMPORTANT (matches R): default is last bin, so CDF==1 stays in last bin
    dist["partition_XX"] = int(by_opt)

    cut_off: List[float] = []
    quantiles_temp: List[float] = [0.0]

    # Bin k = j - 1 receives the distinct |Δ| whose CDF lies in
    # [quantiles[k-1], quantiles[k]). A bin that ends up empty (point mass
    # spanning several quantiles) contributes no cut-off and no cutpoint.
    for j in range(2, len(quantiles) + 1):  # j = 2..len(quantiles)
        lo = float(quantiles[j - 2])
        hi = float(quantiles[j - 1])
        mask = (dist["cdf"] >= lo) & (dist["cdf"] < hi)
        dist.loc[mask, "partition_XX"] = j - 1
        if (dist["partition_XX"] == (j - 1)).any():
            cut_off.append(float(dist.loc[dist["partition_XX"] == (j - 1), "delta_pre_XX"].min()))
            quantiles_temp.append(float(dist.loc[dist["partition_XX"] == (j - 1), "cdf"].max()))

    # last cutoff is max |Δ|
    cut_off.append(float(dist["delta_pre_XX"].max()))

    # Map |Δ| -> partition (exactly as in R, but in pandas)
    map_part = dist.set_index("delta_pre_XX")["partition_XX"]

    # Same condition as the df_switch0 filter above, stored as a 0/1 column.
    df_bal["switchers_XX"] = ((df_bal["delta_pre_XX"].notna()) & (df_bal["delta_pre_XX"] != 0) & (df_bal["in_aggregation_XX"] == 1)).astype(int)
    # 0.0 marks rows that are not switchers; switchers get their bin number.
    df_bal["partition_XX"] = 0.0
    m_sw = df_bal["switchers_XX"] == 1
    df_bal.loc[m_sw, "partition_XX"] = df_bal.loc[m_sw, "delta_pre_XX"].map(map_part).astype(float)

    # Not in aggregation => NA partition (matches R behavior via subset(... != 0))
    df_bal.loc[df_bal["in_aggregation_XX"] == 0, "partition_XX"] = np.nan

    # Summary by partition for plotting
    df_sw = df_bal[m_sw].copy()
    # some bins may be collapsed; keep only finite partitions
    df_sw = df_sw[df_sw["partition_XX"].notna()]
    if not df_sw.empty:
        df_sw["partition_XX"] = df_sw["partition_XX"].astype(int)
        switch_df = (
            df_sw.groupby("partition_XX", as_index=False)
            .agg(N_switchers=("delta_pre_XX", "size"), delta_median=("delta_pre_XX", "median"))
            .sort_values("partition_XX")
            .reset_index(drop=True)
        )
    else:
        switch_df = pd.DataFrame(columns=["partition_XX", "N_switchers", "delta_median"])

    # Quick histogram plot of |Δ| among switchers (optional)
    # Best effort: any plotting failure is swallowed and reported as "no plot".
    # NOTE(review): if plt.subplots() succeeds and a later call raises, the
    # figure is dropped here without being closed — confirm whether that matters.
    fig = None
    try:
        fig, ax = plt.subplots()
        ax.hist(df_switch0["delta_pre_XX"].to_numpy(dtype=float), bins=min(30, max(5, len(dist))))
        ax.set_xlabel("|Δ|")
        ax.set_ylabel("Count")
        ax.set_title("|Δ| distribution among switchers (eligible periods)")
        ax.grid(True, axis="y", alpha=0.3)
        fig.tight_layout()
    except Exception:
        fig = None

    return {
        "df": df_bal,
        "val_quantiles": quantiles_temp,
        "quantiles": quantiles_temp,   # matches R overwrite
        "switch_df": switch_df,
        "quantiles_plot": {"figure": fig} if fig is not None else None,
        "cut_off": cut_off,
    }

def did_multiplegt_stat(
    df: pd.DataFrame,
    Y: str,
    ID: str,
    Time: str,
    D: str,
    Z: Optional[str] = None,
    estimator: Optional[Union[str, Sequence[str]]] = None,
    estimation_method: Optional[str] = None,
    order: int = 1,
    noextrapolation: bool = False,
    placebo: bool = False,
    switchers: Optional[str] = None,
    disaggregate: bool = False,
    aoss_vs_waoss: bool = False,
    exact_match: bool = False,
    by: Optional[Union[str, Sequence[str]]] = None,
    by_fd: Optional[int] = None,
    other_treatments: Optional[Union[str, Sequence[str]]] = None,
    cluster: Optional[str] = None,
    weight: Optional[str] = None,
    legacy_r_phi_scale: str = "auto",
) -> Dict[str, Any]:
    """Python interface mirroring did_multiplegt_stat.R.

    Validates the options, optionally splits the sample (``by``) or the
    switchers (``by_fd``), and calls ``did_multiplegt_stat_main`` once per
    group.

    Parameters
    ----------
    df : input panel (copied, never modified).
    Y, ID, Time, D : column names forwarded to ``did_multiplegt_stat_main``.
    Z : optional instrument column; when given and ``estimator`` is None the
        default estimator is ``["ivwaoss"]``, otherwise ``["aoss", "waoss"]``.
    estimator : one name or a sequence of names among aoss, waoss, ivwaoss.
    estimation_method : with ``exact_match`` it must be None/"ra"/"dr"/"ps"
        and is forced to "ra"; without ``exact_match`` it is forced to "dr".
    order : integer; reset to 1 under ``exact_match``.
    noextrapolation : ignored (reset to False) under ``exact_match``.
    switchers : None, "up" or "down".
    by : one column name or a sequence of column names, each constant within
        ID; one estimation per distinct combination of values.
    by_fd : positive integer dividing 100; number of quantile bins built by
        ``did_multiplegt_stat_quantiles``. Cannot be combined with ``by``.
    other_treatments : one column name or a sequence of column names.
    placebo, disaggregate, aoss_vs_waoss, cluster, weight, legacy_r_phi_scale :
        forwarded to ``did_multiplegt_stat_main``.

    Returns
    -------
    dict with "args", "_class", the estimation output under "results" (no
    grouping) or "results_by_{j}" (j-th group, 1-based), plus "by_levels" and
    graph / quantile entries when ``by`` or ``by_fd`` is used.

    Raises
    ------
    TypeError for a non-integer ``order`` or ``by_fd``.
    ValueError for any other invalid or inconsistent option.
    """
    if not isinstance(order, int):
        raise TypeError("order must be an integer.")
    if switchers is not None and switchers not in ("up", "down"):
        raise ValueError("Switchers could be either None, 'up' or 'down'.")

    # Default estimator (match R)
    if estimator is None and Z is None:
        estimator_list = ["aoss", "waoss"]
    elif estimator is None and Z is not None:
        estimator_list = ["ivwaoss"]
    elif isinstance(estimator, str):
        estimator_list = [estimator]
    else:
        estimator_list = list(estimator)

    allowed = {"aoss", "waoss", "ivwaoss"}
    if any(e not in allowed for e in estimator_list):
        raise ValueError("Syntax error in estimator option: only aoss, waoss and ivwaoss allowed.")

    # --- Stata .ado behavior (Aug 2025): estimation_method is overridden ---
    if bool(exact_match):
        # Stata: exact_match => forces RA and ignores order/noextrapolation
        if estimation_method not in (None, "ra", "dr", "ps"):
            raise ValueError("Syntax error in estimation_method option.")
        if estimation_method not in (None, "ra"):
            print("As exact_match is specified, estimation_method() is ignored and set to 'ra'.")
        estimation_method = "ra"

        if bool(noextrapolation):
            print("As exact_match is specified, noextrapolation is ignored.")
            noextrapolation = False

        if order != 1:
            print("As exact_match is specified, order() is ignored and set to 1.")
            order = 1
    else:
        # Stata: without exact_match => forces DR for all estimators (ignores user input)
        if estimation_method is not None and str(estimation_method).lower() != "dr":
            print("Without exact_match, estimation_method() is ignored (Stata behavior). Using 'dr'.")
        estimation_method = "dr"
    # ---------------------------------------------------------------
    # estimation_method is "ra" or "dr" at this point, so the two checks below
    # cannot fire today; they are kept as a safety net should the override
    # above ever be relaxed.
    if estimation_method not in ("ra", "dr", "ps"):
        raise ValueError("Syntax error in estimation_method option.")
    # Stata behavior: even if overall method is DR, AOSS itself is computed via RA.
    # Therefore we do not error out when estimation_method is "dr" with aoss.
    if ("aoss" in estimator_list) and (estimation_method == "ps"):
        raise ValueError("The propensity score-based approach is only available for waoss and ivwaoss (not aoss).")

    if ("ivwaoss" in estimator_list) and any(e in ("aoss", "waoss") for e in estimator_list):
        raise ValueError("The estimation of AOSS or WAOSS cannot be combined with IV-WAOSS.")
    if bool(aoss_vs_waoss) and sum(e in ("aoss", "waoss") for e in estimator_list) != 2:
        raise ValueError("To test equality between AOSS and WAOSS you must specify both aoss and waoss.")
    if ("ivwaoss" in estimator_list) and (Z is None):
        raise ValueError("To compute ivwaoss you must specify the IV variable Z.")

    if by is not None and by_fd is not None:
        raise ValueError("You cannot specify both by and by_fd.")

    # A bare string is one variable name, not a sequence of characters
    # (list("region") would silently become ['r', 'e', 'g', ...]).
    if isinstance(by, str):
        by = [by]
    elif by is not None:
        by = list(by)
        if not by:
            raise ValueError("The by option requires at least one variable name.")
    if isinstance(other_treatments, str):
        other_treatments = [other_treatments]
    elif other_treatments is not None:
        other_treatments = list(other_treatments)

    if by_fd is not None:
        if not isinstance(by_fd, int):
            raise TypeError("by_fd must be an integer.")
        # Reject 0 (would raise ZeroDivisionError below) and negative values
        # (e.g. 100 % -5 == 0 in Python, which would pass the divisibility test).
        if by_fd <= 0:
            raise ValueError("Syntax error in by_fd option: by_fd must be a positive integer.")
        if 100 % by_fd != 0:
            raise ValueError("Syntax error in by option. When by_fd is specified it must divide 100.")

    out: Dict[str, Any] = {
        "args": {
            "Y": Y,
            "ID": ID,
            "Time": Time,
            "D": D,
            "Z": Z,
            "estimator": estimator_list,
            "estimation_method": estimation_method,
            "order": int(order),
            "noextrapolation": bool(noextrapolation),
            "placebo": bool(placebo),
            "switchers": switchers,
            "disaggregate": bool(disaggregate),
            "aoss_vs_waoss": bool(aoss_vs_waoss),
            "exact_match": bool(exact_match),
            "by": list(by) if by is not None else None,
            "by_fd": int(by_fd) if by_fd is not None else None,
            "other_treatments": list(other_treatments) if other_treatments is not None else None,
            "cluster": cluster,
            "weight": weight,
        }
    }

    df_work = df.copy()
    mode = "_no_by"
    by_levels = ["_no_by"]
    by_str = None

    if by is not None:
        for v in by:
            if not by_check(df_work, ID, v):
                raise ValueError("The by option requires that the variable(s) are constant within ID.")
        # Build one comma-joined label per row; rows with a missing value in
        # any by variable get a missing label and are therefore in no group.
        comp = df_work[by].copy()
        na_row = comp.isna().any(axis=1)
        comp = comp.astype(str)
        by_total = comp[by[0]].astype(object)
        for v in by[1:]:
            by_total = by_total.astype(str) + "," + comp[v].astype(str)
        by_total = by_total.where(~na_row, np.nan)
        df_work["by_total"] = by_total
        by_levels = pd.Series(df_work["by_total"].dropna().unique()).sort_values().tolist()
        by_str = ",".join(by)
        out["by_levels"] = by_levels
        mode = "by"

    if by_fd is not None:
        # Breakpoints are accumulated by repeated addition (not linspace), so
        # the floating-point values handed to the quantile helper stay as before.
        q_levels = [0.0]
        for _ in range(by_fd):
            q_levels.append(q_levels[-1] + 1.0 / by_fd)

        by_set = did_multiplegt_stat_quantiles(df=df_work, ID=ID, Time=Time, D=D, Z=Z, by_opt=by_fd, quantiles=q_levels)
        df_work = by_set["df"]
        out["val_quantiles"] = by_set.get("val_quantiles")
        out["quantiles"] = by_set.get("quantiles")
        out["switch_df"] = by_set.get("switch_df")
        out["quantiles_plot"] = by_set.get("quantiles_plot")

        # Bins that actually contain switchers (0 / NaN mark non-switcher rows).
        part = df_work.loc[df_work["partition_XX"].notna() & (df_work["partition_XX"] != 0), "partition_XX"]
        by_levels = sorted(part.astype(int).unique().tolist())
        out["by_levels"] = by_levels
        # Match R: warn if point mass collapses bins
        if len(by_levels) != by_fd:
            print(f"Point mass > {100/by_fd:.0f}% detected. {by_fd - len(by_levels)} bin(s) collapsed.")
        mode = "by_fd"

    def _call_main(df_in: pd.DataFrame, by_fd_opt: Optional[Any]) -> Dict[str, Any]:
        # Single place where the validated options are handed to the estimator.
        return did_multiplegt_stat_main(
            df=df_in,
            Y=Y,
            ID=ID,
            Time=Time,
            D=D,
            Z=Z,
            estimator=estimator_list,
            estimation_method=estimation_method,
            order=int(order),
            noextrapolation=bool(noextrapolation),
            placebo=bool(placebo),
            switchers=switchers,
            disaggregate=bool(disaggregate),
            aoss_vs_waoss=bool(aoss_vs_waoss),
            exact_match=bool(exact_match),
            weight=weight,
            cluster=cluster,
            by_fd_opt=by_fd_opt,
            other_treatments=list(other_treatments) if other_treatments is not None else None,
            legacy_r_phi_scale=legacy_r_phi_scale,
        )

    if mode == "_no_by":
        out["results"] = _call_main(df_work, by_fd_opt=None)
    elif mode == "by":
        for j, lev in enumerate(by_levels, start=1):
            df_sub = df_work[df_work["by_total"] == lev].copy()
            print(f"Running did_multiplegt_stat with {by_str} = {lev}")
            out[f"results_by_{j}"] = _call_main(df_sub, by_fd_opt=None)
    else:  # by_fd
        diff_var = "Z" if "ivwaoss" in estimator_list else "D"
        qs = out.get("quantiles", None)
        for j, lev in enumerate(by_levels, start=1):
            # qs holds one cutpoint per non-empty bin (plus the leading 0), so
            # it is indexed by the position j of the bin, not by its number.
            if isinstance(qs, (list, tuple)) and len(qs) >= j + 1:
                left = float(qs[j-1]) * 100
                right = float(qs[j]) * 100
            else:
                left = np.nan
                right = np.nan
            sep = "[" if j == 1 else "("
            print(f"Running did_multiplegt_stat with switchers with abs(delta_{diff_var}) in {sep}{left:.0f}%, {right:.0f}%] quantile bin")
            # The full frame is passed; the bin number selects the switchers.
            out[f"results_by_{j}"] = _call_main(df_work, by_fd_opt=int(lev))

    # Optional graphs
    if out["args"].get("by") is not None:
        out["by_graph"] = by_graph(out)
    if out.get("quantiles") is not None:
        out["by_fd_graph"] = by_fd_graph(out)

    out["_class"] = "did_multiplegt_stat"
    return out
