In [1]:
import copy
import pickle
import warnings
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

# Reward coefficients handed to DataGenConfig (as group_reward_coeff_override
# and action_reward_coeff).
group_reward_coeff = np.array([[-2.68, 2.68], [2.68, -2.68]], dtype=DTYPE)
action_reward_coeff = [-2.89, 2.89]

# Keyword arguments shared by every DataGenConfig built in this notebook;
# the per-experiment seed is added at construction time.
data_config_init = {
    "num_trajectories": 100,
    "num_time_steps": 40,
    "group_reward_coeff_override": group_reward_coeff,
    "action_reward_coeff": action_reward_coeff,
    "num_burnin_steps": 100,
    "basis_expansion_method": "NONE",
    "add_intercept_column": True,
}

FEATURE_TYPE = "NONE"

# Set to True to (re)generate the truth file, then set it back to False.
# The truth only needs to be generated once and is reused across runs with
# different N and T.
# NOTE(review): the truth config still contains num_time_steps (only
# num_trajectories is dropped) -- confirm T really does not affect the truth.
COMPUTE_TRUTH = False
# Rename the truth file whenever the data-generating settings above change.
TRUTH_FILE = f"hetero/data/{FEATURE_TYPE}_truth_20230528_2.68_2.89.pkl"
print("truth file name =", TRUTH_FILE)

# Result files are tagged with N, T and a wall-clock timestamp so runs never
# overwrite each other.
time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = "N={}_T={}_{}".format(
    data_config_init["num_trajectories"],
    data_config_init["num_time_steps"],
    time_tag,
)
RESULT_FILE = f"hetero/data/{FEATURE_TYPE}_result_20230528_2.68_2.89_{tag}.pkl"
print("result file name =", RESULT_FILE)

SAVE_RESULT = True
if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

NUM_EXPERIMENTS = 100

truth file name = hetero/data/NONE_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/NONE_result_20230528_2.68_2.89_N=100_T=40_20230602_12-52-03.pkl


=====================================================================================================
# Algorithm 

- Set the algorithm configuration below.

In [2]:

# Hyper-parameters for the end-to-end learning algorithm.
algo_config = AlgoConfig(
    max_num_iters=10,
    # Penalty / step parameters.
    gam=2.7,
    lam=1.6,
    rho=2.0,
    # Outlier trimming; the bounds are presumably percentiles (confirm in AlgoConfig).
    should_remove_outlier=True,
    outlier_lower_perc=2,
    outlier_upper_perc=98,
    nu_coeff=1e-5,
    delta_coeff=1e-5,
    use_group_wise_regression_init=True,
)

# Policy whose value is being evaluated.
pi_eval = AlternativePolicy(2)

# Grouping settings, left at their defaults.
grouping_config = GroupingConfig()

In [3]:
# Ground-truth U/V values that every estimate below is scored against.
# With COMPUTE_TRUTH set they are estimated by simulation and written to
# TRUTH_FILE; otherwise they are read back from TRUTH_FILE.
#
# compute_UV_truths is given its own num_trajectories, so the experiment's
# value is dropped from the data config passed to it (and stored with it).
truth_data_config_init = {
    key: value for key, value in data_config_init.items() if key != "num_trajectories"
}


def _truth_config_mismatches(saved, current):
    """Return the sorted keys whose values differ between two data-config dicts.

    Keys present in only one of the dicts count as differing. Values can be
    scalars, strings, lists or numpy arrays, so each pair is compared with
    ``np.array_equal`` (``==`` on arrays is elementwise and ambiguous in ``if``).
    """
    mismatched = []
    for key in sorted(set(saved) | set(current)):
        if key not in saved or key not in current:
            mismatched.append(key)
            continue
        try:
            same = np.array_equal(np.asarray(saved[key]), np.asarray(current[key]))
        except (TypeError, ValueError):
            same = False  # values of incomparable types cannot be equal
        if not same:
            mismatched.append(key)
    return mismatched


if COMPUTE_TRUTH:
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=10,  # For best results, use num_repeats=10
        num_trajectories=1000,
    )
    # Average the repeated estimates.
    u_truth = us.mean(axis=0)
    v_truth = vs.mean(axis=0)

    with open(TRUTH_FILE, "wb") as f:
        pickle.dump(
            dict(
                u_truth=u_truth,
                v_truth=v_truth,
                data_config_dict=truth_data_config_init,
                algo_config_dict=attrs.asdict(algo_config),
            ),
            f,
        )
