In [1]:
# ============================================================
#  run_census_income_ppci_mean.py  (UPDATED)
#  PPCI / PPCI_label_only / PPI on census_income
#
#  Changes in this revision:
#   - ell tuning: compute a pilot-based ell0_hat, then search a local
#     log-ell grid spanning [c_min * ell0_hat, c_max * ell0_hat]
#   - c_min, c_max depend on sex:
#       sex=1 -> (0.9, 1.1)
#       sex=2 -> (0.8, 1.2)
#  Everything else is unchanged.
# ============================================================

import numpy as np
import cupy as cp

from ppi_py.datasets import load_dataset
from ppi_py import ppi_mean_ci, ppi_mean_pointestimate

# Array-module alias bound to CuPy (not referenced again in the visible code).
xp = cp
# Print NumPy arrays with 4 decimals and without scientific notation.
np.set_printoptions(precision=4, suppress=True)

import os
import sys
import csv

# Compatibility: a script defines __file__, whereas in a notebook it is
# undefined (NameError), so fall back to the current working directory.
try:
    current_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    current_dir = os.getcwd()

# Parent directory of current_dir; appended to sys.path (once) -- presumably
# so that the `conditional_mean_functions` import below can be resolved there.
parent_dir = os.path.abspath(os.path.join(current_dir, ".."))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from conditional_mean_functions import (
    preprocess_X,
    make_age_groups,
    select_logell_by_pilot_logo_cv,
    nw_point_mean_ci_matern52,
    WeightAtX0,
    ppci_conditional_mean,
    ppci_conditional_mean_label_only,
)


def _require_valid_theta0_mode(theta0_mode):
    """Raise ValueError unless theta0_mode is 'obs', 'smooth' or 'mix'."""
    if theta0_mode not in ("obs", "smooth", "mix"):
        raise ValueError("theta0_mode must be 'obs'/'smooth'/'mix'.")


def _ell_multipliers_for_sex(x0_raw):
    """Return (c_min, c_max), the multipliers of ell0_hat, for the sex in x0_raw[1]."""
    sex_raw = int(x0_raw[1])
    if sex_raw == 1:
        return 0.9, 1.1
    if sex_raw == 2:
        return 0.8, 1.2
    raise ValueError(f"Unexpected sex value in x0_raw[1]: {x0_raw[1]} (expect 1 or 2)")


def _select_theta0(theta0_mode, theta0_obs, theta0_smooth):
    """Pick the target theta0 and report its source ('obs' or 'smooth').

    'obs' requires a non-NaN theta0_obs, 'smooth' always uses theta0_smooth,
    and 'mix' prefers theta0_obs but falls back to theta0_smooth when it is NaN.
    """
    _require_valid_theta0_mode(theta0_mode)
    has_obs = not np.isnan(theta0_obs)
    if theta0_mode == "obs":
        if not has_obs:
            raise ValueError("theta0_mode='obs' but x0 has no exact obs.")
        return theta0_obs, "obs"
    if theta0_mode == "mix" and has_obs:
        return theta0_obs, "obs"
    return theta0_smooth, "smooth"


def _mc_estimator_summary(prefix, thetas, covered, widths, theta0):
    """Summarise one estimator's Monte Carlo draws as '<prefix>_*' entries."""
    thetas = np.asarray(thetas, dtype=float)
    covered = np.asarray(covered, dtype=float)
    widths = np.asarray(widths, dtype=float)
    return {
        f"{prefix}_theta_mean": float(np.mean(thetas)),
        f"{prefix}_theta_rmse": float(np.sqrt(np.mean((thetas - theta0) ** 2))),
        f"{prefix}_coverage": float(np.mean(covered)),
        f"{prefix}_avg_ci_width": float(np.mean(widths)),
    }


def _mc_mean_std(prefix, values, B):
    """Return '<prefix>_mean' and '<prefix>_std' (ddof=1; 0.0 unless B > 1)."""
    arr = np.asarray(values, dtype=float)
    return {
        f"{prefix}_mean": float(np.mean(arr)),
        f"{prefix}_std": float(np.std(arr, ddof=1)) if B > 1 else 0.0,
    }


