# In [None]:
import numpy as np
import pandas as pd
from typing import Any, Dict, List, Optional


def _fmt_float7(x):
    """Format a value as a fixed-point string with 7 decimal places.

    Trailing zeros are kept so that table columns line up (Stata-like
    layout). No thousands separator is applied.

    Args:
        x: Value to format; anything ``float()`` accepts.

    Returns:
        The formatted string, or ``""`` when ``x`` is ``None``, cannot be
        converted to a float, or is NaN.
    """
    if x is None:
        return ""
    try:
        v = float(x)
    except (TypeError, ValueError, OverflowError):
        return ""
    # Check NaN *after* conversion so that NaNs which are not ``float``
    # instances (e.g. np.float32, the string "nan") are blanked as well.
    if np.isnan(v):
        return ""
    return f"{v:.7f}"


def _fmt_float5(x):
    """Format a value as a fixed-point string with 5 decimal places.

    No thousands separator is applied.

    Args:
        x: Value to format; anything ``float()`` accepts.

    Returns:
        The formatted string, or ``""`` when ``x`` is ``None``, cannot be
        converted to a float, or is NaN.
    """
    if x is None:
        return ""
    try:
        v = float(x)
    except (TypeError, ValueError, OverflowError):
        return ""
    # Check NaN *after* conversion so that NaNs which are not ``float``
    # instances (e.g. np.float32, the string "nan") are blanked as well.
    if np.isnan(v):
        return ""
    return f"{v:.5f}"


def _fmt_int0(x):
    """Format a value as a whole number with thousands separators (R-like).

    The value is converted with ``float()`` and rendered with zero decimals,
    so non-integer inputs are rounded by the format specification.

    Args:
        x: Value to format; anything ``float()`` accepts.

    Returns:
        The formatted string (e.g. ``"1,234,567"``), or ``""`` when ``x`` is
        ``None``, cannot be converted to a float, or is NaN.
    """
    if x is None:
        return ""
    try:
        v = float(x)
    except (TypeError, ValueError, OverflowError):
        return ""
    # Check NaN *after* conversion so that NaNs which are not ``float``
    # instances (e.g. np.float32, the string "nan") are blanked as well.
    if np.isnan(v):
        return ""
    return f"{v:,.0f}"


def _get(obj: Any, key: str, default=None):
    """Fetch ``key`` from ``obj``: dict lookup for dicts, attribute access otherwise.

    Returns ``default`` when the key/attribute is absent.
    """
    if not isinstance(obj, dict):
        return getattr(obj, key, default)
    return obj.get(key, default)


def strdisplay(label: str, value):
    '''
    Print one aligned 'label = value' line, in the manner of R's strdisplay().

    The label is left-justified and cut to 16 characters. The value is
    right-justified in 16 characters (keeping its last 16 when longer).
    Strings are shown as-is; anything else is shown as a zero-decimal number
    when it converts to float, and via str() otherwise.
    '''
    label_width = value_width = 16

    left = label.ljust(label_width)[:label_width]

    if isinstance(value, str):
        text = value
    else:
        try:
            text = f"{float(value):.0f}"
        except Exception:
            text = str(value)

    right = text.rjust(value_width)[-value_width:]
    print(f"{left} = {right}")


def mat_print(mat, name: Optional[str] = None):
    """
    Stata-like table printing:
      - first 4 columns (Estimate, SE, LB CI, UB CI): 7 decimals fixed
      - remaining columns (Switchers, Stayers): integer counts with thousands separators

    Args:
        mat: A DataFrame, NumPy array, or anything ``pd.DataFrame()`` accepts.
            The input is not modified.
        name: Optional row label; when given and the table has exactly one
            row, that row's index is replaced by ``str(name).upper()``.

    Returns:
        None. The formatted table (or ``"(empty)"`` for a table without rows)
        is written to stdout.
    """
    df = mat if isinstance(mat, pd.DataFrame) else pd.DataFrame(mat)

    if df.shape[0] == 0:
        print("(empty)")
        return

    # Work on an object-dtype copy: writing formatted strings into numeric
    # columns through ``iloc`` is an incompatible-dtype assignment, which
    # pandas deprecates (FutureWarning) and later rejects.
    dis = df.astype(object)
    for j in range(dis.shape[1]):
        col = pd.to_numeric(df.iloc[:, j], errors="coerce")
        # Columns 0-3 hold estimates/SE/CI bounds; the rest hold counts.
        formatter = _fmt_float7 if j <= 3 else _fmt_int0
        dis.iloc[:, j] = [formatter(v) for v in col]

    if name is not None and dis.shape[0] == 1:
        dis.index = [str(name).upper()]

    with pd.option_context("display.max_rows", 200, "display.max_columns", 200, "display.width", 200):
        print(dis)


