In [1]:
import copy
import os
import pickle
import warnings
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

# Reward-coefficient magnitudes for this experiment. Both reward-coefficient
# structures AND the truth/result file names are derived from these values,
# so the file names cannot drift out of sync with the coefficients.
GROUP_COEFF = 2.68
ACTION_COEFF = 2.89
# Date stamp embedded in the truth/result file names.
RUN_DATE_TAG = "20230528"

group_reward_coeff = np.array(
    [
        [-GROUP_COEFF, GROUP_COEFF],
        [GROUP_COEFF, -GROUP_COEFF],
    ],
    dtype=DTYPE,
)

action_reward_coeff = [-ACTION_COEFF, ACTION_COEFF]

data_config_init = dict(
    num_trajectories=20,
    num_time_steps=20,
    group_reward_coeff_override=group_reward_coeff,
    action_reward_coeff=action_reward_coeff,
    num_burnin_steps=100,
    basis_expansion_method="NONE",
    add_intercept_column=True,
)

FEATURE_TYPE = "NONE"
# e.g. "20230528_2.68_2.89"; shared by the truth and result file names.
SETTINGS_TAG = f"{RUN_DATE_TAG}_{GROUP_COEFF}_{ACTION_COEFF}"

# First-time run: keep COMPUTE_TRUTH = True to simulate and save the truth file.
# Once the truth file exists, set it to False to load it instead of recomputing.
COMPUTE_TRUTH = True
# The truth file name only encodes the date tag and the reward coefficients.
# If any other data-generation setting changes, also change RUN_DATE_TAG so a
# stale truth file is not picked up.
TRUTH_FILE = f"hetero/data/{FEATURE_TYPE}_truth_{SETTINGS_TAG}.pkl"
print("truth file name =", TRUTH_FILE)

time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = f'N={data_config_init["num_trajectories"]}_T={data_config_init["num_time_steps"]}_{time_tag}'
RESULT_FILE = f"hetero/data/{FEATURE_TYPE}_result_{SETTINGS_TAG}_{tag}.pkl"
print("result file name =", RESULT_FILE)

SAVE_RESULT = True
if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

NUM_EXPERIMENTS = 100

truth file name = hetero/data/NONE_truth_20230528_1.68_2.89.pkl
result file name = hetero/data/NONE_result_20230528_1.68_2.89_N=20_T=20_20230602_11-01-17.pkl


=====================================================================================================
# Algorithm 

- Set the algorithm configuration below.

In [2]:

# Configuration of the end-to-end learning algorithm, the policy under
# evaluation, and the grouping step.
algo_config = AlgoConfig(
    use_group_wise_regression_init=True,
    max_num_iters=10,
    # Numeric tuning parameters, forwarded unchanged to AlgoConfig.
    gam=2.7,
    lam=1.6,
    rho=2.0,
    nu_coeff=1e-5,
    delta_coeff=1e-5,
    # Outlier handling; the lower/upper cut-offs are presumably percentiles
    # (per the `_perc` suffix) — confirm in hetero.config.AlgoConfig.
    should_remove_outlier=True,
    outlier_lower_perc=2,
    outlier_upper_perc=98,
)

# Policy being evaluated. NOTE(review): the meaning of the argument `2` is
# defined by hetero.policies.AlternativePolicy — confirm there.
pi_eval = AlternativePolicy(2)

# Grouping step with all-default settings.
grouping_config = GroupingConfig()

In [3]:
def _configs_match(saved: dict, current: dict) -> bool:
    """Return True if two data-config dicts have the same keys and equal values.

    Values are compared with ``np.array_equal`` so that array-valued entries
    (such as ``group_reward_coeff_override``) are compared element-wise instead
    of raising on an ambiguous truth value.
    """
    if saved.keys() != current.keys():
        return False
    return all(
        np.array_equal(np.asarray(saved[key]), np.asarray(current[key]))
        for key in current
    )


# Data-generation settings for the truth: the experiment settings minus
# `num_trajectories`, which compute_UV_truths receives separately below.
truth_data_config_init = copy.copy(data_config_init)
truth_data_config_init.pop("num_trajectories")

