In [1]:
import copy
import pickle
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

# Reward coefficients handed to the data generator through data_config_init.
group_reward_coeff = np.array([[-2.68, 2.68], [2.68, -2.68]], dtype=DTYPE)
action_reward_coeff = [-2.89, 2.89]

# Keyword arguments shared by every DataGenConfig built in this notebook; the
# per-experiment seed is added where the config is instantiated.
data_config_init = {
    "num_trajectories": 50,
    "num_time_steps": 30,
    "group_reward_coeff_override": group_reward_coeff,
    "action_reward_coeff": action_reward_coeff,
    "num_burnin_steps": 100,
    "basis_expansion_method": "NONE",
    "add_intercept_column": True,
}

FEATURE_TYPE = "NONE"
NUM_EXPERIMENTS = 100

# Truth values are cached on disk.
#   * First run: set COMPUTE_TRUTH = True so the truth file gets generated.
#   * Afterwards: switch it back to False to reuse the cached file.
# The file name does not track the settings automatically, so rename it whenever
# the settings above change.
COMPUTE_TRUTH = False
TRUTH_FILE = f"hetero/data/{FEATURE_TYPE}_truth_20230528_2.68_2.89.pkl"
print("truth file name =", TRUTH_FILE)

# Each run writes to its own result file: the name carries the sample sizes and
# the wall-clock time at which this cell was executed.
time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = f"N={data_config_init['num_trajectories']}_T={data_config_init['num_time_steps']}_{time_tag}"
RESULT_FILE = f"hetero/data/{FEATURE_TYPE}_result_20230528_2.68_2.89_{tag}.pkl"
print("result file name =", RESULT_FILE)

# Turn saving off only for throw-away experimental runs.
SAVE_RESULT = True
if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

truth file name = hetero/data/NONE_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/NONE_result_20230528_2.68_2.89_N=50_T=30_20230602_12-14-57.pkl


=====================================================================================================
# Algorithm 

- Set the algorithm configuration below.

In [2]:

# Keyword arguments for AlgoConfig, collected in one place so they are easy to tune.
algo_settings = {
    "max_num_iters": 10,
    "gam": 2.7,
    "lam": 1.6,
    "rho": 2.0,
    # Outlier removal is switched on, with lower/upper percentile bounds 2 and 98.
    "should_remove_outlier": True,
    "outlier_lower_perc": 2,
    "outlier_upper_perc": 98,
    "nu_coeff": 1e-5,
    "delta_coeff": 1e-5,
    "use_group_wise_regression_init": True,
}
algo_config = AlgoConfig(**algo_settings)

# Policy evaluated throughout the notebook.
pi_eval = AlternativePolicy(2)

# Grouping settings are left at the library defaults.
grouping_config = GroupingConfig()

In [3]:
# Settings the truth values are generated from: the data config without the
# sample size, which compute_UV_truths receives as its own argument.
truth_data_config_init = copy.copy(data_config_init)
truth_data_config_init.pop("num_trajectories")

if COMPUTE_TRUTH:
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=10,
        num_trajectories=1000,
    )  # For best results, use num_repeats=10
    # Average along the first axis of the returned arrays to get one truth estimate.
    u_truth = us.mean(axis=0)
    v_truth = vs.mean(axis=0)

    # Cache the truth together with the settings that produced it.
    with open(TRUTH_FILE, "wb") as f:
        pickle.dump(
            dict(
                u_truth=u_truth,
                v_truth=v_truth,
                data_config_dict=truth_data_config_init,
                algo_config_dict=attrs.asdict(algo_config),
            ),
            f,
        )