else:
    # pickle.load can execute arbitrary code: only load truth files produced
    # by this notebook.
    with open(TRUTH_FILE, "rb") as f:
        loaded = pickle.load(f)
    u_truth = loaded["u_truth"]
    v_truth = loaded["v_truth"]

    # A truth file generated under other settings would silently skew every
    # z-score / MSE below, so warn (listing the differing keys) when the stored
    # data config disagrees with the current one.
    saved_config = loaded.get("data_config_dict")
    if saved_config is not None:
        mismatched_keys = _truth_config_mismatches(saved_config, truth_data_config_init)
        if mismatched_keys:
            warnings.warn(
                f"{TRUTH_FILE} was generated with different values for "
                f"{mismatched_keys} than the current data_config_init; set "
                "COMPUTE_TRUTH = True to regenerate it or point TRUTH_FILE at "
                "the matching file."
            )

In [None]:
# Run NUM_EXPERIMENTS independent replications. Each one generates a fresh
# dataset from a distinct, reproducible seed and estimates beta twice: with
# the non-grouped baseline and with end-to-end learning.
beta_ng_list = []
beta_learned_list = []

for run_number in range(1, NUM_EXPERIMENTS + 1):
    run_data = generate_data_from_config(
        DataGenConfig(seed=7531 * run_number, **data_config_init)
    )

    ng_beta = beta_estimate_from_nongrouped(run_data, pi_eval, algo_config.discount)
    beta_ng_list.append(ng_beta)

    learned_beta = beta_estimate_from_e2e_learning(
        run_data, algo_config, grouping_config, pi_eval
    )
    beta_learned_list.append(learned_beta)

new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = 0.75136054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75136054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75136054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75136054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75136054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75136054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75136054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75136054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75136054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75136054
MCPImpl: num_above=1000

MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7494612
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7494612
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7494612
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7494612
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7494612
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7494612
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7494612
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7494612
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7494612
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.49246079 -3.86853693  2.20477325  3.86399167 -1.49617279  5.09694418]
 [-1.44985501  3.83122663 -2.21454477 -3.88119866

MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75041175+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75041175+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75041175+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75041175+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75041175+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75041175+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75041175+0j)
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.5305121  -3.85647418  2.16952508  3.87250967 -1.58581467  5.02482494]
 [-1.53569422  3.8810703  -1.95302746 -3.81214032  1.44590719 -4.95105641]] and inertia = 6.778378856835886
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.len

MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7515309
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7515309
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7515309
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7515309
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7515309
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7515309
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.39669729 -3.78757913  2.18877308  3.95997713 -1.50091601  5.05431812]
 [-1.51992512  3.86907801 -2.23612242 -3.75004852  1.52997505 -5.04498395]] and inertia = 11.432504027198782
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records

MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7515459
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7515459
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7515459
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7515459
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.47526461 -3.83285913  2.29903425  3.82451393 -1.50557778  5.21088706]
 [-1.38068152  3.89138308 -2.2916224  -3.85778125  1.54524618 -5.17078839]] and inertia = 9.235681565112461
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = 0.7479251
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7479251
MCPImpl: num_above=10000, num_below=9900


MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7464021+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7464021+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7464021+0j)
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.47879967 -3.78461217  2.25807871  3.85496901 -1.5239078   5.08020574]
 [-1.48201015  3.81717476 -2.2445577  -3.9132178   1.51345387 -5.1393236 ]] and inertia = 9.220886671622555
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = (0.74972636+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.74972636+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.74972636+0j)
MCPImpl:

MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75071514
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75071514
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.47619157 -3.90055699  2.09658958  3.93284187 -1.47864383  4.99880534]
 [-1.46554861  3.81591704 -2.11653614 -3.95201307  1.46284978 -5.02456025]] and inertia = 9.763380288606868
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = (0.7506711+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7506711+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7506711+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7506711+0j)
MCPImpl: num_above=

MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.74829584
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.74829584
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.46236172 -3.75986763  2.10630163  3.86007813 -1.51487394  5.03195013]
 [-1.47144303  3.93876933 -2.12372042 -3.80614593  1.40225615 -4.99952148]] and inertia = 8.973015227043348
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = (0.7486004+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7486004+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7486004+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7486004+0j)
MCPImpl: num_above=

