In [1]:
import copy
import pickle
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

# Per-group reward coefficients (2 x 2), passed to the data generator as
# `group_reward_coeff_override`; the two rows have mirrored signs.
group_reward_coeff = np.array(
    [
        [-2.68, 2.68],
        [2.68, -2.68],
    ],
    dtype=DTYPE,
)

# Action reward coefficients (also mirrored signs).
action_reward_coeff = [-2.89, 2.89]

# Keyword arguments for `DataGenConfig`; the per-run `seed` is supplied later.
data_config_init = dict(
    num_trajectories=20,
    num_time_steps=40,
    group_reward_coeff_override=group_reward_coeff,
    action_reward_coeff=action_reward_coeff,
    num_burnin_steps=100,
    basis_expansion_method="NONE",
    add_intercept_column=True,
)

FEATURE_TYPE = "NONE"

# Tag identifying the simulation settings in the output file names.
# The coefficient magnitudes are derived from the arrays above so the names
# cannot silently drift out of sync with them; for the current values this
# yields "20230528_2.68_2.89", i.e. the same file names as before.
# NOTE: other settings (num_time_steps, num_burnin_steps, ...) are NOT encoded
# here -- bump SETTINGS_DATE_TAG if you change any of those.
SETTINGS_DATE_TAG = "20230528"
SETTINGS_TAG = (
    f"{SETTINGS_DATE_TAG}"
    f"_{np.abs(group_reward_coeff).max():g}"
    f"_{max(abs(c) for c in action_reward_coeff):g}"
)

# First time runner: set COMPUTE_TRUTH = True
# Change the flag to False after the truth file has been generated.
COMPUTE_TRUTH = True
TRUTH_FILE = f"hetero/data/{FEATURE_TYPE}_truth_{SETTINGS_TAG}.pkl"
print("truth file name =", TRUTH_FILE)

# Results get a timestamped name so repeated runs never overwrite each other.
time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = f'N={data_config_init["num_trajectories"]}_T={data_config_init["num_time_steps"]}_{time_tag}'
RESULT_FILE = f"hetero/data/{FEATURE_TYPE}_result_{SETTINGS_TAG}_{tag}.pkl"
print("result file name =", RESULT_FILE)

SAVE_RESULT = True
if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

# Number of independent simulated datasets (each gets its own seed).
NUM_EXPERIMENTS = 100

truth file name = hetero/data/NONE_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/NONE_result_20230528_2.68_2.89_N=20_T=40_20230602_12-02-59.pkl


=====================================================================================================
# Algorithm 

- Set the algorithm configuration below.

In [2]:

# Hyper-parameters for the end-to-end grouping/estimation algorithm, gathered
# in one dict so they are easy to scan and tweak.
algo_kwargs = dict(
    max_num_iters=10,
    gam=2.7,
    lam=1.6,
    rho=2.0,
    # Outlier trimming bounds (presumably percentiles -- see AlgoConfig).
    should_remove_outlier=True,
    outlier_lower_perc=2,
    outlier_upper_perc=98,
    nu_coeff=1e-5,
    delta_coeff=1e-5,
    use_group_wise_regression_init=True,
)
algo_config = AlgoConfig(**algo_kwargs)

# Target policy whose value is being estimated.
pi_eval = AlternativePolicy(2)

# Grouping step uses its default settings.
grouping_config = GroupingConfig()

In [3]:
# Ground-truth U/V values: simulate them (expensive) or load a cached copy.
# `num_trajectories` is dropped from the data settings because
# `compute_UV_truths` is given its own, much larger, `num_trajectories` below.
truth_data_config_init = copy.copy(data_config_init)
truth_data_config_init.pop("num_trajectories")