if COMPUTE_TRUTH:
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=10,
        num_trajectories=1000,
    )
    # Average over axis 0 (presumably the repeat axis — confirm in compute_UV_truths).
    u_truth = us.mean(axis=0)
    v_truth = vs.mean(axis=0)

    # Create the output directory if needed so a first run does not fail on open().
    os.makedirs(os.path.dirname(TRUTH_FILE) or ".", exist_ok=True)
    with open(TRUTH_FILE, "wb") as f:
        pickle.dump(
            dict(
                u_truth=u_truth,
                v_truth=v_truth,
                data_config_dict=truth_data_config_init,
                algo_config_dict=attrs.asdict(algo_config),
            ),
            f,
        )
else:
    # NOTE: pickle.load can execute arbitrary code; only load truth files
    # that this notebook produced.
    with open(TRUTH_FILE, "rb") as f:
        loaded = pickle.load(f)
    u_truth = loaded["u_truth"]
    v_truth = loaded["v_truth"]

    # The truth file name does not encode every setting, so check that the
    # saved data-generation settings still match the current ones.
    saved_data_config = loaded.get("data_config_dict")
    if saved_data_config is not None and not _configs_match(
        saved_data_config, truth_data_config_init
    ):
        warnings.warn(
            f"{TRUTH_FILE} was generated with different data-generation "
            "settings than the current configuration; set COMPUTE_TRUTH = True "
            "(or change TRUTH_FILE) to regenerate the truth.",
            stacklevel=1,
        )

In [4]:
# Run NUM_EXPERIMENTS replications. Replication k (1-based) is generated with
# seed 7531 * k, and both estimators are fit on that same dataset.
beta_ng_list, beta_learned_list = [], []

for run_number in range(1, NUM_EXPERIMENTS + 1):
    data_config = DataGenConfig(seed=7531 * run_number, **data_config_init)
    data = generate_data_from_config(data_config)
    # Non-grouped baseline first, then the end-to-end learned estimate.
    beta_ng_list.append(
        beta_estimate_from_nongrouped(data, pi_eval, algo_config.discount)
    )
    beta_learned_list.append(
        beta_estimate_from_e2e_learning(data, algo_config, grouping_config, pi_eval)
    )

new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = (0.93860525+0j)
MCPImpl: num_above=460, num_below=320
beta_solver, min eigen of left matrix = (0.93860525+0j)
MCPImpl: num_above=555, num_below=225
beta_solver, min eigen of left matrix = (0.93860525+0j)
MCPImpl: num_above=475, num_below=305
beta_solver, min eigen of left matrix = (0.93860525+0j)
MCPImpl: num_above=554, num_below=226
beta_solver, min eigen of left matrix = (0.93860525+0j)
MCPImpl: num_above=479, num_below=301
beta_solver, min eigen of left matrix = (0.93860525+0j)
MCPImpl: num_above=558, num_below=222
beta_solver, min eigen of left matrix = (0.93860525+0j)
MCPImpl: num_above=485, num_below=295
beta_solver, min eigen of left matrix = (0.93860525+0j)
MCPImpl: num_above=558, num_below=222
beta_solver, min eigen of left matrix = (0.93860525+0j)
MCPImpl: num_above=483, num_below=297
beta_solver, min eigen of left matrix = (0.93860525+0j)
MC

MCPImpl: num_above=532, num_below=248
beta_solver, min eigen of left matrix = (0.9386925+0j)
MCPImpl: num_above=447, num_below=333
beta_solver, min eigen of left matrix = (0.9386925+0j)
MCPImpl: num_above=529, num_below=251
beta_solver, min eigen of left matrix = (0.9386925+0j)
MCPImpl: num_above=454, num_below=326
beta_solver, min eigen of left matrix = (0.9386925+0j)
MCPImpl: num_above=532, num_below=248
beta_solver, min eigen of left matrix = (0.9386925+0j)
MCPImpl: num_above=455, num_below=325
beta_solver, min eigen of left matrix = (0.9386925+0j)
MCPImpl: num_above=531, num_below=249
beta_solver, min eigen of left matrix = (0.9386925+0j)
MCPImpl: num_above=453, num_below=327
beta_solver, min eigen of left matrix = (0.9386925+0j)
MCPImpl: num_above=530, num_below=250
kmeans center = [[-2.0358754   4.1755745  -1.85032682 -3.69846283  1.62026552 -5.38579308]
 [ 1.50930918 -3.44859172  2.48496573  3.79255522 -1.44484964  5.24390853]] and inertia = 174.45410271859566
Label mismatch = 0