else:
    try:
        with open(TRUTH_FILE, "rb") as f:
            # NOTE: unpickling can execute arbitrary code; only load truth files
            # that were produced by this notebook.
            loaded = pickle.load(f)
    except FileNotFoundError as err:
        # Same exception type as before, but with an actionable hint for
        # first-time runners.
        raise FileNotFoundError(
            f"Truth file {TRUTH_FILE!r} does not exist; "
            "set COMPUTE_TRUTH = True once to generate it."
        ) from err
    u_truth = loaded["u_truth"]
    v_truth = loaded["v_truth"]

    # The file name is maintained by hand, so a stale cache is easy to pick up.
    # Compare the settings stored next to the truth with the current ones and
    # warn (without failing) when they differ.
    cached_config = loaded.get("data_config_dict")
    if cached_config is not None:
        changed_keys = sorted(
            key
            for key in set(cached_config) | set(truth_data_config_init)
            if key not in cached_config
            or key not in truth_data_config_init
            or not np.array_equal(cached_config[key], truth_data_config_init[key])
        )
        if changed_keys:
            print(
                "WARNING: truth file was generated with different data settings for:",
                changed_keys,
            )

In [4]:
# One entry per experiment: the non-grouped estimate and the end-to-end learned estimate.
beta_ng_list, beta_learned_list = [], []

for experiment_idx in range(NUM_EXPERIMENTS):
    # Every experiment draws its own data set from a distinct, deterministic seed.
    experiment_seed = 7531 * (experiment_idx + 1)
    data_config = DataGenConfig(seed=experiment_seed, **data_config_init)
    data = generate_data_from_config(data_config)

    beta_ng = beta_estimate_from_nongrouped(data, pi_eval, algo_config.discount)
    beta_ng_list.append(beta_ng)

    beta_learned = beta_estimate_from_e2e_learning(
        data, algo_config, grouping_config, pi_eval
    )
    beta_learned_list.append(beta_learned)

new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.86676437
MCPImpl: num_above=2511, num_below=2439
beta_solver, min eigen of left matrix = 0.86676437
MCPImpl: num_above=2518, num_below=2432
beta_solver, min eigen of left matrix = 0.86676437
MCPImpl: num_above=2527, num_below=2423
beta_solver, min eigen of left matrix = 0.86676437
MCPImpl: num_above=2523, num_below=2427
beta_solver, min eigen of left matrix = 0.86676437
MCPImpl: num_above=2528, num_below=2422
beta_solver, min eigen of left matrix = 0.86676437
MCPImpl: num_above=2522, num_below=2428
beta_solver, min eigen of left matrix = 0.86676437
MCPImpl: num_above=2529, num_below=2421
beta_solver, min eigen of left matrix = 0.86676437
MCPImpl: num_above=2523, num_below=2427
beta_solver, min eigen of left matrix = 0.86676437
MCPImpl: num_above=2529, num_below=2421
beta_solver, min eigen of left matrix = 0.86676437
MCPImpl: num_above=2523, num_bel

MCPImpl: num_above=2505, num_below=2445
beta_solver, min eigen of left matrix = (0.87653816+0j)
MCPImpl: num_above=2507, num_below=2443
beta_solver, min eigen of left matrix = (0.87653816+0j)
MCPImpl: num_above=2503, num_below=2447
beta_solver, min eigen of left matrix = (0.87653816+0j)
MCPImpl: num_above=2508, num_below=2442
beta_solver, min eigen of left matrix = (0.87653816+0j)
MCPImpl: num_above=2504, num_below=2446
beta_solver, min eigen of left matrix = (0.87653816+0j)
MCPImpl: num_above=2507, num_below=2443
beta_solver, min eigen of left matrix = (0.87653816+0j)
MCPImpl: num_above=2503, num_below=2447
beta_solver, min eigen of left matrix = (0.87653816+0j)
MCPImpl: num_above=2508, num_below=2442
kmeans center = [[ 1.59039991 -3.75640248  2.15666742  3.64117148 -1.43198984  5.04349094]
 [-1.64027955  3.58216251 -2.0239767  -3.6533418   1.66580223 -4.86301987]] and inertia = 110.3817579227633
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000

MCPImpl: num_above=2510, num_below=2440
beta_solver, min eigen of left matrix = 0.86972964
MCPImpl: num_above=2510, num_below=2440
beta_solver, min eigen of left matrix = 0.86972964
MCPImpl: num_above=2510, num_below=2440
beta_solver, min eigen of left matrix = 0.86972964
MCPImpl: num_above=2510, num_below=2440
beta_solver, min eigen of left matrix = 0.86972964
MCPImpl: num_above=2510, num_below=2440
beta_solver, min eigen of left matrix = 0.86972964
MCPImpl: num_above=2510, num_below=2440
kmeans center = [[ 1.62110771 -3.85340289  2.33246746  3.73486424 -1.56332172  5.25362893]
 [-1.57841912  3.85640595 -2.045093   -3.77989081  1.55100277 -5.11002202]] and inertia = 122.95525537358985
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.8669041
MCPImpl: num_above=2513, num_below=2437
b

