In [1]:
import copy
import pickle
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

# Magnitudes of the reward coefficients.  Each value is defined exactly once so
# that the arrays handed to the data generator and the tags embedded in the
# truth/result file names below are always derived from the same numbers.
GROUP_REWARD_MAGNITUDE = 2.68
ACTION_REWARD_MAGNITUDE = 2.89

group_reward_coeff = np.array(
    [
        [-GROUP_REWARD_MAGNITUDE, GROUP_REWARD_MAGNITUDE],
        [GROUP_REWARD_MAGNITUDE, -GROUP_REWARD_MAGNITUDE],
    ],
    dtype=DTYPE,
)

action_reward_coeff = [-ACTION_REWARD_MAGNITUDE, ACTION_REWARD_MAGNITUDE]

# Keyword arguments later unpacked into DataGenConfig (together with a
# per-experiment seed) and, without num_trajectories, passed to
# compute_UV_truths.
data_config_init = dict(
    num_trajectories=100,
    num_time_steps=20,
    group_reward_coeff_override=group_reward_coeff,
    action_reward_coeff=action_reward_coeff,
    num_burnin_steps=100,
    basis_expansion_method="NONE",
    add_intercept_column=True,
)

# Prefix used in the truth/result file names.
# NOTE(review): presumably meant to mirror `basis_expansion_method` above -- confirm.
FEATURE_TYPE = "NONE"

# Date tag plus coefficient magnitudes; identifies the experimental setting in
# both file names.  Settings that are not encoded here (e.g. num_time_steps)
# still require changing the date tag by hand when they are altered.
SETTING_TAG = f"20230528_{GROUP_REWARD_MAGNITUDE}_{ACTION_REWARD_MAGNITUDE}"

# First-time runners: set COMPUTE_TRUTH = True.
# Change the flag back to False once the truth file has been generated.
COMPUTE_TRUTH = False
# The truth file name follows SETTING_TAG; update the tag if settings change.
TRUTH_FILE = f"hetero/data/{FEATURE_TYPE}_truth_{SETTING_TAG}.pkl"
print("truth file name =", TRUTH_FILE)

# The result file additionally carries the sample sizes and a timestamp, so
# every execution of this cell produces a fresh result file name.
time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = f'N={data_config_init["num_trajectories"]}_T={data_config_init["num_time_steps"]}_{time_tag}'
RESULT_FILE = f"hetero/data/{FEATURE_TYPE}_result_{SETTING_TAG}_{tag}.pkl"
print("result file name =", RESULT_FILE)

SAVE_RESULT = True
if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

# Number of independent replications run by the experiment loop.
NUM_EXPERIMENTS = 100

truth file name = hetero/data/NONE_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/NONE_result_20230528_2.68_2.89_N=100_T=20_20230602_12-51-43.pkl


=====================================================================================================
# Algorithm 

- Set the algorithm configuration below.

In [2]:

# Hyper-parameters for the estimation algorithm.  Fields that are not listed
# here (for instance `discount`, which later cells read from `algo_config`)
# are left at whatever AlgoConfig provides for them.
algo_settings = {
    "max_num_iters": 10,
    "gam": 2.7,
    "lam": 1.6,
    "rho": 2.0,
    "should_remove_outlier": True,
    "outlier_lower_perc": 2,
    "outlier_upper_perc": 98,
    "nu_coeff": 1e-5,
    "delta_coeff": 1e-5,
    "use_group_wise_regression_init": True,
}
algo_config = AlgoConfig(**algo_settings)

# Policy object handed to every estimator below (presumably the policy being
# evaluated -- confirm against hetero.policies).
pi_eval = AlternativePolicy(2)

# Grouping options are left at their defaults.
grouping_config = GroupingConfig()

In [3]:
if not COMPUTE_TRUTH:
    # Reuse the truth values stored by an earlier run.
    # NOTE: pickle.load can execute arbitrary code embedded in the file, so
    # only point TRUTH_FILE at files generated by this notebook.
    with open(TRUTH_FILE, "rb") as f:
        loaded = pickle.load(f)
    u_truth, v_truth = loaded["u_truth"], loaded["v_truth"]