if COMPUTE_TRUTH:
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=10,
        num_trajectories=1000,
    )
    # Average over the independent repeats (presumably stacked along axis 0).
    u_truth = us.mean(axis=0)
    v_truth = vs.mean(axis=0)

    # Store the settings alongside the truths so a later load can verify them.
    with open(TRUTH_FILE, "wb") as f:
        pickle.dump(
            dict(
                u_truth=u_truth,
                v_truth=v_truth,
                data_config_dict=truth_data_config_init,
                algo_config_dict=attrs.asdict(algo_config),
            ),
            f,
        )
else:
    # Trusted local cache written by the branch above (pickle executes code
    # on load, so never point TRUTH_FILE at an untrusted file).
    with open(TRUTH_FILE, "rb") as f:
        loaded = pickle.load(f)
    u_truth = loaded["u_truth"]
    v_truth = loaded["v_truth"]

    # Only the file name links the cached truths to the current settings, so
    # compare the stored settings with the current ones and warn loudly
    # instead of silently evaluating against stale truths.
    stored_data_config = loaded.get("data_config_dict", {})
    stale_keys = sorted(
        key
        for key in set(stored_data_config) | set(truth_data_config_init)
        if not np.array_equal(
            stored_data_config.get(key), truth_data_config_init.get(key)
        )
    )
    stored_algo_config = loaded.get("algo_config_dict", {})
    if (
        "discount" in stored_algo_config
        and stored_algo_config["discount"] != algo_config.discount
    ):
        stale_keys.append("discount")
    if stale_keys:
        print(
            f"WARNING: {TRUTH_FILE} was computed with different settings for "
            f"{stale_keys}; set COMPUTE_TRUTH = True to regenerate it."
        )

In [4]:
# Run NUM_EXPERIMENTS independent replications. Each one simulates a fresh
# dataset from its own deterministic seed and fits both estimators on it.
beta_ng_list = []
beta_learned_list = []

for run_number in range(1, NUM_EXPERIMENTS + 1):
    seed = 7531 * run_number
    data_config = DataGenConfig(seed=seed, **data_config_init)
    data = generate_data_from_config(data_config)

    # Baseline: a single, non-grouped estimate.
    beta_ng = beta_estimate_from_nongrouped(data, pi_eval, algo_config.discount)
    beta_ng_list.append(beta_ng)

    # Proposed method: end-to-end grouping + estimation.
    beta_learned = beta_estimate_from_e2e_learning(
        data, algo_config, grouping_config, pi_eval
    )
    beta_learned_list.append(beta_learned)

new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = (0.93656087+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93656087+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93656087+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93656087+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93656087+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93656087+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93656087+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93656087+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93656087+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93656087+0j)


MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93154365+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93154365+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93154365+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93154365+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93154365+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93154365+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93154365+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.93154365+0j)
MCPImpl: num_above=400, num_below=380
kmeans center = [[-1.44562171  3.52082568 -2.04356447 -3.87883561  1.7375116  -4.98325122]
 [ 1.46564151 -3.88748169  1.99341742  3.83585869 -1.68701898  4.97153158]] and inertia = 49.08847869387748
Label misma

beta_solver, min eigen of left matrix = (0.9249989+0j)
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen of left matrix = (0.9249989+0j)
MCPImpl: num_above=417, num_below=363
beta_solver, min eigen of left matrix = (0.9249989+0j)
MCPImpl: num_above=407, num_below=373
beta_solver, min eigen of left matrix = (0.9249989+0j)
MCPImpl: num_above=417, num_below=363
kmeans center = [[-1.61893391  3.65030682 -1.98329732 -3.97076179  1.67681945 -4.89091337]
 [ 1.66819043 -3.57411768  2.27935995  3.63256556 -1.59258267  5.23865906]] and inertia = 72.25069661498975
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = 0.9355698
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.9355698
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 