beta_solver, min eigen of left matrix = 0.8748
MCPImpl: num_above=2521, num_below=2429
beta_solver, min eigen of left matrix = 0.8748
MCPImpl: num_above=2517, num_below=2433
beta_solver, min eigen of left matrix = 0.8748
MCPImpl: num_above=2520, num_below=2430
kmeans center = [[ 1.68207514 -3.80702679  2.07091363  3.7125789  -1.62036607  4.99638387]
 [-1.51180489  3.71118035 -2.2105144  -3.59079859  1.63365633 -5.14655225]] and inertia = 114.26495682669054
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = (0.8752411+0j)
MCPImpl: num_above=2511, num_below=2439
beta_solver, min eigen of left matrix = (0.8752411+0j)
MCPImpl: num_above=2512, num_below=2438
beta_solver, min eigen of left matrix = (0.8752411+0j)
MCPImpl: num_above=2512, num_below=2438
beta_solver, min eigen of left matrix = 

MCPImpl: num_above=2515, num_below=2435
beta_solver, min eigen of left matrix = (0.86902237+0j)
MCPImpl: num_above=2512, num_below=2438
beta_solver, min eigen of left matrix = (0.86902237+0j)
MCPImpl: num_above=2513, num_below=2437
beta_solver, min eigen of left matrix = (0.86902237+0j)
MCPImpl: num_above=2512, num_below=2438
beta_solver, min eigen of left matrix = (0.86902237+0j)
MCPImpl: num_above=2512, num_below=2438
kmeans center = [[ 1.64749531 -3.81679786  2.04795284  3.87804705 -1.60724654  5.04171414]
 [-1.53054927  3.83694466 -2.17660327 -3.92032928  1.58289201 -5.09884338]] and inertia = 121.27246348072812
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = (0.87802017+0j)
MCPImpl: num_above=2501, num_below=2449
beta_solver, min eigen of left matrix = (0.87802017+0j)
MCPImpl: n

MCPImpl: num_above=2502, num_below=2448
beta_solver, min eigen of left matrix = (0.87480086-0.0033090718j)
MCPImpl: num_above=2502, num_below=2448
beta_solver, min eigen of left matrix = (0.87480086-0.0033090718j)
MCPImpl: num_above=2502, num_below=2448
beta_solver, min eigen of left matrix = (0.87480086-0.0033090718j)
MCPImpl: num_above=2502, num_below=2448
beta_solver, min eigen of left matrix = (0.87480086-0.0033090718j)
MCPImpl: num_above=2502, num_below=2448
kmeans center = [[ 1.42740491 -3.83455724  2.18138964  3.78771012 -1.54642329  5.06060187]
 [-1.53501078  3.82152962 -2.21100277 -3.80395442  1.50790887 -5.12868358]] and inertia = 97.34208569059724
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = (0.8795301+0j)
MCPImpl: num_above=2504, num_below=2446
beta_solver, min eigen o

MCPImpl: num_above=2503, num_below=2447
beta_solver, min eigen of left matrix = 0.8756393
MCPImpl: num_above=2501, num_below=2449
beta_solver, min eigen of left matrix = 0.8756393
MCPImpl: num_above=2503, num_below=2447
kmeans center = [[ 1.51618777 -3.68773587  2.00974971  3.7452594  -1.55555541  5.03186755]
 [-1.57322999  3.8119353  -2.23541001 -3.93605015  1.57941836 -5.18112743]] and inertia = 130.54240685047438
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.87628335
MCPImpl: num_above=2501, num_below=2449
beta_solver, min eigen of left matrix = 0.87628335
MCPImpl: num_above=2505, num_below=2445
beta_solver, min eigen of left matrix = 0.87628335
MCPImpl: num_above=2501, num_below=2449
beta_solver, min eigen of left matrix = 0.87628335
MCPImpl: num_above=2505, num_below=2445
be