else:
    # The truth run chooses its own num_trajectories, so drop that key from a
    # shallow copy of the shared data settings.
    truth_data_config_init = dict(data_config_init)
    del truth_data_config_init["num_trajectories"]

    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=10,  # For best results, use num_repeats=10
        num_trajectories=1000,
    )
    # Average the repeated estimates along the first axis.
    u_truth, v_truth = us.mean(axis=0), vs.mean(axis=0)

    # Store the truths together with the settings that produced them.
    with open(TRUTH_FILE, "wb") as f:
        pickle.dump(
            {
                "u_truth": u_truth,
                "v_truth": v_truth,
                "data_config_dict": truth_data_config_init,
                "algo_config_dict": attrs.asdict(algo_config),
            },
            f,
        )

In [4]:
beta_ng_list, beta_learned_list = [], []

for experiment_idx in range(NUM_EXPERIMENTS):
    # Each replication gets its own deterministic seed: 7531, 15062, ...
    experiment_seed = 7531 * (experiment_idx + 1)
    data_config = DataGenConfig(seed=experiment_seed, **data_config_init)
    data = generate_data_from_config(data_config)

    # Estimate on the same data twice: first without grouping, then with the
    # end-to-end learning procedure.
    ng_fit = beta_estimate_from_nongrouped(data, pi_eval, algo_config.discount)
    beta_ng_list.append(ng_fit)

    learned_fit = beta_estimate_from_e2e_learning(
        data, algo_config, grouping_config, pi_eval
    )
    beta_learned_list.append(learned_fit)

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.80915236-0.001679108j)
MCPImpl: num_above=10123, num_below=9777
beta_solver, min eigen of left matrix = (0.80915236-0.001679108j)
MCPImpl: num_above=10145, num_below=9755
beta_solver, min eigen of left matrix = (0.80915236-0.001679108j)
MCPImpl: num_above=10135, num_below=9765
beta_solver, min eigen of left matrix = (0.80915236-0.001679108j)
MCPImpl: num_above=10142, num_below=9758
beta_solver, min eigen of left matrix = (0.80915236-0.001679108j)
MCPImpl: num_above=10135, num_below=9765
beta_solver, min eigen of left matrix = (0.80915236-0.001679108j)
MCPImpl: num_above=10154, num_below=9746
beta_solver, min eigen of left matrix = (0.80915236-0.001679108j)
MCPImpl: num_above=10136, num_below=9764
beta_solver, min eigen of left matrix = (0.80915236-0.001679108j)
MCPImpl: num_above=10161, num_below=9739
beta_solver, min eigen of left matrix = (0.809

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.80608594
MCPImpl: num_above=10260, num_below=9640
beta_solver, min eigen of left matrix = 0.80608594
MCPImpl: num_above=10269, num_below=9631
beta_solver, min eigen of left matrix = 0.80608594
MCPImpl: num_above=10256, num_below=9644
beta_solver, min eigen of left matrix = 0.80608594
MCPImpl: num_above=10255, num_below=9645
beta_solver, min eigen of left matrix = 0.80608594
MCPImpl: num_above=10249, num_below=9651
beta_solver, min eigen of left matrix = 0.80608594
MCPImpl: num_above=10248, num_below=9652
beta_solver, min eigen of left matrix = 0.80608594
MCPImpl: num_above=10247, num_below=9653
beta_solver, min eigen of left matrix = 0.80608594
MCPImpl: num_above=10250, num_below=9650
beta_solver, min eigen of left matrix = 0.80608594
MCPImpl: num_above=10250, num_below=9650
beta_solver, min eigen of left matrix = 0.80608594
MCPImpl: num_above=1025

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.80490243+0j)
MCPImpl: num_above=10241, num_below=9659
beta_solver, min eigen of left matrix = (0.80490243+0j)
MCPImpl: num_above=10286, num_below=9614
beta_solver, min eigen of left matrix = (0.80490243+0j)
MCPImpl: num_above=10263, num_below=9637
beta_solver, min eigen of left matrix = (0.80490243+0j)
MCPImpl: num_above=10308, num_below=9592
beta_solver, min eigen of left matrix = (0.80490243+0j)
MCPImpl: num_above=10283, num_below=9617
beta_solver, min eigen of left matrix = (0.80490243+0j)
MCPImpl: num_above=10299, num_below=9601
beta_solver, min eigen of left matrix = (0.80490243+0j)
MCPImpl: num_above=10273, num_below=9627
beta_solver, min eigen of left matrix = (0.80490243+0j)
MCPImpl: num_above=10286, num_below=9614
beta_solver, min eigen of left matrix = (0.80490243+0j)
MCPImpl: num_above=10265, num_below=9635
beta_solver, min eigen of lef

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.8156794-0.0038484447j)
MCPImpl: num_above=10509, num_below=9391
beta_solver, min eigen of left matrix = (0.8156794-0.0038484447j)
MCPImpl: num_above=10669, num_below=9231
beta_solver, min eigen of left matrix = (0.8156794-0.0038484447j)
MCPImpl: num_above=10598, num_below=9302
beta_solver, min eigen of left matrix = (0.8156794-0.0038484447j)
MCPImpl: num_above=10648, num_below=9252
beta_solver, min eigen of left matrix = (0.8156794-0.0038484447j)
MCPImpl: num_above=10611, num_below=9289
beta_solver, min eigen of left matrix = (0.8156794-0.0038484447j)
MCPImpl: num_above=10648, num_below=9252
beta_solver, min eigen of left matrix = (0.8156794-0.0038484447j)
MCPImpl: num_above=10626, num_below=9274
beta_solver, min eigen of left matrix = (0.8156794-0.0038484447j)
MCPImpl: num_above=10647, num_below=9253
beta_solver, min eigen of left matrix = (0.815

MCPImpl: num_above=10332, num_below=9568
kmeans center = [[ 1.48790166 -3.69420398  2.10576373  3.78258005 -1.44602164  5.00676489]
 [-1.60521708  3.78552038 -2.07764729 -3.83032166  1.53267746 -5.13085862]] and inertia = 396.60634234694953
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.793701+0j)
MCPImpl: num_above=10235, num_below=9665
beta_solver, min eigen of left matrix = (0.793701+0j)
MCPImpl: num_above=10239, num_below=9661
beta_solver, min eigen of left matrix = (0.793701+0j)
MCPImpl: num_above=10184, num_below=9716
beta_solver, min eigen of left matrix = (0.793701+0j)
MCPImpl: num_above=10215, num_below=9685
beta_solver, min eigen of left matrix = (0.793701+0j)
MCPImpl: num_above=10191, num_below=9709
beta_solver, min eigen of left matrix = (0.793701+0j)
MCPImpl: num_abo

