In [1]:
import copy
import pickle
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

group_reward_coeff = np.array(
    [
        [-2, 2, 2, -2],
        [2, -2, -2, 2],
    ],
    dtype=DTYPE,
)

action_reward_coeff = [-1, 1]

FEATURE_TYPE = "LEGENDRE"
TRANS = "NORMCDF"
NOISE = "STUDENT"

data_config_init = dict(
    num_trajectories=50,
    num_time_steps=40,
    group_reward_coeff_override=group_reward_coeff,
    action_reward_coeff=action_reward_coeff,
    num_burnin_steps=100,
    basis_expansion_method=FEATURE_TYPE,
    transformation_method=TRANS,
    add_intercept_column=True,
    noise_type=NOISE,
    noise_student_degree=4,
)

# First time runner: set COMPUTE_TRUTH = True
# Change the flag to False after generated the truth file.
COMPUTE_TRUTH = False #True
# Change the truth file name if settings are changed.
TRUTH_FILE = f"hetero/data/{FEATURE_TYPE}_{TRANS}_{NOISE}_truth_20230528_2.68_2.89.pkl"
print("truth file name =", TRUTH_FILE)

time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = f'N={data_config_init["num_trajectories"]}_T={data_config_init["num_time_steps"]}_{time_tag}'
RESULT_FILE = f"hetero/data/{FEATURE_TYPE}_{TRANS}_{NOISE}_result_20230528_2.68_2.89_{tag}.pkl"
print("result file name =", RESULT_FILE)

SAVE_RESULT = True
if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

NUM_EXPERIMENTS = 100

truth file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_result_20230528_2.68_2.89_N=50_T=40_20230602_19-32-03.pkl


=====================================================================================================
# Algorithm 

- Set configure below.

In [2]:

algo_config = AlgoConfig(
    max_num_iters=10,
    gam=2.7,
    lam=2.0,
    rho=2.0,
    should_remove_outlier=True,
    outlier_lower_perc=2,
    outlier_upper_perc=98,
    nu_coeff=1e-5,
    delta_coeff=1e-5,
    use_group_wise_regression_init=True,
)

pi_eval = AlternativePolicy(2)

grouping_config = GroupingConfig()

In [3]:
if COMPUTE_TRUTH:
    truth_data_config_init = copy.copy(data_config_init)
    truth_data_config_init.pop("num_trajectories")
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=100,
        num_trajectories=1000,
    )  # For best results, use num_repeats=10
    u_truth = us.mean(axis=0)
    v_truth = vs.mean(axis=0)

    with open(TRUTH_FILE, "wb") as f:
        pickle.dump(
            dict(
                u_truth=u_truth,
                v_truth=v_truth,
                data_config_dict=truth_data_config_init,
                algo_config_dict=attrs.asdict(algo_config),
            ),
            f,
        )
else:
    with open(TRUTH_FILE, "rb") as f:
        loaded = pickle.load(f)
        u_truth = loaded["u_truth"]
        v_truth = loaded["v_truth"]

In [4]:
beta_ng_list = []
beta_learned_list = []

for i in range(NUM_EXPERIMENTS):
    data_config = DataGenConfig(seed=7531 * (i + 1), **data_config_init)
    data = generate_data_from_config(data_config)
    beta_ng_list.append(
        beta_estimate_from_nongrouped(data, pi_eval, algo_config.discount)
    )
    beta_learned_list.append(
        beta_estimate_from_e2e_learning(data, algo_config, grouping_config, pi_eval)
    )

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.1715187+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.1715187+0j)
MCPImpl: num_above=3070, num_below=1880
beta_solver, min eigen of left matrix = (0.1715187+0j)
MCPImpl: num_above=4241, num_below=709
beta_solver, min eigen of left matrix = (0.1715187+0j)
MCPImpl: num_above=2910, num_below=2040
beta_solver, min eigen of left matrix = (0.1715187+0j)
MCPImpl: num_above=4234, num_below=716
beta_solver, min eigen of left matrix = (0.1715187+0j)
MCPImpl: num_above=2869, num_below=2081
beta_solver, min eigen of left matrix = (0.1715187+0j)
MCPImpl: num_above=4477, num_below=473
beta_solver, min eigen of left matrix = (0.1715187+0j)
MCPImpl: num_above=3885, num_below=1065
beta_solver, min eigen of left matrix = (0.1715187+0j)
MCPImpl: num_above=3788, num_below=1162
beta_solver, min eigen of left matrix = (0.1715187

MCPImpl: num_above=2844, num_below=2106
kmeans center = [[ 1.24932054 -2.82718807 -2.79209861  2.94109569  1.36312233  2.6744919
  -1.04185377 -2.81598069  2.76825365  0.9292674 ]
 [-0.91733271  2.75940578  2.7976736  -2.62609678 -1.80615873 -2.86927613
   1.12027924  2.69822995 -2.78118248 -0.92977694]] and inertia = 751.8871410025682
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.17945807+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.17945807+0j)
MCPImpl: num_above=2965, num_below=1985
beta_solver, min eigen of left matrix = (0.17945807+0j)
MCPImpl: num_above=4504, num_below=446
beta_solver, min eigen of left matrix = (0.17945807+0j)
MCPImpl: num_above=3816, num_below=1134
beta_solver, min eigen of left matrix = (0.17945807+0j)
MCPImpl: 

