In [1]:
import copy
import pickle
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

group_reward_coeff = np.array(
    [
        [-2.68, 2.68],
        [2.68, -2.68],
    ],
    dtype=DTYPE,
)

action_reward_coeff = [-2.89, 2.89]

data_config_init = dict(
    num_trajectories=50,
    num_time_steps=20,
    group_reward_coeff_override=group_reward_coeff,
    action_reward_coeff=action_reward_coeff,
    num_burnin_steps=100,
    basis_expansion_method="NONE",
    add_intercept_column=True,
)

FEATURE_TYPE = "NONE"

# First time runner: set COMPUTE_TRUTH = True
# Change the flag to False after generated the truth file.
COMPUTE_TRUTH = True
# Change the truth file name if settings are changed.
TRUTH_FILE = f"hetero/data/{FEATURE_TYPE}_truth_20230528_2.68_2.89.pkl"
print("truth file name =", TRUTH_FILE)

time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = f'N={data_config_init["num_trajectories"]}_T={data_config_init["num_time_steps"]}_{time_tag}'
RESULT_FILE = f"hetero/data/{FEATURE_TYPE}_result_20230528_2.68_2.89_{tag}.pkl"
print("result file name =", RESULT_FILE)

SAVE_RESULT = True
if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

NUM_EXPERIMENTS = 100

truth file name = hetero/data/NONE_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/NONE_result_20230528_2.68_2.89_N=50_T=20_20230602_12-14-22.pkl


=====================================================================================================
# Algorithm 

- Set configure below.

In [2]:

algo_config = AlgoConfig(
    max_num_iters=10,
    gam=2.7,
    lam=1.6,
    rho=2.0,
    should_remove_outlier=True,
    outlier_lower_perc=2,
    outlier_upper_perc=98,
    nu_coeff=1e-5,
    delta_coeff=1e-5,
    use_group_wise_regression_init=True,
)

pi_eval = AlternativePolicy(2)

grouping_config = GroupingConfig()

In [3]:
if COMPUTE_TRUTH:
    truth_data_config_init = copy.copy(data_config_init)
    truth_data_config_init.pop("num_trajectories")
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=10,
        num_trajectories=1000,
    )  # For best results, use num_repeats=10
    u_truth = us.mean(axis=0)
    v_truth = vs.mean(axis=0)

    with open(TRUTH_FILE, "wb") as f:
        pickle.dump(
            dict(
                u_truth=u_truth,
                v_truth=v_truth,
                data_config_dict=truth_data_config_init,
                algo_config_dict=attrs.asdict(algo_config),
            ),
            f,
        )
else:
    with open(TRUTH_FILE, "rb") as f:
        loaded = pickle.load(f)
        u_truth = loaded["u_truth"]
        v_truth = loaded["v_truth"]

In [4]:
beta_ng_list = []
beta_learned_list = []

for i in range(NUM_EXPERIMENTS):
    data_config = DataGenConfig(seed=7531 * (i + 1), **data_config_init)
    data = generate_data_from_config(data_config)
    beta_ng_list.append(
        beta_estimate_from_nongrouped(data, pi_eval, algo_config.discount)
    )
    beta_learned_list.append(
        beta_estimate_from_e2e_learning(data, algo_config, grouping_config, pi_eval)
    )