MCPImpl: num_above=10235, num_below=9665
kmeans center = [[ 1.53827385 -3.86270219  2.12497911  3.89568604 -1.57848672  5.17176564]
 [-1.58734464  3.77851996 -2.28021736 -3.55218123  1.52854573 -5.16769453]] and inertia = 322.4319523379709
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.8146654
MCPImpl: num_above=10292, num_below=9608
beta_solver, min eigen of left matrix = 0.8146654
MCPImpl: num_above=10332, num_below=9568
beta_solver, min eigen of left matrix = 0.8146654
MCPImpl: num_above=10322, num_below=9578
beta_solver, min eigen of left matrix = 0.8146654
MCPImpl: num_above=10300, num_below=9600
beta_solver, min eigen of left matrix = 0.8146654
MCPImpl: num_above=10302, num_below=9598
beta_solver, min eigen of left matrix = 0.8146654
MCPImpl: num_above=10276, num_below=9624


new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.807074+0j)
MCPImpl: num_above=10282, num_below=9618
beta_solver, min eigen of left matrix = (0.807074+0j)
MCPImpl: num_above=10382, num_below=9518
beta_solver, min eigen of left matrix = (0.807074+0j)
MCPImpl: num_above=10368, num_below=9532
beta_solver, min eigen of left matrix = (0.807074+0j)
MCPImpl: num_above=10400, num_below=9500
beta_solver, min eigen of left matrix = (0.807074+0j)
MCPImpl: num_above=10390, num_below=9510
beta_solver, min eigen of left matrix = (0.807074+0j)
MCPImpl: num_above=10405, num_below=9495
beta_solver, min eigen of left matrix = (0.807074+0j)
MCPImpl: num_above=10399, num_below=9501
beta_solver, min eigen of left matrix = (0.807074+0j)
MCPImpl: num_above=10414, num_below=9486
beta_solver, min eigen of left matrix = (0.807074+0j)
MCPImpl: num_above=10407, num_below=9493
beta_solver, min eigen of left matrix = (0.8070

beta_solver, min eigen of left matrix = 0.80538666
MCPImpl: num_above=10212, num_below=9688
beta_solver, min eigen of left matrix = 0.80538666
MCPImpl: num_above=10208, num_below=9692
beta_solver, min eigen of left matrix = 0.80538666
MCPImpl: num_above=10221, num_below=9679
beta_solver, min eigen of left matrix = 0.80538666
MCPImpl: num_above=10203, num_below=9697
beta_solver, min eigen of left matrix = 0.80538666
MCPImpl: num_above=10226, num_below=9674
beta_solver, min eigen of left matrix = 0.80538666
MCPImpl: num_above=10208, num_below=9692
beta_solver, min eigen of left matrix = 0.80538666
MCPImpl: num_above=10230, num_below=9670
beta_solver, min eigen of left matrix = 0.80538666
MCPImpl: num_above=10222, num_below=9678
beta_solver, min eigen of left matrix = 0.80538666
MCPImpl: num_above=10227, num_below=9673
kmeans center = [[ 1.52089596 -3.78097503  2.18278786  3.66384912 -1.63081463  5.09111597]
 [-1.5952523   3.62447357 -2.06435517 -3.76345808  1.56514925 -5.12148251]] and i

new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.8103355-0.0039516347j)
MCPImpl: num_above=10133, num_below=9767
beta_solver, min eigen of left matrix = (0.8103355-0.0039516347j)
MCPImpl: num_above=10120, num_below=9780
beta_solver, min eigen of left matrix = (0.8103355-0.0039516347j)
MCPImpl: num_above=10102, num_below=9798
beta_solver, min eigen of left matrix = (0.8103355-0.0039516347j)
MCPImpl: num_above=10103, num_below=9797
beta_solver, min eigen of left matrix = (0.8103355-0.0039516347j)
MCPImpl: num_above=10100, num_below=9800
beta_solver, min eigen of left matrix = (0.8103355-0.0039516347j)
MCPImpl: num_above=10110, num_below=9790
beta_solver, min eigen of left matrix = (0.8103355-0.0039516347j)
MCPImpl: num_above=10100, num_below=9800
beta_solver, min eigen of left matrix = (0.8103355-0.0039516347j)
MCPImpl: num_above=10099, num_below=9801
beta_solver, min eigen of left matrix = (0.8103355-0.0039516347j)
MCPImpl: num_above=10100, num