def simulate_one_x0(
    Y_total,
    Yhat_total,
    X_total,
    seed,
    B,
    n_label,
    N_unlab_total,
    lam_grid,
    LOGELL_grid,
    n_pilot,
    alpha,
    z_alpha,
    x0_raw,
    theta0_mode="smooth",
    N_t_full=7000,
    age_bin_width=1,
):
    """Monte Carlo comparison of PPCI, label-only PPCI and PPI at one point x0.

    Pipeline:
      1. Validate the configuration (before any data or GPU work).
      2. Divide Y and Yhat by 10000 and standardize X with ``preprocess_X``.
      3. Compute the target ``theta0``: the mean of Y over rows whose raw X
         equals ``x0_raw`` exactly ("obs"), and the point estimate returned
         by ``nw_point_mean_ci_matern52`` on the full data with bandwidth
         ``ell_oracle`` = median scaled distance to x0 ("smooth").
      4. Drop the exact-x0 rows and split the remainder, with a single
         permutation, into pilot / aux / evaluation pool.
      5. Tune ``ell`` on the pilot set with ``select_logell_by_pilot_logo_cv``
         over a local log-grid [c_min * ell0_hat, c_max * ell0_hat], where
         ``ell0_hat`` is the median pilot distance to x0 and (c_min, c_max)
         depend on the sex in ``x0_raw[1]``.
      6. Build ``WeightAtX0`` on the aux set, select lambda and form ``f_hat``.
      7. For each of ``B`` replications draw ``n_label`` labelled and
         ``N_unlab_total - N_t_full`` unlabelled points from the evaluation
         pool, run the three estimators, and record estimate, coverage of
         ``theta0`` and CI width.

    Args:
        Y_total: Array-like outcomes.
        Yhat_total: Array-like predictions, aligned with ``Y_total``.
        X_total: 2-D array-like covariates; column 0 is used as age when
            building the pilot groups.
        seed: Seed for ``np.random.default_rng``.
        B: Number of Monte Carlo replications.
        n_label: Labelled sample size per replication.
        N_unlab_total: Total unlabelled budget; ``N_t_full`` of it forms the
            aux set and the remainder is the per-replication unlabelled size.
        lam_grid: Grid passed to ``WeightAtX0.select_lambda_lcurve``.
        LOGELL_grid: Only its length is used (number of local grid points);
            ``None`` means 50 points.
        n_pilot: Size of the pilot set used for ell tuning.
        alpha: Level passed to ``ppi_mean_ci``.
        z_alpha: Critical value passed to the NW and PPCI routines
            (presumably consistent with ``alpha``; this is not checked).
        x0_raw: Target point in raw units, ``[age, sex]``.
        theta0_mode: 'obs', 'smooth' or 'mix' (observed if available, else
            smooth).
        N_t_full: Size of the aux set.
        age_bin_width: Bin width passed to ``make_age_groups``.

    Returns:
        dict of diagnostics and per-method summaries. All outcome-scale
        quantities are in units of Y / 10000 (see ``scale_y`` in the result).

    Raises:
        ValueError: invalid ``theta0_mode``; ``N_unlab_total < N_t_full``;
            sex in ``x0_raw[1]`` not 1 or 2; ``theta0_mode='obs'`` without an
            exact observation; or too few usable rows for the requested split.
    """
    # Fail fast: these checks depend only on the arguments, so run them before
    # the (expensive) full-data GPU smoothing below.
    x0_raw = np.asarray(x0_raw, dtype=float)
    _require_valid_theta0_mode(theta0_mode)
    if N_unlab_total < N_t_full:
        raise ValueError("N_unlab_total must be >= N_t_full.")
    c_min, c_max = _ell_multipliers_for_sex(x0_raw)
    N_use = int(N_unlab_total - N_t_full)

    rng = np.random.default_rng(seed)

    X_total_raw = np.asarray(X_total, dtype=float)

    # Work on a rescaled outcome; every reported quantity is in these units.
    scale_y = 10000.0
    Y_total = np.asarray(Y_total, dtype=float) / scale_y
    Yhat_total = np.asarray(Yhat_total, dtype=float) / scale_y

    # standardize X, and map x0 into the same standardized coordinates
    X_scaled_total, X_mean, X_std = preprocess_X(X_total_raw)
    N_total = len(Y_total)
    x0_scaled = (x0_raw - X_mean) / X_std

    # rows whose raw covariates equal x0 exactly (excluded from sampling below)
    mask_x0 = np.all(X_total_raw == x0_raw.reshape(1, -1), axis=1)
    n_x0 = int(mask_x0.sum())
    print(f"[x0 = {x0_raw}] exact duplicate observations (excluded): {n_x0}")

    theta0_obs = float(np.mean(Y_total[mask_x0])) if np.any(mask_x0) else np.nan

    # theta0_smooth: full-data estimate with bandwidth = median distance to x0
    dists = np.linalg.norm(X_scaled_total - x0_scaled.reshape(1, -1), axis=1)
    ell_oracle = float(np.median(dists))

    theta0_s, _, _ = nw_point_mean_ci_matern52(
        X_l=cp.asarray(X_scaled_total),
        Y_l=cp.asarray(Y_total),
        x0=cp.asarray(x0_scaled),
        ell=ell_oracle,
        z_alpha=z_alpha,
    )
    theta0_smooth = float(theta0_s)

    theta0, theta0_source = _select_theta0(theta0_mode, theta0_obs, theta0_smooth)

    # used pool: everything except the exact-x0 rows
    keep = ~mask_x0
    X_scaled_used = X_scaled_total[keep]
    Y_used = Y_total[keep]
    Yhat_used = Yhat_total[keep]
    X_used_raw = X_total_raw[keep]

    N_used = len(Y_used)
    print(f"[x0 = {x0_raw}] N_total={N_total}, N_used={N_used}")

    min_required = n_pilot + N_t_full + (n_label + N_use)
    if N_used < min_required:
        raise ValueError(
            f"N_used={N_used} too small for pilot({n_pilot}) + aux({N_t_full}) "
            f"+ labeled({n_label}) + unlab({N_use})."
        )

    # GPU arrays
    X_scaled_gpu = cp.asarray(X_scaled_used)
    Y_gpu = cp.asarray(Y_used)
    Yhat_gpu = cp.asarray(Yhat_used)

    # global split: pilot / aux / eval_pool (disjoint slices of one permutation)
    perm_all = rng.permutation(N_used)
    idx_pilot = perm_all[:n_pilot]
    idx_aux = perm_all[n_pilot : n_pilot + N_t_full]
    idx_eval_pool = perm_all[n_pilot + N_t_full :]

    # pilot ell tuning (LOGO on age groups)
    X_pilot_gpu = X_scaled_gpu[idx_pilot]
    Y_pilot_gpu = Y_gpu[idx_pilot]
    pilot_age_raw = X_used_raw[idx_pilot, 0]
    pilot_groups = make_age_groups(pilot_age_raw, bin_width=age_bin_width)

    # pilot-based ell0_hat -> local log-ell grid [c_min*ell0_hat, c_max*ell0_hat]
    x0_gpu = cp.asarray(x0_scaled)
    d_pilot = cp.linalg.norm(X_pilot_gpu - x0_gpu[None, :], axis=1)
    ell0_hat = float(cp.asnumpy(cp.median(d_pilot)))
    ell0_hat = max(1e-8, ell0_hat)  # keep the log of the grid endpoints finite

    ell_low = max(1e-8, c_min * ell0_hat)
    ell_high = c_max * ell0_hat

    # LOGELL_grid only determines how many grid points are used.
    n_grid = int(len(LOGELL_grid)) if LOGELL_grid is not None else 50
    LOGELL_grid_local = np.linspace(np.log(ell_low), np.log(ell_high), n_grid)

    ell_star, _ = select_logell_by_pilot_logo_cv(
        LOGELL_grid=LOGELL_grid_local,
        X_pilot=X_pilot_gpu,
        Y_pilot=Y_pilot_gpu,
        group_labels=pilot_groups,
    )

    # build f_hat + lambda_hat on the aux set
    X_aux = X_scaled_gpu[idx_aux]
    locw = WeightAtX0(X_aux=X_aux, x0=x0_scaled, ell=ell_star)
    lam_hat, alpha_hat, _ = locw.select_lambda_lcurve(
        lam_grid=lam_grid, normalize=False, make_plot=False
    )
    f_hat = locw.make_f_hat(alpha_hat)

    # MC lists
    thetas_ppci, covered_ppci, widths_ppci = [], [], []
    sigma2_YmA_list, sigma2_A_list = [], []

    thetas_ppilo, covered_ppilo, widths_ppilo = [], [], []
    sigma2_Y_list = []

    thetas_ppi, covered_ppi, widths_ppi = [], [], []

    for _ in range(B):
        # fresh labelled / unlabelled draw from the evaluation pool
        perm_eval = rng.permutation(idx_eval_pool)
        idx_l = perm_eval[:n_label]
        idx_u = perm_eval[n_label : n_label + N_use]

        X_l = X_scaled_gpu[idx_l]
        Y_l = Y_gpu[idx_l]
        A_l = Yhat_gpu[idx_l]

        X_u = X_scaled_gpu[idx_u]
        A_u = Yhat_gpu[idx_u]

        # PPCI
        th, _, (lo, up), ex = ppci_conditional_mean(
            X_l=X_l, Y_l=Y_l, X_u=X_u, f_hat=f_hat, A_l=A_l, A_u=A_u,
            z_alpha=z_alpha, return_extras=True
        )
        thetas_ppci.append(th)
        covered_ppci.append(float(lo <= theta0 <= up))
        widths_ppci.append(float(up - lo))
        sigma2_YmA_list.append(ex["sigma2_Y_minus_A"])
        sigma2_A_list.append(ex["sigma2_A"])

        # label-only
        th2, _, (lo2, up2), ex2 = ppci_conditional_mean_label_only(
            X_l=X_l, Y_l=Y_l, f_hat=f_hat, z_alpha=z_alpha, return_extras=True
        )
        thetas_ppilo.append(th2)
        covered_ppilo.append(float(lo2 <= theta0 <= up2))
        widths_ppilo.append(float(up2 - lo2))
        sigma2_Y_list.append(ex2["sigma2_Y"])

        # PPI (global mean) runs on host arrays
        Y_l_cpu = cp.asnumpy(Y_l).ravel()
        A_l_cpu = cp.asnumpy(A_l).ravel()
        A_u_cpu = cp.asnumpy(A_u).ravel()

        ppi_lo, ppi_up = ppi_mean_ci(Y_l_cpu, A_l_cpu, A_u_cpu, alpha=alpha)
        ppi_th = ppi_mean_pointestimate(Y_l_cpu, A_l_cpu, A_u_cpu)

        thetas_ppi.append(float(ppi_th))
        covered_ppi.append(float(ppi_lo <= theta0 <= ppi_up))
        widths_ppi.append(float(ppi_up - ppi_lo))

    # Assemble the result. Insertion order is kept stable because the CSV
    # writer derives its column order from these keys.
    out = {
        # x0 / oracle
        "age": int(x0_raw[0]),
        "sex": int(x0_raw[1]),
        "x0_raw": x0_raw.tolist(),
        "theta0_mode": theta0_mode,
        "theta0_source": theta0_source,
        "theta0": float(theta0),
        "theta0_obs": float(theta0_obs) if not np.isnan(theta0_obs) else np.nan,
        "theta0_smooth": float(theta0_smooth),
        "ell_oracle": float(ell_oracle),

        # tuning diagnostics
        "c_min": float(c_min),
        "c_max": float(c_max),
        "ell0_hat": float(ell0_hat),
        "ell_low": float(ell_low),
        "ell_high": float(ell_high),
        "ell_star": float(ell_star),
        "lambda_hat": float(lam_hat),

        # budgets
        "B": int(B),
        "n_label": int(n_label),
        "n_pilot": int(n_pilot),
        "N_unlab_total": int(N_unlab_total),
        "N_t_full": int(N_t_full),
        "N_use": int(N_use),
        "age_bin_width": int(age_bin_width),
    }

    # PPCI metrics + sigma2 (MC mean/std)
    out.update(_mc_estimator_summary("PPCI", thetas_ppci, covered_ppci, widths_ppci, theta0))
    out.update(_mc_mean_std("PPCI_sigma2_Y_minus_A", sigma2_YmA_list, B))
    out.update(_mc_mean_std("PPCI_sigma2_A", sigma2_A_list, B))

    # label-only metrics + sigma2 (MC mean/std)
    out.update(_mc_estimator_summary("PPCILO", thetas_ppilo, covered_ppilo, widths_ppilo, theta0))
    out.update(_mc_mean_std("PPCILO_sigma2_Y", sigma2_Y_list, B))

    # PPI metrics
    out.update(_mc_estimator_summary("PPI", thetas_ppi, covered_ppi, widths_ppi, theta0))

    # scale
    out["scale_y"] = float(scale_y)
    return out