beta_solver, min eigen of left matrix = (0.937611+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.937611+0j)
MCPImpl: num_above=414, num_below=366
beta_solver, min eigen of left matrix = (0.937611+0j)
MCPImpl: num_above=404, num_below=376
beta_solver, min eigen of left matrix = (0.937611+0j)
MCPImpl: num_above=413, num_below=367
beta_solver, min eigen of left matrix = (0.937611+0j)
MCPImpl: num_above=404, num_below=376
beta_solver, min eigen of left matrix = (0.937611+0j)
MCPImpl: num_above=413, num_below=367
beta_solver, min eigen of left matrix = (0.937611+0j)
MCPImpl: num_above=405, num_below=375
beta_solver, min eigen of left matrix = (0.937611+0j)
MCPImpl: num_above=413, num_below=367
beta_solver, min eigen of left matrix = (0.937611+0j)
MCPImpl: num_above=404, num_below=376
beta_solver, min eigen of left matrix = (0.937611+0j)
MCPImpl: num_above=413, num_below=367
kmeans center = [[-1.70070712  3.6554379  -1.84559234 -3.72287873  1.5584433  -4

beta_solver, min eigen of left matrix = (0.9339988+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9339988+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9339988+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9339988+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9339988+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9339988+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9339988+0j)
MCPImpl: num_above=400, num_below=380
kmeans center = [[-1.75144343  3.90280337 -2.34808172 -3.64306684  1.44012983 -5.06225718]
 [ 1.53836533 -3.77115396  2.06848523  3.78888195 -1.73852245  5.13752348]] and inertia = 49.75067307761417
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches numb

MCPImpl: num_above=407, num_below=373
beta_solver, min eigen of left matrix = (0.9331223+0j)
MCPImpl: num_above=407, num_below=373
beta_solver, min eigen of left matrix = (0.9331223+0j)
MCPImpl: num_above=407, num_below=373
beta_solver, min eigen of left matrix = (0.9331223+0j)
MCPImpl: num_above=408, num_below=372
kmeans center = [[-1.64896626  3.89110138 -2.44114166 -4.04547613  1.50031341 -5.23248519]
 [ 1.42233056 -3.90067407  2.30333382  3.91163999 -1.6176537   5.09164139]] and inertia = 72.41238251037493
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = (0.92967844+0j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0.92967844+0j)
MCPImpl: num_above=416, num_below=364
beta_solver, min eigen of left matrix = (0.92967844+0j)
MCPImpl: num_above=406, num

beta_solver, min eigen of left matrix = (0.9275719-0.0013132484j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9275719-0.0013132484j)
MCPImpl: num_above=402, num_below=378
kmeans center = [[-1.56180294  3.77658536 -2.34390129 -3.73682891  1.73117275 -5.30868073]
 [ 1.3363801  -3.68853781  2.2041923   3.65070784 -1.53511492  5.07210322]] and inertia = 59.04638300803642
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = (0.9282399+0j)
MCPImpl: num_above=402, num_below=378
beta_solver, min eigen of left matrix = (0.9282399+0j)
MCPImpl: num_above=420, num_below=360
beta_solver, min eigen of left matrix = (0.9282399+0j)
MCPImpl: num_above=405, num_below=375
beta_solver, min eigen of left matrix = (0.9282399+0j)
MCPImpl: num_above=420, num_below=360
beta_sol

MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9295254+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9295254+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9295254+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9295254+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9295254+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9295254+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9295254+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9295254+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9295254+0j)
MCPImpl: num_above=400, num_below=380
kmeans center = [[-1.58087557  3.74545806 -2.21749366 -3.96061772  1.5447453  -5.14482601]
 [ 1.80776339 -3.82208465  2.12037

beta_solver, min eigen of left matrix = (0.9289225+0j)
MCPImpl: num_above=401, num_below=379
beta_solver, min eigen of left matrix = (0.9289225+0j)
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen of left matrix = (0.9289225+0j)
MCPImpl: num_above=401, num_below=379
beta_solver, min eigen of left matrix = (0.9289225+0j)
MCPImpl: num_above=406, num_below=374
kmeans center = [[-1.33709285  3.70084147 -2.36998862 -3.57366622  1.56693688 -5.15927651]
 [ 1.63610145 -3.58908711  2.28881571  3.75807429 -1.68005585  5.27936746]] and inertia = 55.083280835525414
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = (0.9301954+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9301954+0j)
MCPImpl: num_above=408, num_below=372
beta_solver, min eigen of lef

MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9241811+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9241811+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9241811+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9241811+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9241811+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9241811+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9241811+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9241811+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9241811+0j)
MCPImpl: num_above=400, num_below=380
kmeans center = [[-1.63195567  3.90632582 -2.02171795 -3.96922902  1.80081166 -5.03374925]
 [ 1.47622774 -3.44311758  2.06767

beta_solver, min eigen of left matrix = 0.93487877
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen of left matrix = 0.93487877
MCPImpl: num_above=402, num_below=378
beta_solver, min eigen of left matrix = 0.93487877
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen of left matrix = 0.93487877
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = 0.93487877
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen of left matrix = 0.93487877
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = 0.93487877
MCPImpl: num_above=406, num_below=374
kmeans center = [[-1.48554714  3.9155592  -2.22974407 -3.80980293  1.47164182 -5.27609163]
 [ 1.71073654 -3.83025052  2.0277049   3.81931997 -1.73057541  4.94657212]] and inertia = 55.972821431452886
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.le

MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9356282+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9356282+0j)
MCPImpl: num_above=400, num_below=380
kmeans center = [[-1.42208014  3.69986465 -2.02863927 -3.59349718  1.51788082 -4.9227519 ]
 [ 1.61609306 -3.86257154  2.05502668  3.82263888 -1.58137608  5.06820482]] and inertia = 47.20564103730188
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = 0.93390787
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.93390787
MCPImpl: num_above=404, num_below=376
beta_solver, min eigen of left matrix = 0.93390787
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.93390787
MCPImpl: num_above=404, num_below=376
beta_sol

beta_solver, min eigen of left matrix = (0.9303806+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9303806+0j)
MCPImpl: num_above=404, num_below=376
beta_solver, min eigen of left matrix = (0.9303806+0j)
MCPImpl: num_above=402, num_below=378
beta_solver, min eigen of left matrix = (0.9303806+0j)
MCPImpl: num_above=404, num_below=376
beta_solver, min eigen of left matrix = (0.9303806+0j)
MCPImpl: num_above=402, num_below=378
beta_solver, min eigen of left matrix = (0.9303806+0j)
MCPImpl: num_above=404, num_below=376
beta_solver, min eigen of left matrix = (0.9303806+0j)
MCPImpl: num_above=402, num_below=378
beta_solver, min eigen of left matrix = (0.9303806+0j)
MCPImpl: num_above=404, num_below=376
beta_solver, min eigen of left matrix = (0.9303806+0j)
MCPImpl: num_above=402, num_below=378
beta_solver, min eigen of left matrix = (0.9303806+0j)
MCPImpl: num_above=404, num_below=376
kmeans center = [[-1.71000707  3.79099055 -2.254176   -4.0052298   1.6

beta_solver, min eigen of left matrix = (0.9300188+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9300188+0j)
MCPImpl: num_above=401, num_below=379
beta_solver, min eigen of left matrix = (0.9300188+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9300188+0j)
MCPImpl: num_above=401, num_below=379
beta_solver, min eigen of left matrix = (0.9300188+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9300188+0j)
MCPImpl: num_above=401, num_below=379
kmeans center = [[-1.48787792  3.86396071 -2.11256025 -3.7496424   1.46555591 -5.02786069]
 [ 1.57696805 -3.78327586  2.2080563   3.73437711 -1.59579964  5.06474701]] and inertia = 52.58234451439329
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left

new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = (0.9268125+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9268125+0j)
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen of left matrix = (0.9268125+0j)
MCPImpl: num_above=401, num_below=379
beta_solver, min eigen of left matrix = (0.9268125+0j)
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen of left matrix = (0.9268125+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9268125+0j)
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen of left matrix = (0.9268125+0j)
MCPImpl: num_above=401, num_below=379
beta_solver, min eigen of left matrix = (0.9268125+0j)
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen of left matrix = (0.9268125+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.9268125+0j)
MCPImpl: num_above=406, num_below=374
kmeans center = [[-1.

beta_solver, min eigen of left matrix = (0.93556386+0j)
MCPImpl: num_above=411, num_below=369
beta_solver, min eigen of left matrix = (0.93556386+0j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0.93556386+0j)
MCPImpl: num_above=411, num_below=369
beta_solver, min eigen of left matrix = (0.93556386+0j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0.93556386+0j)
MCPImpl: num_above=411, num_below=369
beta_solver, min eigen of left matrix = (0.93556386+0j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0.93556386+0j)
MCPImpl: num_above=411, num_below=369
beta_solver, min eigen of left matrix = (0.93556386+0j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0.93556386+0j)
MCPImpl: num_above=411, num_below=369
kmeans center = [[-1.47559811  3.65159512 -1.98490463 -3.87549952  1.64984401 -4.9445889 ]
 [ 1.62956739 -3.83464397  2.14130473  3.65426755 -1.70744696  

In [5]:
# Inspect the betas learned in the first replication (the recorded output
# shows one coefficient array per learned group).
first_run_estimate = beta_learned_list[0]
first_run_estimate.betas

[array([-1.3146352,  4.1371684, -2.2762327, -4.026389 ,  1.5439901,
        -5.2463217], dtype=float32),
 array([ 1.3276949, -3.8160374,  2.4772327,  3.7160957, -1.5836291,
         5.4018064], dtype=float32)]

In [6]:
# Convert each run's betas into value estimates (mu, sigma) and standardize
# them against the ground truth: z = (mu - truth) / sigma.
mu_learned_list = []
sigma_learned_list = []
z_score_learned_list = []
mu_ng_list = []
sigma_ng_list = []
z_score_ng_list = []

# One entry per true group; the pooled estimate is broadcast to this shape.
group_shape = np.shape(v_truth)

for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    # Learned (grouped) estimator: estimates are paired with the true groups
    # by position.
    # NOTE(review): assumes learned group labels come out in the same order as
    # the true groups (no label switching) -- confirm in the e2e learner.
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    z_score_learned = [
        (mu - truth) / sigma for mu, truth, sigma in zip(v_mus, v_truth, v_sigmas)
    ]
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_score_learned)

    # Non-grouped estimator: a single pooled estimate (length 1 in the
    # recorded output) compared against every group's truth. Broadcast it
    # explicitly: repeating with `* len(v_truth)` only works for Python lists
    # and would multiply the value if a NumPy array were returned.
    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    z_score_ng = [
        (mu - truth) / sigma
        for mu, truth, sigma in zip(
            np.broadcast_to(np.asarray(ngv_mus), group_shape),
            v_truth,
            np.broadcast_to(np.asarray(ngv_sigmas), group_shape),
        )
    ]
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_score_ng)

# Reports that average over two groups

In [7]:
# Empirical coverage pooled over all runs and groups: the fraction of
# |z| below 1.96 (the two-sided 95% standard-normal critical value).
Z_THRESHOLD = 1.96
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# The mean of a boolean array is the fraction of True entries.
learned_in_threshold_perc = learned_in_threshold.mean()
ng_in_threshold_perc = ng_in_threshold.mean()

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    results = dict(
        mu_learned_list=mu_learned_list,
        sigma_learned_list=sigma_learned_list,
        z_score_learned=z_score_learned,
        mu_ng_list=mu_ng_list,
        sigma_ng_list=sigma_ng_list,
        z_score_ng=z_score_ng,
        beta_learned_list=beta_learned_list,
        beta_ng_list=beta_ng_list,
    )
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(results, f)

learned_in_threshold_perc= 0.96 , ng_in_threshold_perc= 0.0


In [8]:
# Pooled summary metrics (over runs and groups):
#   ACL = average confidence-interval length, 2 * z * sigma averaged;
#   MSE = mean squared error of the value estimates against the truth.
ci_width_factor = 2 * Z_THRESHOLD

ac_acl = ci_width_factor * np.mean(sigma_learned_list)
ac_mse = np.mean(np.square(np.asarray(mu_learned_list) - v_truth))

mv_acl = ci_width_factor * np.mean(sigma_ng_list)
mv_mse = np.mean(np.square(np.asarray(mu_ng_list) - v_truth))

In [9]:
# Pooled report for the grouped (ACPE, learned) and non-grouped (MVPE)
# estimators, one section each.
pooled_report = (
    ("ACPE results: (average over groups)", ac_acl, ac_mse, learned_in_threshold_perc),
    ("MVPE results: (average over groups)", mv_acl, mv_mse, ng_in_threshold_perc),
)
for title, acl, mse, ecp in pooled_report:
    print(title)
    print(f"ACL: {acl}")
    print(f"MSE: {mse}")
    print(f"ECP: {ecp}")

ACPE results: (average over groups)
ACL: 0.5630263983020498
MSE: 0.01685441844165325
ECP: 0.96
MVPE results: (average over groups)
ACL: 0.8594511736297609
MSE: 12.973557472229004
ECP: 0.0


# Reports that separate the two groups

In [10]:
# Per-group empirical coverage: the same z-scores as the pooled report, but
# the fraction of runs with |z| < Z_THRESHOLD is computed separately for each
# group (axis 0 indexes the runs, axis 1 the groups).
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)
Z_THRESHOLD = 1.96
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# Mean over runs of a boolean array = per-group fraction of runs covered.
learned_in_threshold_perc = learned_in_threshold.mean(axis=0)
ng_in_threshold_perc = ng_in_threshold.mean(axis=0)

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

# RESULT_FILE is not re-written here: the pooled-report cell already saves
# exactly these lists/arrays, and a second copy-pasted dump would only risk
# the two payloads drifting apart.

learned_in_threshold_perc= [0.96 0.96] , ng_in_threshold_perc= [0. 0.]


In [11]:
# Per-group summary metrics (averaging over runs on axis 0 leaves one value
# per group):
#   ACL = average confidence-interval length, 2 * z * sigma;
#   MSE = mean squared error of the value estimates against the truth.
ci_width_factor = 2 * Z_THRESHOLD
group_shape = np.shape(v_truth)

ac_acl = ci_width_factor * np.mean(sigma_learned_list, axis=0)
ac_mse = np.mean((np.asarray(mu_learned_list) - v_truth) ** 2, axis=0)

# The non-grouped estimator carries a single pooled sigma (one entry in the
# recorded output); broadcast its ACL to one entry per group so it lines up
# with the other per-group rows of the report.
mv_acl = np.broadcast_to(
    ci_width_factor * np.mean(sigma_ng_list, axis=0), group_shape
).copy()
mv_mse = np.mean((np.asarray(mu_ng_list) - v_truth) ** 2, axis=0)

In [12]:
# Per-group report: grouped (ACPE) estimator first, then the non-grouped
# (MVPE) baseline, separated by a "===" rule.
per_group_report = (
    ("ACPE results: Group1, Group 2", ac_mse, ac_acl, learned_in_threshold_perc),
    ("MVPE results: ", mv_mse, mv_acl, ng_in_threshold_perc),
)
for section_idx, (header, mse, acl, ecp) in enumerate(per_group_report):
    if section_idx > 0:
        print("===")
    print(header)
    print(f"MSE: {mse}")
    print(f"ACL: {acl}")
    print(f"ECP: {ecp}")

ACPE results: Group1, Group 2
MSE: [0.01729083 0.016418  ]
ACL: [0.56384562 0.56220717]
ECP: [0.96 0.96]
===
MVPE results: 
MSE: [12.861096 13.086015]
ACL: [0.85945117]
ECP: [0. 0.]