MCPImpl: num_above=10352, num_below=9548
beta_solver, min eigen of left matrix = 0.81502056
MCPImpl: num_above=10470, num_below=9430
beta_solver, min eigen of left matrix = 0.81502056
MCPImpl: num_above=10401, num_below=9499
beta_solver, min eigen of left matrix = 0.81502056
MCPImpl: num_above=10432, num_below=9468
beta_solver, min eigen of left matrix = 0.81502056
MCPImpl: num_above=10352, num_below=9548
beta_solver, min eigen of left matrix = 0.81502056
MCPImpl: num_above=10396, num_below=9504
beta_solver, min eigen of left matrix = 0.81502056
MCPImpl: num_above=10367, num_below=9533
beta_solver, min eigen of left matrix = 0.81502056
MCPImpl: num_above=10402, num_below=9498
beta_solver, min eigen of left matrix = 0.81502056
MCPImpl: num_above=10376, num_below=9524
beta_solver, min eigen of left matrix = 0.81502056
MCPImpl: num_above=10405, num_below=9495
kmeans center = [[ 1.77993407 -3.60379492  1.85704595  3.85218165 -1.5330941   4.87851487]
 [-1.57862121  3.82226819 -2.21112451 -3

MCPImpl: num_above=10318, num_below=9582
beta_solver, min eigen of left matrix = 0.80592006
MCPImpl: num_above=10403, num_below=9497
beta_solver, min eigen of left matrix = 0.80592006
MCPImpl: num_above=10429, num_below=9471
beta_solver, min eigen of left matrix = 0.80592006
MCPImpl: num_above=10448, num_below=9452
beta_solver, min eigen of left matrix = 0.80592006
MCPImpl: num_above=10450, num_below=9450
beta_solver, min eigen of left matrix = 0.80592006
MCPImpl: num_above=10464, num_below=9436
beta_solver, min eigen of left matrix = 0.80592006
MCPImpl: num_above=10458, num_below=9442
beta_solver, min eigen of left matrix = 0.80592006
MCPImpl: num_above=10470, num_below=9430
kmeans center = [[ 1.59209448 -3.88757818  2.07274559  3.74279164 -1.5356614   4.97099552]
 [-1.59300835  3.66746844 -2.12455179 -3.80897177  1.53934986 -5.05140206]] and inertia = 400.1033955056497
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records


beta_solver, min eigen of left matrix = (0.80678123+0j)
MCPImpl: num_above=10365, num_below=9535
beta_solver, min eigen of left matrix = (0.80678123+0j)
MCPImpl: num_above=10361, num_below=9539
beta_solver, min eigen of left matrix = (0.80678123+0j)
MCPImpl: num_above=10358, num_below=9542
beta_solver, min eigen of left matrix = (0.80678123+0j)
MCPImpl: num_above=10362, num_below=9538
beta_solver, min eigen of left matrix = (0.80678123+0j)
MCPImpl: num_above=10366, num_below=9534
beta_solver, min eigen of left matrix = (0.80678123+0j)
MCPImpl: num_above=10380, num_below=9520
beta_solver, min eigen of left matrix = (0.80678123+0j)
MCPImpl: num_above=10371, num_below=9529
kmeans center = [[ 1.60922958 -3.81157673  2.1209197   3.75151752 -1.52113762  5.05374001]
 [-1.49993809  3.78788592 -1.91411489 -3.63974822  1.61212598 -4.95910176]] and inertia = 423.2559224474698
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_la

MCPImpl: num_above=10232, num_below=9668
beta_solver, min eigen of left matrix = (0.8088002+0j)
MCPImpl: num_above=10217, num_below=9683
beta_solver, min eigen of left matrix = (0.8088002+0j)
MCPImpl: num_above=10234, num_below=9666
beta_solver, min eigen of left matrix = (0.8088002+0j)
MCPImpl: num_above=10218, num_below=9682
beta_solver, min eigen of left matrix = (0.8088002+0j)
MCPImpl: num_above=10231, num_below=9669
beta_solver, min eigen of left matrix = (0.8088002+0j)
MCPImpl: num_above=10215, num_below=9685
kmeans center = [[ 1.58594244 -3.6989351   2.08448458  3.79118357 -1.4982618   4.98370086]
 [-1.61179714  3.81485607 -1.94738861 -3.81342538  1.48336426 -4.89708547]] and inertia = 374.49175525369105
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.8202375+0j)
MCPImpl: n

beta_solver, min eigen of left matrix = 0.8077963
MCPImpl: num_above=10202, num_below=9698
beta_solver, min eigen of left matrix = 0.8077963
MCPImpl: num_above=10197, num_below=9703
beta_solver, min eigen of left matrix = 0.8077963
MCPImpl: num_above=10196, num_below=9704
beta_solver, min eigen of left matrix = 0.8077963
MCPImpl: num_above=10188, num_below=9712
kmeans center = [[ 1.49162055 -3.7138159   2.1518243   3.63620737 -1.53266832  5.03127958]
 [-1.47781747  3.81570123 -2.16549125 -3.86000002  1.60392135 -5.22197128]] and inertia = 309.6133230250275
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.8028911
MCPImpl: num_above=10348, num_below=9552
beta_solver, min eigen of left matrix = 0.8028911
MCPImpl: num_above=10286, num_below=9614
beta_solver, min eigen of left matrix = 0

MCPImpl: num_above=10323, num_below=9577
beta_solver, min eigen of left matrix = 0.8131408
MCPImpl: num_above=10308, num_below=9592
beta_solver, min eigen of left matrix = 0.8131408
MCPImpl: num_above=10317, num_below=9583
beta_solver, min eigen of left matrix = 0.8131408
MCPImpl: num_above=10304, num_below=9596
beta_solver, min eigen of left matrix = 0.8131408
MCPImpl: num_above=10333, num_below=9567
beta_solver, min eigen of left matrix = 0.8131408
MCPImpl: num_above=10319, num_below=9581
kmeans center = [[ 1.56874658 -3.88369791  2.21971096  3.65433766 -1.62048444  5.14617545]
 [-1.47645109  3.72955134 -2.07922615 -3.81330794  1.52748132 -5.12322993]] and inertia = 395.66605909180237
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.81081706+0j)
MCPImpl: num_above=10401, num_belo

