In [None]:
# ============================================================
# Chaisemartin & D'Haultfoeuille (2020) — DID_M
# ATT(g,t), average ATT, and event study by k, with a cluster bootstrap by id
# ============================================================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import t as tdist

# ---------------- Parameters ----------------
REF            = -1            # event-study reference period; rows with k == REF are dropped from the plotted frame
WINDOW         = None          # e.g. (-10, 10) keeps only lo <= k <= hi in the event-study table; None keeps every k
INCLUDE_NEVER  = True          # True: controls at t are G > t plus missing G (never treated); False: only non-missing G > t (original note: "recommended by the authors" — TODO confirm)
B_BOOT         = 200           # number of cluster-bootstrap replications (ids are the resampled blocks)
ALPHA          = 0.05          # significance level of the two-sided CI (0.05 -> 95%)
SHOW_ATT_LINE  = True          # draw a horizontal line at the average ATT in the plot

# ---------------- Data preparation ----------------
# Work on a copy so the caller's `df` is never mutated.
df_cdh = df.copy()

# Ensure G (first-adoption cohort) exists; fall back to `_nfd` when G is absent.
if "G" not in df_cdh.columns:
    if "_nfd" in df_cdh.columns:
        df_cdh["G"] = df_cdh["_nfd"]
    else:
        raise ValueError("Não encontrei 'G' nem '_nfd' na base.")

# Basic dtypes
df_cdh["id"] = df_cdh["id"].astype("category")
_year_num = pd.to_numeric(df_cdh["year"], errors="coerce")
if _year_num.isna().any():
    # Coercion turns unparsable years into NaN; fail with a clear message here
    # rather than with an opaque integer-cast error on the next line.
    raise ValueError("Coluna 'year' contém valores ausentes ou não numéricos.")
df_cdh["year"] = _year_num.astype(int)
df_cdh["G"] = pd.to_numeric(df_cdh["G"], errors="coerce")

# Ever-treated flag: G is finite (NaN and +/-inf both count as never treated).
df_cdh["treated_ever"] = np.isfinite(df_cdh["G"]).astype(int)

# Sort within unit and take the one-period difference dY_it = Y_it - Y_i,t-1.
df_cdh = df_cdh.sort_values(["id", "year"]).reset_index(drop=True)
# observed=True only avoids the categorical-grouper warning; diff is row-wise.
df_cdh["_dY"] = df_cdh.groupby("id", observed=True)["Y"].diff(1)

# Build the mask AFTER the sort/reset so its index lines up with df_cdh.
# A mask built before the reset is aligned on stale labels and can select the
# wrong rows (or fail to align at all when the original index is not 0..n-1).
ever = np.isfinite(df_cdh["G"])

# Calendar years and treated cohorts used by the estimators below.
years   = np.sort(df_cdh["year"].unique())
cohorts = np.sort(df_cdh.loc[ever, "G"].unique())

# ---------------- Helper functions ----------------
def att_by_gt(df_in):
    """
    Compute a difference-in-differences cell estimate for every (g, t) with t >= g.

    For each calendar year t (from the module-level `years`, skipping the first
    one because a lagged outcome is required) and each cohort g (module-level
    `cohorts`):

        att_gt = mean(_dY | G == g, year == t) - mean(_dY | control, year == t)

    Control units at t:
      * INCLUDE_NEVER=True : G > t, or G missing (never treated).
      * INCLUDE_NEVER=False: finite G > t only, so never-treated units are
        excluded whether they are coded as NaN or as +inf.

    Cells with no treated or no control observation of `_dY` are skipped.

    NOTE(review): `_dY` is a one-period change, so for t > g the cell measures
    the change between t-1 and t relative to controls — confirm this is the
    intended estimand for the dynamic (k > 0) effects.

    Parameters
    ----------
    df_in : DataFrame with columns 'year', 'G' (numeric) and '_dY'.

    Returns
    -------
    DataFrame with columns ['g', 't', 'n_treated', 'att_gt'], ordered by t and
    then by g; empty (with those columns) when no cell is estimable.
    """
    columns = ["g", "t", "n_treated", "att_gt"]
    cohort_col = df_in["G"]
    delta = df_in["_dY"]
    is_finite_g = np.isfinite(cohort_col)

    rows = []
    for t in years[1:]:  # the first year has no lagged outcome
        in_year = df_in["year"] == t

        # The control group depends only on t, so build it once per year.
        if INCLUDE_NEVER:
            is_control = in_year & (cohort_col.isna() | (cohort_col > t))
        else:
            is_control = in_year & is_finite_g & (cohort_col > t)
        control_dy = delta[is_control].dropna()
        if control_dy.empty:
            continue
        control_mean = control_dy.mean()

        for g in cohorts:
            if t < g:
                continue
            treated_dy = delta[in_year & (cohort_col == g)].dropna()
            if treated_dy.empty:
                continue
            rows.append({
                "g": int(g),
                "t": int(t),
                "n_treated": int(treated_dy.shape[0]),
                "att_gt": float(treated_dy.mean() - control_mean),
            })
    return pd.DataFrame(rows, columns=columns)