beta_solver, min eigen of left matrix = (0.9306005+0j)
MCPImpl: num_above=548, num_below=232
beta_solver, min eigen of left matrix = (0.9306005+0j)
MCPImpl: num_above=489, num_below=291
beta_solver, min eigen of left matrix = (0.9306005+0j)
MCPImpl: num_above=547, num_below=233
kmeans center = [[-1.7636872   3.94582288 -2.10430029 -3.413082    1.56089002 -5.23767376]
 [ 1.51891405 -3.54491323  2.39928081  3.64454725 -1.48341785  5.27314174]] and inertia = 223.0977887838717
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = (0.9077001+0j)
MCPImpl: num_above=452, num_below=328
beta_solver, min eigen of left matrix = (0.9077001+0j)
MCPImpl: num_above=522, num_below=258
beta_solver, min eigen of left matrix = (0.9077001+0j)
MCPImpl: num_above=471, num_below=309
beta_solver, min eigen of left ma

new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = (0.94086593+0j)
MCPImpl: num_above=454, num_below=326
beta_solver, min eigen of left matrix = (0.94086593+0j)
MCPImpl: num_above=558, num_below=222
beta_solver, min eigen of left matrix = (0.94086593+0j)
MCPImpl: num_above=467, num_below=313
beta_solver, min eigen of left matrix = (0.94086593+0j)
MCPImpl: num_above=560, num_below=220
beta_solver, min eigen of left matrix = (0.94086593+0j)
MCPImpl: num_above=472, num_below=308
beta_solver, min eigen of left matrix = (0.94086593+0j)
MCPImpl: num_above=560, num_below=220
beta_solver, min eigen of left matrix = (0.94086593+0j)
MCPImpl: num_above=477, num_below=303
beta_solver, min eigen of left matrix = (0.94086593+0j)
MCPImpl: num_above=561, num_below=219
beta_solver, min eigen of left matrix = (0.94086593+0j)
MCPImpl: num_above=476, num_below=304
beta_solver, min eigen of left matrix = (0.94086593+0j)
MCPImpl: num_above=560, num_below=220
kmeans cente

MCPImpl: num_above=522, num_below=258
beta_solver, min eigen of left matrix = 0.93428254
MCPImpl: num_above=461, num_below=319
beta_solver, min eigen of left matrix = 0.93428254
MCPImpl: num_above=524, num_below=256
beta_solver, min eigen of left matrix = 0.93428254
MCPImpl: num_above=464, num_below=316
beta_solver, min eigen of left matrix = 0.93428254
MCPImpl: num_above=524, num_below=256
beta_solver, min eigen of left matrix = 0.93428254
MCPImpl: num_above=464, num_below=316
beta_solver, min eigen of left matrix = 0.93428254
MCPImpl: num_above=524, num_below=256
beta_solver, min eigen of left matrix = 0.93428254
MCPImpl: num_above=463, num_below=317
beta_solver, min eigen of left matrix = 0.93428254
MCPImpl: num_above=524, num_below=256
kmeans center = [[-1.72244128  3.40746114 -2.34660755 -3.71316162  1.75290404 -5.3872859 ]
 [ 1.65697328 -3.9478065   1.99102149  3.63242336 -1.58688987  5.17505225]] and inertia = 186.03789610011899
Label mismatch = 0
new_labels.length=40 matches nu

beta_solver, min eigen of left matrix = 0.9374779
MCPImpl: num_above=492, num_below=288
beta_solver, min eigen of left matrix = 0.9374779
MCPImpl: num_above=536, num_below=244
beta_solver, min eigen of left matrix = 0.9374779
MCPImpl: num_above=493, num_below=287
beta_solver, min eigen of left matrix = 0.9374779
MCPImpl: num_above=536, num_below=244
kmeans center = [[-1.31999855  3.88001028 -2.19390679 -4.02529765  1.55594188 -5.38565215]
 [ 1.4838259  -3.43617896  2.13373362  3.74667622 -1.71465786  5.09563879]] and inertia = 209.83422092128671
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = 0.93532014
MCPImpl: num_above=456, num_below=324
beta_solver, min eigen of left matrix = 0.93532014
MCPImpl: num_above=560, num_below=220
beta_solver, min eigen of left matrix = 0.93532014
MCPImpl: 