MCPImpl: num_above=10461, num_below=9439
beta_solver, min eigen of left matrix = 0.80383265
MCPImpl: num_above=10414, num_below=9486
beta_solver, min eigen of left matrix = 0.80383265
MCPImpl: num_above=10461, num_below=9439
beta_solver, min eigen of left matrix = 0.80383265
MCPImpl: num_above=10423, num_below=9477
beta_solver, min eigen of left matrix = 0.80383265
MCPImpl: num_above=10470, num_below=9430
beta_solver, min eigen of left matrix = 0.80383265
MCPImpl: num_above=10420, num_below=9480
beta_solver, min eigen of left matrix = 0.80383265
MCPImpl: num_above=10466, num_below=9434
kmeans center = [[ 1.57606089 -3.85188288  1.98174588  3.84592318 -1.55342073  5.10937143]
 [-1.57061303  3.99674393 -1.94722009 -3.82303206  1.56957446 -4.89806185]] and inertia = 441.81645516645506
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of r

MCPImpl: num_above=10203, num_below=9697
beta_solver, min eigen of left matrix = (0.810733+0j)
MCPImpl: num_above=10262, num_below=9638
beta_solver, min eigen of left matrix = (0.810733+0j)
MCPImpl: num_above=10216, num_below=9684
beta_solver, min eigen of left matrix = (0.810733+0j)
MCPImpl: num_above=10263, num_below=9637
beta_solver, min eigen of left matrix = (0.810733+0j)
MCPImpl: num_above=10221, num_below=9679
beta_solver, min eigen of left matrix = (0.810733+0j)
MCPImpl: num_above=10265, num_below=9635
kmeans center = [[ 1.51616478 -3.69036631  2.22361364  3.8794214  -1.62084274  5.12015033]
 [-1.5212672   3.82929473 -2.06428651 -3.81634334  1.52009598 -4.9288878 ]] and inertia = 353.59946096481644
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.8083798
MCPImpl: num_above=1