def aggregate_att(att_gt_df):
    """
    Average the (g, t) cell estimates, weighting each cell by its share of
    `n_treated`. Returns NaN when the input frame is empty.
    """
    if att_gt_df.empty:
        return np.nan
    counts = att_gt_df["n_treated"].to_numpy(dtype=float)
    shares = counts / counts.sum()
    effects = att_gt_df["att_gt"].to_numpy()
    return float(np.sum(shares * effects))

def event_study_from_gt(att_gt_df, ref=REF, window=WINDOW):
    """
    Aggregate (g, t) cell estimates into event-time estimates, k = t - g.

    Each ATT_k is the average of `att_gt` over cells with the same k, weighted
    by `n_treated`.

    Parameters
    ----------
    att_gt_df : DataFrame with columns 'g', 't', 'n_treated', 'att_gt'.
    ref       : event time dropped from the plotting frame.
    window    : optional (lo, hi); keeps only lo <= k <= hi. None keeps all k.

    Returns
    -------
    (est, est_plot) : two DataFrames with columns ['k', 'estimate', 'n_treated'].
        `est` holds every k (after the window filter), sorted by k; `est_plot`
        is `est` without the k == ref row. Both are empty when `att_gt_df` is
        empty, so callers can always unpack two values.
    """
    columns = ["k", "estimate", "n_treated"]
    if att_gt_df.empty:
        # Same 2-tuple shape as the normal path; a single frame here would
        # break every `a, b = event_study_from_gt(...)` call site.
        return pd.DataFrame(columns=columns), pd.DataFrame(columns=columns)

    weights = att_gt_df["n_treated"].astype(float)
    cells = pd.DataFrame({
        "k": att_gt_df["t"] - att_gt_df["g"],
        "_weighted": att_gt_df["att_gt"] * weights,
        "n_treated": weights,
    })
    # Explicit weighted sums: unlike groupby(...).apply returning a Series,
    # this keeps the 'k' column on every pandas version.
    est = cells.groupby("k", as_index=False)[["_weighted", "n_treated"]].sum()
    est["estimate"] = est["_weighted"] / est["n_treated"]
    est = est[columns].copy()

    # Optional event-time window
    if window is not None:
        lo, hi = window
        est = est[(est["k"] >= lo) & (est["k"] <= hi)].copy()

    # The reference period is omitted from the plotting frame only.
    est_plot = est[est["k"] != ref].sort_values("k").reset_index(drop=True)
    return est, est_plot

def cluster_bootstrap_ids(df_in, B=B_BOOT, random_state=123):
    """
    Cluster (block) bootstrap over ids.

    Each replicate draws as many ids as are observed in `df_in`, with
    replacement, and keeps every row of each drawn id. Every draw gets its own
    cluster label: if an id drawn twice kept its original label, the copies
    would interleave after sorting and the within-id difference would be taken
    between duplicated rows (spurious zeros) instead of consecutive years.

    Parameters
    ----------
    df_in        : DataFrame with columns 'id', 'year', 'Y', 'G'.
    B            : number of bootstrap replicates.
    random_state : seed for numpy's default_rng.

    Returns
    -------
    list of (att_b, es_full_b), one per replicate: the average ATT and the full
    event-study frame. A replicate with no estimable (g, t) cell contributes
    (nan, empty frame).

    Raises
    ------
    ValueError if `df_in` has no non-missing id to resample.
    """
    rng = np.random.default_rng(random_state)
    base = df_in.reset_index(drop=True)

    # Row positions of each observed id, computed once instead of scanning the
    # whole frame for every drawn id in every replicate.
    codes, uniques = pd.factorize(base["id"])
    n_ids = len(uniques)
    if n_ids == 0:
        raise ValueError("Nenhum id observado para reamostrar no bootstrap.")
    order = np.argsort(codes, kind="stable")
    order = order[codes[order] >= 0]  # rows with a missing id have code -1
    sizes = np.bincount(codes[codes >= 0], minlength=n_ids)
    blocks = np.split(order, np.cumsum(sizes)[:-1])

    empty_es = pd.DataFrame(columns=["k", "estimate", "n_treated"])
    out = []
    for _ in range(B):
        draw = rng.integers(0, n_ids, size=n_ids)
        rows = np.concatenate([blocks[j] for j in draw])
        df_b = base.iloc[rows].copy()
        # Fresh label per draw, so repeated ids stay separate clusters.
        df_b["id"] = np.repeat(np.arange(n_ids), sizes[draw])

        # Recompute the one-period difference within each resampled cluster.
        df_b = df_b.sort_values(["id", "year"]).reset_index(drop=True)
        df_b["_dY"] = df_b.groupby("id")["Y"].diff(1)

        att_gt_b = att_by_gt(df_b)
        if att_gt_b.empty:
            out.append((np.nan, empty_es.copy()))
            continue
        att_b = aggregate_att(att_gt_b)
        ES_full_b, _ = event_study_from_gt(att_gt_b, ref=REF, window=WINDOW)
        out.append((att_b, ES_full_b))
    return out

