In [1]:
import copy
import pickle
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

group_reward_coeff = np.array(
    [
        [-2.68, 2.68],
        [2.68, -2.68],
    ],
    dtype=DTYPE,
)

action_reward_coeff = [-2.89, 2.89]

data_config_init = dict(
    num_trajectories=50,
    num_time_steps=40,
    group_reward_coeff_override=group_reward_coeff,
    action_reward_coeff=action_reward_coeff,
    num_burnin_steps=100,
    basis_expansion_method="NONE",
    add_intercept_column=True,
)

FEATURE_TYPE = "NONE"

# First time runner: set COMPUTE_TRUTH = True
# Change the flag to False after generated the truth file.
COMPUTE_TRUTH = True
# Change the truth file name if settings are changed.
TRUTH_FILE = f"hetero/data/{FEATURE_TYPE}_truth_20230528_2.68_2.89.pkl"
print("truth file name =", TRUTH_FILE)

time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = f'N={data_config_init["num_trajectories"]}_T={data_config_init["num_time_steps"]}_{time_tag}'
RESULT_FILE = f"hetero/data/{FEATURE_TYPE}_result_20230528_2.68_2.89_{tag}.pkl"
print("result file name =", RESULT_FILE)

SAVE_RESULT = True
if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

NUM_EXPERIMENTS = 100

truth file name = hetero/data/NONE_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/NONE_result_20230528_2.68_2.89_N=50_T=40_20230602_12-15-37.pkl


=====================================================================================================
# Algorithm 

- Set configure below.

In [2]:

algo_config = AlgoConfig(
    max_num_iters=10,
    gam=2.7,
    lam=1.6,
    rho=2.0,
    should_remove_outlier=True,
    outlier_lower_perc=2,
    outlier_upper_perc=98,
    nu_coeff=1e-5,
    delta_coeff=1e-5,
    use_group_wise_regression_init=True,
)

pi_eval = AlternativePolicy(2)

grouping_config = GroupingConfig()

In [3]:
if COMPUTE_TRUTH:
    truth_data_config_init = copy.copy(data_config_init)
    truth_data_config_init.pop("num_trajectories")
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=10,
        num_trajectories=1000,
    )  # For best results, use num_repeats=10
    u_truth = us.mean(axis=0)
    v_truth = vs.mean(axis=0)

    with open(TRUTH_FILE, "wb") as f:
        pickle.dump(
            dict(
                u_truth=u_truth,
                v_truth=v_truth,
                data_config_dict=truth_data_config_init,
                algo_config_dict=attrs.asdict(algo_config),
            ),
            f,
        )
else:
    with open(TRUTH_FILE, "rb") as f:
        loaded = pickle.load(f)
        u_truth = loaded["u_truth"]
        v_truth = loaded["v_truth"]

In [4]:
beta_ng_list = []
beta_learned_list = []

for i in range(NUM_EXPERIMENTS):
    data_config = DataGenConfig(seed=7531 * (i + 1), **data_config_init)
    data = generate_data_from_config(data_config)
    beta_ng_list.append(
        beta_estimate_from_nongrouped(data, pi_eval, algo_config.discount)
    )
    beta_learned_list.append(
        beta_estimate_from_e2e_learning(data, algo_config, grouping_config, pi_eval)
    )

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.8558852
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8558852
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8558852
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8558852
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8558852
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8558852
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8558852
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8558852
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8558852
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8558852
MCPImpl: num_above=2500, num_below=2450
km

MCPImpl: num_above=2501, num_below=2449
beta_solver, min eigen of left matrix = (0.85078424-0.00053900096j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85078424-0.00053900096j)
MCPImpl: num_above=2501, num_below=2449
beta_solver, min eigen of left matrix = (0.85078424-0.00053900096j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85078424-0.00053900096j)
MCPImpl: num_above=2501, num_below=2449
beta_solver, min eigen of left matrix = (0.85078424-0.00053900096j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85078424-0.00053900096j)
MCPImpl: num_above=2501, num_below=2449
beta_solver, min eigen of left matrix = (0.85078424-0.00053900096j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85078424-0.00053900096j)
MCPImpl: num_above=2501, num_below=2449
kmeans center = [[ 1.49412713 -3.99420384  2.13268928  3.80482153 -1.57779511  5.13051233]
 [-1.

MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85442144+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85442144+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85442144+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85442144+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85442144+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85442144+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85442144+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85442144+0j)
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.53613383 -3.84928748  2.23936296  3.85102453 -1.47540621  5.14751737]
 [-1.421419    3.72081659 -2.20940387 -3.82242733  1.48407489 -4.95017726]] and inertia = 51.29920738

MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85511494+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85511494+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85511494+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85511494+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85511494+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85511494+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85511494+0j)
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.52293238 -3.92220492  2.11130775  3.80851072 -1.48556374  5.04787488]
 [-1.53250344  3.73965334 -2.10289951 -3.92604539  1.60775191 -5.04636535]] and inertia = 51.59789259414171
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000

MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8458287
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8458287
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8458287
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8458287
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8458287
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.59841496 -3.80181815  2.11087569  3.72726985 -1.54375936  4.93741177]
 [-1.63251634  3.82340286 -2.08138686 -3.89793892  1.52374523 -5.1388842 ]] and inertia = 43.07900899139379
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.8518705+0j)
MCPImpl: num_above=2500, num_below=2450
be

MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.84913844
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.84913844
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.84913844
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.84913844
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.62855168 -3.85167161  2.02771488  3.78486455 -1.53273454  4.90917778]
 [-1.52708511  3.64748575 -2.32641836 -3.74715657  1.61474203 -5.10355919]] and inertia = 37.786663807867384
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.85023266
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.85023266
MCPImpl: num_above=2500, num_below=2450


MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.8521398+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.8521398+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.8521398+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.8521398+0j)
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.6422442  -3.89846667  2.11881763  3.79495532 -1.59589539  5.05705815]
 [-1.57187499  3.74962986 -2.21509198 -3.98777027  1.57515217 -5.12624716]] and inertia = 64.3901521141959
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.8562241+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.8562241+0j)
MCPImpl: num_above

MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.8551264+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.8551264+0j)
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.51330024 -3.80248202  2.16118167  3.860234   -1.57169577  5.00798203]
 [-1.58767955  3.81016393 -2.09169055 -3.76460129  1.62909704 -5.02945527]] and inertia = 49.7834180904119
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.8557877+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.8557877+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.8557877+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.8557877+0j)
MCPImpl: num_above

MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85243833+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85243833+0j)
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.38776388 -3.75679413  2.2137474   3.57307051 -1.56886152  5.03502393]
 [-1.51135421  3.88116299 -2.04548933 -3.87344687  1.65638582 -5.00480001]] and inertia = 47.60100830433457
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.8554216
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8554216
MCPImpl: num_above=2501, num_below=2449
beta_solver, min eigen of left matrix = 0.8554216
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8554216
MCPImpl: num_above=2501, num_below=

MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85245717+0j)
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.56882545 -3.79311855  2.10248279  3.91729218 -1.61991012  5.01392338]
 [-1.52176925  3.86834437 -2.24680894 -3.85523094  1.51665025 -5.10108746]] and inertia = 46.65207874463128
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.84592694+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.84592694+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.84592694+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.84592694+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.84592694+0j)
MCPImpl: nu

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.85202724+0j)
MCPImpl: num_above=2501, num_below=2449
beta_solver, min eigen of left matrix = (0.85202724+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85202724+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85202724+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85202724+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85202724+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85202724+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85202724+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85202724+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix 

beta_solver, min eigen of left matrix = 0.8546302
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8546302
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8546302
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8546302
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8546302
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8546302
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8546302
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8546302
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8546302
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.53691504 -3.8089397   2.28472525  3.82937985 -1.46330598  5.15864333]
 [-1.49914322  3.7108403  -1.97936598 -3.77863851  1.57111022 -4.92901791]] and inertia = 48.613104

MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8374714
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8374714
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8374714
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8374714
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8374714
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8374714
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8374714
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.46334007 -3.8137041   2.13766127  3.92060868 -1.50845102  5.0845418 ]
 [-1.54388477  3.88185359 -2.06445272 -3.91414204  1.56221787 -4.97338734]] and inertia = 54.78228717560385
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.leng

MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8486414
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8486414
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8486414
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8486414
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8486414
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8486414
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.61179035 -3.88960987  2.2564541   3.75531508 -1.59871464  5.16217813]
 [-1.60182565  3.74595423 -2.12858764 -3.94385861  1.5219897  -5.09087653]] and inertia = 64.92702361857127
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_so

beta_solver, min eigen of left matrix = (0.8562641+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.8562641+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.8562641+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.8562641+0j)
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.61666207 -3.89622957  2.13238649  3.6984946  -1.47189327  5.0914025 ]
 [-1.43945048  3.79922439 -2.14401603 -3.75146874  1.47710521 -4.94270318]] and inertia = 55.979691043982065
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.8533538
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8533538
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of 

MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.8548473
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.49426877 -3.81013794  2.15032249  3.91197567 -1.45075377  5.23073745]
 [-1.53961127  3.89375074 -2.13636731 -3.80870367  1.57980552 -4.99117875]] and inertia = 44.42287884635045
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.85755855+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85755855+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85755855+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85755855+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85755855+0j)
MCPImpl: num_abov

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.85192615+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85192615+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85192615+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85192615+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85192615+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85192615+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85192615+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85192615+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.85192615+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix 

