In [1]:
import copy
import pickle
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

# Reward coefficients per group: the two rows are mirror images of each other.
group_reward_coeff = np.array([[-2.68, 2.68], [2.68, -2.68]], dtype=DTYPE)

# Passed to DataGenConfig as `action_reward_coeff`.
action_reward_coeff = [-2.89, 2.89]

# Keyword arguments for DataGenConfig shared by every experiment; the per-run
# seed is supplied separately when each DataGenConfig is built.
data_config_init = {
    "num_trajectories": 20,
    "num_time_steps": 30,
    "group_reward_coeff_override": group_reward_coeff,
    "action_reward_coeff": action_reward_coeff,
    "num_burnin_steps": 100,
    "basis_expansion_method": "NONE",
    "add_intercept_column": True,
}

FEATURE_TYPE = "NONE"

# Ground-truth cache: on the very first run set COMPUTE_TRUTH = True so the
# truth file gets generated, then switch it back to False so it is just loaded.
COMPUTE_TRUTH = False
# The truth file name does not encode the settings above, so pick a new name
# whenever those settings change.
TRUTH_FILE = f"hetero/data/{FEATURE_TYPE}_truth_20230528_2.68_2.89.pkl"
print("truth file name =", TRUTH_FILE)

# Result files carry N, T and a second-resolution timestamp in their name.
time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = f'N={data_config_init["num_trajectories"]}_T={data_config_init["num_time_steps"]}_{time_tag}'
RESULT_FILE = f"hetero/data/{FEATURE_TYPE}_result_20230528_2.68_2.89_{tag}.pkl"
print("result file name =", RESULT_FILE)

# Set to False only for throw-away experimental runs.
SAVE_RESULT = True
if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

NUM_EXPERIMENTS = 100

truth file name = hetero/data/NONE_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/NONE_result_20230528_2.68_2.89_N=20_T=30_20230602_11-50-39.pkl


=====================================================================================================
# Algorithm 

- Set the algorithm configuration below.

In [2]:

# Settings for the end-to-end grouped learner. `algo_config.discount` is not set
# here (presumably the AlgoConfig default) and is also used by the truth
# computation and by the non-grouped baseline.
algo_config_kwargs = dict(
    max_num_iters=10,
    gam=2.7,
    lam=1.6,
    rho=2.0,
    # Outlier handling; 2 / 98 are presumably percentiles -- TODO confirm.
    should_remove_outlier=True,
    outlier_lower_perc=2,
    outlier_upper_perc=98,
    nu_coeff=1e-5,
    delta_coeff=1e-5,
    use_group_wise_regression_init=True,
)
algo_config = AlgoConfig(**algo_config_kwargs)

# Policy handed to every estimator and to the truth computation.
pi_eval = AlternativePolicy(2)

# Default grouping settings for the end-to-end learner.
grouping_config = GroupingConfig()

In [3]:
# Ground-truth U/V values are expensive to simulate, so they are cached in
# TRUTH_FILE: simulated and pickled when COMPUTE_TRUTH is True, loaded otherwise.
# compute_UV_truths gets its own `num_trajectories`, so that key is dropped from
# the shared data config.
truth_data_config_init = copy.copy(data_config_init)
truth_data_config_init.pop("num_trajectories")

if COMPUTE_TRUTH:
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=10,
        num_trajectories=1000,
    )  # For best results, use num_repeats=10
    # Average over axis 0 (presumably the num_repeats repetitions -- TODO confirm).
    u_truth = us.mean(axis=0)
    v_truth = vs.mean(axis=0)

    with open(TRUTH_FILE, "wb") as f:
        pickle.dump(
            dict(
                u_truth=u_truth,
                v_truth=v_truth,
                data_config_dict=truth_data_config_init,
                algo_config_dict=attrs.asdict(algo_config),
            ),
            f,
        )
else:
    # NOTE: pickle.load can execute arbitrary code; only load truth files that
    # this notebook produced.
    with open(TRUTH_FILE, "rb") as f:
        loaded = pickle.load(f)
    u_truth = loaded["u_truth"]
    v_truth = loaded["v_truth"]

    # TRUTH_FILE's name does not encode the data settings, so a stale file would
    # otherwise be used silently. Compare the stored settings (when the file
    # recorded them) with the current ones; np.array_equal also handles the
    # array-valued entries such as the reward-coefficient override.
    saved_data_config = loaded.get("data_config_dict")
    if saved_data_config is not None:
        changed_keys = sorted(
            key
            for key in set(saved_data_config) | set(truth_data_config_init)
            if key not in saved_data_config
            or key not in truth_data_config_init
            or not np.array_equal(saved_data_config[key], truth_data_config_init[key])
        )
        if changed_keys:
            print(
                f"WARNING: {TRUTH_FILE} was generated with different data settings "
                f"(changed: {', '.join(changed_keys)}); set COMPUTE_TRUTH = True "
                "to regenerate it."
            )