def tab_print(mat):
    """
    Stata-like printing for small numeric matrices (all columns to 5 decimals).
    Used for the "AOSS vs WAOSS difference test" table.

    Args:
        mat: A DataFrame, NumPy array, or anything ``pd.DataFrame()`` accepts.
            The input is not modified.

    Returns:
        None. The formatted table (or ``"(empty)"`` for a table without rows)
        is written to stdout.
    """
    df = mat if isinstance(mat, pd.DataFrame) else pd.DataFrame(mat)

    if df.shape[0] == 0:
        print("(empty)")
        return

    # Work on an object-dtype copy: writing formatted strings into numeric
    # columns through ``iloc`` is an incompatible-dtype assignment, which
    # pandas deprecates (FutureWarning) and later rejects.
    dis = df.astype(object)
    for j in range(dis.shape[1]):
        col = pd.to_numeric(df.iloc[:, j], errors="coerce")
        dis.iloc[:, j] = [_fmt_float5(v) for v in col]

    with pd.option_context("display.max_rows", 200, "display.max_columns", 200, "display.width", 200):
        print(dis)


def _infer_estimators(args: Dict[str, Any]) -> List[str]:
    """Return the list of estimator names implied by ``args``.

    Args:
        args: Options dict; the ``'estimator'`` and ``'Z'`` keys are read.

    Returns:
        ``['aoss', 'waoss']`` when neither ``estimator`` nor ``Z`` is set,
        ``['ivwaoss']`` when only ``Z`` is set, otherwise the requested
        estimator(s) as a list. A single estimator given as a plain string
        yields a one-element list.
    """
    estimator = args.get('estimator')
    if estimator is None:
        return ['aoss', 'waoss'] if args.get('Z') is None else ['ivwaoss']
    if isinstance(estimator, str):
        # A bare string is one estimator name; list() would split it into
        # single characters and no estimator would ever match.
        return [estimator] if estimator else []
    return list(estimator or [])