beta_solver, min eigen of left matrix = (0.7485852+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7485852+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7485852+0j)
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.47264088 -4.002702    2.20798482  3.90095458 -1.55084336  5.10832774]
 [-1.5140816   3.98507218 -2.17525923 -3.94833288  1.48360227 -5.14296627]] and inertia = 9.302528059603855
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = (0.751206+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.751206+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.751206+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, m

MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.74908376+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.74908376+0j)
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.50052606 -3.84575245  2.17030983  3.84896236 -1.50135845  5.03434487]
 [-1.50931042  3.90017566 -2.11577083 -3.89765865  1.49221877 -4.9508233 ]] and inertia = 10.242077957272206
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = (0.7516212+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7516212+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7516212+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7516212+0j)
MCPImpl:

beta_solver, min eigen of left matrix = (0.74903405+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.74903405+0j)
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.52769165 -3.88797724  2.06630619  3.77486737 -1.44326436  4.9074961 ]
 [-1.44578323  3.90245533 -2.07425668 -3.86180324  1.54972163 -5.08428387]] and inertia = 9.798732310722725
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = (0.74757165+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.74757165+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.74757165+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.74757165+0j)
MCPImpl: num_above=10000, num_below=9900
beta_

MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7487573+0j)
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.51279984 -3.90264338  2.10675518  3.76787833 -1.43704958  5.04903288]
 [-1.44646001  3.79648791 -2.15268532 -3.68476497  1.51937583 -5.04340348]] and inertia = 9.652151393188507
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = 0.7469554
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7469554
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7469554
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7469554
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7469554
MCPImpl: num_above=10000, num_below=

new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = (0.7478022+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7478022+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7478022+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7478022+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7478022+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7478022+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7478022+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7478022+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7478022+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix 

beta_solver, min eigen of left matrix = 0.75163054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75163054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75163054
MCPImpl: num_above=10001, num_below=9899
beta_solver, min eigen of left matrix = 0.75163054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75163054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75163054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75163054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75163054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75163054
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.75163054
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.50766142 -3.88324936  2.06734957  3.91805738 -1.44300211  4

MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75238997+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75238997+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75238997+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75238997+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75238997+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75238997+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75238997+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75238997+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.75238997+0j)
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.50826716 -3.92446276  2.10898196  4.01034211 -1.59844403  5.08836

beta_solver, min eigen of left matrix = (0.7467225+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7467225+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7467225+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7467225+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7467225+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7467225+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7467225+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7467225+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.7467225+0j)
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.4954884  -3.85410629  2.15241573  3.83184203 -1.49683031  5.0201964 ]
 [-1.4065968   3.85832895 -2.2531759  -3.9111

MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7516162
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7516162
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7516162
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7516162
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7516162
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7516162
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.7516162
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.58131305 -3.83561424  2.08963223  3.7782184  -1.45431561  4.97241707]
 [-1.5166278   3.82543096 -2.11846927 -3.84714097  1.4470808  -5.04248294]] and inertia = 10.996480195822251
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_la

In [None]:
# Peek at the learned betas from the first replication.
first_learned_beta = beta_learned_list[0]
first_learned_beta.betas

In [None]:
# Score every replication against the ground truth. For each group g the
# z-score is (estimated V_g - true V_g) / estimated sigma_g.
mu_learned_list = []
sigma_learned_list = []
z_score_learned_list = []
mu_ng_list = []
sigma_ng_list = []
z_score_ng_list = []

n_groups = len(v_truth)

for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    # Grouped (learned) estimate: one mean / sigma per group.
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    z_score_learned = [
        (mu - truth) / sigma for mu, truth, sigma in zip(v_mus, v_truth, v_sigmas)
    ]
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_score_learned)

    # Non-grouped estimate: presumably a single mean / sigma shared by all
    # groups, so it is tiled to one entry per group before scoring.
    # np.resize repeats the values for both lists and numpy arrays; the old
    # `ngv_mus * len(v_truth)` only repeated lists and would have *scaled*
    # the values if compute_V_estimate returned an array.
    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    z_score_ng = [
        (mu - truth) / sigma
        for mu, truth, sigma in zip(
            np.resize(ngv_mus, n_groups), v_truth, np.resize(ngv_sigmas, n_groups)
        )
    ]
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_score_ng)

# Reports averaged over the two groups

