In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Optional, List, Dict, Any


def _balance_panel_fill(df: pd.DataFrame, ID: str, Time: str) -> pd.DataFrame:
    """
    Rough Python equivalent of `plm::make.pbalanced(..., balance.type="fill")`.

    Assumptions (same as the R code comment):
    - At most one observation per (ID, Time) cell.

    Adds:
    - `tsfilled_XX`: 1 if the row was added by balancing, 0 otherwise.
    """
    d = df.copy()

    # R note: multiple observations per cell are not handled.
    if d.duplicated(subset=[ID, Time], keep=False).any():
        raise ValueError(
            "Multiple observations per (ID, Time) cell detected. "
            "The R code notes this case is not handled. Please aggregate first."
        )

    d = d.sort_values([ID, Time], kind="mergesort")

    ids = pd.Index(pd.unique(d[ID]))
    # Safer than Index.sort_values() for older pandas
    times_raw = pd.unique(d[Time])
    try:
        # numeric-ish
        times = np.sort(times_raw.astype(float))
    except Exception:
        times = sorted(list(times_raw))

    full_index = pd.MultiIndex.from_product([ids, times], names=[ID, Time])

    d2 = d.set_index([ID, Time]).reindex(full_index)

    orig_index = pd.MultiIndex.from_frame(d[[ID, Time]])
    d2["tsfilled_XX"] = np.where(d2.index.isin(orig_index), 0.0, 1.0)

    return d2.reset_index()