MCPImpl: num_above=2503, num_below=2447
beta_solver, min eigen of left matrix = 0.86909777
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.86909777
MCPImpl: num_above=2503, num_below=2447
kmeans center = [[ 1.54564956 -3.74849519  2.19992541  3.77447116 -1.45112185  5.0257114 ]
 [-1.56308302  3.74892058 -2.19603518 -3.75835288  1.54977209 -4.89525753]] and inertia = 98.98563673897878
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.85649395
MCPImpl: num_above=2506, num_below=2444
beta_solver, min eigen of left matrix = 0.85649395
MCPImpl: num_above=2524, num_below=2426
beta_solver, min eigen of left matrix = 0.85649395
MCPImpl: num_above=2514, num_below=2436
beta_solver, min eigen of left matrix = 0.85649395
MCPImpl: num_above=2525, num_below=2425
b

MCPImpl: num_above=2510, num_below=2440
beta_solver, min eigen of left matrix = (0.8721921-0.0040180376j)
MCPImpl: num_above=2507, num_below=2443
beta_solver, min eigen of left matrix = (0.8721921-0.0040180376j)
MCPImpl: num_above=2510, num_below=2440
kmeans center = [[ 1.52006471 -3.7865052   2.16306047  3.70718876 -1.56147577  5.02886497]
 [-1.5998698   3.63962044 -2.09302506 -3.68820135  1.57882201 -4.94799181]] and inertia = 101.47677567452178
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = (0.8717826+0j)
MCPImpl: num_above=2519, num_below=2431
beta_solver, min eigen of left matrix = (0.8717826+0j)
MCPImpl: num_above=2531, num_below=2419
beta_solver, min eigen of left matrix = (0.8717826+0j)
MCPImpl: num_above=2526, num_below=2424
beta_solver, min eigen of left matrix = (0.871782

MCPImpl: num_above=2533, num_below=2417
beta_solver, min eigen of left matrix = (0.87235755+0j)
MCPImpl: num_above=2533, num_below=2417
kmeans center = [[ 1.57943598 -3.79468256  2.21956709  3.90887828 -1.54563253  5.07873899]
 [-1.56095062  3.69834381 -2.39617209 -3.7650833   1.59454692 -5.13151754]] and inertia = 140.3126037536649
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = (0.8656611+0j)
MCPImpl: num_above=2545, num_below=2405
beta_solver, min eigen of left matrix = (0.8656611+0j)
MCPImpl: num_above=2579, num_below=2371
beta_solver, min eigen of left matrix = (0.8656611+0j)
MCPImpl: num_above=2561, num_below=2389
beta_solver, min eigen of left matrix = (0.8656611+0j)
MCPImpl: num_above=2580, num_below=2370
beta_solver, min eigen of left matrix = (0.8656611+0j)
MCPImpl: num_abo

MCPImpl: num_above=2555, num_below=2395
beta_solver, min eigen of left matrix = (0.87171763+0j)
MCPImpl: num_above=2562, num_below=2388
kmeans center = [[ 1.55353411 -3.82184129  2.22863974  3.80630977 -1.77424935  5.19591304]
 [-1.56949468  3.67039145 -2.42295804 -3.72689653  1.62575987 -5.25392424]] and inertia = 152.97519110866904
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.87857795
MCPImpl: num_above=2502, num_below=2448
beta_solver, min eigen of left matrix = 0.87857795
MCPImpl: num_above=2508, num_below=2442
beta_solver, min eigen of left matrix = 0.87857795
MCPImpl: num_above=2504, num_below=2446
beta_solver, min eigen of left matrix = 0.87857795
MCPImpl: num_above=2508, num_below=2442
beta_solver, min eigen of left matrix = 0.87857795
MCPImpl: num_above=2504, num_below=

MCPImpl: num_above=2528, num_below=2422
kmeans center = [[ 1.61494259 -3.58080633  2.07387169  3.87983737 -1.48825999  4.9039294 ]
 [-1.77567323  3.8940661  -2.11811038 -3.66832142  1.57439556 -5.16776584]] and inertia = 158.40013724972465
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.86995274
MCPImpl: num_above=2507, num_below=2443
beta_solver, min eigen of left matrix = 0.86995274
MCPImpl: num_above=2516, num_below=2434
beta_solver, min eigen of left matrix = 0.86995274
MCPImpl: num_above=2508, num_below=2442
beta_solver, min eigen of left matrix = 0.86995274
MCPImpl: num_above=2515, num_below=2435
beta_solver, min eigen of left matrix = 0.86995274
MCPImpl: num_above=2508, num_below=2442
beta_solver, min eigen of left matrix = 0.86995274
MCPImpl: num_above=2514, num_below=2436


MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.8647538-0.0026059293j)
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.72415298 -3.72785996  2.15153331  3.85906537 -1.52023071  5.01218387]
 [-1.54554075  3.74757849 -2.01580431 -3.86113845  1.57708801 -4.93493765]] and inertia = 108.48425640524609
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.8847187
MCPImpl: num_above=2518, num_below=2432
beta_solver, min eigen of left matrix = 0.8847187
MCPImpl: num_above=2545, num_below=2405
beta_solver, min eigen of left matrix = 0.8847187
MCPImpl: num_above=2528, num_below=2422
beta_solver, min eigen of left matrix = 0.8847187
MCPImpl: num_above=2546, num_below=2404
beta_solver, min eigen of left matrix = 0.8847187
MCPImpl: num_above=2527, num_b