def summary_did_multiplegt_stat(obj: Dict[str, Any]):
    '''
    Console summary similar to Stata's did_multiplegt_stat output.

    For every result set -- ``'results'``, or ``'results_by_1'``,
    ``'results_by_2'``, ... when the ``by`` or ``by_fd`` option is set -- this
    prints N and the estimation settings, then one table per requested
    estimator, optionally followed by placebo tables and the AOSS-vs-WAOSS
    difference test. Result sets that are missing from ``obj`` are skipped.

    Args:
        obj: Result container (dict or attribute-style object). Reads
            ``args`` (options dict), ``by_levels`` and the result sets named
            above. Each result set may provide ``table``, ``pairs``,
            ``n_clusters``, ``table_placebo`` and ``aoss_vs_waoss``.

    Returns:
        None. All output is written to stdout.
    '''
    args = _get(obj, 'args', {}) or {}
    estim_list = _infer_estimators(args)
    disaggregate = bool(args.get('disaggregate'))
    exact_match = bool(args.get('exact_match'))

    by = args.get('by')
    by_fd = args.get('by_fd')
    if by is None and by_fd is None:
        by_levs = ['_no_by']
        by_obj = ['results']
        header = None
    else:
        # A by_levels entry that is present but None is treated as "no levels".
        levels = _get(obj, 'by_levels', None)
        by_levs = [] if levels is None else list(levels)
        by_obj = [f"results_by_{j+1}" for j in range(len(by_levs))]
        by_name = 'quantiles' if by_fd is not None else by
        header = f"## did_multiplegt_stat by {by_name} ({len(by_levs):.0f} levels)"

    if header is not None:
        print('\n' + '#' * 70)
        print(header)
        print('#' * 70 + '\n')

    # Positional block of each estimator inside the stacked results table.
    estims = {'aoss': 0, 'waoss': 1, 'ivwaoss': 2}
    methods = {'ra': 'Reg. Adjustment', 'dr': 'Doubly Robust', 'ps': 'Propensity Score'}

    def _select_rows(tbl: pd.DataFrame, est_label: str, l_bound: int, u_bound: int, disaggregate: bool, pairs: int) -> pd.DataFrame:
        """Select an estimator's rows by index label, else by position."""
        if isinstance(tbl, pd.DataFrame):
            # Deliberately best-effort: any problem while matching labels
            # falls back to the positional slice below.
            try:
                idx_norm = pd.Index([str(x).strip().upper() for x in tbl.index])
                target = str(est_label).strip().upper()
                pos = np.where(idx_norm == target)[0]
                if len(pos) > 0:
                    if disaggregate:
                        return tbl.iloc[pos[:max(1, int(pairs))]].copy()
                    return tbl.iloc[pos[:1]].copy()
            except Exception:
                pass
        return tbl.iloc[l_bound:u_bound].copy()

    def _N_from_row(row):
        """Switchers + stayers for one table row, or NaN if unavailable."""
        if isinstance(row, pd.Series):
            if {'N_switchers', 'N_stayers'}.issubset(row.index):
                return float(row['N_switchers'] + row['N_stayers'])
            if len(row) >= 6:
                # Unnamed layout: the counts are the 5th and 6th columns.
                return float(row.iloc[4] + row.iloc[5])
        return float('nan')

    for level, key in zip(by_levs, by_obj):
        print_obj = _get(obj, key, None)
        if print_obj is None:
            continue

        if level != '_no_by':
            msg = f" By level: {level}"
            print('#' * max(1, 70 - len(msg)) + msg)

        print('\n' + '-' * 35)

        table = _get(print_obj, 'table', None)
        pairs = int(_get(print_obj, 'pairs', 1) or 1)

        N = np.nan
        if isinstance(table, pd.DataFrame) and len(table) > 0:
            if {'N_switchers', 'N_stayers'}.issubset(set(table.columns)):
                try:
                    tmp = table[['N_switchers', 'N_stayers']].apply(pd.to_numeric, errors='coerce')
                    s = (tmp['N_switchers'] + tmp['N_stayers']).dropna()
                    if len(s):
                        N = float(s.iloc[0])
                except Exception:
                    N = np.nan
            if np.isnan(N):
                # Fall back to the first row of the relevant estimator block.
                if 'ivwaoss' in estim_list:
                    r = 2 * pairs
                else:
                    r = pairs if 'waoss' in estim_list else 0
                r = min(max(r, 0), len(table) - 1)
                N = _N_from_row(table.iloc[r])

        strdisplay('N', N)

        method = args.get('estimation_method') or 'dr'
        if exact_match:
            method = 'ra'
        for m in ('waoss', 'ivwaoss'):
            if m in estim_list:
                strdisplay(f"{m.upper()} Method", methods.get(method, method))

        if not exact_match and args.get('order') is not None:
            strdisplay('Polynomial Order', args.get('order'))

        if exact_match:
            strdisplay('Common Support', 'Exact Matching')
        if bool(args.get('noextrapolation')):
            strdisplay('Common Support', 'No Extrapolation')

        if args.get('switchers') is not None:
            strdisplay('Switchers', str(args.get('switchers')))

        print('-' * 35)

        cluster = args.get('cluster')
        ID = args.get('ID')
        if cluster is not None and cluster != ID:
            n_clusters = _get(print_obj, 'n_clusters', None)
            if isinstance(n_clusters, (list, tuple)) and len(n_clusters) > 0:
                nc = n_clusters[0]
            else:
                nc = n_clusters
            if nc is not None:
                print(f"(Std. errors adjusted for {int(nc):.0f} clusters in {cluster})")

        for t in ('aoss', 'waoss', 'ivwaoss'):
            if t not in estim_list:
                continue

            print('\n' + '-' * 70)
            print(' ' * 20 + f"Estimation of {t.upper()}(s)")
            print('-' * 70)

            # Computed unconditionally: the placebo table below needs these
            # bounds even when the main table is missing.
            l_bound = estims[t] * pairs
            u_bound = l_bound + (pairs if disaggregate else 1)

            if isinstance(table, pd.DataFrame):
                mat_sel = _select_rows(table, t, l_bound, u_bound, disaggregate, pairs)
                mat_print(mat_sel)
            else:
                print('(no table to print)')

            if bool(args.get('placebo')):
                table_p = _get(print_obj, 'table_placebo', None)
                if isinstance(table_p, pd.DataFrame):
                    print('\n' + '-' * 70)
                    print(' ' * 15 + f"Estimation of {t.upper()}(s) - Placebo")
                    print('-' * 70)
                    mat_sel_p = _select_rows(table_p, t, l_bound, u_bound, disaggregate, pairs)
                    mat_print(mat_sel_p)

        if bool(args.get('aoss_vs_waoss')):
            print('\n' + '-' * 70)
            print(' ' * 15 + 'Difference test: AOSS and WAOSS')
            print('-' * 70)
            print('H0: AOSS = WAOSS')
            diff_tab = _get(print_obj, 'aoss_vs_waoss', None)
            if diff_tab is not None:
                tab_print(diff_tab)


def print_did_multiplegt_stat(obj: Dict[str, Any]):
    '''
    Print the console summary for ``obj``.

    Thin alias of summary_did_multiplegt_stat(), kept to match R's print()
    method.
    '''
    return summary_did_multiplegt_stat(obj)


if __name__ == '__main__':
    # Minimal smoke demo: one AOSS row with 10 switchers and 90 stayers.
    demo_table = pd.DataFrame(
        {
            'pe': [0.1],
            'sd': [0.2],
            'lb': [-0.3],
            'ub': [0.5],
            'N_switchers': [10],
            'N_stayers': [90],
        }
    )
    demo = {
        'args': {'estimator': ['aoss'], 'order': 1, 'disaggregate': False},
        'results': {'table': demo_table, 'pairs': 1},
    }
    print_did_multiplegt_stat(demo)