def write_results_to_csv(results_list, csv_path):
    """Write a sequence of result dicts to ``csv_path`` as CSV with a header.

    The columns are the union of the keys of all rows in first-seen order, so
    when every row has the same keys the layout is exactly the first row's key
    order. The header is computed before the file is opened: a later row
    carrying an extra key therefore cannot abort the write half-way and leave
    a truncated file (``csv.DictWriter`` raises ``ValueError`` on unknown
    keys). Keys missing from a row are written as empty cells.

    Args:
        results_list: Non-empty sequence of dicts, one per CSV row.
        csv_path: Destination path; an existing file is overwritten.

    Raises:
        ValueError: If ``results_list`` is empty.
    """
    if len(results_list) == 0:
        raise ValueError("results_list is empty; nothing to write.")
    # dict.fromkeys de-duplicates while preserving first-seen order.
    fieldnames = list(dict.fromkeys(key for row in results_list for key in row))
    # Explicit encoding keeps the output independent of the platform locale.
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results_list)
    print(f"\n[Saved] {csv_path}")


if __name__ == "__main__":
    # Run all CuPy work on the CUDA device with index 1.
    cp.cuda.Device(1).use()

    # Load the census_income data and unpack outcome, prediction, covariates.
    data = load_dataset("./data/", "census_income")
    Y_total, Yhat_total, X_total = data["Y"], data["Yhat"], data["X"]  # X columns: [age, sex]

    # ---- experiment configuration ----
    seed = 2025
    B = 500
    n_label = 200
    n_pilot = 200

    N_unlab_total = 10000
    N_t_full = 7000

    alpha = 0.05
    z_alpha = 1.96
    age_bin_width = 1  # LOGO by raw age

    lam_grid = np.logspace(np.log10(0.1 / n_label), np.log10(10.0 / n_label), 50)

    # Per-sex log-ell grids; simulate_one_x0 only uses their length now.
    logell_grid_by_sex = {
        1.0: np.linspace(np.log(0.35), np.log(4.0), 50),
        2.0: np.linspace(np.log(0.35), np.log(4.0), 50),
    }

    # Arguments that are identical for every target point.
    shared_kwargs = dict(
        seed=seed,
        B=B,
        n_label=n_label,
        N_unlab_total=N_unlab_total,
        lam_grid=lam_grid,
        n_pilot=n_pilot,
        alpha=alpha,
        z_alpha=z_alpha,
        theta0_mode="smooth",
        N_t_full=N_t_full,
        age_bin_width=age_bin_width,
    )

    all_rows = []

    for sex_val in (1.0, 2.0):
        print("\n" + "=" * 80)
        print(f"======================  Running for sex = {sex_val:.0f}  ======================")
        print("=" * 80)

        # Restrict every array to the rows of the current sex.
        mask_sex = X_total[:, 1] == sex_val
        Y_sex, Yhat_sex, X_sex = (
            arr[mask_sex] for arr in (Y_total, Yhat_total, X_total)
        )
        LOGELL_grid = logell_grid_by_sex[sex_val]

        # Target points: ages 70..100 (inclusive) at the current sex.
        for idx, age in enumerate(range(70, 101), start=1):
            x0_raw = np.asarray([age, sex_val], dtype=float)

            print("\n" + "-" * 80)
            print(f"[sex={int(sex_val)}] x0 #{idx}: age={age}")

            out = simulate_one_x0(
                Y_total=Y_sex,
                Yhat_total=Yhat_sex,
                X_total=X_sex,
                LOGELL_grid=LOGELL_grid,
                x0_raw=x0_raw,
                **shared_kwargs,
            )
            all_rows.append(out)

            # One-line reports: tuning diagnostics, then one line per method.
            tuning_line = (
                f"theta0={out['theta0']:.6f}  "
                f"ell_oracle={out['ell_oracle']:.4f}  "
                f"ell0_hat={out['ell0_hat']:.4f}  "
                f"c=[{out['c_min']:.2f},{out['c_max']:.2f}]  "
                f"ell_star={out['ell_star']:.4f}  "
                f"lambda_hat={out['lambda_hat']:.6g}"
            )
            ppci_line = (
                "PPCI : "
                f"rmse={out['PPCI_theta_rmse']:.6f}  "
                f"cov={out['PPCI_coverage']:.3f}  width={out['PPCI_avg_ci_width']:.6f}  "
                f"s2_YmA={out['PPCI_sigma2_Y_minus_A_mean']:.6g}  "
                f"s2_A={out['PPCI_sigma2_A_mean']:.6g}"
            )
            label_only_line = (
                "LO   : "
                f"rmse={out['PPCILO_theta_rmse']:.6f}  "
                f"cov={out['PPCILO_coverage']:.3f}  width={out['PPCILO_avg_ci_width']:.6f}  "
                f"s2_Y={out['PPCILO_sigma2_Y_mean']:.6g}"
            )
            ppi_line = (
                "PPI  : "
                f"rmse={out['PPI_theta_rmse']:.6f}  "
                f"cov={out['PPI_coverage']:.3f}  width={out['PPI_avg_ci_width']:.6f}"
            )
            for report_line in (tuning_line, ppci_line, label_only_line, ppi_line):
                print(report_line)

    # Persist every x0 result in a single CSV under <current_dir>/results.
    save_dir = os.path.join(current_dir, "results")
    os.makedirs(save_dir, exist_ok=True)
    csv_path = os.path.join(save_dir, "census_income_ppci_mean_all_x0.csv")
    write_results_to_csv(all_rows, csv_path)


  from .autonotebook import tqdm as notebook_tqdm