In [5]:
# Sanity check: display the `.betas` attribute of the first experiment's
# end-to-end estimate.
beta_learned_list[0].betas

[array([-1.5275658,  3.9006035, -2.0532732, -3.8493872,  1.4957978,
        -4.929152 ], dtype=float32),
 array([ 1.5100113, -3.8850377,  2.1295114,  3.979889 , -1.4891409,
         5.1304836], dtype=float32)]

In [6]:
# Per-experiment value estimates (mu), their standard errors (sigma) and the
# resulting z-scores against the truth, for the learned ("learned") and the
# non-grouped ("ng") estimators.
mu_learned_list = []
sigma_learned_list = []
z_score_learned_list = []
mu_ng_list = []
sigma_ng_list = []
z_score_ng_list = []

# One truth value per group.
num_groups = len(v_truth)

for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    z_score_learned = [
        (mu - truth) / sigma for mu, truth, sigma in zip(v_mus, v_truth, v_sigmas)
    ]
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_score_learned)

    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    # The non-grouped estimate is repeated so that it is compared against the
    # truth of every group.  Converting to a list first guarantees that `*`
    # means sequence repetition; on a NumPy array it would instead multiply
    # the values by num_groups and leave the length unchanged.
    repeated_ng_mus = list(ngv_mus) * num_groups
    repeated_ng_sigmas = list(ngv_sigmas) * num_groups
    z_score_ng = [
        (mu - truth) / sigma
        for mu, truth, sigma in zip(repeated_ng_mus, v_truth, repeated_ng_sigmas)
    ]
    # The un-repeated estimates are what gets stored.
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_score_ng)

# Reports averaged over the two groups

In [7]:
# |z| below this value counts as "truth covered" (1.96 is the two-sided 95%
# critical value of the standard normal distribution).
Z_THRESHOLD = 1.96

z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

# Boolean masks with one entry per z-score.
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# Share of all entries inside the band; the mean of a boolean mask equals
# (number of True entries) / (number of entries).
learned_in_threshold_perc = learned_in_threshold.mean()
ng_in_threshold_perc = ng_in_threshold.mean()
print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    # Persist the raw per-experiment results so reports can be rebuilt later.
    with open(RESULT_FILE, "wb") as result_fh:
        pickle.dump(
            {
                "mu_learned_list": mu_learned_list,
                "sigma_learned_list": sigma_learned_list,
                "z_score_learned": z_score_learned,
                "mu_ng_list": mu_ng_list,
                "sigma_ng_list": sigma_ng_list,
                "z_score_ng": z_score_ng,
                "beta_learned_list": beta_learned_list,
                "beta_ng_list": beta_ng_list,
            },
            result_fh,
        )

learned_in_threshold_perc= 0.965 , ng_in_threshold_perc= 0.0