MCPImpl: num_above=2503, num_below=2447
kmeans center = [[ 1.58420078 -3.92155281  2.0573392   3.80478894 -1.55119994  4.96416949]
 [-1.50673089  3.81065757 -2.03813305 -3.84234709  1.56895764 -4.99948783]] and inertia = 98.15706593670272
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = (0.8755347-0.0006017631j)
MCPImpl: num_above=2507, num_below=2443
beta_solver, min eigen of left matrix = (0.8755347-0.0006017631j)
MCPImpl: num_above=2511, num_below=2439
beta_solver, min eigen of left matrix = (0.8755347-0.0006017631j)
MCPImpl: num_above=2510, num_below=2440
beta_solver, min eigen of left matrix = (0.8755347-0.0006017631j)
MCPImpl: num_above=2512, num_below=2438
beta_solver, min eigen of left matrix = (0.8755347-0.0006017631j)
MCPImpl: num_above=2511, num_below=2439
beta_solver, min 

MCPImpl: num_above=2501, num_below=2449
beta_solver, min eigen of left matrix = (0.8807393+0j)
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.58183884 -3.64280915  2.13756668  3.71376275 -1.53257812  5.11317929]
 [-1.48245743  3.70498273 -2.33854005 -3.58774241  1.48452336 -5.0643586 ]] and inertia = 106.3311130162355
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.8783055
MCPImpl: num_above=2539, num_below=2411
beta_solver, min eigen of left matrix = 0.8783055
MCPImpl: num_above=2556, num_below=2394
beta_solver, min eigen of left matrix = 0.8783055
MCPImpl: num_above=2543, num_below=2407
beta_solver, min eigen of left matrix = 0.8783055
MCPImpl: num_above=2557, num_below=2393
beta_solver, min eigen of left matrix = 0.8783055
MCPImpl: num_above=2547, num_below=2403
be

new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.869139
MCPImpl: num_above=2516, num_below=2434
beta_solver, min eigen of left matrix = 0.869139
MCPImpl: num_above=2536, num_below=2414
beta_solver, min eigen of left matrix = 0.869139
MCPImpl: num_above=2520, num_below=2430
beta_solver, min eigen of left matrix = 0.869139
MCPImpl: num_above=2531, num_below=2419
beta_solver, min eigen of left matrix = 0.869139
MCPImpl: num_above=2519, num_below=2431
beta_solver, min eigen of left matrix = 0.869139
MCPImpl: num_above=2531, num_below=2419
beta_solver, min eigen of left matrix = 0.869139
MCPImpl: num_above=2519, num_below=2431
beta_solver, min eigen of left matrix = 0.869139
MCPImpl: num_above=2531, num_below=2419
beta_solver, min eigen of left matrix = 0.869139
MCPImpl: num_above=2519, num_below=2431
beta_solver, min eigen of left matrix = 0.869139
MCPImpl: num_above=2531, num_below=2419
kmeans cente