MCPImpl: num_above=564, num_below=216
kmeans center = [[-1.64493385  3.3876527  -1.64730618 -3.61330709  1.54553806 -4.50685267]
 [ 1.72380032 -4.43031073  1.90494428  3.38923969 -1.68433255  4.96198805]] and inertia = 277.76539692773537
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = (0.94895613+0j)
MCPImpl: num_above=426, num_below=354
beta_solver, min eigen of left matrix = (0.94895613+0j)
MCPImpl: num_above=534, num_below=246
beta_solver, min eigen of left matrix = (0.94895613+0j)
MCPImpl: num_above=436, num_below=344
beta_solver, min eigen of left matrix = (0.94895613+0j)
MCPImpl: num_above=530, num_below=250
beta_solver, min eigen of left matrix = (0.94895613+0j)
MCPImpl: num_above=448, num_below=332
beta_solver, min eigen of left matrix = (0.94895613+0j)
MCPImpl: num_above=529, nu

beta_solver, min eigen of left matrix = 0.9344738
MCPImpl: num_above=565, num_below=215
beta_solver, min eigen of left matrix = 0.9344738
MCPImpl: num_above=486, num_below=294
beta_solver, min eigen of left matrix = 0.9344738
MCPImpl: num_above=559, num_below=221
beta_solver, min eigen of left matrix = 0.9344738
MCPImpl: num_above=491, num_below=289
beta_solver, min eigen of left matrix = 0.9344738
MCPImpl: num_above=561, num_below=219
beta_solver, min eigen of left matrix = 0.9344738
MCPImpl: num_above=485, num_below=295
beta_solver, min eigen of left matrix = 0.9344738
MCPImpl: num_above=558, num_below=222
beta_solver, min eigen of left matrix = 0.9344738
MCPImpl: num_above=486, num_below=294
beta_solver, min eigen of left matrix = 0.9344738
MCPImpl: num_above=558, num_below=222
kmeans center = [[-1.53779104  3.71276077 -2.56642884 -3.87547606  1.68655051 -5.3965873 ]
 [ 1.95777365 -4.02607385  1.55917246  3.35098693 -1.79300857  4.59575418]] and inertia = 237.43260553160857
Label mi