--------------------------------------------------------------------------------
[sex=1] x0 #1: age=70
[x0 = [70.  1.]] exact duplicate observations (excluded): 1990
[x0 = [70.  1.]] N_total=187471, N_used=185481
theta0=6.886993  ell_oracle=1.3541  ell0_hat=1.3104  c=[0.90,1.10]  ell_star=1.1794  lambda_hat=0.0101179
PPCI : rmse=0.794959  cov=0.946  width=2.964835  s2_YmA=28.1518  s2_A=4.62462
LO   : rmse=0.862868  cov=0.948  width=3.331407  s2_Y=35.9491
PPI  : rmse=1.701188  cov=0.160  width=1.979438

--------------------------------------------------------------------------------
[sex=1] x0 #2: age=71
[x0 = [71.  1.]] exact duplicate observations (excluded): 1838
[x0 = [71.  1.]] N_total=187471, N_used=185633
theta0=6.852685  ell_oracle=1.3978  ell0_hat=1.5944  c=[0.90,1.10]  ell_star=1.4349  lambda_hat=0.0101179
PPCI : rmse=0.706139  cov=0.928  width=2.789535  s2_YmA=31.9227  s2_A=5.22384
LO   : rmse=0.788063  cov=0.934  width=3.175475  s2_Y=41.6331
PPI  : rmse=1.618327  cov=0.222