In [5]:
beta_learned_list[0].betas

[array([-1.4498472,  3.5568786, -2.0234642, -3.9543889,  1.5418319,
        -5.0576715], dtype=float32),
 array([ 1.4610441, -3.893724 ,  2.0736306,  3.9288344, -1.4145379,
         4.9486365], dtype=float32)]

In [6]:
mu_learned_list = []
sigma_learned_list = []
z_score_learned_list = []
mu_ng_list = []
sigma_ng_list = []
z_score_ng_list = []

for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    z_score_learned = [
        (mu - truth) / sigma for mu, truth, sigma in zip(v_mus, v_truth, v_sigmas)
    ]
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_score_learned)

    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    z_score_ng = [
        (mu - truth) / sigma
        for mu, truth, sigma in zip(
            ngv_mus * len(v_truth), v_truth, ngv_sigmas * len(v_truth)
        )
    ]
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_score_ng)

# Reports that average over two groups

In [7]:
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)
Z_THRESHOLD = 1.96
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD
learned_in_threshold_perc = learned_in_threshold.sum() / learned_in_threshold.size
ng_in_threshold_perc = ng_in_threshold.sum() / ng_in_threshold.size
print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)
if SAVE_RESULT:
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(
            dict(
                mu_learned_list=mu_learned_list,
                sigma_learned_list=sigma_learned_list,
                z_score_learned=z_score_learned,
                mu_ng_list=mu_ng_list,
                sigma_ng_list=sigma_ng_list,
                z_score_ng=z_score_ng,
                beta_learned_list=beta_learned_list,
                beta_ng_list=beta_ng_list,
            ),
            f,
        )

learned_in_threshold_perc= 0.97 , ng_in_threshold_perc= 0.0


In [8]:
ac_acl = 2*Z_THRESHOLD*np.mean(sigma_learned_list)
ac_mse = np.mean( (mu_learned_list - v_truth)**2 )

mv_acl = 2*Z_THRESHOLD*np.mean(sigma_ng_list)
mv_mse = np.mean( (mu_ng_list - v_truth)**2 )

In [9]:
print("ACPE results: (average over groups)")
print(f"ACL: {ac_acl}") 
print(f"MSE: {ac_mse}")
print(f"ECP: {learned_in_threshold_perc}")


print("MVPE results: (average over groups)")
print(f"ACL: {mv_acl}") 
print(f"MSE: {mv_mse}")
print(f"ECP: {ng_in_threshold_perc}")

ACPE results: (average over groups)
ACL: 0.3566992240625845
MSE: 0.006548701319843531
ECP: 0.97
MVPE results: (average over groups)
ACL: 0.5454984091820911
MSE: 12.952200889587402
ECP: 0.0


# Reports that seperate two groups

In [10]:
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)
Z_THRESHOLD = 1.96
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

learned_in_threshold_perc = np.sum(learned_in_threshold, axis=0) / learned_in_threshold.shape[0]
ng_in_threshold_perc = np.sum(ng_in_threshold, axis=0) / ng_in_threshold.shape[0]

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(
            dict(
                mu_learned_list=mu_learned_list,
                sigma_learned_list=sigma_learned_list,
                z_score_learned=z_score_learned,
                mu_ng_list=mu_ng_list,
                sigma_ng_list=sigma_ng_list,
                z_score_ng=z_score_ng,
                beta_learned_list=beta_learned_list,
                beta_ng_list=beta_ng_list,
            ),
            f,
        )

learned_in_threshold_perc= [0.99 0.95] , ng_in_threshold_perc= [0. 0.]


In [11]:
ac_acl = 2*Z_THRESHOLD*np.mean(sigma_learned_list, axis=0)
ac_mse = np.mean((mu_learned_list - v_truth)**2, axis=0)

mv_acl = 2*Z_THRESHOLD*np.mean(sigma_ng_list, axis=0)
mv_mse = np.mean( (mu_ng_list - v_truth)**2, axis=0 )

In [12]:
print("ACPE results: Group1, Group 2")
print(f"MSE: {ac_mse}")
print(f"ACL: {ac_acl}") 
print(f"ECP: {learned_in_threshold_perc}")

print("===")
print("MVPE results: ")
print(f"MSE: {mv_mse}")
print(f"ACL: {mv_acl}") 
print(f"ECP: {ng_in_threshold_perc}")

ACPE results: Group1, Group 2
MSE: [0.00571771 0.00737969]
ACL: [0.35693396 0.35646449]
ECP: [0.99 0.95]
===
MVPE results: 
MSE: [13.011875  12.8925295]
ACL: [0.54549841]
ECP: [0. 0.]