beta_solver, min eigen of left matrix = (0.95169234-0.006679985j)
MCPImpl: num_above=566, num_below=214
beta_solver, min eigen of left matrix = (0.95169234-0.006679985j)
MCPImpl: num_above=491, num_below=289
beta_solver, min eigen of left matrix = (0.95169234-0.006679985j)
MCPImpl: num_above=555, num_below=225
beta_solver, min eigen of left matrix = (0.95169234-0.006679985j)
MCPImpl: num_above=492, num_below=288
beta_solver, min eigen of left matrix = (0.95169234-0.006679985j)
MCPImpl: num_above=555, num_below=225
beta_solver, min eigen of left matrix = (0.95169234-0.006679985j)
MCPImpl: num_above=492, num_below=288
beta_solver, min eigen of left matrix = (0.95169234-0.006679985j)
MCPImpl: num_above=555, num_below=225
beta_solver, min eigen of left matrix = (0.95169234-0.006679985j)
MCPImpl: num_above=492, num_below=288
beta_solver, min eigen of left matrix = (0.95169234-0.006679985j)
MCPImpl: num_above=556, num_below=224
kmeans center = [[-1.66254229  3.70898004 -2.35743305 -3.9555437

beta_solver, min eigen of left matrix = (0.94038355+0j)
MCPImpl: num_above=603, num_below=177
beta_solver, min eigen of left matrix = (0.94038355+0j)
MCPImpl: num_above=538, num_below=242
beta_solver, min eigen of left matrix = (0.94038355+0j)
MCPImpl: num_above=602, num_below=178
beta_solver, min eigen of left matrix = (0.94038355+0j)
MCPImpl: num_above=544, num_below=236
beta_solver, min eigen of left matrix = (0.94038355+0j)
MCPImpl: num_above=602, num_below=178
beta_solver, min eigen of left matrix = (0.94038355+0j)
MCPImpl: num_above=543, num_below=237
beta_solver, min eigen of left matrix = (0.94038355+0j)
MCPImpl: num_above=602, num_below=178
beta_solver, min eigen of left matrix = (0.94038355+0j)
MCPImpl: num_above=545, num_below=235
beta_solver, min eigen of left matrix = (0.94038355+0j)
MCPImpl: num_above=602, num_below=178
kmeans center = [[-1.56171856  3.54815749 -2.15420142 -4.1896613   1.27272805 -4.87312475]
 [ 1.75277973 -3.52116268  2.06978257  3.33735684 -1.64035721  

MCPImpl: num_above=496, num_below=284
beta_solver, min eigen of left matrix = (0.93417794-0.0053090816j)
MCPImpl: num_above=553, num_below=227
beta_solver, min eigen of left matrix = (0.93417794-0.0053090816j)
MCPImpl: num_above=494, num_below=286
beta_solver, min eigen of left matrix = (0.93417794-0.0053090816j)
MCPImpl: num_above=553, num_below=227
kmeans center = [[-1.28431781  4.54607809 -1.78678753 -2.60954543  1.42572876 -5.34795299]
 [ 1.38316522 -3.76769475  2.09174677  3.70440195 -1.35419562  5.12419226]] and inertia = 806.0082136786142
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = 0.9437819
MCPImpl: num_above=504, num_below=276
beta_solver, min eigen of left matrix = 0.9437819
MCPImpl: num_above=624, num_below=156
beta_solver, min eigen of left matrix = 0.9437819
MCPImpl: num

MCPImpl: num_above=451, num_below=329
beta_solver, min eigen of left matrix = 0.9435801
MCPImpl: num_above=505, num_below=275
beta_solver, min eigen of left matrix = 0.9435801
MCPImpl: num_above=455, num_below=325
beta_solver, min eigen of left matrix = 0.9435801
MCPImpl: num_above=504, num_below=276
beta_solver, min eigen of left matrix = 0.9435801
MCPImpl: num_above=455, num_below=325
beta_solver, min eigen of left matrix = 0.9435801
MCPImpl: num_above=505, num_below=275
beta_solver, min eigen of left matrix = 0.9435801
MCPImpl: num_above=455, num_below=325
beta_solver, min eigen of left matrix = 0.9435801
MCPImpl: num_above=504, num_below=276
kmeans center = [[-1.60055328  3.6883166  -2.08438322 -4.01398063  1.64228597 -5.12207175]
 [ 1.47504015 -3.8117011   1.97904569  3.36856261 -1.67779589  4.49087982]] and inertia = 163.00882613770915
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches nu

beta_solver, min eigen of left matrix = (0.9429064+0j)
MCPImpl: num_above=493, num_below=287
beta_solver, min eigen of left matrix = (0.9429064+0j)
MCPImpl: num_above=468, num_below=312
beta_solver, min eigen of left matrix = (0.9429064+0j)
MCPImpl: num_above=493, num_below=287
kmeans center = [[-1.67790764  3.43317141 -2.23570794 -3.77249966  1.55977654 -4.78291678]
 [ 1.47012964 -3.33501855  2.19627649  3.38972972 -1.47429924  5.10788333]] and inertia = 244.9641471072207
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = (0.93381935+0j)
MCPImpl: num_above=448, num_below=332
beta_solver, min eigen of left matrix = (0.93381935+0j)
MCPImpl: num_above=547, num_below=233
beta_solver, min eigen of left matrix = (0.93381935+0j)
MCPImpl: num_above=473, num_below=307
beta_solver, min eigen of left

MCPImpl: num_above=575, num_below=205
beta_solver, min eigen of left matrix = (0.9454069+0j)
MCPImpl: num_above=493, num_below=287
beta_solver, min eigen of left matrix = (0.9454069+0j)
MCPImpl: num_above=572, num_below=208
beta_solver, min eigen of left matrix = (0.9454069+0j)
MCPImpl: num_above=497, num_below=283
beta_solver, min eigen of left matrix = (0.9454069+0j)
MCPImpl: num_above=571, num_below=209
beta_solver, min eigen of left matrix = (0.9454069+0j)
MCPImpl: num_above=498, num_below=282
beta_solver, min eigen of left matrix = (0.9454069+0j)
MCPImpl: num_above=571, num_below=209
beta_solver, min eigen of left matrix = (0.9454069+0j)
MCPImpl: num_above=497, num_below=283
beta_solver, min eigen of left matrix = (0.9454069+0j)
MCPImpl: num_above=571, num_below=209
kmeans center = [[-1.72928237  3.34367618 -2.30311673 -3.9234781   1.72465415 -5.41420426]
 [ 1.44487334 -3.7315508   2.59064659  3.79970868 -1.719337    5.44707954]] and inertia = 226.8116318015359
Label mismatch = 0


beta_solver, min eigen of left matrix = (0.9439122+0j)
MCPImpl: num_above=474, num_below=306
beta_solver, min eigen of left matrix = (0.9439122+0j)
MCPImpl: num_above=528, num_below=252
kmeans center = [[-1.72191032  4.35114625 -1.81850063 -3.55835915  1.42459004 -4.71927665]
 [ 1.10219215 -3.24966624  1.9098904   3.64623458 -1.4518702   4.9189262 ]] and inertia = 213.52569054684562
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = (0.9303064+0j)
MCPImpl: num_above=453, num_below=327
beta_solver, min eigen of left matrix = (0.9303064+0j)
MCPImpl: num_above=555, num_below=225
beta_solver, min eigen of left matrix = (0.9303064+0j)
MCPImpl: num_above=472, num_below=308
beta_solver, min eigen of left matrix = (0.9303064+0j)
MCPImpl: num_above=564, num_below=216
beta_solver, min eigen of left m

MCPImpl: num_above=555, num_below=225
beta_solver, min eigen of left matrix = (0.937416+0j)
MCPImpl: num_above=477, num_below=303
beta_solver, min eigen of left matrix = (0.937416+0j)
MCPImpl: num_above=549, num_below=231
beta_solver, min eigen of left matrix = (0.937416+0j)
MCPImpl: num_above=479, num_below=301
beta_solver, min eigen of left matrix = (0.937416+0j)
MCPImpl: num_above=550, num_below=230
beta_solver, min eigen of left matrix = (0.937416+0j)
MCPImpl: num_above=481, num_below=299
beta_solver, min eigen of left matrix = (0.937416+0j)
MCPImpl: num_above=550, num_below=230
beta_solver, min eigen of left matrix = (0.937416+0j)
MCPImpl: num_above=481, num_below=299
beta_solver, min eigen of left matrix = (0.937416+0j)
MCPImpl: num_above=550, num_below=230
kmeans center = [[-1.60995267  3.52027208 -2.03971176 -3.61779664  1.67990897 -5.19401004]
 [ 1.82057515 -3.74032232  2.10619897  3.77427425 -1.62687795  4.81304636]] and inertia = 210.7352551761818
Label mismatch = 0
new_labe

In [23]:
# Inspect the `betas` of the first end-to-end learned estimate
# (displayed below as a list of two coefficient arrays).
beta_learned_list[0].betas

[array([-1.4556423,  4.090999 , -2.28717  , -4.066601 ,  1.411059 ,
        -5.2927995], dtype=float32),
 array([ 1.2268863, -3.8711095,  2.4759016,  3.9264438, -1.504575 ,
         5.364038 ], dtype=float32)]

In [5]:
# For every experiment, compare both estimators with the true values v_truth:
# the end-to-end learned estimator (one estimate per group) and the
# non-grouped estimator, whose estimate is compared against every group.
mu_learned_list = []
sigma_learned_list = []
z_score_learned_list = []
mu_ng_list = []
sigma_ng_list = []
z_score_ng_list = []

num_groups = len(v_truth)

for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    # z-score of each group's estimate relative to that group's true value.
    z_score_learned = [
        (mu - truth) / sigma for mu, truth, sigma in zip(v_mus, v_truth, v_sigmas)
    ]
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_score_learned)

    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    # The non-grouped estimate (presumably a single value) is reused for every
    # group. Broadcast explicitly: `ngv_mus * len(v_truth)` only repeats when
    # ngv_mus is a Python list and would instead scale the estimate by the
    # number of groups if it were a NumPy array. broadcast_to also raises if
    # the length is neither 1 nor num_groups instead of silently truncating.
    ng_mus_per_group = np.broadcast_to(np.ravel(ngv_mus), (num_groups,))
    ng_sigmas_per_group = np.broadcast_to(np.ravel(ngv_sigmas), (num_groups,))
    z_score_ng = [
        (mu - truth) / sigma
        for mu, truth, sigma in zip(ng_mus_per_group, v_truth, ng_sigmas_per_group)
    ]
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_score_ng)

# Reports averaged over both groups

In [12]:
# Coverage pooled over experiments and groups: the fraction of z-scores whose
# magnitude is below Z_THRESHOLD (1.96 is the standard-normal 97.5% quantile,
# i.e. a nominal two-sided 95% interval).
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)
Z_THRESHOLD = 1.96
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD
learned_in_threshold_perc = learned_in_threshold.sum() / learned_in_threshold.size
ng_in_threshold_perc = ng_in_threshold.sum() / ng_in_threshold.size
print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)
# Persist raw estimates and z-scores; "wb" overwrites RESULT_FILE if it exists.
if SAVE_RESULT:
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(
            dict(
                mu_learned_list=mu_learned_list,
                sigma_learned_list=sigma_learned_list,
                z_score_learned=z_score_learned,
                mu_ng_list=mu_ng_list,
                sigma_ng_list=sigma_ng_list,
                z_score_ng=z_score_ng,
                beta_learned_list=beta_learned_list,
                beta_ng_list=beta_ng_list,
            ),
            f,
        )