new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = (0.8896219+0j)
MCPImpl: num_above=2649, num_below=2301
beta_solver, min eigen of left matrix = (0.8896219+0j)
MCPImpl: num_above=2945, num_below=2005
beta_solver, min eigen of left matrix = (0.8896219+0j)
MCPImpl: num_above=2747, num_below=2203
beta_solver, min eigen of left matrix = (0.8896219+0j)
MCPImpl: num_above=2945, num_below=2005
beta_solver, min eigen of left matrix = (0.8896219+0j)
MCPImpl: num_above=2780, num_below=2170
beta_solver, min eigen of left matrix = (0.8896219+0j)
MCPImpl: num_above=2941, num_below=2009
beta_solver, min eigen of left matrix = (0.8896219+0j)
MCPImpl: num_above=2787, num_below=2163
beta_solver, min eigen of left matrix = (0.8896219+0j)
MCPImpl: num_above=2943, num_below=2007
beta_solver, min eigen of left matrix = (0.8896219+0j)
MCPImpl: num_above=2792, num_below=2158
beta_solver, min eigen of left matrix = (0.8896

MCPImpl: num_above=2628, num_below=2322
beta_solver, min eigen of left matrix = 0.8779244
MCPImpl: num_above=2911, num_below=2039
beta_solver, min eigen of left matrix = 0.8779244
MCPImpl: num_above=2732, num_below=2218
beta_solver, min eigen of left matrix = 0.8779244
MCPImpl: num_above=2926, num_below=2024
beta_solver, min eigen of left matrix = 0.8779244
MCPImpl: num_above=2754, num_below=2196
beta_solver, min eigen of left matrix = 0.8779244
MCPImpl: num_above=2932, num_below=2018
beta_solver, min eigen of left matrix = 0.8779244
MCPImpl: num_above=2756, num_below=2194
beta_solver, min eigen of left matrix = 0.8779244
MCPImpl: num_above=2926, num_below=2024
beta_solver, min eigen of left matrix = 0.8779244
MCPImpl: num_above=2761, num_below=2189
beta_solver, min eigen of left matrix = 0.8779244
MCPImpl: num_above=2927, num_below=2023
kmeans center = [[ 1.60170456 -3.65373016  2.19708182  3.75257216 -1.54507987  5.28022813]
 [-1.60188571  3.98075496 -2.03909162 -3.84720504  1.591129

MCPImpl: num_above=2908, num_below=2042
beta_solver, min eigen of left matrix = 0.89288497
MCPImpl: num_above=3114, num_below=1836
beta_solver, min eigen of left matrix = 0.89288497
MCPImpl: num_above=2937, num_below=2013
beta_solver, min eigen of left matrix = 0.89288497
MCPImpl: num_above=3116, num_below=1834
beta_solver, min eigen of left matrix = 0.89288497
MCPImpl: num_above=2942, num_below=2008
beta_solver, min eigen of left matrix = 0.89288497
MCPImpl: num_above=3110, num_below=1840
beta_solver, min eigen of left matrix = 0.89288497
MCPImpl: num_above=2941, num_below=2009
beta_solver, min eigen of left matrix = 0.89288497
MCPImpl: num_above=3111, num_below=1839
kmeans center = [[ 1.40448221 -3.82603701  2.37823504  3.56146831 -1.52806388  5.35466312]
 [-1.57877566  3.72347121 -1.74072811 -3.75127891  1.42873432 -4.87364762]] and inertia = 503.0148323146308
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labe

MCPImpl: num_above=2906, num_below=2044
beta_solver, min eigen of left matrix = (0.8736882-0.0039033543j)
MCPImpl: num_above=3032, num_below=1918
beta_solver, min eigen of left matrix = (0.8736882-0.0039033543j)
MCPImpl: num_above=2903, num_below=2047
beta_solver, min eigen of left matrix = (0.8736882-0.0039033543j)
MCPImpl: num_above=3034, num_below=1916
beta_solver, min eigen of left matrix = (0.8736882-0.0039033543j)
MCPImpl: num_above=2910, num_below=2040
beta_solver, min eigen of left matrix = (0.8736882-0.0039033543j)
MCPImpl: num_above=3039, num_below=1911
kmeans center = [[ 1.5698425  -4.40654731  2.0640857   3.56399959 -1.65745283  5.01945122]
 [-1.724189    3.73926812 -1.9507166  -3.59190992  1.57426682 -4.67678358]] and inertia = 443.5971313712804
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min 

MCPImpl: num_above=2719, num_below=2231
beta_solver, min eigen of left matrix = 0.8812564
MCPImpl: num_above=3041, num_below=1909
beta_solver, min eigen of left matrix = 0.8812564
MCPImpl: num_above=2881, num_below=2069
beta_solver, min eigen of left matrix = 0.8812564
MCPImpl: num_above=3069, num_below=1881
beta_solver, min eigen of left matrix = 0.8812564
MCPImpl: num_above=2931, num_below=2019
beta_solver, min eigen of left matrix = 0.8812564
MCPImpl: num_above=3078, num_below=1872
beta_solver, min eigen of left matrix = 0.8812564
MCPImpl: num_above=2953, num_below=1997
beta_solver, min eigen of left matrix = 0.8812564
MCPImpl: num_above=3083, num_below=1867
beta_solver, min eigen of left matrix = 0.8812564
MCPImpl: num_above=2956, num_below=1994
beta_solver, min eigen of left matrix = 0.8812564
MCPImpl: num_above=3082, num_below=1868
kmeans center = [[ 1.50196351 -3.74563864  2.43609418  4.0702477  -1.53254481  5.20632854]
 [-1.54410425  3.9599828  -2.13829816 -3.6938195   1.510531

MCPImpl: num_above=2833, num_below=2117
beta_solver, min eigen of left matrix = (0.8851531+0j)
MCPImpl: num_above=2688, num_below=2262
beta_solver, min eigen of left matrix = (0.8851531+0j)
MCPImpl: num_above=2851, num_below=2099
beta_solver, min eigen of left matrix = (0.8851531+0j)
MCPImpl: num_above=2713, num_below=2237
beta_solver, min eigen of left matrix = (0.8851531+0j)
MCPImpl: num_above=2866, num_below=2084
beta_solver, min eigen of left matrix = (0.8851531+0j)
MCPImpl: num_above=2723, num_below=2227
beta_solver, min eigen of left matrix = (0.8851531+0j)
MCPImpl: num_above=2854, num_below=2096
beta_solver, min eigen of left matrix = (0.8851531+0j)
MCPImpl: num_above=2723, num_below=2227
beta_solver, min eigen of left matrix = (0.8851531+0j)
MCPImpl: num_above=2857, num_below=2093
kmeans center = [[ 1.45258296 -3.78724639  2.06929957  3.80399995 -1.47486725  4.73485967]
 [-1.36605898  3.83382309 -2.42039418 -3.69539155  1.70147654 -5.37165462]] and inertia = 317.66274606158055


MCPImpl: num_above=3049, num_below=1901
beta_solver, min eigen of left matrix = 0.8939735
MCPImpl: num_above=2927, num_below=2023
beta_solver, min eigen of left matrix = 0.8939735
MCPImpl: num_above=3051, num_below=1899
beta_solver, min eigen of left matrix = 0.8939735
MCPImpl: num_above=2931, num_below=2019
beta_solver, min eigen of left matrix = 0.8939735
MCPImpl: num_above=3051, num_below=1899
beta_solver, min eigen of left matrix = 0.8939735
MCPImpl: num_above=2933, num_below=2017
beta_solver, min eigen of left matrix = 0.8939735
MCPImpl: num_above=3051, num_below=1899
kmeans center = [[ 1.68073071 -3.71356847  2.16873497  3.76980357 -1.65993235  5.07031681]
 [-1.47848788  3.96343499 -2.48417892 -3.77677788  1.50218496 -5.6598021 ]] and inertia = 470.65304855557065
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_s

MCPImpl: num_above=2975, num_below=1975
beta_solver, min eigen of left matrix = (0.8740625+0j)
MCPImpl: num_above=2825, num_below=2125
beta_solver, min eigen of left matrix = (0.8740625+0j)
MCPImpl: num_above=2983, num_below=1967
beta_solver, min eigen of left matrix = (0.8740625+0j)
MCPImpl: num_above=2838, num_below=2112
beta_solver, min eigen of left matrix = (0.8740625+0j)
MCPImpl: num_above=2988, num_below=1962
beta_solver, min eigen of left matrix = (0.8740625+0j)
MCPImpl: num_above=2843, num_below=2107
beta_solver, min eigen of left matrix = (0.8740625+0j)
MCPImpl: num_above=2994, num_below=1956
kmeans center = [[ 1.54237557 -3.66715419  1.85528319  3.54531462 -1.59144526  4.94595309]
 [-1.51416164  3.80910129 -2.24108505 -3.70924858  1.51154809 -5.02909701]] and inertia = 449.2592557296182
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 mat

MCPImpl: num_above=2869, num_below=2081
beta_solver, min eigen of left matrix = (0.8648326+0j)
MCPImpl: num_above=3062, num_below=1888
beta_solver, min eigen of left matrix = (0.8648326+0j)
MCPImpl: num_above=2908, num_below=2042
beta_solver, min eigen of left matrix = (0.8648326+0j)
MCPImpl: num_above=3075, num_below=1875
beta_solver, min eigen of left matrix = (0.8648326+0j)
MCPImpl: num_above=2921, num_below=2029
beta_solver, min eigen of left matrix = (0.8648326+0j)
MCPImpl: num_above=3085, num_below=1865
kmeans center = [[ 1.50842095 -3.79781519  2.09620811  3.79115508 -1.55390836  5.15058746]
 [-1.67223828  3.89736746 -1.80480112 -3.74615765  1.47969753 -4.85149726]] and inertia = 426.46738817142176
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = (0.8806715+0j)
MCPImpl: num_abo

MCPImpl: num_above=2722, num_below=2228
beta_solver, min eigen of left matrix = (0.88662225+0j)
MCPImpl: num_above=2843, num_below=2107
beta_solver, min eigen of left matrix = (0.88662225+0j)
MCPImpl: num_above=2718, num_below=2232
beta_solver, min eigen of left matrix = (0.88662225+0j)
MCPImpl: num_above=2845, num_below=2105
kmeans center = [[ 1.57154468 -3.70939752  2.09895618  3.61207464 -1.58891762  4.8955204 ]
 [-1.65887492  3.82937684 -2.21745913 -3.66739041  1.47282271 -4.93681642]] and inertia = 316.11465251955724
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = (0.88954324-0.0018141443j)
MCPImpl: num_above=2748, num_below=2202
beta_solver, min eigen of left matrix = (0.88954324-0.0018141443j)
MCPImpl: num_above=3144, num_below=1806
beta_solver, min eigen of left matrix = (0.8

MCPImpl: num_above=2915, num_below=2035
beta_solver, min eigen of left matrix = (0.8757429+0j)
MCPImpl: num_above=2802, num_below=2148
beta_solver, min eigen of left matrix = (0.8757429+0j)
MCPImpl: num_above=2918, num_below=2032
beta_solver, min eigen of left matrix = (0.8757429+0j)
MCPImpl: num_above=2809, num_below=2141
beta_solver, min eigen of left matrix = (0.8757429+0j)
MCPImpl: num_above=2918, num_below=2032
beta_solver, min eigen of left matrix = (0.8757429+0j)
MCPImpl: num_above=2818, num_below=2132
beta_solver, min eigen of left matrix = (0.8757429+0j)
MCPImpl: num_above=2916, num_below=2034
kmeans center = [[ 1.49693367 -3.93362789  2.13503596  3.88921041 -1.68361892  5.08142666]
 [-1.58401996  3.66212568 -1.95585496 -3.86957582  1.52189314 -4.89556168]] and inertia = 402.1067353634118
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 mat

MCPImpl: num_above=2910, num_below=2040
beta_solver, min eigen of left matrix = (0.88972324-0.004563966j)
MCPImpl: num_above=3098, num_below=1852
beta_solver, min eigen of left matrix = (0.88972324-0.004563966j)
MCPImpl: num_above=2921, num_below=2029
beta_solver, min eigen of left matrix = (0.88972324-0.004563966j)
MCPImpl: num_above=3107, num_below=1843
beta_solver, min eigen of left matrix = (0.88972324-0.004563966j)
MCPImpl: num_above=2932, num_below=2018
beta_solver, min eigen of left matrix = (0.88972324-0.004563966j)
MCPImpl: num_above=3108, num_below=1842
kmeans center = [[ 1.47676717 -3.79896412  1.959035    3.75850567 -1.4668771   5.1276993 ]
 [-1.57822354  3.54920607 -2.20078625 -3.81399961  1.66673936 -5.34355309]] and inertia = 418.95659138268576
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min

MCPImpl: num_above=2856, num_below=2094
beta_solver, min eigen of left matrix = (0.88032544+0j)
MCPImpl: num_above=2761, num_below=2189
beta_solver, min eigen of left matrix = (0.88032544+0j)
MCPImpl: num_above=2858, num_below=2092
beta_solver, min eigen of left matrix = (0.88032544+0j)
MCPImpl: num_above=2777, num_below=2173
beta_solver, min eigen of left matrix = (0.88032544+0j)
MCPImpl: num_above=2864, num_below=2086
kmeans center = [[ 1.75817251 -3.71024361  2.0649422   3.60369468 -1.73045658  5.08007077]
 [-1.55309848  3.66897266 -1.95729136 -3.74627058  1.68417818 -4.90697429]] and inertia = 343.4632395181935
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = 0.8853666
MCPImpl: num_above=2562, num_below=2388
beta_solver, min eigen of left matrix = 0.8853666
MCPImpl: num_above=2703

MCPImpl: num_above=2784, num_below=2166
beta_solver, min eigen of left matrix = (0.89537746+0j)
MCPImpl: num_above=2661, num_below=2289
beta_solver, min eigen of left matrix = (0.89537746+0j)
MCPImpl: num_above=2786, num_below=2164
beta_solver, min eigen of left matrix = (0.89537746+0j)
MCPImpl: num_above=2662, num_below=2288
beta_solver, min eigen of left matrix = (0.89537746+0j)
MCPImpl: num_above=2784, num_below=2166
kmeans center = [[ 1.52967555 -3.75142277  2.1488255   3.55750908 -1.59011405  5.00147432]
 [-1.6929916   3.57305464 -2.09126601 -3.60719779  1.63969284 -5.15168754]] and inertia = 283.50123869495127
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = 0.8840495
MCPImpl: num_above=2677, num_below=2273
beta_solver, min eigen of left matrix = 0.8840495
MCPImpl: num_above=297

MCPImpl: num_above=2979, num_below=1971
beta_solver, min eigen of left matrix = (0.8727211+0j)
MCPImpl: num_above=2836, num_below=2114
beta_solver, min eigen of left matrix = (0.8727211+0j)
MCPImpl: num_above=2977, num_below=1973
beta_solver, min eigen of left matrix = (0.8727211+0j)
MCPImpl: num_above=2838, num_below=2112
beta_solver, min eigen of left matrix = (0.8727211+0j)
MCPImpl: num_above=2979, num_below=1971
kmeans center = [[ 1.5995407  -3.68954324  1.97266678  3.62922727 -1.31518454  5.09268234]
 [-1.613302    3.95911309 -2.09271605 -3.56488019  1.50651943 -4.8881512 ]] and inertia = 381.22406789353664
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = (0.8896263-0.0012499299j)
MCPImpl: num_above=2750, num_below=2200
beta_solver, min eigen of left matrix = (0.8896263-0.0012499

MCPImpl: num_above=2717, num_below=2233
beta_solver, min eigen of left matrix = 0.88138163
MCPImpl: num_above=2830, num_below=2120
beta_solver, min eigen of left matrix = 0.88138163
MCPImpl: num_above=2721, num_below=2229
beta_solver, min eigen of left matrix = 0.88138163
MCPImpl: num_above=2828, num_below=2122
kmeans center = [[ 1.77934994 -3.87239897  2.09717113  3.60010371 -1.5873759   4.82300038]
 [-1.63882567  3.626674   -2.11627723 -3.70952756  1.43999375 -4.96817708]] and inertia = 334.38850473907587
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = (0.88294125+0j)
MCPImpl: num_above=2586, num_below=2364
beta_solver, min eigen of left matrix = (0.88294125+0j)
MCPImpl: num_above=2760, num_below=2190
beta_solver, min eigen of left matrix = (0.88294125+0j)
MCPImpl: num_above=2635, 

MCPImpl: num_above=3045, num_below=1905
beta_solver, min eigen of left matrix = 0.89667416
MCPImpl: num_above=2895, num_below=2055
beta_solver, min eigen of left matrix = 0.89667416
MCPImpl: num_above=3048, num_below=1902
kmeans center = [[ 1.58670433 -3.45292558  2.06538893  3.76114532 -1.51794222  5.00427046]
 [-1.43821509  3.71862658 -2.23358965 -3.6308479   1.52359406 -5.03272067]] and inertia = 428.3299272159064
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = 0.8771966
MCPImpl: num_above=2591, num_below=2359
beta_solver, min eigen of left matrix = 0.8771966
MCPImpl: num_above=2778, num_below=2172
beta_solver, min eigen of left matrix = 0.8771966
MCPImpl: num_above=2675, num_below=2275
beta_solver, min eigen of left matrix = 0.8771966
MCPImpl: num_above=2788, num_below=2162
beta_

In [5]:
beta_learned_list[0].betas

[array([-1.4616398,  3.7746   , -1.9842905, -3.8889172,  1.4649606,
        -4.8728433], dtype=float32),
 array([ 1.4120233, -3.8926744,  2.3052633,  4.0163927, -1.5487857,
         5.1435328], dtype=float32)]

In [6]:
mu_learned_list = []
sigma_learned_list = []
z_score_learned_list = []
mu_ng_list = []
sigma_ng_list = []
z_score_ng_list = []

for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    z_score_learned = [
        (mu - truth) / sigma for mu, truth, sigma in zip(v_mus, v_truth, v_sigmas)
    ]
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_score_learned)

    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    z_score_ng = [
        (mu - truth) / sigma
        for mu, truth, sigma in zip(
            ngv_mus * len(v_truth), v_truth, ngv_sigmas * len(v_truth)
        )
    ]
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_score_ng)

# Reports that average over two groups

In [10]:
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)
Z_THRESHOLD = 1.96
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD
learned_in_threshold_perc = learned_in_threshold.sum() / learned_in_threshold.size
ng_in_threshold_perc = ng_in_threshold.sum() / ng_in_threshold.size
print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)
if SAVE_RESULT:
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(
            dict(
                mu_learned_list=mu_learned_list,
                sigma_learned_list=sigma_learned_list,
                z_score_learned=z_score_learned,
                mu_ng_list=mu_ng_list,
                sigma_ng_list=sigma_ng_list,
                z_score_ng=z_score_ng,
                beta_learned_list=beta_learned_list,
                beta_ng_list=beta_ng_list,
            ),
            f,
        )

learned_in_threshold_perc= 0.98 , ng_in_threshold_perc= 0.0


In [11]:
ac_acl = 2*Z_THRESHOLD*np.mean(sigma_learned_list)
ac_mse = np.mean( (mu_learned_list - v_truth)**2 )

mv_acl = 2*Z_THRESHOLD*np.mean(sigma_ng_list)
mv_mse = np.mean( (mu_ng_list - v_truth)**2 )

In [12]:
print("ACPE results: (average over groups)")
print(f"ACL: {ac_acl}") 
print(f"MSE: {ac_mse}")
print(f"ECP: {learned_in_threshold_perc}")


print("MVPE results: (average over groups)")
print(f"ACL: {mv_acl}") 
print(f"MSE: {mv_mse}")
print(f"ECP: {ng_in_threshold_perc}")

ACPE results: (average over groups)
ACL: 0.5032844780992225
MSE: 0.014394965954124928
ECP: 0.98
MVPE results: (average over groups)
ACL: 0.7723561754295803
MSE: 12.966915130615234
ECP: 0.0


# Reports that seperate two groups

In [13]:
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)
Z_THRESHOLD = 1.96
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

learned_in_threshold_perc = np.sum(learned_in_threshold, axis=0) / learned_in_threshold.shape[0]
ng_in_threshold_perc = np.sum(ng_in_threshold, axis=0) / ng_in_threshold.shape[0]

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(
            dict(
                mu_learned_list=mu_learned_list,
                sigma_learned_list=sigma_learned_list,
                z_score_learned=z_score_learned,
                mu_ng_list=mu_ng_list,
                sigma_ng_list=sigma_ng_list,
                z_score_ng=z_score_ng,
                beta_learned_list=beta_learned_list,
                beta_ng_list=beta_ng_list,
            ),
            f,
        )

learned_in_threshold_perc= [0.99 0.97] , ng_in_threshold_perc= [0. 0.]


In [14]:
ac_acl = 2*Z_THRESHOLD*np.mean(sigma_learned_list, axis=0)
ac_mse = np.mean((mu_learned_list - v_truth)**2, axis=0)

mv_acl = 2*Z_THRESHOLD*np.mean(sigma_ng_list, axis=0)
mv_mse = np.mean( (mu_ng_list - v_truth)**2, axis=0 )

In [15]:
print("ACPE results: Group1, Group 2")
print(f"MSE: {ac_mse}")
print(f"ACL: {ac_acl}") 
print(f"ECP: {learned_in_threshold_perc}")

print("===")
print("MVPE results: ")
print(f"MSE: {mv_mse}")
print(f"ACL: {mv_acl}") 
print(f"ECP: {ng_in_threshold_perc}")

ACPE results: Group1, Group 2
MSE: [0.01420538 0.01458455]
ACL: [0.50368047 0.50288849]
ECP: [0.99 0.97]
===
MVPE results: 
MSE: [13.034944 12.898888]
ACL: [0.77235618]
ECP: [0. 0.]