In [4]:
# Monte-Carlo experiments: every run simulates a fresh data set from its own
# deterministic seed (7531, 2*7531, ...) and fits both the non-grouped baseline
# and the end-to-end grouped learner on that same data set.
beta_ng_list, beta_learned_list = [], []

for run_number in range(1, NUM_EXPERIMENTS + 1):
    data_config = DataGenConfig(seed=7531 * run_number, **data_config_init)
    data = generate_data_from_config(data_config)

    ng_fit = beta_estimate_from_nongrouped(data, pi_eval, algo_config.discount)
    beta_ng_list.append(ng_fit)

    e2e_fit = beta_estimate_from_e2e_learning(
        data, algo_config, grouping_config, pi_eval
    )
    beta_learned_list.append(e2e_fit)

new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = (0.938776+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.938776+0j)
MCPImpl: num_above=416, num_below=364
beta_solver, min eigen of left matrix = (0.938776+0j)
MCPImpl: num_above=402, num_below=378
beta_solver, min eigen of left matrix = (0.938776+0j)
MCPImpl: num_above=416, num_below=364
beta_solver, min eigen of left matrix = (0.938776+0j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0.938776+0j)
MCPImpl: num_above=416, num_below=364
beta_solver, min eigen of left matrix = (0.938776+0j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0.938776+0j)
MCPImpl: num_above=416, num_below=364
beta_solver, min eigen of left matrix = (0.938776+0j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0.938776+0j)
MCPImpl: num_above=4

MCPImpl: num_above=410, num_below=370
beta_solver, min eigen of left matrix = (0.94556266-0.0013921246j)
MCPImpl: num_above=422, num_below=358
beta_solver, min eigen of left matrix = (0.94556266-0.0013921246j)
MCPImpl: num_above=411, num_below=369
beta_solver, min eigen of left matrix = (0.94556266-0.0013921246j)
MCPImpl: num_above=422, num_below=358
beta_solver, min eigen of left matrix = (0.94556266-0.0013921246j)
MCPImpl: num_above=412, num_below=368
beta_solver, min eigen of left matrix = (0.94556266-0.0013921246j)
MCPImpl: num_above=422, num_below=358
kmeans center = [[-1.43969647  3.62144383 -1.95641708 -3.91238058  1.44940491 -4.92622492]
 [ 1.63413085 -3.53822159  1.83872371  4.05011482 -1.80424491  5.13876724]] and inertia = 95.55896197390419
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of

MCPImpl: num_above=416, num_below=364
beta_solver, min eigen of left matrix = 0.9455186
MCPImpl: num_above=434, num_below=346
beta_solver, min eigen of left matrix = 0.9455186
MCPImpl: num_above=416, num_below=364
beta_solver, min eigen of left matrix = 0.9455186
MCPImpl: num_above=435, num_below=345
kmeans center = [[-1.61409444  3.97698785 -2.16847812 -3.52214889  1.44106328 -4.94282687]
 [ 1.55440169 -3.44586403  2.43202991  4.00313685 -1.73302253  5.26707204]] and inertia = 97.63915980526261
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = 0.9257688
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = 0.9257688
MCPImpl: num_above=422, num_below=358
beta_solver, min eigen of left matrix = 0.9257688
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen

beta_solver, min eigen of left matrix = (0.9261977+0j)
MCPImpl: num_above=431, num_below=349
kmeans center = [[-1.68942196  3.70728144 -1.69337314 -3.73181184  1.639556   -4.76229157]
 [ 1.59384834 -4.01390547  1.95868499  3.94595503 -1.62048278  4.82646749]] and inertia = 89.62272392224608
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = 0.9376974
MCPImpl: num_above=401, num_below=379
beta_solver, min eigen of left matrix = 0.9376974
MCPImpl: num_above=422, num_below=358
beta_solver, min eigen of left matrix = 0.9376974
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = 0.9376974
MCPImpl: num_above=422, num_below=358
beta_solver, min eigen of left matrix = 0.9376974
MCPImpl: num_above=404, num_below=376
beta_solver, min eigen of left matrix = 0.9376974
MCPIm