# ---------------- Point estimation ----------------
att_gt = att_by_gt(df_cdh)
ATT_hat = aggregate_att(att_gt)
ES_full, ES_plot = event_study_from_gt(att_gt, ref=REF, window=WINDOW)

# ---------------- Standard errors via cluster bootstrap ----------------
boot = cluster_bootstrap_ids(df_cdh, B=B_BOOT, random_state=123)

# SE of the average ATT: sample std. dev. of the replicate estimates (NaNs ignored)
att_boot = np.array([replicate[0] for replicate in boot], dtype=float)
se_att = float(np.nanstd(att_boot, ddof=1))

# CI from a t distribution with (#clusters - 1) degrees of freedom, floored at 1
try:
    G = df_cdh["id"].cat.categories.size
except AttributeError:
    # 'id' is not categorical: count distinct values instead
    G = df_cdh["id"].nunique()
df_t = max(G - 1, 1)
tcrit = float(tdist.ppf(1 - ALPHA/2, df_t))
ci_att = (ATT_hat - tcrit*se_att, ATT_hat + tcrit*se_att)

print(
    "=== CDH (2020) — DID_M: ATT médio (not-yet-treated) ===",
    f"ATT            : {ATT_hat: .4f}",
    f"SE (cluster id): {se_att: .4f}",
    f"IC {int((1-ALPHA)*100)}%     : [{ci_att[0]: .4f}, {ci_att[1]: .4f}]",
    sep="\n",
)

# SE for each event time k of the event study (aligns the support of k across replicates)
def es_at_k(es_df, k):
    """Return the 'estimate' of the first row of `es_df` whose 'k' equals `k`, or NaN if none."""
    matches = es_df["estimate"][es_df["k"] == k]
    if len(matches) == 0:
        return np.nan
    return float(matches.iloc[0])

# Bootstrap SE per event time: std. dev. across replicates of ATT_k, where
# replicates lacking that k contribute NaN and are ignored by nanstd.
ks_all = sorted(ES_full["k"].unique())
se_k = [
    np.nanstd(
        np.array([es_at_k(rep_es, k) for (_rep_att, rep_es) in boot], dtype=float),
        ddof=1,
    )
    for k in ks_all
]
ES_full["se"] = se_k
ES_full["ci_low"] = ES_full["estimate"] - tcrit*ES_full["se"]
ES_full["ci_high"] = ES_full["estimate"] + tcrit*ES_full["se"]

# Plotting frame: drop the reference period k == REF
ES_plot = ES_full.loc[ES_full["k"].ne(REF)].copy().sort_values("k")

# ---------------- Plot (clean style) ----------------
plt.figure(figsize=(10, 6))

# Point estimates with asymmetric error bars built from the CI bounds
_es_point = ES_plot["estimate"]
_es_yerr = [_es_point - ES_plot["ci_low"], ES_plot["ci_high"] - _es_point]
_es_style = dict(
    fmt='s', color='black', markersize=4,
    ecolor='gray', elinewidth=0.8, capsize=2, label='CDH DID_M',
)
plt.errorbar(ES_plot["k"], _es_point, yerr=_es_yerr, **_es_style)

# Zero line, optional average-ATT line, and the reference-period marker
plt.axhline(0, color='black', linestyle='--', linewidth=0.8)
if SHOW_ATT_LINE and np.isfinite(ATT_hat):
    plt.axhline(
        ATT_hat, color='blue', linestyle='-', linewidth=1,
        label=f'ATT ({ATT_hat:.2f})',
    )
plt.axvline(REF, color='black', linestyle='--', linewidth=0.8, label=f'Reference ({REF})')

# Labels and layout
plt.title('Chaisemartin & D’Haultfoeuille (2020) — DID_M Event Study', fontsize=12)
plt.xlabel('Years Relative to Treatment (k = t - g)', fontsize=10)
plt.ylabel('Coefficient (ATT_k)', fontsize=10)
plt.legend(frameon=False, fontsize=10)
plt.grid(True, linestyle=':', alpha=0.5)
plt.tight_layout()
plt.show()