new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = (0.86381966+0j)
MCPImpl: num_above=2522, num_below=2428
beta_solver, min eigen of left matrix = (0.86381966+0j)
MCPImpl: num_above=2532, num_below=2418
beta_solver, min eigen of left matrix = (0.86381966+0j)
MCPImpl: num_above=2525, num_below=2425
beta_solver, min eigen of left matrix = (0.86381966+0j)
MCPImpl: num_above=2529, num_below=2421
beta_solver, min eigen of left matrix = (0.86381966+0j)
MCPImpl: num_above=2524, num_below=2426
beta_solver, min eigen of left matrix = (0.86381966+0j)
MCPImpl: num_above=2528, num_below=2422
beta_solver, min eigen of left matrix = (0.86381966+0j)
MCPImpl: num_above=2524, num_below=2426
beta_solver, min eigen of left matrix = (0.86381966+0j)
MCPImpl: num_above=2527, num_below=2423
beta_solver, min eigen of left matrix = (0.86381966+0j)
MCPImpl: num_above=2524, num_below=2426
beta_solver, min eigen of left matrix 

In [5]:
# Sanity check: display the coefficient vectors learned in the first experiment.
first_learned_estimate = beta_learned_list[0]
first_learned_estimate.betas

[array([-1.4987195,  4.122486 , -1.9482466, -4.0749016,  1.4737415,
        -4.9391704], dtype=float32),
 array([ 1.523692 , -3.9672484,  2.1617239,  3.761832 , -1.5627939,
         5.026748 ], dtype=float32)]

In [6]:
# Per-experiment value estimates (mu), their sigmas, and the z-scores against the
# truth, for the learned (grouped) and the non-grouped estimators.
mu_learned_list = []
sigma_learned_list = []
z_score_learned_list = []
mu_ng_list = []
sigma_ng_list = []
z_score_ng_list = []

num_groups = len(v_truth)

for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    # Learned estimator: pair each mu/sigma with the matching truth entry.
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    z_score_learned = [
        (mu - truth) / sigma for mu, truth, sigma in zip(v_mus, v_truth, v_sigmas)
    ]
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_score_learned)

    # Non-grouped estimator: its mu/sigma are repeated so that every truth entry
    # is compared against them. Converting to list first guarantees that `*`
    # repeats the sequence; on a NumPy array it would scale the values instead
    # and zip would then produce too few z-scores.
    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    z_score_ng = [
        (mu - truth) / sigma
        for mu, truth, sigma in zip(
            list(ngv_mus) * num_groups, v_truth, list(ngv_sigmas) * num_groups
        )
    ]
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_score_ng)

# Reports that average over two groups

In [7]:
# Threshold on |z|; presumably the two-sided 95% normal critical value.
Z_THRESHOLD = 1.96

z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

# Boolean masks: True where |z| is strictly below the threshold.
learned_in_threshold = np.less(np.absolute(z_score_learned), Z_THRESHOLD)
ng_in_threshold = np.less(np.absolute(z_score_ng), Z_THRESHOLD)

# Share of all entries (every experiment and group pooled) inside the threshold.
learned_in_threshold_perc = np.sum(learned_in_threshold) / learned_in_threshold.size
ng_in_threshold_perc = np.sum(ng_in_threshold) / ng_in_threshold.size
print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    # Everything needed to rebuild the reports without re-running the experiments.
    result_payload = {
        "mu_learned_list": mu_learned_list,
        "sigma_learned_list": sigma_learned_list,
        "z_score_learned": z_score_learned,
        "mu_ng_list": mu_ng_list,
        "sigma_ng_list": sigma_ng_list,
        "z_score_ng": z_score_ng,
        "beta_learned_list": beta_learned_list,
        "beta_ng_list": beta_ng_list,
    }
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(result_payload, f)