learned_in_threshold_perc= 0.97 , ng_in_threshold_perc= 0.0


In [13]:
# Pooled (over experiments and groups) metrics for both estimators.
# ACL: mean width 2 * Z_THRESHOLD * sigma of the interval mu ± Z_THRESHOLD * sigma.
# MSE: mean squared difference between the value estimates and v_truth.
ac_acl = 2 * Z_THRESHOLD * np.mean(sigma_learned_list)
ac_mse = np.mean(np.square(np.subtract(mu_learned_list, v_truth)))

mv_acl = 2 * Z_THRESHOLD * np.mean(sigma_ng_list)
mv_mse = np.mean(np.square(np.subtract(mu_ng_list, v_truth)))

In [14]:
def _print_pooled_report(header, acl, mse, ecp):
    """Print one estimator's pooled ACL / MSE / ECP summary under `header`."""
    print(header)
    print(f"ACL: {acl}")
    print(f"MSE: {mse}")
    print(f"ECP: {ecp}")


_print_pooled_report(
    "ACPE results: (average over groups)", ac_acl, ac_mse, learned_in_threshold_perc
)
_print_pooled_report(
    "MVPE results: (average over groups)", mv_acl, mv_mse, ng_in_threshold_perc
)

ACPE results: 
ACL: 0.8017043736565425
MSE: 0.0373266376554966
ECP: 0.97
MVPE results: 
ACL: 1.219222704607102
MSE: 13.000341415405273
ECP: 0.0


