In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Batch Regression-Discontinuity (RD) – Price Shocks (Modularized)
------------------------------------------------------------------
This script runs a sharp RDD analysis of each asset CSV in ./benchmark and ./crypto_data around the event dates listed in ./events.
It builds an Excel workbook in the 'outcome' folder that includes:

    • A “Summary” sheet listing τ̂ (treatment effect), p-value, 95 % CI, N, etc.
    • An RD plot embedded in each corresponding row.

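Optional dependencies:
    $ pip install xlsxwriter   (Excel export with embedded plots; without it,
                                only a CSV summary is written, with no plots)
    statsmodels and scipy are optional too; NumPy fallbacks are used otherwise.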
"""

import io
import pathlib
import warnings
from typing import List, Dict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ─────────────────────────────────────────────────────────────────────────────
# 1)  Dependency checks
try:
    import statsmodels.api as sm
    HAVE_SM = True
except Exception as e:
    warnings.warn(f"statsmodels unavailable ({e}); falling back to NumPy OLS.")
    HAVE_SM = False

try:
    from scipy import stats as sps
    HAVE_SCIPY = True
except Exception:
    HAVE_SCIPY = False

# ─────────────────────────────────────────────────────────────────────────────
# 2)  Global parameters
BANDWIDTHS = [10, 20]            # ±n CALENDAR-day windows (distance is (Date - event).dt.days, so weekends count)
OUTCOME    = "Price"             # Column under analysis (matched case-insensitively after std_cols)

# ─────────────────────────────────────────────────────────────────────────────
# 3)  Helper functions
def read_csv_robustly(path: pathlib.Path, engine: str = 'c', sep=','):
    """Read a CSV file, trying a sequence of common encodings in turn.

    Args:
        path: CSV file to read.
        engine: pandas parser engine. With ``'python'`` the separator is
            sniffed automatically and *sep* is ignored.
        sep: Field separator, used with the default C engine.

    Returns:
        The DataFrame from the first encoding that both decodes and parses.

    Raises:
        ValueError: If no encoding worked. The last decode/parse error is
            chained as ``__cause__`` so the real reason is not lost.
    """
    encodings_to_try = ['utf-8', 'utf-8-sig', 'gbk', 'gb2312', 'latin-1']
    if engine == 'python':
        sep = None  # let the python engine sniff the delimiter
    last_error = None
    for enc in encodings_to_try:
        try:
            return pd.read_csv(path, encoding=enc, engine=engine, sep=sep)
        except (UnicodeDecodeError, UnicodeError, pd.errors.ParserError) as exc:
            # Remember why this encoding failed, then try the next one.
            last_error = exc
    raise ValueError(
        f"Failed to read or parse '{path}' (tried {', '.join(encodings_to_try)}): {last_error}"
    ) from last_error

def std_cols(df: pd.DataFrame) -> pd.DataFrame:
    """Canonicalise the column labels of *df* in place and return the same frame.

    Each label is lower-cased, every space and every dot is deleted, and any
    remaining leading/trailing whitespace (e.g. tabs) is trimmed, so
    ``" Adj. Close "`` becomes ``"adjclose"``.
    """
    labels = df.columns.str.lower()
    for unwanted in (" ", "."):
        labels = labels.str.replace(unwanted, "", regex=False)
    df.columns = labels.str.strip()
    return df

def load_price(path: pathlib.Path) -> pd.DataFrame:
    """Read a price CSV, clean the outcome column, and return it sorted by date.

    Column names are normalised with ``std_cols`` first, so the file needs
    columns that normalise to ``date`` and ``OUTCOME.lower()``. The outcome
    column is renamed back to ``OUTCOME``. A parsed ``Date`` column is added;
    the raw ``date`` column is kept.

    Raises:
        ValueError: If the required columns are missing.
    """
    df = std_cols(read_csv_robustly(path))

    # std_cols lower-cases everything, so look for the lower-case names.
    outcome_lower = OUTCOME.lower()
    if "date" not in df.columns or outcome_lower not in df.columns:
        raise ValueError(f"File '{path}' must contain 'Date' and '{OUTCOME}' columns.")

    # Restore the configured spelling so downstream code can index with OUTCOME.
    df = df.rename(columns={outcome_lower: OUTCOME})
    df["Date"] = pd.to_datetime(df["date"])

    # Drop thousands separators, currency symbols, '%' and similar characters.
    # Anything that is still not a number afterwards becomes NaN instead of
    # aborting the whole load. Examples: a lone "-" placeholder, or "None",
    # which reduces to "e".
    cleaned = df[OUTCOME].astype(str).str.replace(r"[^0-9\.\+\-eE]", "", regex=True)
    df[OUTCOME] = pd.to_numeric(cleaned, errors="coerce").astype(float)
    return df.sort_values("Date").reset_index(drop=True)


def ols_numpy(y: np.ndarray, X: np.ndarray, maxlags: int = 3):
    """Lightweight OLS with Newey-West (HAC) standard errors.

    The covariance is the sandwich ``inv(X'X) · S · inv(X'X)``. Here ``S`` sums
    the per-row scores ``x_t·u_t`` and their cross-products up to *maxlags*
    lags, using Bartlett weights ``1 - l/(maxlags+1)``. This is the same
    estimator the statsmodels branch requests with ``cov_type="HAC"``.
    ``maxlags=0`` gives White's HC0 estimator. Rows are assumed to be in
    time order.

    Args:
        y: Response vector, length n.
        X: Design matrix, shape (n, k).
        maxlags: Number of autocovariance lags in the HAC estimator.

    Returns:
        ``(beta, p, cov)``: the coefficients, two-sided p-values, and the
        (k, k) covariance. P-values use Student-t with n-k df when scipy is
        available, otherwise the exact normal tail.
    """
    import math

    n, k = X.shape
    beta = np.linalg.lstsq(X, y, rcond=None)[0]
    resid = y - X @ beta

    bread = np.linalg.inv(X.T @ X)
    scores = X * resid[:, None]          # row t is x_t * u_t
    meat = scores.T @ scores             # lag-0 term (White / HC0)
    for lag in range(1, min(maxlags, n - 1) + 1):
        weight = 1.0 - lag / (maxlags + 1.0)
        gamma = scores[lag:].T @ scores[:-lag]
        meat += weight * (gamma + gamma.T)
    cov = bread @ meat @ bread

    se = np.sqrt(np.diag(cov))
    with np.errstate(divide="ignore", invalid="ignore"):
        z = np.abs(beta / se)
    if HAVE_SCIPY:
        p = 2 * sps.t.sf(z, df=n - k)
    else:
        # Exact two-sided normal p-value: P(|Z| > z) = erfc(z / sqrt(2)).
        p = np.array([math.erfc(v / math.sqrt(2.0)) for v in z])
    return beta, p, cov


class _NumpyOLSResult:
    """Minimal fitted-model container used when statsmodels is unavailable.

    Exposes the attributes the rest of this script reads from a fitted model:
    ``params`` / ``pvalues`` (Series), ``cov`` (DataFrame), ``colnames``
    (list) and ``conf_int()``.
    """

    def __init__(self, beta: np.ndarray, pvalues: np.ndarray, cov: np.ndarray, colnames: List[str]):
        self.colnames = list(colnames)
        self.params = pd.Series(beta, index=self.colnames)
        self.pvalues = pd.Series(pvalues, index=self.colnames)
        self.cov = pd.DataFrame(cov, index=self.colnames, columns=self.colnames)

    def conf_int(self) -> pd.DataFrame:
        """Return 95 % normal-approximation bounds (estimate ± 1.96·SE)."""
        se = np.sqrt(np.diag(self.cov.to_numpy()))
        est = self.params.to_numpy()
        return pd.DataFrame({"low": est - 1.96 * se, "high": est + 1.96 * se}, index=self.colnames)


def rd_design(df: pd.DataFrame, event_date: pd.Timestamp, bw: int, y_col: str):
    """Fit a local-linear sharp RD around *event_date*.

    The running variable ``D`` is the signed calendar-day distance
    ``Date - event_date``. The treatment ``T`` is 1 on or after the event.
    The window is the rows with ``-bw <= D <= bw`` and a non-null *y_col*.
    On that window the model is estimated by OLS:

        y = const + τ·T + β·D + γ·(T·D)

    The coefficient on ``T`` (τ) is the jump at the cutoff.

    Returns:
        ``(win, model)``. ``win`` is the window DataFrame with added ``D`` and
        ``T`` columns. ``model`` exposes:

        - ``params`` / ``pvalues`` indexed by ``["const", "T", "D", "TD"]``
        - ``cov`` (4×4, same order)
        - ``colnames``
        - ``conf_int()``

        With statsmodels, ``model`` is its results object fitted with HAC
        (3 lags) covariance. Otherwise it is a ``_NumpyOLSResult`` built from
        ``ols_numpy``.
    """
    df = df.copy()
    df["D"] = (df["Date"] - event_date).dt.days
    df["T"] = (df["Date"] >= event_date).astype(int)
    win = df[df["D"].between(-bw, bw)].dropna(subset=[y_col])

    X_df = win[["T", "D"]].astype(float)
    X_df["TD"] = X_df["T"] * X_df["D"]
    X = np.column_stack([np.ones(len(X_df)), X_df.to_numpy()])
    y = win[y_col].to_numpy(float)
    cols = ["const", "T", "D", "TD"]

    if HAVE_SM:
        model = sm.OLS(y, X).fit(cov_type="HAC", cov_kwds={"maxlags": 3})
        # X is a bare ndarray, so statsmodels returns unlabelled arrays.
        # Attach name-indexed copies so callers can use model.params["T"].
        params = pd.Series(model.params, index=cols)
        pvalues = pd.Series(model.pvalues, index=cols)
        cov = model.cov_params()
        model.colnames, model.params, model.pvalues, model.cov = cols, params, pvalues, cov
        return win, model

    # The fallback path previously used a class defined inside this function.
    # It assigned `cov` in its class body, and class bodies look such names up
    # in module globals rather than the enclosing function, so it always
    # raised NameError. A regular class avoids that.
    beta, p, cov = ols_numpy(y, X)
    return win, _NumpyOLSResult(beta, p, cov, cols)


def get_ci(model, param: str):
    """Return the 95 % confidence bounds (low, high) for *param*.

    Works with either backend. If ``model.conf_int()`` yields a DataFrame,
    the row for *param* is returned as a Series. If it yields an array, the
    row is located through ``model.colnames`` and a ``(low, high)`` tuple is
    returned.
    """
    bounds = model.conf_int()
    if isinstance(bounds, pd.DataFrame):
        return bounds.loc[param]
    row = model.colnames.index(param)
    return bounds[row, 0], bounds[row, 1]


def rd_plot_fig(win: pd.DataFrame, model, evt: pd.Timestamp, bw: int, asset: str, y_col: str):
    """Draw the RD scatter for one event window and return the Figure.

    Points left of the cutoff (D < 0) are blue and the rest are red. Each side
    gets its fitted line (green) and a pointwise ±1.96·SE band (grey). The SE
    of the fitted mean at each day is sqrt(x' · cov · x), using
    ``model.cov``. Expects ``model.params`` / ``model.pvalues`` indexable by
    ``"const", "T", "D", "TD"`` and ``model.cov`` (ndarray or DataFrame) in
    that order. The caller is responsible for closing the figure.
    """
    days = np.arange(-bw, bw + 1)
    intercept, jump, slope, slope_change = (model.params[name] for name in ("const", "T", "D", "TD"))
    cov = model.cov if isinstance(model.cov, np.ndarray) else model.cov.to_numpy()

    ones, zeros = np.ones_like(days), np.zeros_like(days)
    # One entry per side: (fitted mean on the full grid, design rows, mask for that side).
    sides = (
        (intercept + slope * days,
         np.column_stack([ones, zeros, days, zeros]),
         days < 0),
        (intercept + jump + (slope + slope_change) * days,
         np.column_stack([ones, ones, days, days]),
         days >= 0),
    )

    fig, ax = plt.subplots(figsize=(7, 4))
    point_colors = np.where(win["D"] < 0, "royalblue", "firebrick")
    ax.scatter(win["D"], win[y_col], s=18, color=point_colors, alpha=0.7, zorder=2)

    # Lines first, then bands, matching the original artist order.
    for mean, _, mask in sides:
        ax.plot(days[mask], mean[mask], color="forestgreen", lw=2)
    for mean, design, mask in sides:
        half_width = 1.96 * np.sqrt(np.einsum("ij,jk,ik->i", design, cov, design))
        ax.fill_between(days[mask], (mean - half_width)[mask], (mean + half_width)[mask], color='grey', alpha=0.3)

    ax.axvline(0, color="crimson", lw=2, ls="--")
    ax.set_title(f"{asset} | {evt.date()} ±{bw}d  τ̂={model.params['T']:.4f}  p={model.pvalues['T']:.3g}")
    ax.set_xlabel("Days relative to event")
    ax.set_ylabel(y_col)
    plt.tight_layout()
    return fig

# ─────────────────────────────────────────────────────────────────────────────
# 4)  Main driver
if __name__ == "__main__":
    # --- Benchmark price series (fixed file list) ---
    print("--- Loading Benchmark Data ---")
    BENCHMARK_DIR = pathlib.Path("./benchmark")
    if not BENCHMARK_DIR.is_dir():
        raise FileNotFoundError(f"Benchmark directory '{BENCHMARK_DIR}' not found.")

    bench_files = ["Gold.csv", "Nasdaq100.csv", "SPY.csv"]
    # Keyed by file stem, e.g. "Gold"; crypto assets are added to the same dict below.
    loaded_data = {pathlib.Path(f).stem: load_price(BENCHMARK_DIR / f) for f in bench_files}
    print(f"  • Loaded: {', '.join(loaded_data.keys())}")

    # --- Every CSV in the crypto folder ---
    print("\n--- Loading Crypto Asset Data ---")
    CRYPTO_DATA_DIR = pathlib.Path("./crypto_data")
    if not CRYPTO_DATA_DIR.is_dir():
        raise FileNotFoundError(f"Directory '{CRYPTO_DATA_DIR}' not found.")
    crypto_files = list(CRYPTO_DATA_DIR.glob("*.csv"))
    if not crypto_files:
        raise FileNotFoundError(f"No CSV files found in '{CRYPTO_DATA_DIR}'.")
    for f_path in crypto_files:
        # NOTE(review): a crypto file whose stem equals a benchmark stem replaces that benchmark.
        loaded_data[f_path.stem] = load_price(f_path)
    print(f"  • Found and loaded {len(crypto_files)} crypto assets.")

    # --- Wide event calendars: one column per event group ---
    print("\n--- Loading Wide-Format Event Calendar Data ---")
    EVENTS_DIR = pathlib.Path("./events")
    train_events_file, test_events_file = EVENTS_DIR / "training_set.csv", EVENTS_DIR / "test_set.csv"
    if not EVENTS_DIR.is_dir():
        raise FileNotFoundError(f"Directory '{EVENTS_DIR}' not found.")
    if not train_events_file.is_file():
        raise FileNotFoundError(f"File '{train_events_file}' not found.")
    if not test_events_file.is_file():
        raise FileNotFoundError(f"File '{test_events_file}' not found.")

    def load_wide_events(path: pathlib.Path, suffix: str) -> dict:
        """Map each (normalised) column name + *suffix* to that column's parsed, non-null dates."""
        df = std_cols(read_csv_robustly(path, engine='python'))
        return {
            f"{group_name}{suffix}": pd.to_datetime(df[group_name].dropna(), errors='coerce').dropna()
            for group_name in df.columns
        }

    # Both files were verified to exist above.
    events = {}
    events.update(load_wide_events(train_events_file, "_train"))
    events.update(load_wide_events(test_events_file, "_test"))
    print(f"  • Total unique event groups to process: {len(events)}")
    print("----------------------------------------------------\n")

    # --- Output path and Excel writer ---
    OUTPUT_DIR = pathlib.Path("./outcome")
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    out_xlsx = OUTPUT_DIR / f"rd_results_{OUTCOME}.xlsx"

    try:
        writer = pd.ExcelWriter(out_xlsx, engine="xlsxwriter")
    except ImportError:
        warnings.warn("'xlsxwriter' not found. To embed plots in Excel, run: pip install xlsxwriter")
        writer = None  # fall back to a CSV summary at the end

    headers = ["Asset", "Group", "Event", "Bandwidth", "Tau", "P_value", "CI_Low", "CI_High", "N_obs", "Plot"]
    if writer is not None:
        workbook = writer.book
        ws = workbook.add_worksheet("Summary")
        writer.sheets["Summary"] = ws
        for c, h in enumerate(headers):
            ws.write(0, c, h)
        current_row = 1

    # Collected for every writer mode; written out only by the CSV fallback.
    summary_data_for_csv = []

    # --- Main loop: asset × event group × event date × bandwidth ---
    for asset, df in loaded_data.items():
        print(f"\n=== {asset} ====================================================")
        if df.empty or OUTCOME not in df.columns:
            warnings.warn(f"'{OUTCOME}' column not found or DataFrame is empty for {asset}. Skipping.")
            continue

        for group, dates in events.items():
            for d in dates:
                evt = pd.to_datetime(d)
                # Only events that fall on an observed date for this asset are analysed.
                if evt not in df["Date"].values:
                    continue

                for bw in BANDWIDTHS:
                    fig = None
                    try:
                        win, model = rd_design(df, evt, bw, OUTCOME)
                        if len(win) < 10:
                            warnings.warn(f"Skipping {asset} on {d.date()} (±{bw}d) due to insufficient data.")
                            continue

                        ci_lo, ci_hi = get_ci(model, "T")

                        result_dict = {
                            "Asset": asset, "Group": group, "Event": d.strftime('%Y-%m-%d'), "Bandwidth": bw,
                            "Tau": float(model.params["T"]), "P_value": float(model.pvalues["T"]),
                            "CI_Low": ci_lo, "CI_High": ci_hi, "N_obs": len(win)
                        }
                        summary_data_for_csv.append(result_dict)
                        print(f"{d.date()} {group:<22} ±{bw:>2}d τ̂={result_dict['Tau']:.6f} p={result_dict['P_value']:.4f}")

                        if writer is not None:
                            row = current_row
                            ws.write_row(row, 0, [result_dict[h] for h in headers if h != 'Plot'])
                            # The row now holds valid data. Advance before plotting so that
                            # a plotting failure cannot let the next result overwrite it.
                            current_row += 1

                            fig = rd_plot_fig(win, model, evt, bw, asset, OUTCOME)
                            buf = io.BytesIO()
                            fig.savefig(buf, format="png", bbox_inches="tight")
                            buf.seek(0)
                            ws.set_row(row, 220)  # row height to fit the plot
                            ws.insert_image(row, len(headers) - 1, f"img_{asset}_{d.date()}_{bw}", {"image_data": buf, "x_scale": 0.8, "y_scale": 0.8, 'object_position': 2})

                    except Exception as e:
                        print(f"Could not process event {d.date()} for {asset} (±{bw}d). Error: {e}")
                    finally:
                        # Always release the figure. If savefig or insert_image fails, it
                        # would otherwise stay registered with pyplot for the rest of the run.
                        if fig is not None:
                            plt.close(fig)

    # --- Finalize export ---
    if writer is not None:
        writer.close()
        print(f"\nDone! Results and charts saved to {out_xlsx}")
    else:
        # CSV fallback when xlsxwriter is unavailable (no plots).
        print("\nExporting summary data to CSV...")
        df_summary = pd.DataFrame(summary_data_for_csv)
        csv_path = OUTPUT_DIR / f"rd_results_{OUTCOME}.csv"
        df_summary.to_csv(csv_path, index=False)
        print(f"Done! Summary data saved to {csv_path}. Plots were not saved.")