In [8]:
# Width of the symmetric band mu +/- Z_THRESHOLD * sigma, per unit of sigma.
ci_width_factor = 2 * Z_THRESHOLD

# "ac_*" summarise the learned estimator, "mv_*" the non-grouped one; both
# are averaged over every experiment and group at once.
learned_sq_err = (mu_learned_list - v_truth) ** 2
ac_acl = ci_width_factor * np.mean(sigma_learned_list)
ac_mse = np.mean(learned_sq_err)

# NumPy broadcasting pairs each stored non-grouped estimate with v_truth.
ng_sq_err = (mu_ng_list - v_truth) ** 2
mv_acl = ci_width_factor * np.mean(sigma_ng_list)
mv_mse = np.mean(ng_sq_err)

In [9]:
# One (header, ACL, MSE, ECP) row per estimator, printed in a fixed layout.
averaged_reports = (
    ("ACPE results: (average over groups)", ac_acl, ac_mse, learned_in_threshold_perc),
    ("MVPE results: (average over groups)", mv_acl, mv_mse, ng_in_threshold_perc),
)

for header, acl, mse, ecp in averaged_reports:
    print(header)
    print(f"ACL: {acl}")
    print(f"MSE: {mse}")
    print(f"ECP: {ecp}")

ACPE results: (average over groups)
ACL: 0.3566814630069848
MSE: 0.00714547373354435
ECP: 0.965
MVPE results: (average over groups)
ACL: 0.5454122750423384
MSE: 12.959375381469727
ECP: 0.0


# Reports that separate the two groups

In [10]:
# |z| below this value counts as "truth covered" (1.96 is the two-sided 95%
# critical value of the standard normal distribution).
Z_THRESHOLD = 1.96

z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# Coverage per column: average the boolean masks over the first axis only
# (one row per experiment).
# NOTE(review): these names held scalars in the averaged report; from here on
# they hold one value per group.
learned_in_threshold_perc = learned_in_threshold.mean(axis=0)
ng_in_threshold_perc = ng_in_threshold.mean(axis=0)

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    # NOTE(review): this writes the same keys to the same RESULT_FILE as the
    # averaged report does; the file is simply overwritten.
    with open(RESULT_FILE, "wb") as result_fh:
        pickle.dump(
            {
                "mu_learned_list": mu_learned_list,
                "sigma_learned_list": sigma_learned_list,
                "z_score_learned": z_score_learned,
                "mu_ng_list": mu_ng_list,
                "sigma_ng_list": sigma_ng_list,
                "z_score_ng": z_score_ng,
                "beta_learned_list": beta_learned_list,
                "beta_ng_list": beta_ng_list,
            },
            result_fh,
        )

learned_in_threshold_perc= [0.97 0.96] , ng_in_threshold_perc= [0. 0.]


In [11]:
# Width of the symmetric band mu +/- Z_THRESHOLD * sigma, per unit of sigma.
ci_width_factor = 2 * Z_THRESHOLD

# Same metrics as the averaged report, but reduced over the first axis only
# (the experiments), leaving one value per remaining column.
learned_sq_err = (mu_learned_list - v_truth) ** 2
ac_acl = ci_width_factor * np.mean(sigma_learned_list, axis=0)
ac_mse = np.mean(learned_sq_err, axis=0)

# NumPy broadcasting pairs each stored non-grouped estimate with v_truth.
ng_sq_err = (mu_ng_list - v_truth) ** 2
mv_acl = ci_width_factor * np.mean(sigma_ng_list, axis=0)
mv_mse = np.mean(ng_sq_err, axis=0)

In [12]:
# One (header, MSE, ACL, ECP) row per estimator; a "===" line separates them.
per_group_reports = (
    ("ACPE results: Group1, Group 2", ac_mse, ac_acl, learned_in_threshold_perc),
    ("MVPE results: ", mv_mse, mv_acl, ng_in_threshold_perc),
)

for report_idx, (header, mse, acl, ecp) in enumerate(per_group_reports):
    if report_idx > 0:
        print("===")
    print(header)
    print(f"MSE: {mse}")
    print(f"ACL: {acl}")
    print(f"ECP: {ecp}")

ACPE results: Group1, Group 2
MSE: [0.00673532 0.00755562]
ACL: [0.35752927 0.35583366]
ECP: [0.97 0.96]
===
MVPE results: 
MSE: [13.088767 12.829983]
ACL: [0.54541228]
ECP: [0. 0.]