learned_in_threshold_perc= 0.975 , ng_in_threshold_perc= 0.0


In [8]:
# Pooled metrics over all experiments and groups.
# ACL: 2 * Z_THRESHOLD times the mean sigma; MSE: mean squared error against v_truth.
interval_factor = 2 * Z_THRESHOLD

learned_error = mu_learned_list - v_truth
ac_acl = interval_factor * np.mean(sigma_learned_list)
ac_mse = np.mean(learned_error ** 2)

ng_error = mu_ng_list - v_truth
mv_acl = interval_factor * np.mean(sigma_ng_list)
mv_mse = np.mean(ng_error ** 2)

In [9]:
# Print the pooled metrics, one section per estimator, in ACL / MSE / ECP order.
for section_title, acl, mse, ecp in (
    ("ACPE results: (average over groups)", ac_acl, ac_mse, learned_in_threshold_perc),
    ("MVPE results: (average over groups)", mv_acl, mv_mse, ng_in_threshold_perc),
):
    print(section_title)
    print(f"ACL: {acl}")
    print(f"MSE: {mse}")
    print(f"ECP: {ecp}")

ACPE results: (average over groups)
ACL: 0.4115450861620926
MSE: 0.008924947120249271
ECP: 0.975
MVPE results: (average over groups)
ACL: 0.6286287159498496
MSE: 12.965012550354004
ECP: 0.0


# Reports that separate the two groups

In [10]:
# Threshold on |z|; presumably the two-sided 95% normal critical value.
Z_THRESHOLD = 1.96

z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

# Boolean masks: True where |z| is strictly below the threshold.
learned_in_threshold = np.less(np.absolute(z_score_learned), Z_THRESHOLD)
ng_in_threshold = np.less(np.absolute(z_score_ng), Z_THRESHOLD)

# Share of entries inside the threshold, reduced along the first axis only so
# that the remaining axis is reported separately.
learned_in_threshold_perc = learned_in_threshold.sum(axis=0) / len(learned_in_threshold)
ng_in_threshold_perc = ng_in_threshold.sum(axis=0) / len(ng_in_threshold)

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    # Overwrites RESULT_FILE with the raw estimates, z-scores and betas.
    result_payload = {
        "mu_learned_list": mu_learned_list,
        "sigma_learned_list": sigma_learned_list,
        "z_score_learned": z_score_learned,
        "mu_ng_list": mu_ng_list,
        "sigma_ng_list": sigma_ng_list,
        "z_score_ng": z_score_ng,
        "beta_learned_list": beta_learned_list,
        "beta_ng_list": beta_ng_list,
    }
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(result_payload, f)

learned_in_threshold_perc= [0.97 0.98] , ng_in_threshold_perc= [0. 0.]


In [11]:
# Same metrics as the pooled report, but reduced along the first axis only.
interval_factor = 2 * Z_THRESHOLD

learned_error = mu_learned_list - v_truth
ac_acl = interval_factor * np.mean(sigma_learned_list, axis=0)
ac_mse = np.mean(learned_error ** 2, axis=0)

ng_error = mu_ng_list - v_truth
mv_acl = interval_factor * np.mean(sigma_ng_list, axis=0)
mv_mse = np.mean(ng_error ** 2, axis=0)

In [12]:
# Print the group-wise metrics for each estimator in MSE / ACL / ECP order.
print("ACPE results: Group1, Group 2")
for metric_name, metric_value in (
    ("MSE", ac_mse),
    ("ACL", ac_acl),
    ("ECP", learned_in_threshold_perc),
):
    print(f"{metric_name}: {metric_value}")

print("===")
print("MVPE results: ")
for metric_name, metric_value in (
    ("MSE", mv_mse),
    ("ACL", mv_acl),
    ("ECP", ng_in_threshold_perc),
):
    print(f"{metric_name}: {metric_value}")

ACPE results: Group1, Group 2
MSE: [0.00784053 0.01000936]
ACL: [0.41227241 0.41081776]
ECP: [0.97 0.98]
===
MVPE results: 
MSE: [13.213402 12.716619]
ACL: [0.62862872]
ECP: [0. 0.]
