# In [None]:
"""
did_multiplegt_stat_pairwise.py

A faithful, functional Python translation (pandas + statsmodels) of the R internal
function `did_multiplegt_stat_pairwise()`.

Key design choices:
- Keep the same logic/flow as the R code (pairwise DiD between consecutive periods).
- Keep variable names with `_XX` suffix (matches upstream standardization).
- Add detailed comments to make debugging/maintenance easy.

Dependencies:
    pip install pandas numpy statsmodels patsy

IMPORTANT ASSUMPTIONS (same as R path):
- `df` already contains standardized columns created upstream:
    ID_XX, T_XX, Y_XX, D_XX, (optional) Z_XX,
    tsfilled_XX (0/1 panel-fill flag),
    weight_XX (obs weights, optional),
    and if clustering is used: cluster_XX and weight_c_XX (cluster weights, optional).
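- This function is called for a given `pairwise` time index p.
  It uses T_XX in {p-1, p} when placebo == 0, or
  {p-placebo-1, p-placebo, p-1, p} when placebo >= 1 (three distinct periods
  for placebo == 1, four for placebo > 1), then internally relabels the
  retained periods to consecutive integers 1..k.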

Return format:
    {
      "scalars": updated_scalars_dict,
      "to_add":  df_with_ID_and_influence_functions_and_aux_columns (or None)
    }
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd

import statsmodels.api as sm
import statsmodels.formula.api as smf


# ============================================================
# Small robust helpers (avoid crashes on empty/all-NaN groups)
# ============================================================

def _nanmin_or_nan(s: pd.Series) -> float:
    """Smallest non-NaN value of ``s`` as a float.

    Returns NaN when ``s`` is empty or holds only NaN, so the helper can be
    used as a group-wise aggregator without raising on degenerate groups.
    """
    values = s.to_numpy(dtype=float)
    observed = values[~np.isnan(values)]
    if observed.size:
        return float(observed.min())
    return float("nan")

def _nanmax_or_nan(s: pd.Series) -> float:
    """Largest non-NaN value of ``s`` as a float.

    Returns NaN when ``s`` is empty or holds only NaN.
    """
    values = s.to_numpy(dtype=float)
    is_observed = ~np.isnan(values)
    return float(values[is_observed].max()) if is_observed.any() else float("nan")



def _nanmax_or_minus_inf(s: pd.Series) -> float:
    """Largest non-NaN value of ``s``, or -Inf when there is none.

    The -Inf result for an empty or all-NaN input is intended to match R's
    ``max(..., na.rm=TRUE)``.
    """
    values = s.to_numpy(dtype=float)
    observed = values[~np.isnan(values)]
    # Seeding the reduction with -Inf yields -Inf for an empty selection and
    # leaves the maximum of any non-empty selection unchanged.
    return float(np.max(observed, initial=-np.inf))

def _ensure_numeric(df: pd.DataFrame, col: str) -> None:
    """Coerce ``df[col]`` to a numeric dtype in place.

    Values that cannot be parsed become NaN. Nothing happens when ``col`` is
    not a column of ``df``.
    """
    if col not in df.columns:
        return
    df[col] = pd.to_numeric(df[col], errors="coerce")


# ============================================================
# Weighted utilities (mirroring your R utils.R semantics)
# ============================================================

def _get_weight_col(df: pd.DataFrame, w: Optional[str]) -> Optional[str]:
    """Resolve the weight column to use for ``df``.

    ``w=None`` stands for the standardized ``"weight_XX"`` column. The resolved
    name is returned only if it is an actual column of ``df``; otherwise the
    result is ``None``.
    """
    candidate = "weight_XX" if w is None else w
    if candidate in df.columns:
        return candidate
    return None

def wSum(df: pd.DataFrame, w: Optional[str] = None) -> float:
    """Total weight carried by ``df`` (the sum of the weights, not of a variable).

    Serves as the "effective N" denominator in sd/sqrt(wSum).

    Parameters
    ----------
    df : frame to summarize; its weight column is coerced to numeric in place.
    w : weight column name; ``None`` selects ``"weight_XX"``.

    Returns
    -------
    The NaN-ignoring sum of the weight column, or the row count of ``df`` when
    no weight column is available.
    """
    wcol = _get_weight_col(df, w)
    if wcol is not None:
        _ensure_numeric(df, wcol)
        weights = df[wcol].to_numpy(dtype=float)
        return float(np.nansum(weights))
    return float(len(df))

def Mean(var: str, df: pd.DataFrame, w: Optional[str] = None) -> float:
    """Weighted mean of ``df[var]`` using ``df[w]`` as weights.

    Parameters
    ----------
    var : column to average. It is coerced to numeric in place.
    df : data frame holding ``var`` and, optionally, the weight column.
    w : weight column name; ``None`` selects ``"weight_XX"``. If the resolved
        column is absent, the plain NaN-ignoring mean is returned.

    Returns
    -------
    The mean as a float. NaN is returned when ``var`` is not a column, when no
    row has both a value and a weight, or when the retained weights sum to
    zero.
    """
    if var not in df.columns:
        return float("nan")
    wcol = _get_weight_col(df, w)
    _ensure_numeric(df, var)
    x = df[var].to_numpy(dtype=float)

    if wcol is None:
        return float(np.nanmean(x)) if np.any(~np.isnan(x)) else float("nan")

    _ensure_numeric(df, wcol)
    ww = df[wcol].to_numpy(dtype=float)
    # Only rows where both the value and its weight are observed contribute.
    mask = (~np.isnan(x)) & (~np.isnan(ww))
    if not np.any(mask):
        return float("nan")

    # np.average raises ZeroDivisionError when the weights sum to zero. That
    # can happen here because missing weights are replaced by 0 elsewhere in
    # this module, so report the mean as undefined instead of crashing.
    if float(np.sum(ww[mask])) == 0.0:
        return float("nan")
    return float(np.average(x[mask], weights=ww[mask]))

def Sum(var: str, df: pd.DataFrame, w: Optional[str] = None) -> float:
    """Weighted total Σ w_i * x_i of ``df[var]``.

    Parameters
    ----------
    var : column to total. It is coerced to numeric in place.
    df : data frame holding ``var`` and, optionally, the weight column.
    w : weight column name; ``None`` selects ``"weight_XX"``.

    Returns
    -------
    ``0.0`` when ``var`` is not a column. Without a weight column, the
    NaN-ignoring sum of the values. Otherwise the sum of value * weight over
    the rows where both are observed.
    """
    if var not in df.columns:
        return 0.0
    wcol = _get_weight_col(df, w)
    _ensure_numeric(df, var)
    values = df[var].to_numpy(dtype=float)

    if wcol is not None:
        _ensure_numeric(df, wcol)
        weights = df[wcol].to_numpy(dtype=float)
        usable = ~(np.isnan(values) | np.isnan(weights))
        return float((values[usable] * weights[usable]).sum())

    return float(np.nansum(values))

def Sd(var: str, df: pd.DataFrame, w: Optional[str] = None) -> float:
    """Weighted standard deviation of ``df[var]``.

    The formula follows the default of R's Hmisc::wtd.var:
        Sd(x,w) = sqrt( Σ w_i (x_i-μ_w)^2 / (Σ w_i - 1) )
    i.e. frequency weights with an "unbiased" (sum(w) - 1) denominator.

    Parameters
    ----------
    var : column to summarize. It is coerced to numeric in place.
    df : data frame holding ``var`` and, optionally, the weight column.
    w : weight column name; ``None`` selects ``"weight_XX"``.

    Returns
    -------
    NaN when ``var`` is not a column, when no row has both a value and a
    weight, or when the retained weights sum to 1 or less. Without a weight
    column, the ordinary NaN-ignoring sample sd (ddof=1).
    """
    undefined = float("nan")
    if var not in df.columns:
        return undefined

    wcol = _get_weight_col(df, w)
    _ensure_numeric(df, var)
    values = df[var].to_numpy(dtype=float)

    if wcol is None:
        return float(np.nanstd(values, ddof=1))

    _ensure_numeric(df, wcol)
    weights = df[wcol].to_numpy(dtype=float)
    usable = ~(np.isnan(values) | np.isnan(weights))
    if not usable.any():
        return undefined

    values, weights = values[usable], weights[usable]
    total_weight = float(weights.sum())
    if total_weight <= 1.0:
        # The (Σw - 1) denominator would be zero or negative.
        return undefined

    center = float((weights * values).sum() / total_weight)
    weighted_ss = (weights * (values - center) ** 2).sum()
    variance = float(weighted_ss / (total_weight - 1.0))
    return float(np.sqrt(variance))

# ============================================================
# Statsmodels wrappers to mimic the R "stata_logit" + predict
# ============================================================

def stata_logit(formula: str, df: pd.DataFrame, wcol: str = "weight_XX",
                maxit: int = 300, tol: float = 1e-8):
    """Python counterpart of the R helper `stata_logit()`.

    Fits a Binomial GLM through the statsmodels formula interface, using
    ``df[wcol]`` as frequency weights when that column exists.

    Parameters
    ----------
    formula : model formula passed to ``smf.glm``.
    df : estimation data. ``df[wcol]`` is coerced to numeric in place.
    wcol : frequency-weight column. Missing weights are passed as 0. If the
        column is absent, no weights are passed.
    maxit, tol : forwarded to ``fit`` as ``maxiter`` and ``tol``.

    Returns
    -------
    The fitted statsmodels results object.
    """
    weights = None
    if wcol in df.columns:
        _ensure_numeric(df, wcol)
        weights = df[wcol].fillna(0.0)

    glm_spec = smf.glm(
        formula=formula,
        data=df,
        family=sm.families.Binomial(),
        freq_weights=weights,
    )
    return glm_spec.fit(maxiter=maxit, tol=tol, disp=0)

def _lpredict_fallback(df: pd.DataFrame, outcol: str, fitted_model) -> pd.DataFrame:
    """Write the model's predictions for ``df`` into ``df[outcol]``.

    This is the built-in stand-in for the R helper `lpredict()`:
    - Binomial GLM (logit): the predicted probabilities are post-processed as
        * NaN -> 1
        * values < 1e-10 -> 0
      (R uses `sensitivity <- 10^-10` and applies these only for prob=TRUE.)
    - Any other model (WLS/OLS): predictions are stored untouched.

    The hard trimming matters for strict replication of the WAOSS/IVWAOSS
    paths, where propensity scores close to zero are trimmed in R.

    Parameters
    ----------
    df : frame to predict on; modified in place.
    outcol : name of the column that receives the predictions.
    fitted_model : object exposing ``predict(df)``.

    Returns
    -------
    The same ``df`` object, with ``outcol`` added or overwritten.
    """
    fitted = np.asarray(fitted_model.predict(df), dtype=float)

    # Best-effort detection of a statsmodels Binomial GLM result; any failure
    # while inspecting the object is treated as "not binomial".
    try:
        family = getattr(getattr(fitted_model, "model", None), "family", None)
        binomial = isinstance(family, sm.families.Binomial)
    except Exception:
        binomial = False

    if binomial:
        threshold = 1e-10
        # NaN is mapped to 1 first; NaN never compares below the threshold,
        # so the nested selection matches applying the two rules in sequence.
        fitted = np.where(
            np.isnan(fitted),
            1.0,
            np.where(fitted < threshold, 0.0, fitted),
        )

    df[outcol] = fitted
    return df






# ============================================================
# Compatibility wrapper for `lpredict`
# ============================================================
# In this project, `lpredict` is implemented in `lpredict.ipynb` with the same
# signature as the R helper:
#   lpredict(df, varname, model, varlist, const=True, prob=False, factor=False)
#
# Earlier versions of this file used a simplified internal `_lpredict(df, outcol, model)`.
# To make the code robust (and to avoid Jupyter order-of-execution issues), we:
#   - keep a fallback implementation (`_lpredict_fallback`) that works with statsmodels
#   - call the external `lpredict` if it is already loaded AND has the R-like signature
#   - otherwise, fall back to `_lpredict_fallback`.
#
# IMPORTANT: For logit/GLM(Binomial) models, R calls lpredict(..., prob=TRUE).
# The wrapper below auto-detects Binomial models and forces `prob=True` in that case.

import inspect

def _infer_varlist_from_model(model) -> list[str]:
    """List the regressor names of a fitted statsmodels-like model.

    The names come from the index of ``model.params``, with the intercept
    labels ("(Intercept)", "Intercept", "const") removed. An empty list is
    returned when the model has no ``params`` or its index cannot be read.
    """
    if not hasattr(model, "params"):
        return []

    coefficients = model.params
    try:
        labels = list(coefficients.index)
    except Exception:
        labels = []

    intercept_labels = {"(Intercept)", "Intercept", "const"}
    return [label for label in labels if label not in intercept_labels]

def _is_binomial_glm(model) -> bool:
    """Tell whether ``model`` is a fitted result whose ``model.family`` is Binomial.

    Any error raised while inspecting the object is treated as "no".
    """
    try:
        underlying = getattr(model, "model", None)
        family = getattr(underlying, "family", None)
        verdict = isinstance(family, sm.families.Binomial)
    except Exception:
        verdict = False
    return verdict

def _lpredict(df: pd.DataFrame, varname: str, model, prob: bool = False) -> pd.DataFrame:
    """Predict into ``df[varname]``, preferring an external R-like ``lpredict``.

    If a callable named ``lpredict`` exists in this module's globals and takes
    at least four parameters, it is called as
    ``lpredict(df, varname, model, varlist, const=True, prob=..., factor=False)``.
    In every other case, including when that call raises, the built-in
    ``_lpredict_fallback`` is used.

    Parameters
    ----------
    df : frame to predict on.
    varname : name of the output column.
    model : fitted model. Binomial GLM results force ``prob=True`` for the
        external helper.
    prob : request probability predictions from the external helper.

    Returns
    -------
    The frame returned by whichever implementation produced the predictions.

    Warns
    -----
    RuntimeWarning
        When the external ``lpredict`` raises. The fallback is still used, but
        the switch of implementation is no longer silent.
    """
    import warnings

    prob_eff = bool(prob) or _is_binomial_glm(model)

    ext = globals().get("lpredict", None)
    if callable(ext):
        # inspect.signature raises TypeError/ValueError for callables it cannot
        # introspect; treat those as "signature not R-like".
        try:
            n_params = len(inspect.signature(ext).parameters)
        except (TypeError, ValueError):
            n_params = 0

        if n_params >= 4:
            varlist = _infer_varlist_from_model(model)
            try:
                return ext(df, varname, model, varlist, const=True, prob=prob_eff, factor=False)
            except Exception as err:
                # Keep the best-effort fallback, but surface the failure: the
                # two implementations are not guaranteed to agree.
                warnings.warn(
                    f"external lpredict failed for '{varname}' "
                    f"({type(err).__name__}: {err}); using built-in fallback.",
                    RuntimeWarning,
                    stacklevel=2,
                )

    return _lpredict_fallback(df, varname, model)


# Module-level alias for `_lpredict`. `did_multiplegt_stat_pairwise` below calls
# the prediction helper under this name. Python mangles double-underscore names
# only inside class bodies, so the alias resolves normally from a plain function.
__lpredict = _lpredict

# ============================================================
# Main: did_multiplegt_stat_pairwise
# ============================================================

def did_multiplegt_stat_pairwise(
    df: pd.DataFrame,
    Y: str,
    ID: str,
    Time: str,
    D: str,
    Z: Optional[str],
    estimator: str,
    order: int,
    noextrapolation: bool,
    weight: Optional[str],
    switchers: Optional[str],
    pairwise: int,
    IDs: Any,
    aoss: int,
    waoss: int,
    ivwaoss: int,
    estimation_method: str,
    scalars: Dict[str, Any],
    placebo: int,  # FIXED: 0=no placebo, 1+=placebo index
    exact_match: bool,
    cluster: Optional[str],
    by_fd_opt: Optional[int],
    other_treatments: Optional[List[str]],
) -> Dict[str, Any]:
    """
    Internal function for estimation of pairwise DiD between consecutive time periods.

    NOTE:
    - Interface keeps Y/ID/Time/D/Z for compatibility, but function uses *_XX columns.
    - `scalars` is updated in-place (like in R), and returned too.
    """

    df = df.copy()

    # ------------------------------------------------------------
    # 0) Ensure standardized *_XX columns exist (standalone-friendly)
    # ------------------------------------------------------------
    if "ID_XX" not in df.columns and ID in df.columns:
        df["ID_XX"] = df[ID]
    if "T_XX" not in df.columns and Time in df.columns:
        df["T_XX"] = df[Time]
    if "Y_XX" not in df.columns and Y in df.columns:
        df["Y_XX"] = df[Y]
    if "D_XX" not in df.columns and D in df.columns:
        df["D_XX"] = df[D]
    if Z is not None:
        if "Z_XX" not in df.columns and Z in df.columns:
            df["Z_XX"] = df[Z]
    # `tsfilled_XX` is 0/1 indicating panel-filled rows (created upstream in R main).
    # If absent, default to 0 (assume original observations).
    if "tsfilled_XX" not in df.columns:
        df["tsfilled_XX"] = 0.0
    # Observation weights: in the R code, missing weights are set to 0 (not dropped).
    if weight is not None and ("weight_XX" not in df.columns) and (weight in df.columns):
        df["weight_XX"] = df[weight]
    # Cluster id alias (if provided and not already standardized)
    if cluster is not None and ("cluster_XX" not in df.columns) and (cluster in df.columns):
        df["cluster_XX"] = df[cluster]

    pl = "_pl" if placebo > 0 else ""
    placebo_index = placebo  # FIXED: store placebo index

    # ------------------------------------------------------------
    # 1) Subset time window: {p-1,p} or {p-2,p-1,p} for placebo
    # ------------------------------------------------------------
    # FIXED: Period selection based on placebo index
    if placebo_index == 0:
        # Main effect: periods {p-1, p}
        df = df[df["T_XX"].isin([pairwise - 1, pairwise])]
    else:
        # Placebo: periods {p - placebo - 1, p - placebo, p - 1, p}
        periods_to_keep = sorted(set([
            pairwise - placebo_index - 1,
            pairwise - placebo_index,
            pairwise - 1,
            pairwise
        ]))
        df = df[df["T_XX"].isin(periods_to_keep)]

    # ------------------------------------------------------------
    # 2) Detect "gap" via tsfilled_XX
    #    gap_XX = max_t min_i tsfilled_{it}
    # ------------------------------------------------------------
    if len(df) == 0:
        gap_XX = 1.0
    else:
        df["tsfilled_min_XX"] = df.groupby("T_XX")["tsfilled_XX"].transform(_nanmin_or_nan)
        gap_XX = float(_nanmax_or_nan(df["tsfilled_min_XX"]))

    # ------------------------------------------------------------
    # 3) Relabel time to consecutive ids: 1..k
    # ------------------------------------------------------------
    if len(df) > 0:
        tvals = np.sort(df["T_XX"].dropna().unique())
        tmap = {t: i + 1 for i, t in enumerate(tvals)}
        df["T_XX"] = df["T_XX"].map(tmap).astype(float)

    # ------------------------------------------------------------
    # 4) Sort and compute within-ID first differences
    # ------------------------------------------------------------
    _ensure_numeric(df, "ID_XX")
    df = df.sort_values(["ID_XX", "T_XX"], kind="mergesort").reset_index(drop=True)

    g = df.groupby("ID_XX")
    df["delta_Y_XX"] = g["Y_XX"].diff()  # backward diff: Y_t - Y_{t-1} (plm::diff alignment)
    df["delta_D_XX"] = g["D_XX"].diff()  # backward diff: D_t - D_{t-1} (plm::diff alignment)

    if ivwaoss == 1:
        if "Z_XX" not in df.columns:
            raise ValueError("ivwaoss==1 but df has no Z_XX column.")
        df["delta_Z_XX"] = g["Z_XX"].diff()  # backward diff: Z_t - Z_{t-1} (plm::diff alignment)

    # Other treatments: fd_ot = sum(diff(ot)) within ID
    if other_treatments:
        for v in other_treatments:
            df[f"_fdtmp_{v}"] = df.groupby("ID_XX")[v].diff()
        for v in other_treatments:
            df[f"fd_{v}_XX"] = df.groupby("ID_XX")[f"_fdtmp_{v}"].transform(
                lambda s: float(np.nansum(s.to_numpy(dtype=float)))
            )
            df.drop(columns=[f"_fdtmp_{v}"], inplace=True)

    # by_fd: keep lead of partition then drop partition
    if "partition_XX" in df.columns:
        df["partition_lead_XX"] = df.groupby("ID_XX")["partition_XX"].shift(-1)
        df.drop(columns=["partition_XX"], inplace=True)

    # ------------------------------------------------------------
    # 5) Make delta_Y constant per ID
    # ------------------------------------------------------------
    # FIXED: deltaY calculation based on placebo index
    if placebo_index == 0:
        df["delta_Y_XX"] = df.groupby("ID_XX")["delta_Y_XX"].transform("mean")
    else:
        # For placebo: take deltaY at T_XX == 2 (corresponds to Y_{t-placebo} - Y_{t-placebo-1})
        df["delta_temp"] = np.where(df["T_XX"] == 2, df["delta_Y_XX"], np.nan)
        df["delta_Y_XX"] = df.groupby("ID_XX")["delta_temp"].transform("mean")
        df.drop(columns=["delta_temp"], inplace=True)

    # ------------------------------------------------------------
    # 6) Placebo restriction
    # ------------------------------------------------------------
    # FIXED: Placebo restriction based on placebo index
    if placebo_index > 0 and (aoss == 1 or waoss == 1):
        df["inSamplePlacebo_temp_XX"] = np.where(
            (df["delta_D_XX"] == 0) & (df["T_XX"] == 2),
            1.0,
            0.0,
        )
        df.loc[df["delta_D_XX"].isna(), "inSamplePlacebo_temp_XX"] = np.nan
        df["inSamplePlacebo_XX"] = df.groupby("ID_XX")["inSamplePlacebo_temp_XX"].transform(_nanmax_or_minus_inf)

        # Stata: drop if T_XX == 1
        df = df[df["T_XX"] != 1]
        
        # Stata: if placebo > 1: drop if T_XX == 2
        if placebo_index > 1:
            df = df[df["T_XX"] != 2]
        
        # Stata: deltaD only at the correct period
        # if placebo == 1: deltaD = . if T_XX != 3
        # if placebo > 1: deltaD = . if T_XX != 4
        if placebo_index == 1:
            df["delta_D_XX"] = np.where(df["T_XX"] != 3, np.nan, df["delta_D_XX"])
        else:  # placebo_index > 1
            df["delta_D_XX"] = np.where(df["T_XX"] != 4, np.nan, df["delta_D_XX"])

    # FIXED: IV placebo restriction based on placebo index
    if placebo_index > 0 and ivwaoss == 1:
        df["inSamplePlacebo_IV_temp_XX"] = np.where(
            (df["delta_Z_XX"] == 0) & (df["T_XX"] == 2),
            1.0,
            0.0,
        )
        df.loc[df["delta_Z_XX"].isna(), "inSamplePlacebo_IV_temp_XX"] = np.nan
        df["inSamplePlacebo_XX"] = df.groupby("ID_XX")["inSamplePlacebo_IV_temp_XX"].transform(_nanmax_or_minus_inf)

        df = df[df["T_XX"] != 1]
        
        if placebo_index > 1:
            df = df[df["T_XX"] != 2]
        
        if placebo_index == 1:
            df["delta_Z_XX"] = np.where(df["T_XX"] != 3, np.nan, df["delta_Z_XX"])
        else:
            df["delta_Z_XX"] = np.where(df["T_XX"] != 4, np.nan, df["delta_Z_XX"])

    # ------------------------------------------------------------
    # 7) Empty after restrictions => early exit
    # ------------------------------------------------------------
    if len(df) == 0:
        if aoss == 1:
            scalars[f"P_{pairwise}{pl}_XX"] = 0.0
        if waoss == 1:
            scalars[f"E_abs_delta_D_{pairwise}{pl}_XX"] = 0.0
        if ivwaoss == 1:
            scalars[f"denom_delta_IV_{pairwise}{pl}_XX"] = 0.0

        scalars[f"non_missing_{pairwise}{pl}_XX"] = 0.0

        for v in ("Switchers", "Stayers"):
            for n in (1, 2, 3):
                scalars[f"N_{v}_{n}_{pairwise}{pl}_XX"] = 0.0

        for i, active in enumerate([aoss, waoss, ivwaoss], start=1):
            if active == 1:
                scalars[f"delta_{i}_{pairwise}{pl}_XX"] = 0.0
                scalars[f"sd_delta_{i}_{pairwise}{pl}_XX"] = np.nan
                scalars[f"LB_{i}_{pairwise}{pl}_XX"] = np.nan
                scalars[f"UB_{i}_{pairwise}{pl}_XX"] = np.nan

        return {"scalars": scalars, "to_add": None}

    # ------------------------------------------------------------
    # 8) Make delta_D (and delta_Z) constant per ID
    # ------------------------------------------------------------
    df["delta_D_XX"] = df.groupby("ID_XX")["delta_D_XX"].transform("mean")

    if ivwaoss == 1:
        df["delta_Z_XX"] = df.groupby("ID_XX")["delta_Z_XX"].transform("mean")
        df["SI_XX"] = np.sign(df["delta_Z_XX"]).astype(float)
        df["Z1_XX"] = df["Z_XX"]

    # ------------------------------------------------------------
    # 9) used_in indicators + switcher sign S_XX + abs deltas
    # ------------------------------------------------------------
    df[f"used_in_{pairwise}{pl}_XX"] = (
        (~df["delta_Y_XX"].isna()) & (~df["delta_D_XX"].isna())
    ).astype(float)

    if ivwaoss == 1:
        df[f"used_in_IV_{pairwise}{pl}_XX"] = (
            (df[f"used_in_{pairwise}{pl}_XX"] == 1.0) & (~df["delta_Z_XX"].isna())
        ).astype(float)
        df = df[df[f"used_in_IV_{pairwise}{pl}_XX"] == 1.0]

    df["S_XX"] = np.sign(df["delta_D_XX"]).astype(float)

    if (waoss == 1 or aoss == 1):
        df["abs_delta_D_XX"] = df["S_XX"] * df["delta_D_XX"]
        if switchers == "up":
            df = df[df["S_XX"] != -1.0]
        elif switchers == "down":
            df = df[df["S_XX"] != 1.0]

    if ivwaoss == 1:
        if switchers == "up":
            df = df[df["SI_XX"] != -1.0]
        elif switchers == "down":
            df = df[df["SI_XX"] != 1.0]
        df["abs_delta_Z_XX"] = df["SI_XX"] * df["delta_Z_XX"]

    # ------------------------------------------------------------
    # 10) Drop the second-year line (keep first row of pair)
    # ------------------------------------------------------------
    df = df[df["T_XX"] != df["T_XX"].max()]

    df["D1_XX"] = df["D_XX"]
    df.drop(columns=["D_XX"], inplace=True)

    df["Ht_XX"] = ((~df["delta_D_XX"].isna()) & (~df["delta_Y_XX"].isna())).astype(float)
    df.loc[df["Ht_XX"] == 0, "S_XX"] = np.nan

    if ivwaoss == 1:
        df["Ht_XX"] = ((df["Ht_XX"] == 1.0) & (~df["delta_Z_XX"].isna())).astype(float)
        df.loc[df["Ht_XX"] == 0, "SI_XX"] = np.nan

    if by_fd_opt is not None and "partition_lead_XX" in df.columns:
        df = df[(df["partition_lead_XX"] == 0) | (df["partition_lead_XX"] == by_fd_opt)]

    # ------------------------------------------------------------
    # 11) Set missing if placebo condition fails or other_treatments change
    # ------------------------------------------------------------
    vars_to_set_missing = ["S_XX", "delta_D_XX", "delta_Y_XX", "D1_XX"]
    if aoss == 1 or waoss == 1:
        vars_to_set_missing += ["abs_delta_D_XX"]
    else:
        vars_to_set_missing += ["Z1_XX", "SI_XX"]

    if placebo_index > 0 and "inSamplePlacebo_XX" in df.columns:
        mask_bad = (df["inSamplePlacebo_XX"] == 0)
        for v in vars_to_set_missing:
            if v in df.columns:
                df.loc[mask_bad, v] = np.nan
        df.loc[mask_bad, "Ht_XX"] = np.nan

    if other_treatments:
        for ot in other_treatments:
            colfd = f"fd_{ot}_XX"
            if colfd in df.columns:
                mask_bad = (df[colfd] != 0)
                for v in vars_to_set_missing:
                    if v in df.columns:
                        df.loc[mask_bad, v] = np.nan
                df.loc[mask_bad, "Ht_XX"] = np.nan

    # ------------------------------------------------------------
    # 12) No-extrapolation trimming
    # ------------------------------------------------------------
    scalars.setdefault("N_drop_total_XX", 0.0)
    scalars.setdefault("N_drop_total_C_XX", 0.0)

    if noextrapolation:
        if aoss == 1 or waoss == 1:
            stayers = df[df["S_XX"] == 0]
            if len(stayers):
                max_D1 = float(np.nanmax(stayers["D1_XX"]))
                min_D1 = float(np.nanmin(stayers["D1_XX"]))
            else:
                # mimic R's max/min with na.rm=TRUE on empty: -Inf / +Inf
                max_D1 = float("-inf")
                min_D1 = float("inf")

            d1 = df["D1_XX"].to_numpy(dtype=float)
            df["outofBounds_XX"] = (~np.isnan(d1)) & ((d1 < min_D1) | (d1 > max_D1))
            N_drop = float(np.nansum(df["outofBounds_XX"].astype(float)))
            scalars[f"N_drop_{pairwise}{pl}_XX"] = N_drop
            df = df[~df["outofBounds_XX"]]

            if (N_drop > 0) and (placebo_index == 0) and (gap_XX == 0) and (N_drop < len(df) - 1):
                scalars["N_drop_total_XX"] += N_drop

        if ivwaoss == 1:
            stayers = df[df["SI_XX"] == 0]
            if len(stayers):
                max_Z1 = float(np.nanmax(stayers["Z1_XX"]))
                min_Z1 = float(np.nanmin(stayers["Z1_XX"]))
            else:
                # mimic R's max/min with na.rm=TRUE on empty: -Inf / +Inf
                max_Z1 = float("-inf")
                min_Z1 = float("inf")

            z1 = df["Z1_XX"].to_numpy(dtype=float)
            df["outofBoundsIV_XX"] = (~np.isnan(z1)) & ((z1 < min_Z1) | (z1 > max_Z1))
            N_IVdrop = float(np.nansum(df["outofBoundsIV_XX"].astype(float)))
            scalars[f"N_IVdrop_{pairwise}{pl}_XX"] = N_IVdrop
            df = df[~df["outofBoundsIV_XX"]]

            if (N_IVdrop > 0) and (placebo_index == 0) and (gap_XX == 0) and (N_IVdrop < len(df) - 1):
                scalars["N_drop_total_XX"] += N_IVdrop

    # ------------------------------------------------------------
    # 13) Exact matching feasibility drops
    # ------------------------------------------------------------
    if exact_match:
        if aoss == 1 or waoss == 1:
            group_cols = ["D1_XX"] + (other_treatments or [])
            g = df.groupby(group_cols, dropna=False)

            df["has_match_min_XX"] = g["abs_delta_D_XX"].transform(_nanmin_or_nan)
            df["has_match_max_XX"] = g["abs_delta_D_XX"].transform(_nanmax_or_minus_inf)

            df["s_has_match_XX"] = np.where(
                ~df["S_XX"].isna(),
                (df["has_match_min_XX"] == 0).astype(float),
                -1.0
            )
            df.loc[df["S_XX"] == 0, "s_has_match_XX"] = -1.0

            df["c_has_match_XX"] = np.where(
                ~df["S_XX"].isna(),
                (df["has_match_max_XX"] > 0).astype(float),
                -1.0
            )
            df.loc[(df["S_XX"] != 0) & (~df["S_XX"].isna()), "c_has_match_XX"] = -1.0

        else:
            group_cols = ["Z1_XX"] + (other_treatments or [])
            g = df.groupby(group_cols, dropna=False)

            df["has_match_min_XX"] = g["abs_delta_Z_XX"].transform(_nanmin_or_nan)
            df["has_match_max_XX"] = g["abs_delta_Z_XX"].transform(_nanmax_or_minus_inf)

            df["s_has_match_XX"] = np.where(
                ~df["SI_XX"].isna(),
                (df["has_match_min_XX"] == 0).astype(float),
                -1.0
            )
            df.loc[df["SI_XX"] == 0, "s_has_match_XX"] = -1.0

            df["c_has_match_XX"] = np.where(
                ~df["SI_XX"].isna(),
                (df["has_match_max_XX"] > 0).astype(float),
                -1.0
            )
            df.loc[(df["SI_XX"] != 0) & (~df["SI_XX"].isna()), "c_has_match_XX"] = -1.0

        N_drop_s = float((df["s_has_match_XX"] == 0).sum())
        N_drop_c = float((df["c_has_match_XX"] == 0).sum())
        scalars[f"N_drop_{pairwise}{pl}_XX"] = N_drop_s
        scalars[f"N_drop_{pairwise}{pl}_C_XX"] = N_drop_c

        if (N_drop_s > 0) and (N_drop_s != len(df)) and (gap_XX == 0):
            scalars["N_drop_total_XX"] += N_drop_s
        if (N_drop_c > 0) and (N_drop_c != len(df)) and (gap_XX == 0):
            scalars["N_drop_total_C_XX"] += N_drop_c

        mask_bad = (df["s_has_match_XX"] == 0) | (df["c_has_match_XX"] == 0)
        for v in vars_to_set_missing:
            if v in df.columns:
                df.loc[mask_bad, v] = np.nan
        df.loc[mask_bad, "Ht_XX"] = np.nan

        if "D1_XX" in df.columns:
            nun = int(df["D1_XX"].nunique(dropna=True))
            if nun >= 1:
                order = min(order, nun)

        df.drop(columns=[c for c in ["has_match_min_XX", "has_match_max_XX"] if c in df.columns], inplace=True)

    # ------------------------------------------------------------
    # 14) Bookkeeping scalars for this pair
    # ------------------------------------------------------------
    if "weight_XX" not in df.columns:
        df["weight_XX"] = 1.0
    _ensure_numeric(df, "weight_XX")

    # R behavior: set missing weights to 0 (not NA)
    df["weight_XX"] = df["weight_XX"].fillna(0.0)
    # Guardrail: negative weights are almost surely unintended
    if (df["weight_XX"] < 0).any():
        raise ValueError("Negative weights are not supported (weight_XX < 0).")

    scalars[f"W{pl}_XX"] = float(np.nansum(df["weight_XX"].to_numpy(dtype=float)))
    scalars[f"N{pl}_XX"] = float(len(df))

    if waoss == 1 or aoss == 1:
        scalars[f"N_Switchers{pl}_XX"] = float(((df["S_XX"] != 0) & (~df["S_XX"].isna())).sum())
        scalars[f"N_Stayers{pl}_XX"] = float((df["S_XX"] == 0).sum())

    if ivwaoss == 1:
        scalars[f"N_Switchers_IV{pl}_XX"] = float(((df["SI_XX"] != 0) & (~df["SI_XX"].isna())).sum())
        scalars[f"N_Stayers_IV{pl}_XX"] = float((df["SI_XX"] == 0).sum())

    # ------------------------------------------------------------
    # 15) Build polynomial regressors and formula strings
    # ------------------------------------------------------------
    for pol_level in range(1, order + 1):
        df[f"D1_{pol_level}_XX"] = df["D1_XX"] ** pol_level

    reg_pol_terms = " + ".join([f"D1_{k}_XX" for k in range(1, order + 1)])

    if other_treatments:
        interact = "D1_1_XX"
        for v in other_treatments:
            interact = f"{interact} * {v}"
        reg_pol_terms = f"{reg_pol_terms} + {interact}"

    if ivwaoss == 1:
        for pol_level in range(1, order + 1):
            df[f"Z1_{pol_level}_XX"] = df["Z1_XX"] ** pol_level

        IV_reg_pol_terms = " + ".join([f"Z1_{k}_XX" for k in range(1, order + 1)])
        if other_treatments:
            interact = "Z1_1_XX"
            for v in other_treatments:
                interact = f"{interact} * {v}"
            IV_reg_pol_terms = f"{IV_reg_pol_terms} + {interact}"
    else:
        IV_reg_pol_terms = ""

    df["S_bis_XX"] = np.where(df["S_XX"].isna(), np.nan, (df["S_XX"] != 0).astype(float))

    # ------------------------------------------------------------
    # 16) Feasibility check
    # ------------------------------------------------------------
    if aoss == 1 or waoss == 1:
        feasible_est = (gap_XX == 0) and (scalars[f"N_Switchers{pl}_XX"] > 0) and (scalars[f"N_Stayers{pl}_XX"] > 1)
    else:
        feasible_est = (gap_XX == 0) and (scalars[f"N_Switchers_IV{pl}_XX"] > 0) and (scalars[f"N_Stayers_IV{pl}_XX"] > 1)

    scalars[f"P_Ht_{pairwise}{pl}_XX"] = Mean("Ht_XX", df)

    # ------------------------------------------------------------
    # 17) Cluster preparation
    # ------------------------------------------------------------
    cluster_col = None
    if cluster is not None:
        if "cluster_XX" in df.columns:
            cluster_col = "cluster_XX"
        elif cluster in df.columns:
            df["cluster_XX"] = df[cluster]
            cluster_col = "cluster_XX"
        else:
            raise ValueError("cluster specified but neither cluster_XX nor the given cluster column exists in df.")
        # If cluster is identical to ID, clustering collapses to ID-robust -> treat as no cluster (matches R main behavior).
        same_as_id = False
        try:
            same_as_id = (
                df[cluster_col].astype("string").fillna("<NA>")
                .equals(df["ID_XX"].astype("string").fillna("<NA>"))
            )
        except Exception:
            same_as_id = False

        if same_as_id:
            cluster_col = None
            cluster = None
        else:
            # In R main: weight_c_XX = sum(weight_XX) by (cluster, time).
            if "weight_c_XX" not in df.columns:
                if "weight_XX" in df.columns:
                    _ensure_numeric(df, "weight_XX")
                    df["weight_XX"] = df["weight_XX"].fillna(0.0)
                    df["weight_c_XX"] = df.groupby([cluster_col, "T_XX"])["weight_XX"].transform("sum")
                else:
                    # fallback: unweighted cluster-time counts
                    df["weight_c_XX"] = df.groupby([cluster_col, "T_XX"])["ID_XX"].transform("size").astype(float)

            _ensure_numeric(df, "weight_c_XX")
            df["weight_c_XX"] = df["weight_c_XX"].fillna(0.0)

        if cluster_col is not None:
            df["_first_in_id"] = df.groupby("ID_XX").cumcount().eq(0).astype(float)
            df["_Nc"] = df.groupby(cluster_col)["_first_in_id"].transform(lambda s: float(np.nansum(s.to_numpy(dtype=float))))
            scalars[f"N_bar_c_{pairwise}{pl}_XX"] = float(np.nanmean(df["_Nc"].to_numpy(dtype=float)))
        df.drop(columns=["_first_in_id", "_Nc"], inplace=True, errors="ignore")

    # ============================================================
    # 18) Estimation if feasible
    # ============================================================
    if feasible_est:

        # -------------------------
        # 18A) Common prelims AOSS/WAOSS
        # -------------------------
        if waoss == 1 or aoss == 1:
            df0 = df[df["S_XX"] == 0].copy()

            ra_formula = f"delta_Y_XX ~ {reg_pol_terms}"
            ra_model = smf.wls(ra_formula, data=df0, weights=df0["weight_XX"]).fit()
            df = __lpredict(df, "mean_pred_XX", ra_model)

            df["inner_sum_delta_1_2_XX"] = df["delta_Y_XX"] - df["mean_pred_XX"]
            df["S0_XX"] = 1.0 - df["S_bis_XX"]

            if not exact_match:
                ps0_formula = f"S0_XX ~ {reg_pol_terms}"
                ps0_model = stata_logit(ps0_formula, df)
                df = __lpredict(df, "PS_0_D_1_XX", ps0_model)
            else:
                esbis_formula = f"S_bis_XX ~ {reg_pol_terms}"
                esbis_model = smf.wls(esbis_formula, data=df, weights=df["weight_XX"]).fit()
                df = __lpredict(df, "ES_bis_XX_D_1", esbis_model)

                es_formula = f"S_XX ~ {reg_pol_terms}"
                es_model = smf.wls(es_formula, data=df, weights=df["weight_XX"]).fit()
                df = __lpredict(df, "ES_XX_D_1", es_model)

            scalars[f"PS_0{pl}_XX"] = Mean("S0_XX", df)

        # -------------------------
        # 18B) AOSS
        # -------------------------
        if aoss == 1:
            ES = Mean("S_bis_XX", df)
            scalars[f"ES{pl}_XX"] = ES

            scalars[f"P_{pairwise}{pl}_XX"] = ES * scalars[f"P_Ht_{pairwise}{pl}_XX"]
            scalars[f"PS_sum{pl}_XX"] = scalars.get(f"PS_sum{pl}_XX", 0.0) + scalars[f"P_{pairwise}{pl}_XX"]

            # Step 1: S_over_delta_D = S_bis / delta_D (for switchers)
            df["S_over_delta_D_XX"] = df["S_bis_XX"] / df["delta_D_XX"]
            df.loc[df["S_bis_XX"] == 0, "S_over_delta_D_XX"] = 0.0

            # Step 2: Regress S/deltaD on D1 to get E[S/deltaD | D1]
            sdd_formula = f"S_over_delta_D_XX ~ {reg_pol_terms}"
            sdd_model = smf.wls(sdd_formula, data=df, weights=df["weight_XX"]).fit()
            df = __lpredict(df, "mean_S_over_delta_D_XX", sdd_model)

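            # Step 3: Compute delta_1 using the doubly-robust formula (matching Stata).
            # For stayers (S_bis == 0): -(mean_S_over_delta_D / PS_0_D_1) * inner_sum_delta_1_2;
            #   under exact_match the denominator is (1 - ES_bis_XX_D_1) instead of PS_0_D_1.
            # For all other rows (switchers): S_over_delta_D * inner_sum_delta_1_2.
            # Zero denominators are replaced by NaN so the ratio does not become +/-inf.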
            if not exact_match:
                df["dr_delta1_DR_XX"] = np.where(
                    df["S_bis_XX"] == 0,
                    -(df["mean_S_over_delta_D_XX"] / df["PS_0_D_1_XX"].replace(0, np.nan)) * df["inner_sum_delta_1_2_XX"],
                    df["S_over_delta_D_XX"] * df["inner_sum_delta_1_2_XX"]
                )
            else:
                denom_exact = (1.0 - df["ES_bis_XX_D_1"]).replace(0, np.nan)
                df["dr_delta1_DR_XX"] = np.where(
                    df["S_bis_XX"] == 0,
                    -(df["mean_S_over_delta_D_XX"] / denom_exact) * df["inner_sum_delta_1_2_XX"],
                    df["S_over_delta_D_XX"] * df["inner_sum_delta_1_2_XX"]
                )
            
            # delta_1 = weighted mean of the DR estimator
            scalars[f"delta_1_{pairwise}{pl}_XX"] = Mean("dr_delta1_DR_XX", df)

            # Keep inner_sum_delta_1_XX for compatibility (used elsewhere)
            df["inner_sum_delta_1_XX"] = df["inner_sum_delta_1_2_XX"] / df["delta_D_XX"]
            df.loc[df["delta_D_XX"] == 0, "inner_sum_delta_1_XX"] = np.nan

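            # Step 4: Influence function (raw_phi). For S_bis in {0, 1} this is the same
            # doubly-robust term as Step 3 written as a single expression; unlike Step 3,
            # zero denominators are NOT replaced by NaN here.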
            if not exact_match:
                adj = (1.0 - df["S_bis_XX"]) / df["PS_0_D_1_XX"]
                raw_phi = (df["S_over_delta_D_XX"] - df["mean_S_over_delta_D_XX"] * adj) * df["inner_sum_delta_1_2_XX"]
            else:
                adj = (1.0 - df["S_bis_XX"]) / (1.0 - df["ES_bis_XX_D_1"])
                raw_phi = (df["S_over_delta_D_XX"] - df["mean_S_over_delta_D_XX"] * adj) * df["inner_sum_delta_1_2_XX"]

            df[f"Phi_1_{pairwise}{pl}_XX"] = (
                raw_phi - (scalars[f"delta_1_{pairwise}{pl}_XX"] * df["S_bis_XX"])
            ) / (ES * scalars[f"P_Ht_{pairwise}{pl}_XX"])

            df.loc[df["Ht_XX"] == 0, f"Phi_1_{pairwise}{pl}_XX"] = 0.0

            # SE delta_1
            if cluster_col is not None:
                phi = f"Phi_1_{pairwise}{pl}_XX"
                df["_phi_c"] = df.groupby(cluster_col)[phi].transform(lambda s: float(np.nansum(s.to_numpy(dtype=float))))
                df["_first_clus"] = df.groupby(cluster_col).cumcount().eq(0)
                df["_phi_c"] = np.where(df["_first_clus"], df["_phi_c"], np.nan) / scalars[f"N_bar_c_{pairwise}{pl}_XX"]

                nobs_c = wSum(df[~df["_phi_c"].isna()], w="weight_c_XX")
                sd_phi = Sd("_phi_c", df, w="weight_c_XX") / np.sqrt(nobs_c) if nobs_c > 0 else np.nan
                scalars[f"sd_delta_1_{pairwise}{pl}_XX"] = sd_phi
                df.drop(columns=["_phi_c", "_first_clus"], inplace=True)
            else:
                # IMPORTANT: R uses sd() unweighted with ddof=1 here
                phi_vals = df[f"Phi_1_{pairwise}{pl}_XX"].to_numpy(dtype=float)
                scalars[f"sd_delta_1_{pairwise}{pl}_XX"] = np.nanstd(phi_vals, ddof=1) / np.sqrt(wSum(df))

            se = scalars[f"sd_delta_1_{pairwise}{pl}_XX"]
            scalars[f"LB_1_{pairwise}{pl}_XX"] = scalars[f"delta_1_{pairwise}{pl}_XX"] - 1.96 * se
            scalars[f"UB_1_{pairwise}{pl}_XX"] = scalars[f"delta_1_{pairwise}{pl}_XX"] + 1.96 * se

            df[f"S_{pairwise}{pl}_XX"] = df["S_bis_XX"]
            df.loc[df["Ht_XX"] == 0, f"S_{pairwise}{pl}_XX"] = 0.0

        # -------------------------
        # 18C) WAOSS
        # -------------------------
        if waoss == 1:
            scalars[f"E_abs_delta_D{pl}_XX"] = Mean("abs_delta_D_XX", df)
            scalars[f"E_abs_delta_D_{pairwise}{pl}_XX"] = scalars[f"E_abs_delta_D{pl}_XX"] * scalars[f"P_Ht_{pairwise}{pl}_XX"]
            scalars[f"E_abs_delta_D_sum{pl}_XX"] = scalars.get(f"E_abs_delta_D_sum{pl}_XX", 0.0) + scalars[f"E_abs_delta_D_{pairwise}{pl}_XX"]

            for suffix in ("Minus", "Plus"):
                target_S = 1.0 if suffix == "Plus" else -1.0
                df["Ster_XX"] = np.where(df["S_XX"].isna(), np.nan, (df["S_XX"] == target_S).astype(float))

                df["prod_sgn_delta_D_delta_D_XX"] = df["S_XX"] * df["delta_D_XX"]
                sum_prod = Sum("prod_sgn_delta_D_delta_D_XX", df[df["Ster_XX"] == 1])
                scalars[f"w_{suffix}_{pairwise}{pl}_XX"] = sum_prod / scalars[f"N{pl}_XX"] if scalars[f"N{pl}_XX"] > 0 else 0.0

                denom = Sum("delta_D_XX", df[df["Ster_XX"] == 1])
                scalars[f"denom_delta_2_{suffix}_{pairwise}{pl}_XX"] = denom

                if estimation_method == "ra":
                    if denom == 0:
                        denom = 1.0
                        scalars[f"denom_delta_2_{suffix}_{pairwise}{pl}_XX"] = denom
                    num = Sum("inner_sum_delta_1_2_XX", df[df["Ster_XX"] == 1])
                    scalars[f"num_delta_2_{suffix}_{pairwise}{pl}_XX"] = num
                    scalars[f"delta_2_{suffix}_{pairwise}{pl}_XX"] = num / denom

                nb_sw = float(df[df["Ster_XX"] == 1].shape[0])
                scalars[f"nb_Switchers_{suffix}{pl}_XX"] = nb_sw
                scalars[f"PS_{suffix}1{pl}_XX"] = nb_sw / scalars[f"N{pl}_XX"] if scalars[f"N{pl}_XX"] > 0 else 0.0

                if not exact_match:
                    if scalars[f"PS_{suffix}1{pl}_XX"] == 0:
                        scalars[f"delta_2_{suffix}_{pairwise}{pl}_XX"] = 0.0
                        df[f"PS_1_{suffix}_D_1_XX"] = 0.0
                    else:
                        ps1_formula = f"Ster_XX ~ {reg_pol_terms}"
                        ps1_model = stata_logit(ps1_formula, df)
                        df = __lpredict(df, f"PS_1_{suffix}_D_1_XX", ps1_model)

                        if estimation_method == "ps":
                            df[f"delta_Y_P_{suffix}_XX"] = (
                                df["delta_Y_XX"]
                                * (df[f"PS_1_{suffix}_D_1_XX"] / df["PS_0_D_1_XX"])
                                * (scalars[f"PS_0{pl}_XX"] / scalars[f"PS_{suffix}1{pl}_XX"])
                            )
                            mean_delta_Y_P = Mean(f"delta_Y_P_{suffix}_XX", df[df["S_XX"] == 0])
                            mean_delta_Y = Mean("delta_Y_XX", df[df["Ster_XX"] == 1])
                            mean_delta_D = Mean("delta_D_XX", df[df["Ster_XX"] == 1])
                            scalars[f"delta_2_{suffix}_{pairwise}{pl}_XX"] = (mean_delta_Y - mean_delta_Y_P) / mean_delta_D

            if estimation_method in ("ra", "ps"):
                w_plus = scalars.get(f"w_Plus_{pairwise}{pl}_XX", 0.0)
                w_minus = scalars.get(f"w_Minus_{pairwise}{pl}_XX", 0.0)
                denomw = w_plus + w_minus
                scalars[f"W_Plus_{pairwise}{pl}_XX"] = (w_plus / denomw) if denomw != 0 else 0.0

            if not exact_match:
                df["dr_delta_Y_XX"] = (
                    (df["S_XX"]
                     - ((df.get("PS_1_Plus_D_1_XX", 0.0) - df.get("PS_1_Minus_D_1_XX", 0.0)) / df["PS_0_D_1_XX"])
                     * (1.0 - df["S_bis_XX"]))
                    * df["inner_sum_delta_1_2_XX"]
                )
                scalars[f"denom_dr_delta_2{pl}_XX"] = Sum("dr_delta_Y_XX", df)

            if estimation_method in ("ra", "ps"):
                Wp = scalars[f"W_Plus_{pairwise}{pl}_XX"]
                scalars[f"delta_2_{pairwise}{pl}_XX"] = (
                    Wp * scalars[f"delta_2_Plus_{pairwise}{pl}_XX"]
                    + (1.0 - Wp) * scalars[f"delta_2_Minus_{pairwise}{pl}_XX"]
                )
            elif estimation_method == "dr":
                sum_abs = Sum("abs_delta_D_XX", df)
                scalars[f"delta_2_{pairwise}{pl}_XX"] = scalars[f"denom_dr_delta_2{pl}_XX"] / sum_abs if sum_abs != 0 else 0.0

            if not exact_match:
                df[f"Phi_2_{pairwise}{pl}_XX"] = df["dr_delta_Y_XX"] - scalars[f"delta_2_{pairwise}{pl}_XX"] * df["abs_delta_D_XX"]
            else:
                df[f"Phi_2_{pairwise}{pl}_XX"] = (
                    (df["S_XX"] - df["ES_XX_D_1"] * ((1.0 - df["S_bis_XX"]) / (1.0 - df["ES_bis_XX_D_1"])))
                    * df["inner_sum_delta_1_2_XX"]
                    - scalars[f"delta_2_{pairwise}{pl}_XX"] * df["abs_delta_D_XX"]
                )

            denom_if = scalars[f"P_Ht_{pairwise}{pl}_XX"] * scalars[f"E_abs_delta_D{pl}_XX"]
            df[f"Phi_2_{pairwise}{pl}_XX"] = df[f"Phi_2_{pairwise}{pl}_XX"] / denom_if if denom_if != 0 else np.nan
            df.loc[df["Ht_XX"] == 0, f"Phi_2_{pairwise}{pl}_XX"] = 0.0

            # SE delta_2
            if cluster_col is not None:
                phi = f"Phi_2_{pairwise}{pl}_XX"
                df["_phi_c"] = df.groupby(cluster_col)[phi].transform(lambda s: float(np.nansum(s.to_numpy(dtype=float))))
                df["_first_clus"] = df.groupby(cluster_col).cumcount().eq(0)
                df["_phi_c"] = np.where(df["_first_clus"], df["_phi_c"], np.nan) / scalars[f"N_bar_c_{pairwise}{pl}_XX"]

                nobs_c = wSum(df[~df["_phi_c"].isna()], w="weight_c_XX")
                sd_phi = Sd("_phi_c", df, w="weight_c_XX") / np.sqrt(nobs_c) if nobs_c > 0 else np.nan
                scalars[f"sd_delta_2_{pairwise}{pl}_XX"] = sd_phi
                df.drop(columns=["_phi_c", "_first_clus"], inplace=True)
            else:
                # R uses its weighted Sd() helper here (variance divided by sum of weights).
                scalars[f"sd_delta_2_{pairwise}{pl}_XX"] = Sd(
                    f"Phi_2_{pairwise}{pl}_XX", df
                ) / np.sqrt(wSum(df))

            se = scalars[f"sd_delta_2_{pairwise}{pl}_XX"]
            scalars[f"LB_2_{pairwise}{pl}_XX"] = scalars[f"delta_2_{pairwise}{pl}_XX"] - 1.96 * se
            scalars[f"UB_2_{pairwise}{pl}_XX"] = scalars[f"delta_2_{pairwise}{pl}_XX"] + 1.96 * se

            df[f"abs_delta_D_{pairwise}{pl}_XX"] = np.where(df["Ht_XX"] == 0, 0.0, df["abs_delta_D_XX"])

        # -------------------------
        # 18D) IV-WAOSS (delta_3)
        # -------------------------
        if ivwaoss == 1:
            scalars[f"E_abs_delta_Z{pl}_XX"] = Mean("abs_delta_Z_XX", df)

            df["SI_bis_XX"] = ((df["SI_XX"] != 0) & (~df["SI_XX"].isna())).astype(float)
            df["SI_Plus_XX"] = np.where(df["SI_XX"].isna(), np.nan, (df["SI_XX"] == 1).astype(float))
            df["SI_Minus_XX"] = np.where(df["SI_XX"].isna(), np.nan, (df["SI_XX"] == -1).astype(float))

            df["S_IV_0_XX"] = 1.0 - df["SI_bis_XX"]

            if not exact_match:
                psiv0_formula = f"S_IV_0_XX ~ {IV_reg_pol_terms}"
                psiv0_model = stata_logit(psiv0_formula, df)
                df = __lpredict(df, "PS_IV_0_Z_1_XX", psiv0_model)
            else:
                esibis_formula = f"SI_bis_XX ~ {IV_reg_pol_terms}"
                esibis_model = smf.wls(esibis_formula, data=df, weights=df["weight_XX"]).fit()
                df = __lpredict(df, "ES_I_bis_XX_Z_1", esibis_model)

                esi_formula = f"SI_XX ~ {IV_reg_pol_terms}"
                esi_model = smf.wls(esi_formula, data=df, weights=df["weight_XX"]).fit()
                df = __lpredict(df, "ES_I_XX_Z_1", esi_model)

            scalars[f"PS_IV_0{pl}_XX"] = Mean("S_IV_0_XX", df)

            for suffix in ("Minus", "Plus"):
                flag = "SI_Minus_XX" if suffix == "Minus" else "SI_Plus_XX"
                nb = float((df[flag] == 1).sum())
                scalars[f"nb_Switchers_I_{suffix}{pl}_XX"] = nb
                scalars[f"PS_I_{suffix}_1{pl}_XX"] = nb / scalars[f"N{pl}_XX"] if scalars[f"N{pl}_XX"] > 0 else 0.0

                if scalars[f"PS_I_{suffix}_1{pl}_XX"] == 0:
                    df[f"PS_I_{suffix}_1_Z_1_XX"] = 0.0
                else:
                    if not exact_match:
                        psis_formula = f"{flag} ~ {IV_reg_pol_terms}"
                        psis_model = stata_logit(psis_formula, df)
                        df = __lpredict(df, f"PS_I_{suffix}_1_Z_1_XX", psis_model)

            
            # Products used by the PS estimator (mirrors the R code)
            df["prod_sgn_delta_Z_delta_Y_XX"] = df["SI_XX"] * df["delta_Y_XX"]
            df["prod_sgn_delta_Z_delta_D_XX"] = df["SI_XX"] * df["delta_D_XX"]

            df_temp = df[df["SI_XX"] == 0].copy()
            mY_formula = f"delta_Y_XX ~ {IV_reg_pol_terms}"
            mY_model = smf.wls(mY_formula, data=df_temp, weights=df_temp["weight_XX"]).fit()
            df = __lpredict(df, "mean_delta_Y_pred_IV_XX", mY_model)
            df["inner_sum_IV_num_XX"] = df["delta_Y_XX"] - df["mean_delta_Y_pred_IV_XX"]

            mD_formula = f"delta_D_XX ~ {IV_reg_pol_terms}"
            mD_model = smf.wls(mD_formula, data=df_temp, weights=df_temp["weight_XX"]).fit()
            df = __lpredict(df, "mean_delta_D_pred_IV_XX", mD_model)
            df["inner_sum_IV_denom_XX"] = df["delta_D_XX"] - df["mean_delta_D_pred_IV_XX"]

            if estimation_method == "ra":
                # Multiply by SI_XX 
                df["inner_sum_IV_num_XX"] = df["inner_sum_IV_num_XX"] * df["SI_XX"]
                df["inner_sum_IV_denom_XX"] = df["inner_sum_IV_denom_XX"] * df["SI_XX"]
                
                # FIX: Use Sum/N instead of Mean to match Stata behavior
                # Mean can give different results if there are NA values
                N_total = float(scalars[f"N{pl}_XX"])
                if N_total > 0:
                    scalars[f"num_delta_IV_{pairwise}{pl}_XX"] = Sum("inner_sum_IV_num_XX", df) / N_total
                    scalars[f"denom_delta_IV_{pairwise}{pl}_XX"] = Sum("inner_sum_IV_denom_XX", df) / N_total
                else:
                    scalars[f"num_delta_IV_{pairwise}{pl}_XX"] = np.nan
                    scalars[f"denom_delta_IV_{pairwise}{pl}_XX"] = np.nan

            if estimation_method == "ps":
                if exact_match:
                    raise ValueError("estimation_method='ps' is not implemented with exact_match=True (matches R behavior).")

                # Reweight stayers (SI_bis_XX==0) to mimic Stata's PS reweighting in the R code
                df["delta_Y_P_IV_XX"] = (
                    df["delta_Y_XX"]
                    * ((df.get("PS_I_Plus_1_Z_1_XX", 0.0) - df.get("PS_I_Minus_1_Z_1_XX", 0.0)) / df["PS_IV_0_Z_1_XX"])
                    * scalars[f"PS_IV_0{pl}_XX"]
                )
                mean_delta_Y_P_IV = Mean("delta_Y_P_IV_XX", df[df["SI_bis_XX"] == 0])
                mean_prod_sgn_Z_delta_Y = Mean("prod_sgn_delta_Z_delta_Y_XX", df)
                scalars[f"num_delta_IV_{pairwise}{pl}_XX"] = mean_prod_sgn_Z_delta_Y - mean_delta_Y_P_IV

                df["delta_D_P_IV_XX"] = (
                    df["delta_D_XX"]
                    * ((df.get("PS_I_Plus_1_Z_1_XX", 0.0) - df.get("PS_I_Minus_1_Z_1_XX", 0.0)) / df["PS_IV_0_Z_1_XX"])
                    * scalars[f"PS_IV_0{pl}_XX"]
                )
                mean_delta_D_P_IV = Mean("delta_D_P_IV_XX", df[df["SI_bis_XX"] == 0])
                mean_prod_sgn_Z_delta_D = Mean("prod_sgn_delta_Z_delta_D_XX", df)
                scalars[f"denom_delta_IV_{pairwise}{pl}_XX"] = mean_prod_sgn_Z_delta_D - mean_delta_D_P_IV

            if estimation_method == "dr":
                df["dr_IV_delta_Y_XX"] = (
                    (df["SI_XX"] - ((df.get("PS_I_Plus_1_Z_1_XX", 0.0) - df.get("PS_I_Minus_1_Z_1_XX", 0.0)) / df["PS_IV_0_Z_1_XX"])
                     * (1.0 - df["SI_bis_XX"]))
                    * df["inner_sum_IV_num_XX"]
                )
                scalars[f"num_delta_IV_{pairwise}{pl}_XX"] = Mean("dr_IV_delta_Y_XX", df)

                df["dr_IV_delta_D_XX"] = (
                    (df["SI_XX"] - ((df.get("PS_I_Plus_1_Z_1_XX", 0.0) - df.get("PS_I_Minus_1_Z_1_XX", 0.0)) / df["PS_IV_0_Z_1_XX"])
                     * (1.0 - df["SI_bis_XX"]))
                    * df["inner_sum_IV_denom_XX"]
                )
                scalars[f"denom_delta_IV_{pairwise}{pl}_XX"] = Mean("dr_IV_delta_D_XX", df)

            scalars[f"delta_3_{pairwise}{pl}_XX"] = (
                scalars[f"num_delta_IV_{pairwise}{pl}_XX"] / scalars[f"denom_delta_IV_{pairwise}{pl}_XX"]
                if scalars[f"denom_delta_IV_{pairwise}{pl}_XX"] != 0 else np.nan
            )

            
            # ------------------------------------------------------------
            # Influence function for IV-WAOSS (delta_3) + its pairwise SE
            # Mirrors the R code:
            #   Phi_Y_XX, Phi_D_XX  ->  Phi_3 = (Phi_Y - delta_3 * Phi_D) / delta_Dbar
            # ------------------------------------------------------------

            # Track the overall weight used later in did_multiplegt_stat_main
            scalars[f"denom_delta_IV_sum{pl}_XX"] = (
                scalars.get(f"denom_delta_IV_sum{pl}_XX", 0.0)
                + scalars.get(f"denom_delta_IV_{pairwise}{pl}_XX", np.nan)
            )

            # "Moment means" (used to center Phi_Y/Phi_D)
            scalars[f"delta_Y{pl}_XX"] = Mean("inner_sum_IV_num_XX", df)
            scalars[f"delta_D{pl}_XX"] = Mean("inner_sum_IV_denom_XX", df)

            # Residuals from the "stayers" regressions
            df["resid_Y_IV_XX"] = df["delta_Y_XX"] - df["mean_delta_Y_pred_IV_XX"]
            df["resid_D_IV_XX"] = df["delta_D_XX"] - df["mean_delta_D_pred_IV_XX"]

            E_abs = scalars.get(f"E_abs_delta_Z{pl}_XX", np.nan)

            # DR-style "score" on the residuals
            if not exact_match:
                denom_ps = df["PS_IV_0_Z_1_XX"].replace({0.0: np.nan})
                score = (
                    df["SI_XX"]
                    - (df.get("PS_I_Plus_1_Z_1_XX", 0.0) - df.get("PS_I_Minus_1_Z_1_XX", 0.0))
                    * (1.0 - df["SI_bis_XX"]) / denom_ps
                )
            else:
                denom_es = (1.0 - df["ES_I_bis_XX_Z_1"]).replace({0.0: np.nan})
                score = (
                    df["SI_XX"]
                    - df["ES_I_XX_Z_1"] * ((1.0 - df["SI_bis_XX"]) / denom_es)
                )

            df["Phi_Y_XX"] = (
                score * df["resid_Y_IV_XX"] - scalars[f"delta_Y{pl}_XX"] * df["abs_delta_Z_XX"]
            ) / E_abs

            df["Phi_D_XX"] = (
                score * df["resid_D_IV_XX"] - scalars[f"delta_D{pl}_XX"] * df["abs_delta_Z_XX"]
            ) / E_abs

            delta_D_bar = scalars.get(f"delta_D{pl}_XX", np.nan)
            delta3 = scalars.get(f"delta_3_{pairwise}{pl}_XX", np.nan)

            if (delta_D_bar is None) or np.isnan(delta_D_bar) or (delta_D_bar == 0):
                df[f"Phi_3_{pairwise}{pl}_XX"] = np.nan
                scalars[f"sd_delta_3_{pairwise}{pl}_XX"] = np.nan
                scalars[f"LB_3_{pairwise}{pl}_XX"] = np.nan
                scalars[f"UB_3_{pairwise}{pl}_XX"] = np.nan
            else:
                df[f"Phi_3_{pairwise}{pl}_XX"] = (df["Phi_Y_XX"] - delta3 * df["Phi_D_XX"]) / delta_D_bar

                # SE for the pairwise delta_3
                if cluster_col is not None:
                    phi = f"Phi_3_{pairwise}{pl}_XX"
                    df["_phi_c"] = df.groupby(cluster_col)[phi].transform(
                        lambda s: float(np.nansum(s.to_numpy(dtype=float)))
                    )
                    df["_first_clus"] = df.groupby(cluster_col).cumcount().eq(0)
                    df["_phi_c"] = np.where(df["_first_clus"], df["_phi_c"], np.nan) / scalars[
                        f"N_bar_c_{pairwise}{pl}_XX"
                    ]

                    nobs_c = wSum(df[~df["_phi_c"].isna()], w="weight_c_XX")
                    sd_phi = Sd("_phi_c", df, w="weight_c_XX") / np.sqrt(nobs_c) if nobs_c > 0 else np.nan
                    scalars[f"sd_delta_3_{pairwise}{pl}_XX"] = sd_phi
                    df.drop(columns=["_phi_c", "_first_clus"], inplace=True)
                else:
                    # R uses its weighted Sd() helper here too.
                    scalars[f"sd_delta_3_{pairwise}{pl}_XX"] = Sd(
                        f"Phi_3_{pairwise}{pl}_XX", df
                    ) / np.sqrt(wSum(df))

                se3 = scalars[f"sd_delta_3_{pairwise}{pl}_XX"]
                scalars[f"LB_3_{pairwise}{pl}_XX"] = delta3 - 1.96 * se3
                scalars[f"UB_3_{pairwise}{pl}_XX"] = delta3 + 1.96 * se3

            # This is later used in did_multiplegt_stat_main for weighting across pairs
            df[f"inner_sum_IV_denom_{pairwise}{pl}_XX"] = df["inner_sum_IV_denom_XX"]

        scalars[f"non_missing_{pairwise}{pl}_XX"] = 1.0

    else:
        # Not feasible => defaults
        for i in (1, 2, 3):
            scalars[f"delta_{i}_{pairwise}{pl}_XX"] = 0.0
            scalars[f"sd_delta_{i}_{pairwise}{pl}_XX"] = np.nan
            scalars[f"LB_{i}_{pairwise}{pl}_XX"] = np.nan
            scalars[f"UB_{i}_{pairwise}{pl}_XX"] = np.nan
            df[f"Phi_{i}_{pairwise}{pl}_XX"] = np.nan

        if aoss == 1:
            scalars[f"P_{pairwise}{pl}_XX"] = 0.0
        if waoss == 1:
            scalars[f"E_abs_delta_D_{pairwise}{pl}_XX"] = 0.0
        if ivwaoss == 1:
            scalars[f"denom_delta_IV_{pairwise}{pl}_XX"] = 0.0

        scalars[f"non_missing_{pairwise}{pl}_XX"] = 0.0

    # ------------------------------------------------------------
    # 20) Prepare "to_add" DataFrame
    # ------------------------------------------------------------
    df = df.sort_values(["ID_XX"], kind="mergesort").reset_index(drop=True)

    keep_cols = [
        "ID_XX",
        f"Phi_1_{pairwise}{pl}_XX",
        f"Phi_2_{pairwise}{pl}_XX",
        f"Phi_3_{pairwise}{pl}_XX",
        f"S_{pairwise}{pl}_XX",
        f"abs_delta_D_{pairwise}{pl}_XX",
        f"used_in_{pairwise}{pl}_XX",
        f"inner_sum_IV_denom_{pairwise}{pl}_XX",
    ]

    if cluster is not None:
        if "cluster_XX" in df.columns:
            keep_cols.append("cluster_XX")
        elif cluster in df.columns:
            keep_cols.append(cluster)

    keep_cols = [c for c in keep_cols if c in df.columns]
    out_df = df.loc[:, keep_cols].copy()

    # ------------------------------------------------------------
    # 21) Final scalar bookkeeping for aggregation
    # ------------------------------------------------------------
    if waoss == 1 or aoss == 1:
        scalars[f"N_Switchers_1_{pairwise}{pl}_XX"] = scalars.get(f"N_Switchers{pl}_XX", 0.0)
        scalars[f"N_Stayers_1_{pairwise}{pl}_XX"] = scalars.get(f"N_Stayers{pl}_XX", 0.0)
        scalars[f"N_Switchers_2_{pairwise}{pl}_XX"] = scalars.get(f"N_Switchers{pl}_XX", 0.0)
        scalars[f"N_Stayers_2_{pairwise}{pl}_XX"] = scalars.get(f"N_Stayers{pl}_XX", 0.0)

    if ivwaoss == 1:
        scalars[f"N_Switchers_3_{pairwise}{pl}_XX"] = scalars.get(f"N_Switchers_IV{pl}_XX", 0.0)
        scalars[f"N_Stayers_3_{pairwise}{pl}_XX"] = scalars.get(f"N_Stayers_IV{pl}_XX", 0.0)

    for i, active in enumerate([aoss, waoss, ivwaoss], start=1):
        if active == 1:
            scalars[f"delta_{i}_{pairwise}{pl}_XX"] = scalars.get(f"delta_{i}_{pairwise}{pl}_XX", 0.0)
            scalars[f"sd_delta_{i}_{pairwise}{pl}_XX"] = scalars.get(f"sd_delta_{i}_{pairwise}{pl}_XX", np.nan)
            scalars[f"LB_{i}_{pairwise}{pl}_XX"] = scalars.get(f"LB_{i}_{pairwise}{pl}_XX", np.nan)
            scalars[f"UB_{i}_{pairwise}{pl}_XX"] = scalars.get(f"UB_{i}_{pairwise}{pl}_XX", np.nan)

    return {"scalars": scalars, "to_add": out_df}