new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = (0.93709207-0.000565137j)
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen of left matrix = (0.93709207-0.000565137j)
MCPImpl: num_above=451, num_below=329
beta_solver, min eigen of left matrix = (0.93709207-0.000565137j)
MCPImpl: num_above=425, num_below=355
beta_solver, min eigen of left matrix = (0.93709207-0.000565137j)
MCPImpl: num_above=457, num_below=323
beta_solver, min eigen of left matrix = (0.93709207-0.000565137j)
MCPImpl: num_above=426, num_below=354
beta_solver, min eigen of left matrix = (0.93709207-0.000565137j)
MCPImpl: num_above=456, num_below=324
beta_solver, min eigen of left matrix = (0.93709207-0.000565137j)
MCPImpl: num_above=426, num_below=354
beta_solver, min eigen of left matrix = (0.93709207-0.000565137j)
MCPImpl: num_above=456, num_below=324
beta_solver, min eigen of left matrix = (0.93709207-0.000565137j)
MCPI

beta_solver, min eigen of left matrix = (0.94156986+0j)
MCPImpl: num_above=401, num_below=379
beta_solver, min eigen of left matrix = (0.94156986+0j)
MCPImpl: num_above=412, num_below=368
beta_solver, min eigen of left matrix = (0.94156986+0j)
MCPImpl: num_above=402, num_below=378
beta_solver, min eigen of left matrix = (0.94156986+0j)
MCPImpl: num_above=412, num_below=368
beta_solver, min eigen of left matrix = (0.94156986+0j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0.94156986+0j)
MCPImpl: num_above=412, num_below=368
beta_solver, min eigen of left matrix = (0.94156986+0j)
MCPImpl: num_above=402, num_below=378
beta_solver, min eigen of left matrix = (0.94156986+0j)
MCPImpl: num_above=412, num_below=368
beta_solver, min eigen of left matrix = (0.94156986+0j)
MCPImpl: num_above=402, num_below=378
beta_solver, min eigen of left matrix = (0.94156986+0j)
MCPImpl: num_above=412, num_below=368
kmeans center = [[-1.69350678  3.74175413 -2.17671705 -3.577

MCPImpl: num_above=449, num_below=331
beta_solver, min eigen of left matrix = (0.9314034+0j)
MCPImpl: num_above=426, num_below=354
beta_solver, min eigen of left matrix = (0.9314034+0j)
MCPImpl: num_above=449, num_below=331
beta_solver, min eigen of left matrix = (0.9314034+0j)
MCPImpl: num_above=426, num_below=354
beta_solver, min eigen of left matrix = (0.9314034+0j)
MCPImpl: num_above=449, num_below=331
kmeans center = [[-1.67740556  3.65095655 -1.99256067 -3.7636005   1.5708132  -5.09242896]
 [ 1.49470947 -4.16671305  2.13761133  3.71669518 -1.65660727  4.74889527]] and inertia = 119.20318247115146
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = (0.9329359+0j)
MCPImpl: num_above=404, num_below=376
beta_solver, min eigen of left matrix = (0.9329359+0j)
MCPImpl: num_above=424, num_b

MCPImpl: num_above=423, num_below=357
beta_solver, min eigen of left matrix = 0.94109124
MCPImpl: num_above=404, num_below=376
beta_solver, min eigen of left matrix = 0.94109124
MCPImpl: num_above=425, num_below=355
beta_solver, min eigen of left matrix = 0.94109124
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen of left matrix = 0.94109124
MCPImpl: num_above=425, num_below=355
beta_solver, min eigen of left matrix = 0.94109124
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen of left matrix = 0.94109124
MCPImpl: num_above=426, num_below=354
beta_solver, min eigen of left matrix = 0.94109124
MCPImpl: num_above=407, num_below=373
beta_solver, min eigen of left matrix = 0.94109124
MCPImpl: num_above=425, num_below=355
kmeans center = [[-1.81167371  3.5369716  -2.09832394 -3.86247719  1.56247228 -5.18598659]
 [ 1.58725988 -3.8562647   1.79568259  3.63867901 -1.59868019  4.72600388]] and inertia = 89.35639450323345
Label mismatch = 0
new_labels.length=40 matches num

beta_solver, min eigen of left matrix = (0.9463826-0.0011695676j)
MCPImpl: num_above=405, num_below=375
beta_solver, min eigen of left matrix = (0.9463826-0.0011695676j)
MCPImpl: num_above=427, num_below=353
kmeans center = [[-1.46608147  3.69219593 -2.05783772 -3.76564509  1.55493703 -5.02914703]
 [ 1.46686601 -3.57130944  2.49231221  3.88976695 -1.53055716  5.23659882]] and inertia = 85.13537344207114
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = (0.94573414+0j)
MCPImpl: num_above=401, num_below=379
beta_solver, min eigen of left matrix = (0.94573414+0j)
MCPImpl: num_above=419, num_below=361
beta_solver, min eigen of left matrix = (0.94573414+0j)
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen of left matrix = (0.94573414+0j)
MCPImpl: num_above=420, num_below=360
beta

new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = 0.93657506
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = 0.93657506
MCPImpl: num_above=423, num_below=357
beta_solver, min eigen of left matrix = 0.93657506
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = 0.93657506
MCPImpl: num_above=423, num_below=357
beta_solver, min eigen of left matrix = 0.93657506
MCPImpl: num_above=405, num_below=375
beta_solver, min eigen of left matrix = 0.93657506
MCPImpl: num_above=423, num_below=357
beta_solver, min eigen of left matrix = 0.93657506
MCPImpl: num_above=405, num_below=375
beta_solver, min eigen of left matrix = 0.93657506
MCPImpl: num_above=423, num_below=357
beta_solver, min eigen of left matrix = 0.93657506
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen of left matrix = 0.93657506
MCPImpl: num_above=423, num_below=357
kmeans cente

beta_solver, min eigen of left matrix = 0.9452683
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.9452683
MCPImpl: num_above=419, num_below=361
beta_solver, min eigen of left matrix = 0.9452683
MCPImpl: num_above=402, num_below=378
beta_solver, min eigen of left matrix = 0.9452683
MCPImpl: num_above=417, num_below=363
beta_solver, min eigen of left matrix = 0.9452683
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = 0.9452683
MCPImpl: num_above=417, num_below=363
beta_solver, min eigen of left matrix = 0.9452683
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = 0.9452683
MCPImpl: num_above=417, num_below=363
beta_solver, min eigen of left matrix = 0.9452683
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = 0.9452683
MCPImpl: num_above=417, num_below=363
kmeans center = [[-1.41106826  3.736387   -2.40044289 -3.97920249  1.65868909 -5.41041224]
 [ 1.64557502 -3.76876928  1.

MCPImpl: num_above=417, num_below=363
beta_solver, min eigen of left matrix = (0.94538546+0j)
MCPImpl: num_above=406, num_below=374
beta_solver, min eigen of left matrix = (0.94538546+0j)
MCPImpl: num_above=416, num_below=364
beta_solver, min eigen of left matrix = (0.94538546+0j)
MCPImpl: num_above=407, num_below=373
beta_solver, min eigen of left matrix = (0.94538546+0j)
MCPImpl: num_above=416, num_below=364
beta_solver, min eigen of left matrix = (0.94538546+0j)
MCPImpl: num_above=407, num_below=373
beta_solver, min eigen of left matrix = (0.94538546+0j)
MCPImpl: num_above=416, num_below=364
kmeans center = [[-1.59781997  3.94311821 -2.12415872 -3.81565584  1.67057244 -5.10423615]
 [ 1.5397372  -3.96263624  1.96943196  3.72824068 -1.40332015  4.7432821 ]] and inertia = 82.18050662025594
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches numb

new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = (0.945146-0.00093707j)
MCPImpl: num_above=402, num_below=378
beta_solver, min eigen of left matrix = (0.945146-0.00093707j)
MCPImpl: num_above=414, num_below=366
beta_solver, min eigen of left matrix = (0.945146-0.00093707j)
MCPImpl: num_above=402, num_below=378
beta_solver, min eigen of left matrix = (0.945146-0.00093707j)
MCPImpl: num_above=414, num_below=366
beta_solver, min eigen of left matrix = (0.945146-0.00093707j)
MCPImpl: num_above=404, num_below=376
beta_solver, min eigen of left matrix = (0.945146-0.00093707j)
MCPImpl: num_above=414, num_below=366
beta_solver, min eigen of left matrix = (0.945146-0.00093707j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0.945146-0.00093707j)
MCPImpl: num_above=414, num_below=366
beta_solver, min eigen of left matrix = (0.945146-0.00093707j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0

MCPImpl: num_above=401, num_below=379
beta_solver, min eigen of left matrix = (0.941378+0j)
MCPImpl: num_above=419, num_below=361
beta_solver, min eigen of left matrix = (0.941378+0j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0.941378+0j)
MCPImpl: num_above=420, num_below=360
beta_solver, min eigen of left matrix = (0.941378+0j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0.941378+0j)
MCPImpl: num_above=420, num_below=360
beta_solver, min eigen of left matrix = (0.941378+0j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0.941378+0j)
MCPImpl: num_above=420, num_below=360
beta_solver, min eigen of left matrix = (0.941378+0j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0.941378+0j)
MCPImpl: num_above=420, num_below=360
kmeans center = [[-1.61856022  3.70083441 -2.09714267 -3.79933579  1.66109709 -5.11025533]
 [ 1.42383514 -3.68748534  2.35558586  3.87

beta_solver, min eigen of left matrix = (0.94143313-0.0014806793j)
MCPImpl: num_above=431, num_below=349
beta_solver, min eigen of left matrix = (0.94143313-0.0014806793j)
MCPImpl: num_above=409, num_below=371
beta_solver, min eigen of left matrix = (0.94143313-0.0014806793j)
MCPImpl: num_above=431, num_below=349
kmeans center = [[-1.50283427  3.82731721 -1.98892175 -3.8557657   1.63459654 -4.88165389]
 [ 1.60755664 -3.84479412  2.08102688  3.47235079 -1.49582356  5.23258972]] and inertia = 98.29719794432705
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = (0.9458678-0.001239369j)
MCPImpl: num_above=402, num_below=378
beta_solver, min eigen of left matrix = (0.9458678-0.001239369j)
MCPImpl: num_above=424, num_below=356
beta_solver, min eigen of left matrix = (0.9458678-0.001239369j)
MC

MCPImpl: num_above=415, num_below=365
beta_solver, min eigen of left matrix = 0.94176704
MCPImpl: num_above=438, num_below=342
kmeans center = [[-1.33326238  3.80810728 -2.18001146 -3.77817869  1.7070992  -5.09806177]
 [ 1.3465124  -3.80425278  2.05863902  3.87164334 -1.57762059  4.98543054]] and inertia = 104.52893998133698
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = (0.93208444+0j)
MCPImpl: num_above=447, num_below=333
beta_solver, min eigen of left matrix = (0.93208444+0j)
MCPImpl: num_above=513, num_below=267
beta_solver, min eigen of left matrix = (0.93208444+0j)
MCPImpl: num_above=458, num_below=322
beta_solver, min eigen of left matrix = (0.93208444+0j)
MCPImpl: num_above=515, num_below=265
beta_solver, min eigen of left matrix = (0.93208444+0j)
MCPImpl: num_above=464, num_

In [5]:
# Peek at the coefficients learned in the first experiment (displayed below).
first_learned_fit = beta_learned_list[0]
first_learned_fit.betas

[array([-1.3389812,  3.8575597, -2.521086 , -3.8335528,  1.555515 ,
        -5.3094077], dtype=float32),
 array([ 1.5471345, -3.911604 ,  2.4589884,  3.580307 , -1.401705 ,
         5.114128 ], dtype=float32)]

In [6]:
# For every experiment, turn each fitted beta into value estimates (mu) and
# standard errors (sigma) via compute_V_estimate, and form the z-scores
# (mu - truth) / sigma against the ground truth v_truth.
mu_learned_list = []
sigma_learned_list = []
z_score_learned_list = []
mu_ng_list = []
sigma_ng_list = []
z_score_ng_list = []

num_groups = len(v_truth)
for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    z_score_learned = [
        (mu - truth) / sigma for mu, truth, sigma in zip(v_mus, v_truth, v_sigmas)
    ]
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_score_learned)

    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    # Compare the non-grouped estimate with every group's truth by tiling it to
    # `num_groups` entries. np.resize tiles cyclically, which matches the old
    # `ngv_mus * len(v_truth)` for a Python list, but unlike that expression it
    # does not silently multiply the values when a NumPy array is returned.
    ngv_mus_tiled = np.resize(np.asarray(ngv_mus), num_groups)
    ngv_sigmas_tiled = np.resize(np.asarray(ngv_sigmas), num_groups)
    z_score_ng = [
        (mu - truth) / sigma
        for mu, truth, sigma in zip(ngv_mus_tiled, v_truth, ngv_sigmas_tiled)
    ]
    # Store the raw (untiled) estimates; later cells broadcast them against v_truth.
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_score_ng)

# Reports averaged over the two groups

In [7]:
Z_THRESHOLD = 1.96
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

# Coverage pooled over every (experiment, group) pair: the share of z-scores
# whose magnitude stays below Z_THRESHOLD.
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD
learned_in_threshold_perc = learned_in_threshold.mean()
ng_in_threshold_perc = ng_in_threshold.mean()

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    result_payload = dict(
        mu_learned_list=mu_learned_list,
        sigma_learned_list=sigma_learned_list,
        z_score_learned=z_score_learned,
        mu_ng_list=mu_ng_list,
        sigma_ng_list=sigma_ng_list,
        z_score_ng=z_score_ng,
        beta_learned_list=beta_learned_list,
        beta_ng_list=beta_ng_list,
    )
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(result_payload, f)

learned_in_threshold_perc= 0.98 , ng_in_threshold_perc= 0.0


In [8]:
# Metrics pooled over all experiments and groups:
#   ACL = mean width 2 * Z_THRESHOLD * sigma of the +/- Z_THRESHOLD * sigma interval
#   MSE = mean squared gap between the estimates and v_truth
ac_acl = 2 * Z_THRESHOLD * np.mean(sigma_learned_list)
ac_mse = np.mean(np.square(np.asarray(mu_learned_list) - v_truth))

mv_acl = 2 * Z_THRESHOLD * np.mean(sigma_ng_list)
mv_mse = np.mean(np.square(np.asarray(mu_ng_list) - v_truth))

In [9]:
# Pooled report: grouped learner (ACPE) first, non-grouped baseline (MVPE) second.
print(
    *(
        "ACPE results: (average over groups)",
        f"ACL: {ac_acl}",
        f"MSE: {ac_mse}",
        f"ECP: {learned_in_threshold_perc}",
        "MVPE results: (average over groups)",
        f"ACL: {mv_acl}",
        f"MSE: {mv_mse}",
        f"ECP: {ng_in_threshold_perc}",
    ),
    sep="\n",
)

ACPE results: (average over groups)
ACL: 0.6494026877066001
MSE: 0.020591571927070618
ECP: 0.98
MVPE results: (average over groups)
ACL: 0.991223332284434
MSE: 12.988564491271973
ECP: 0.0


# Reports for each of the two groups separately

In [10]:
# Per-group coverage: for each group (column), the fraction of experiments (rows)
# whose |z| stays below Z_THRESHOLD. The z-score arrays are rebuilt here so this
# cell depends only on the per-experiment lists.
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)
Z_THRESHOLD = 1.96
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# NOTE: these names now hold per-group arrays, replacing the pooled scalars of
# the same name computed for the averaged report.
learned_in_threshold_perc = learned_in_threshold.mean(axis=0)
ng_in_threshold_perc = ng_in_threshold.mean(axis=0)

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

# RESULT_FILE is not re-written here: the averaged-report cell already pickles
# exactly the same payload (the lists and z-score arrays) when SAVE_RESULT is True.

learned_in_threshold_perc= [0.97 0.99] , ng_in_threshold_perc= [0. 0.]


In [11]:
# Per-group metrics, averaged over experiments only (axis 0). For the
# non-grouped baseline, ACL has as many entries as compute_V_estimate returned
# for it, while its MSE broadcasts the estimate against every group's truth.
ac_acl = 2 * Z_THRESHOLD * np.mean(sigma_learned_list, axis=0)
ac_mse = np.mean(np.square(np.asarray(mu_learned_list) - v_truth), axis=0)

mv_acl = 2 * Z_THRESHOLD * np.mean(sigma_ng_list, axis=0)
mv_mse = np.mean(np.square(np.asarray(mu_ng_list) - v_truth), axis=0)

In [13]:
# Per-group report: grouped learner (ACPE) first, non-grouped baseline (MVPE) second.
print(
    *(
        "ACPE results: Group1, Group 2",
        f"MSE: {ac_mse}",
        f"ACL: {ac_acl}",
        f"ECP: {learned_in_threshold_perc}",
        "===",
        "MVPE results: ",
        f"MSE: {mv_mse}",
        f"ACL: {mv_acl}",
        f"ECP: {ng_in_threshold_perc}",
    ),
    sep="\n",
)

ACPE results: Group1, Group 2
MSE: [0.01998241 0.02120073]
ACL: [0.64936136 0.64944401]
ECP: [0.97 0.99]
===
MVPE results: 
MSE: [12.82583  13.151297]
ACL: [0.99122333]
ECP: [0. 0.]