# Reports that separate the two groups

In [32]:
# Coverage reported separately per group: for each group, the fraction of
# experiments (axis 0 holds one row per experiment) whose |z| < Z_THRESHOLD.
Z_THRESHOLD = 1.96
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

learned_in_threshold_perc = learned_in_threshold.sum(axis=0) / len(learned_in_threshold)
ng_in_threshold_perc = ng_in_threshold.sum(axis=0) / len(ng_in_threshold)

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

# Persist raw estimates and z-scores; "wb" overwrites RESULT_FILE if it exists.
if SAVE_RESULT:
    payload = {
        "mu_learned_list": mu_learned_list,
        "sigma_learned_list": sigma_learned_list,
        "z_score_learned": z_score_learned,
        "mu_ng_list": mu_ng_list,
        "sigma_ng_list": sigma_ng_list,
        "z_score_ng": z_score_ng,
        "beta_learned_list": beta_learned_list,
        "beta_ng_list": beta_ng_list,
    }
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(payload, f)

learned_in_threshold_perc= [0.97 0.97] , ng_in_threshold_perc= [0. 0.]


In [33]:
# Per-group metrics: average over experiments (axis 0) only, keeping the
# trailing per-group axis.
ac_acl = 2 * Z_THRESHOLD * np.asarray(sigma_learned_list).mean(axis=0)
ac_mse = ((np.asarray(mu_learned_list) - v_truth) ** 2).mean(axis=0)

mv_acl = 2 * Z_THRESHOLD * np.asarray(sigma_ng_list).mean(axis=0)
mv_mse = ((np.asarray(mu_ng_list) - v_truth) ** 2).mean(axis=0)

In [37]:
# Per-group summary for both estimators, one metric per line.
report_lines = [
    "ACPE results: Group1, Group 2",
    f"MSE: {ac_mse}",
    f"ACL: {ac_acl}",
    f"ECP: {learned_in_threshold_perc}",
    "===",
    "MVPE results: ",
    f"MSE: {mv_mse}",
    f"ACL: {mv_acl}",
    f"ECP: {ng_in_threshold_perc}",
]
print(*report_lines, sep="\n")

ACPE results: Group1, Group 2
MSE: [0.03912388 0.03552939]
ACL: [0.80288801 0.80052074]
ECP: [0.97 0.97]
===
MVPE results: 
MSE: [12.811908 13.188772]
ACL: [1.2192227]
ECP: [0. 0.]