def did_multiplegt_stat_quantiles(
    df: pd.DataFrame,
    ID: str,
    Time: str,
    D: str,
    Z: Optional[str],
    by_opt: int,
    quantiles: List[float],
) -> Dict[str, Any]:
    """
    Python port of `did_multiplegt_stat_quantiles()`.

    This is used by `did_multiplegt_stat` when `by_fd` is specified, to create
    quantile bins based on |ΔD| (or |ΔZ| for IV-WAOSS).

    Parameters
    ----------
    df : pd.DataFrame
        Input panel (not modified; a copy is used).
    ID, Time, D : str
        Column names of the unit identifier, the period and the treatment.
    Z : str or None
        Instrument column. When given, rows with a missing Z are dropped and
        the bins are built on |ΔZ| instead of |ΔD|.
    by_opt : int
        Kept for interface parity with the R function; in this port it does
        not influence any returned value.
    quantiles : list of float
        CDF cut points, finite, in [0, 1] and nondecreasing. 0 and 1 are added
        at the ends when missing; `None` or an empty list means one bin.

    Returns
    -------
    dict with keys:
      - df: balanced df with `partition_XX`, `in_aggregation_XX`, `tsfilled_XX`.
        `partition_XX` is p >= 1 for a switcher in bin p, 0 for any other row
        of an aggregation period, and NaN outside aggregation periods.
      - val_quantiles: cutoffs on |ΔD| / |ΔZ|: the smallest value in each
        non-empty CDF interval [q_{j-1}, q_j), then the overall largest value.
      - quantiles: achieved CDF cutoffs: 0, then the largest CDF value in each
        non-empty interval (may differ from the request under point mass).
      - switch_df: per-bin counts and median |ΔD|/|ΔZ|
      - quantiles_plot: matplotlib Figure (CDF plot), already closed.

    Raises
    ------
    ValueError
        If `quantiles` has a non-finite value, a value outside [0, 1], or is
        not nondecreasing; or, from the balancing step, if an (ID, Time) cell
        holds several observations.
    """
    d = df.copy()

    # Drop rows missing an essential variable (same as R).
    mask_drop = d[ID].isna() | d[Time].isna() | d[D].isna()
    if Z is not None:
        mask_drop |= d[Z].isna()
    d = d.loc[~mask_drop].copy()

    # Balance the panel (fill missing cells) and order each unit's history.
    d = _balance_panel_fill(d, ID=ID, Time=Time)
    d = d.sort_values([ID, Time], kind="mergesort").copy()

    # |first difference| within unit over time. A unit's first period, filled
    # cells and the period right after a filled cell get NaN.
    diff_col, gr_title = (D, "ΔD") if Z is None else (Z, "ΔZ")
    base = pd.to_numeric(d[diff_col], errors="coerce")
    d["delta_pre_XX"] = base.groupby(d[ID]).diff().abs()
    delta = d["delta_pre_XX"]

    # R: switchers_dummy_XX <- delta_pre_XX != 0   (NA stays NA)
    switcher_dummy = (delta != 0).astype(float).where(delta.notna())

    # R: sum(..., na.rm = TRUE) by period. pandas' "sum" skips NaN and gives 0
    # for an all-NaN group, which is what na.rm = TRUE does.
    by_time = d[Time]
    switchers_N = switcher_dummy.groupby(by_time, sort=False).transform("sum")
    stayers_N = (1.0 - switcher_dummy).groupby(by_time, sort=False).transform("sum")

    # IMPORTANT: in R, in_aggregation_XX is logical but later used like 1/0.
    d["in_aggregation_XX"] = ((switchers_N > 0) & (stayers_N > 1)).astype(int)

    # Switchers in aggregation periods: the population whose |Δ| is binned.
    is_switcher = delta.notna() & (delta != 0) & (d["in_aggregation_XX"] == 1)
    N_switchers_plot = int(is_switcher.sum())

    # Frequency of each distinct |Δ| among those switchers, ascending.
    dist = (
        d.loc[is_switcher]
        .groupby("delta_pre_XX", dropna=True)
        .size()
        .rename("tot_delta_XX")
        .reset_index()
        .sort_values("delta_pre_XX", kind="mergesort")
    )

    # No switchers => return minimal structure
    if dist.empty:
        d["partition_XX"] = np.nan
        switch_df = pd.DataFrame(columns=["partition_XX", "N_partition_XX", "Med_delta_pre_XX"])
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.set_title(f"{gr_title} distribution (no switchers to bin)")
        plt.close(fig)
        return {
            "df": d.drop(columns=["delta_pre_XX"], errors="ignore"),
            "val_quantiles": [],
            "quantiles": [0.0, 1.0],
            "switch_df": switch_df,
            "quantiles_plot": fig,
        }

    # Build the CDF from cumulative *integer* counts divided once by the
    # total. Accumulating already-normalised shares drifts (ten equally
    # frequent values give 0.7999999999999999 instead of 0.8 and
    # 0.9999999999999999 instead of 1), which moves values across the cut
    # points in the `>= q_lo` / `< q_hi` tests below.
    counts = dist["tot_delta_XX"]
    total = counts.sum()
    dist["cdf"] = counts.cumsum() / total
    dist["tot_delta_XX"] = counts / total

    # Validate / normalize the quantiles vector.
    q = [float(x) for x in quantiles] if quantiles is not None else []
    if not q:
        q = [0.0, 1.0]
    if any(not np.isfinite(x) or x < 0.0 or x > 1.0 for x in q):
        raise ValueError("`quantiles` must contain finite values in [0, 1].")
    if q[0] != 0.0:
        q.insert(0, 0.0)
    if q[-1] != 1.0:
        q.append(1.0)
    if any(lo > hi for lo, hi in zip(q, q[1:])):
        raise ValueError("`quantiles` must be nondecreasing (e.g., [0, .2, .4, .6, .8, 1]).")

    # R loop `for (j in 2:length(quantiles))`: the CDF interval
    # [q_{j-1}, q_j) defines one bin. Empty intervals (possible under point
    # mass) contribute no cutoff. The top value (CDF exactly 1) lies in no
    # interval; it only enters through the final maximum below.
    cdf = dist["cdf"]
    cut_off: List[float] = []
    quantiles_temp: List[float] = [0.0]
    for q_lo, q_hi in zip(q[:-1], q[1:]):
        in_bin = (cdf >= q_lo) & (cdf < q_hi)
        if not in_bin.any():
            continue
        cut_off.append(float(dist.loc[in_bin, "delta_pre_XX"].min()))
        quantiles_temp.append(float(cdf[in_bin].max()))

    # Final upper cutoff
    cut_off.append(float(dist["delta_pre_XX"].max()))

    # Plot (matplotlib version of ggplot CDF)
    fig, ax = plt.subplots(figsize=(7.5, 4.5))
    ax.plot(dist["delta_pre_XX"], dist["cdf"], marker="o", linewidth=1.0)
    for x in cut_off:
        ax.axvline(x, linewidth=0.6, linestyle="--")
    ax.set_xlabel(f"|{gr_title}|")
    ax.set_ylabel("CDF")
    ax.set_title(f"Empirical distribution of {gr_title} (N={N_switchers_plot})\nBin cutoffs shown as vertical lines")
    fig.tight_layout()
    plt.close(fig)

    # Assign bins on the full df: switchers default to bin 1 (this covers a
    # |Δ| equal to the smallest cutoff); a switcher with
    # cut_off[p-1] < |Δ| <= cut_off[p] goes to bin p. Other rows get 0, and
    # rows outside aggregation periods get NaN.
    partition = pd.Series(np.where(is_switcher, 1.0, 0.0), index=d.index)
    for p, (lo, hi) in enumerate(zip(cut_off[:-1], cut_off[1:]), start=1):
        partition.loc[is_switcher & (delta > lo) & (delta <= hi)] = float(p)
    partition.loc[d["in_aggregation_XX"] != 1] = np.nan
    d["partition_XX"] = partition

    # switch_df summary (same columns as R), computed on a scratch frame so
    # no helper columns are added to (or dropped from) `d`.
    binned = d.loc[
        d["partition_XX"].notna() & (d["partition_XX"] != 0),
        ["partition_XX", "delta_pre_XX"],
    ].assign(it_XX=1.0)
    switch_df = binned.groupby("partition_XX", as_index=False).agg(
        N_partition_XX=("it_XX", "sum"),
        Med_delta_pre_XX=("delta_pre_XX", lambda s: float(np.nanmedian(s.to_numpy(dtype=float)))),
    )

    d = d.drop(columns=["delta_pre_XX"], errors="ignore")

    return {
        "df": d,
        "val_quantiles": cut_off,
        "quantiles": quantiles_temp,
        "switch_df": switch_df,
        "quantiles_plot": fig,
    }