beta_solver, min eigen of left matrix = 0.18091357
MCPImpl: num_above=3894, num_below=1056
beta_solver, min eigen of left matrix = 0.18091357
MCPImpl: num_above=3730, num_below=1220
beta_solver, min eigen of left matrix = 0.18091357
MCPImpl: num_above=3234, num_below=1716
beta_solver, min eigen of left matrix = 0.18091357
MCPImpl: num_above=3474, num_below=1476
kmeans center = [[ 0.98139477 -2.86701581 -2.75357925  2.83037656  1.59766343  2.98426768
  -0.76277391 -2.82305129  2.80836641  0.65725853]
 [-1.18578244  2.85145368  2.63246926 -2.82945869 -1.40042692 -2.66900345
   1.11337414  2.67178985 -2.69789805 -0.88081912]] and inertia = 1867.2170545559104
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.16868652+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of 

MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.17932123
MCPImpl: num_above=2970, num_below=1980
beta_solver, min eigen of left matrix = 0.17932123
MCPImpl: num_above=4459, num_below=491
beta_solver, min eigen of left matrix = 0.17932123
MCPImpl: num_above=3779, num_below=1171
beta_solver, min eigen of left matrix = 0.17932123
MCPImpl: num_above=3515, num_below=1435
beta_solver, min eigen of left matrix = 0.17932123
MCPImpl: num_above=3145, num_below=1805
beta_solver, min eigen of left matrix = 0.17932123
MCPImpl: num_above=3761, num_below=1189
beta_solver, min eigen of left matrix = 0.17932123
MCPImpl: num_above=4533, num_below=417
beta_solver, min eigen of left matrix = 0.17932123
MCPImpl: num_above=3768, num_below=1182
beta_solver, min eigen of left matrix = 0.17932123
MCPImpl: num_above=3067, num_below=1883
kmeans center = [[ 0.89868336 -2.69405747 -2.75859051  2.60348623  1.70198592  2.92065631
  -1.07224468 -2.63460289  2.46982279  0.86722953]
 [

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.17460462+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.17460462+0j)
MCPImpl: num_above=2874, num_below=2076
beta_solver, min eigen of left matrix = (0.17460462+0j)
MCPImpl: num_above=4570, num_below=380
beta_solver, min eigen of left matrix = (0.17460462+0j)
MCPImpl: num_above=4212, num_below=738
beta_solver, min eigen of left matrix = (0.17460462+0j)
MCPImpl: num_above=2878, num_below=2072
beta_solver, min eigen of left matrix = (0.17460462+0j)
MCPImpl: num_above=4166, num_below=784
beta_solver, min eigen of left matrix = (0.17460462+0j)
MCPImpl: num_above=2943, num_below=2007
beta_solver, min eigen of left matrix = (0.17460462+0j)
MCPImpl: num_above=4317, num_below=633
beta_solver, min eigen of left matrix = (0.17460462+0j)
MCPImpl: num_above=3711, num_below=1239
beta_solver, min eigen of left matrix = (0

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.17628823
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.17628823
MCPImpl: num_above=3065, num_below=1885
beta_solver, min eigen of left matrix = 0.17628823
MCPImpl: num_above=4307, num_below=643
beta_solver, min eigen of left matrix = 0.17628823
MCPImpl: num_above=3303, num_below=1647
beta_solver, min eigen of left matrix = 0.17628823
MCPImpl: num_above=3755, num_below=1195
beta_solver, min eigen of left matrix = 0.17628823
MCPImpl: num_above=4256, num_below=694
beta_solver, min eigen of left matrix = 0.17628823
MCPImpl: num_above=3162, num_below=1788
beta_solver, min eigen of left matrix = 0.17628823
MCPImpl: num_above=4028, num_below=922
beta_solver, min eigen of left matrix = 0.17628823
MCPImpl: num_above=2852, num_below=2098
beta_solver, min eigen of left matrix = 0.17628823
MCPImpl: num_above=4304, num_below=

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.17853503
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.17853503
MCPImpl: num_above=3131, num_below=1819
beta_solver, min eigen of left matrix = 0.17853503
MCPImpl: num_above=4131, num_below=819
beta_solver, min eigen of left matrix = 0.17853503
MCPImpl: num_above=2834, num_below=2116
beta_solver, min eigen of left matrix = 0.17853503
MCPImpl: num_above=4143, num_below=807
beta_solver, min eigen of left matrix = 0.17853503
MCPImpl: num_above=3377, num_below=1573
beta_solver, min eigen of left matrix = 0.17853503
MCPImpl: num_above=3783, num_below=1167
beta_solver, min eigen of left matrix = 0.17853503
MCPImpl: num_above=3426, num_below=1524
beta_solver, min eigen of left matrix = 0.17853503
MCPImpl: num_above=3438, num_below=1512
beta_solver, min eigen of left matrix = 0.17853503
MCPImpl: num_above=4025, num_below

beta_solver, min eigen of left matrix = (0.17687562+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.17687562+0j)
MCPImpl: num_above=3021, num_below=1929
beta_solver, min eigen of left matrix = (0.17687562+0j)
MCPImpl: num_above=4384, num_below=566
beta_solver, min eigen of left matrix = (0.17687562+0j)
MCPImpl: num_above=3460, num_below=1490
beta_solver, min eigen of left matrix = (0.17687562+0j)
MCPImpl: num_above=3680, num_below=1270
beta_solver, min eigen of left matrix = (0.17687562+0j)
MCPImpl: num_above=3764, num_below=1186
beta_solver, min eigen of left matrix = (0.17687562+0j)
MCPImpl: num_above=3410, num_below=1540
beta_solver, min eigen of left matrix = (0.17687562+0j)
MCPImpl: num_above=3540, num_below=1410
beta_solver, min eigen of left matrix = (0.17687562+0j)
MCPImpl: num_above=4420, num_below=530
beta_solver, min eigen of left matrix = (0.17687562+0j)
MCPImpl: num_above=3833, num_below=1117
kmeans center = [[ 0.83491963 -2.8044636  

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.17393605
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.17393605
MCPImpl: num_above=3248, num_below=1702
beta_solver, min eigen of left matrix = 0.17393605
MCPImpl: num_above=3708, num_below=1242
beta_solver, min eigen of left matrix = 0.17393605
MCPImpl: num_above=2792, num_below=2158
beta_solver, min eigen of left matrix = 0.17393605
MCPImpl: num_above=4346, num_below=604
beta_solver, min eigen of left matrix = 0.17393605
MCPImpl: num_above=3351, num_below=1599
beta_solver, min eigen of left matrix = 0.17393605
MCPImpl: num_above=3673, num_below=1277
beta_solver, min eigen of left matrix = 0.17393605
MCPImpl: num_above=4194, num_below=756
beta_solver, min eigen of left matrix = 0.17393605
MCPImpl: num_above=3273, num_below=1677
beta_solver, min eigen of left matrix = 0.17393605
MCPImpl: num_above=3626, num_below

MCPImpl: num_above=3969, num_below=981
kmeans center = [[ 1.30048138 -2.52468783 -2.76489186  3.08501875  1.10464884  2.72630266
  -1.18701262 -3.03506825  2.91218115  0.88764685]
 [-1.07391593  2.97667259  2.72650614 -2.41288358 -1.81074864 -3.01135428
   0.90768196  2.69309451 -2.71050743 -0.60408163]] and inertia = 2802.5278479623366
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.16914266
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.16914266
MCPImpl: num_above=3171, num_below=1779
beta_solver, min eigen of left matrix = 0.16914266
MCPImpl: num_above=4093, num_below=857
beta_solver, min eigen of left matrix = 0.16914266
MCPImpl: num_above=2696, num_below=2254
beta_solver, min eigen of left matrix = 0.16914266
MCPImpl: num_above=4535, num_belo

beta_solver, min eigen of left matrix = (0.17660056+0j)
MCPImpl: num_above=3457, num_below=1493
beta_solver, min eigen of left matrix = (0.17660056+0j)
MCPImpl: num_above=3784, num_below=1166
beta_solver, min eigen of left matrix = (0.17660056+0j)
MCPImpl: num_above=3297, num_below=1653
beta_solver, min eigen of left matrix = (0.17660056+0j)
MCPImpl: num_above=3615, num_below=1335
kmeans center = [[ 1.00248078 -3.08280009 -2.66174499  2.55125648  1.77399835  2.96334222
  -0.86137976 -2.79967408  2.60734498  0.79847221]
 [-1.14816806  2.51594748  2.28692491 -2.69784361 -1.31810532 -2.96330984
   1.22517769  2.34077709 -2.71014553 -0.77047925]] and inertia = 2370.893791263567
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.17172308+0j)
MCPImpl: num_above=2500, num_below=2450
beta_so

MCPImpl: num_above=3073, num_below=1877
beta_solver, min eigen of left matrix = 0.17454016
MCPImpl: num_above=4297, num_below=653
beta_solver, min eigen of left matrix = 0.17454016
MCPImpl: num_above=3052, num_below=1898
beta_solver, min eigen of left matrix = 0.17454016
MCPImpl: num_above=3815, num_below=1135
beta_solver, min eigen of left matrix = 0.17454016
MCPImpl: num_above=2793, num_below=2157
beta_solver, min eigen of left matrix = 0.17454016
MCPImpl: num_above=4312, num_below=638
beta_solver, min eigen of left matrix = 0.17454016
MCPImpl: num_above=3592, num_below=1358
beta_solver, min eigen of left matrix = 0.17454016
MCPImpl: num_above=4279, num_below=671
beta_solver, min eigen of left matrix = 0.17454016
MCPImpl: num_above=3751, num_below=1199
kmeans center = [[ 1.09568572 -2.99208311 -2.63306063  2.98962444  1.49454037  2.74859002
  -0.83778074 -2.90767139  2.67164467  0.89296014]
 [-1.03990145  2.98044816  2.48702453 -2.6200784  -1.68551685 -2.86517374
   1.0790534   2.705

beta_solver, min eigen of left matrix = 0.17695455
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.17695455
MCPImpl: num_above=3052, num_below=1898
beta_solver, min eigen of left matrix = 0.17695455
MCPImpl: num_above=4266, num_below=684
beta_solver, min eigen of left matrix = 0.17695455
MCPImpl: num_above=3034, num_below=1916
beta_solver, min eigen of left matrix = 0.17695455
MCPImpl: num_above=3872, num_below=1078
beta_solver, min eigen of left matrix = 0.17695455
MCPImpl: num_above=3766, num_below=1184
beta_solver, min eigen of left matrix = 0.17695455
MCPImpl: num_above=2888, num_below=2062
beta_solver, min eigen of left matrix = 0.17695455
MCPImpl: num_above=4159, num_below=791
beta_solver, min eigen of left matrix = 0.17695455
MCPImpl: num_above=3177, num_below=1773
beta_solver, min eigen of left matrix = 0.17695455
MCPImpl: num_above=4054, num_below=896
kmeans center = [[ 0.85919744 -2.96964413 -2.83662817  2.70654648  1.81922872  3.07997835
  -

MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.17882091
MCPImpl: num_above=3024, num_below=1926
beta_solver, min eigen of left matrix = 0.17882091
MCPImpl: num_above=4416, num_below=534
beta_solver, min eigen of left matrix = 0.17882091
MCPImpl: num_above=3454, num_below=1496
beta_solver, min eigen of left matrix = 0.17882091
MCPImpl: num_above=3038, num_below=1912
beta_solver, min eigen of left matrix = 0.17882091
MCPImpl: num_above=4017, num_below=933
beta_solver, min eigen of left matrix = 0.17882091
MCPImpl: num_above=3226, num_below=1724
beta_solver, min eigen of left matrix = 0.17882091
MCPImpl: num_above=3882, num_below=1068
beta_solver, min eigen of left matrix = 0.17882091
MCPImpl: num_above=2913, num_below=2037
beta_solver, min eigen of left matrix = 0.17882091
MCPImpl: num_above=4319, num_below=631
kmeans center = [[ 0.96889602 -2.81296726 -2.94488807  2.7634455   1.70937836  3.03470062
  -0.88893512 -2.66595574  2.65563893  0.72180144]
 [-

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.1707096+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.1707096+0j)
MCPImpl: num_above=3131, num_below=1819
beta_solver, min eigen of left matrix = (0.1707096+0j)
MCPImpl: num_above=4073, num_below=877
beta_solver, min eigen of left matrix = (0.1707096+0j)
MCPImpl: num_above=2805, num_below=2145
beta_solver, min eigen of left matrix = (0.1707096+0j)
MCPImpl: num_above=4269, num_below=681
beta_solver, min eigen of left matrix = (0.1707096+0j)
MCPImpl: num_above=3204, num_below=1746
beta_solver, min eigen of left matrix = (0.1707096+0j)
MCPImpl: num_above=3998, num_below=952
beta_solver, min eigen of left matrix = (0.1707096+0j)
MCPImpl: num_above=3491, num_below=1459
beta_solver, min eigen of left matrix = (0.1707096+0j)
MCPImpl: num_above=3775, num_below=1175
beta_solver, min eigen of left matrix = (0.1707096

beta_solver, min eigen of left matrix = (0.17487864+0j)
MCPImpl: num_above=4090, num_below=860
kmeans center = [[ 1.17744572 -2.84152934 -2.35303051  2.69898005  1.44151573  2.61395737
  -0.86026576 -2.38702615  2.53990238  0.74573562]
 [-1.17304868  2.78941394  2.8890724  -2.63958724 -1.59586488 -3.00548881
   1.23192634  2.91262911 -2.98616876 -0.76390951]] and inertia = 2972.5466948410794
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.17557034+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.17557034+0j)
MCPImpl: num_above=3029, num_below=1921
beta_solver, min eigen of left matrix = (0.17557034+0j)
MCPImpl: num_above=4299, num_below=651
beta_solver, min eigen of left matrix = (0.17557034+0j)
MCPImpl: num_above=3347, num_below=1603
beta_sol

beta_solver, min eigen of left matrix = 0.17528743
MCPImpl: num_above=4226, num_below=724
beta_solver, min eigen of left matrix = 0.17528743
MCPImpl: num_above=3494, num_below=1456
beta_solver, min eigen of left matrix = 0.17528743
MCPImpl: num_above=3590, num_below=1360
beta_solver, min eigen of left matrix = 0.17528743
MCPImpl: num_above=3757, num_below=1193
beta_solver, min eigen of left matrix = 0.17528743
MCPImpl: num_above=3804, num_below=1146
kmeans center = [[ 1.19626962 -2.95792699 -2.5559038   2.85856765  1.43939252  2.60618215
  -1.10807303 -2.62141712  2.85277519  0.87734842]
 [-0.91443849  2.53483487  2.67315715 -2.55799032 -1.59000703 -2.7412568
   1.1730438   2.60362693 -2.70570527 -0.940578  ]] and inertia = 2218.188650689573
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left mat

MCPImpl: num_above=2961, num_below=1989
beta_solver, min eigen of left matrix = 0.1808616
MCPImpl: num_above=4391, num_below=559
beta_solver, min eigen of left matrix = 0.1808616
MCPImpl: num_above=3572, num_below=1378
beta_solver, min eigen of left matrix = 0.1808616
MCPImpl: num_above=4126, num_below=824
beta_solver, min eigen of left matrix = 0.1808616
MCPImpl: num_above=3359, num_below=1591
beta_solver, min eigen of left matrix = 0.1808616
MCPImpl: num_above=3604, num_below=1346
beta_solver, min eigen of left matrix = 0.1808616
MCPImpl: num_above=3994, num_below=956
beta_solver, min eigen of left matrix = 0.1808616
MCPImpl: num_above=3381, num_below=1569
beta_solver, min eigen of left matrix = 0.1808616
MCPImpl: num_above=3496, num_below=1454
kmeans center = [[ 0.9365491  -2.83502777 -2.79194134  2.64704441  1.73844499  3.02340684
  -1.04585075 -2.70368467  2.52439375  0.84609484]
 [-1.06445185  2.79318166  2.8790254  -2.74347398 -1.79647481 -2.86427185
   0.96261569  2.92576604 -2

In [5]:
beta_learned_list[0].betas

[array([-1.0582747 ,  3.01105287,  2.89693156, -2.92761779, -1.73989323,
        -2.93064042,  1.13319729,  2.96846967, -2.89111097, -0.88097648]),
 array([ 1.14704112, -3.02877021, -2.75908766,  2.71320658,  1.70528253,
         2.96881408, -1.08109668, -2.84457865,  2.85411481,  0.82207178])]

In [6]:
mu_learned_list = []
sigma_learned_list = []
z_score_learned_list = []
mu_ng_list = []
sigma_ng_list = []
z_score_ng_list = []

for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    z_score_learned = [
        (mu - truth) / sigma for mu, truth, sigma in zip(v_mus, v_truth, v_sigmas)
    ]
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_score_learned)

    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    z_score_ng = [
        (mu - truth) / sigma
        for mu, truth, sigma in zip(
            ngv_mus * len(v_truth), v_truth, ngv_sigmas * len(v_truth)
        )
    ]
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_score_ng)

# Reports that average over two groups

In [7]:
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)
Z_THRESHOLD = 1.96
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD
learned_in_threshold_perc = learned_in_threshold.sum() / learned_in_threshold.size
ng_in_threshold_perc = ng_in_threshold.sum() / ng_in_threshold.size
print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)
if SAVE_RESULT:
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(
            dict(
                mu_learned_list=mu_learned_list,
                sigma_learned_list=sigma_learned_list,
                z_score_learned=z_score_learned,
                mu_ng_list=mu_ng_list,
                sigma_ng_list=sigma_ng_list,
                z_score_ng=z_score_ng,
                beta_learned_list=beta_learned_list,
                beta_ng_list=beta_ng_list,
            ),
            f,
        )

learned_in_threshold_perc= 0.935 , ng_in_threshold_perc= 0.0


In [8]:
ac_acl = 2*Z_THRESHOLD*np.mean(sigma_learned_list)
ac_mse = np.mean( (mu_learned_list - v_truth)**2 )

mv_acl = 2*Z_THRESHOLD*np.mean(sigma_ng_list)
mv_mse = np.mean( (mu_ng_list - v_truth)**2 )

In [9]:
print("ACPE results: (average over groups)")
print(f"ACL: {ac_acl}") 
print(f"MSE: {ac_mse}")
print(f"ECP: {learned_in_threshold_perc}")


print("MVPE results: (average over groups)")
print(f"ACL: {mv_acl}") 
print(f"MSE: {mv_mse}")
print(f"ECP: {ng_in_threshold_perc}")

ACPE results: (average over groups)
ACL: 0.1632240695346381
MSE: 0.0018388645059548494
ECP: 0.935
MVPE results: (average over groups)
ACL: 0.20380290685271865
MSE: 1.5527812986190357
ECP: 0.0


# Reports that seperate two groups

In [10]:
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)
Z_THRESHOLD = 1.96
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

learned_in_threshold_perc = np.sum(learned_in_threshold, axis=0) / learned_in_threshold.shape[0]
ng_in_threshold_perc = np.sum(ng_in_threshold, axis=0) / ng_in_threshold.shape[0]

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(
            dict(
                mu_learned_list=mu_learned_list,
                sigma_learned_list=sigma_learned_list,
                z_score_learned=z_score_learned,
                mu_ng_list=mu_ng_list,
                sigma_ng_list=sigma_ng_list,
                z_score_ng=z_score_ng,
                beta_learned_list=beta_learned_list,
                beta_ng_list=beta_ng_list,
            ),
            f,
        )

learned_in_threshold_perc= [0.92 0.95] , ng_in_threshold_perc= [0. 0.]


In [11]:
ac_acl = 2*Z_THRESHOLD*np.mean(sigma_learned_list, axis=0)
ac_mse = np.mean((mu_learned_list - v_truth)**2, axis=0)

mv_acl = 2*Z_THRESHOLD*np.mean(sigma_ng_list, axis=0)
mv_mse = np.mean( (mu_ng_list - v_truth)**2, axis=0 )

In [12]:
print("ACPE results: Group1, Group 2")
print(f"MSE: {ac_mse}")
print(f"ACL: {ac_acl}") 
print(f"ECP: {learned_in_threshold_perc}")

print("===")
print("MVPE results: ")
print(f"MSE: {mv_mse}")
print(f"ACL: {mv_acl}") 
print(f"ECP: {ng_in_threshold_perc}")

ACPE results: Group1, Group 2
MSE: [0.00200633 0.0016714 ]
ACL: [0.16296566 0.16348248]
ECP: [0.92 0.95]
===
MVPE results: 
MSE: [1.50706405 1.59849855]
ACL: [0.20380291]
ECP: [0. 0.]