In [None]:
# Empirical coverage pooled over all replications and both groups: the
# fraction of z-scores with |z| < Z_THRESHOLD (1.96 is the two-sided 95%
# standard-normal quantile).
Z_THRESHOLD = 1.96
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# The mean of a boolean array is the fraction of True entries.
learned_in_threshold_perc = learned_in_threshold.mean()
ng_in_threshold_perc = ng_in_threshold.mean()

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(
            {
                "mu_learned_list": mu_learned_list,
                "sigma_learned_list": sigma_learned_list,
                "z_score_learned": z_score_learned,
                "mu_ng_list": mu_ng_list,
                "sigma_ng_list": sigma_ng_list,
                "z_score_ng": z_score_ng,
                "beta_learned_list": beta_learned_list,
                "beta_ng_list": beta_ng_list,
            },
            f,
        )

In [None]:
# Pooled summary metrics (averaged over replications and groups):
#   ACL - average confidence-interval length, 2 * Z_THRESHOLD * sigma
#   MSE - mean squared error of the V estimates against v_truth
# Operands are converted with np.asarray explicitly: `list - v_truth` only
# works when v_truth happens to be a numpy array.
v_truth_arr = np.asarray(v_truth)

ac_acl = 2 * Z_THRESHOLD * np.mean(sigma_learned_list)
ac_mse = np.mean((np.asarray(mu_learned_list) - v_truth_arr) ** 2)

# Non-grouped estimates broadcast against v_truth (presumably one shared value
# per replication compared with every group's truth).
mv_acl = 2 * Z_THRESHOLD * np.mean(sigma_ng_list)
mv_mse = np.mean((np.asarray(mu_ng_list) - v_truth_arr) ** 2)

In [None]:
# Pooled report for the learned (ACPE) and non-grouped (MVPE) estimators.
for report_header, report_acl, report_mse, report_ecp in (
    ("ACPE results: (average over groups)", ac_acl, ac_mse, learned_in_threshold_perc),
    ("MVPE results: (average over groups)", mv_acl, mv_mse, ng_in_threshold_perc),
):
    print(report_header)
    print(f"ACL: {report_acl}")
    print(f"MSE: {report_mse}")
    print(f"ECP: {report_ecp}")

# Reports for each of the two groups separately

In [None]:
# Empirical coverage per group: rows of the z-score arrays are replications
# and columns are groups, so averaging over axis 0 gives one rate per group.
Z_THRESHOLD = 1.96
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# Column-wise mean of a boolean array = per-column fraction of True entries.
learned_in_threshold_perc = learned_in_threshold.mean(axis=0)
ng_in_threshold_perc = ng_in_threshold.mean(axis=0)

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(
            {
                "mu_learned_list": mu_learned_list,
                "sigma_learned_list": sigma_learned_list,
                "z_score_learned": z_score_learned,
                "mu_ng_list": mu_ng_list,
                "sigma_ng_list": sigma_ng_list,
                "z_score_ng": z_score_ng,
                "beta_learned_list": beta_learned_list,
                "beta_ng_list": beta_ng_list,
            },
            f,
        )

In [None]:
# Per-group summary metrics (averaged over replications only, axis 0):
#   ACL - average confidence-interval length, 2 * Z_THRESHOLD * sigma
#   MSE - mean squared error of the V estimates against v_truth
# Operands are converted with np.asarray explicitly: `list - v_truth` only
# works when v_truth happens to be a numpy array.
v_truth_arr = np.asarray(v_truth)

ac_acl = 2 * Z_THRESHOLD * np.mean(sigma_learned_list, axis=0)
ac_mse = np.mean((np.asarray(mu_learned_list) - v_truth_arr) ** 2, axis=0)

mv_acl = 2 * Z_THRESHOLD * np.mean(sigma_ng_list, axis=0)
mv_mse = np.mean((np.asarray(mu_ng_list) - v_truth_arr) ** 2, axis=0)

In [None]:
# Per-group report for the learned (ACPE) and non-grouped (MVPE) estimators,
# with a "===" separator between the two.
per_group_reports = (
    ("ACPE results: Group1, Group 2", ac_mse, ac_acl, learned_in_threshold_perc),
    ("MVPE results: ", mv_mse, mv_acl, ng_in_threshold_perc),
)
for report_idx, (report_header, report_mse, report_acl, report_ecp) in enumerate(
    per_group_reports
):
    if report_idx > 0:
        print("===")
    print(report_header)
    print(f"MSE: {report_mse}")
    print(f"ACL: {report_acl}")
    print(f"ECP: {report_ecp}")