In [1]:
import copy
import pickle
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

# Reward coefficients: a 2 x 4 matrix (presumably one row per group -- confirm
# against DataGenConfig) plus one coefficient per action.
group_reward_coeff = np.array([[-2, 2, 2, -2], [2, -2, -2, 2]], dtype=DTYPE)
action_reward_coeff = [-1, 1]

# Basis expansion, transformation and noise type. They are forwarded to the
# data generator and are also embedded in the file names built below.
FEATURE_TYPE, TRANS, NOISE = "LEGENDRE", "NORMCDF", "STUDENT"

# Keyword arguments shared by every DataGenConfig created in this notebook.
data_config_init = {
    "num_trajectories": 100,
    "num_time_steps": 30,
    "group_reward_coeff_override": group_reward_coeff,
    "action_reward_coeff": action_reward_coeff,
    "num_burnin_steps": 100,
    "basis_expansion_method": FEATURE_TYPE,
    "transformation_method": TRANS,
    "add_intercept_column": True,
    "noise_type": NOISE,
    "noise_student_degree": 4,
}

# --- Run switches ------------------------------------------------------------
NUM_EXPERIMENTS = 100
# First-time runners: set COMPUTE_TRUTH to True so the truth file is generated,
# then switch it back to False once that file exists.
COMPUTE_TRUTH = False
SAVE_RESULT = True

# --- Output locations --------------------------------------------------------
# Rename the truth file whenever the settings above are changed.
TRUTH_FILE = f"hetero/data/{FEATURE_TYPE}_{TRANS}_{NOISE}_truth_20230528_2.68_2.89.pkl"
print("truth file name =", TRUTH_FILE)

# The result file name carries N, T and the wall-clock time at which this cell
# ran, so every run of the cell yields a fresh file name.
time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = f'N={data_config_init["num_trajectories"]}_T={data_config_init["num_time_steps"]}_{time_tag}'
RESULT_FILE = f"hetero/data/{FEATURE_TYPE}_{TRANS}_{NOISE}_result_20230528_2.68_2.89_{tag}.pkl"
print("result file name =", RESULT_FILE)

if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

truth file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_result_20230528_2.68_2.89_N=100_T=30_20230602_19-36-42.pkl


=====================================================================================================
# Algorithm 

- Set the algorithm configuration below.

In [2]:

# Tunable settings for the learning algorithm, collected in one mapping so they
# can be scanned (or logged) at a glance before being handed to AlgoConfig.
algo_kwargs = {
    "max_num_iters": 10,
    "gam": 2.7,
    "lam": 2.0,
    "rho": 2.0,
    # Outlier removal is enabled with the 2nd / 98th percentiles as bounds.
    "should_remove_outlier": True,
    "outlier_lower_perc": 2,
    "outlier_upper_perc": 98,
    "nu_coeff": 1e-5,
    "delta_coeff": 1e-5,
    "use_group_wise_regression_init": True,
}
algo_config = AlgoConfig(**algo_kwargs)

# Policy whose value is evaluated throughout the notebook.
pi_eval = AlternativePolicy(2)

# Grouping step runs with the library defaults.
grouping_config = GroupingConfig()

In [3]:
# Obtain the reference ("truth") values u_truth / v_truth used by the later
# evaluation cells: either recompute and cache them, or load the cached file.
if COMPUTE_TRUTH:
    # The truth run supplies its own trajectory count, so remove the
    # experiment-sized value from a copy; data_config_init itself is untouched.
    truth_data_config_init = copy.copy(data_config_init)
    truth_data_config_init.pop("num_trajectories")
    # NOTE(review): the original inline remark read "For best results, use
    # num_repeats=10", yet 100 is passed here -- confirm which is intended.
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=100,
        num_trajectories=1000,
    )
    # Average over the leading axis of the returned arrays.
    u_truth = us.mean(axis=0)
    v_truth = vs.mean(axis=0)

    # Store the settings next to the values so the file documents its origin.
    with open(TRUTH_FILE, "wb") as f:
        pickle.dump(
            dict(
                u_truth=u_truth,
                v_truth=v_truth,
                data_config_dict=truth_data_config_init,
                algo_config_dict=attrs.asdict(algo_config),
            ),
            f,
        )
else:
    try:
        # SECURITY: pickle.load can execute arbitrary code; only load truth
        # files that were produced by this notebook.
        with open(TRUTH_FILE, "rb") as f:
            loaded = pickle.load(f)
    except FileNotFoundError as err:
        # Same exception type as before, but tell the user how to recover.
        raise FileNotFoundError(
            f"Truth file {TRUTH_FILE!r} was not found. Set COMPUTE_TRUTH = True "
            "and re-run to generate it, or point TRUTH_FILE at an existing file."
        ) from err
    u_truth = loaded["u_truth"]
    v_truth = loaded["v_truth"]

In [4]:
# Per-experiment estimates: non-grouped baseline and end-to-end learned.
beta_ng_list, beta_learned_list = [], []

for run_number in range(1, NUM_EXPERIMENTS + 1):
    # Deterministic, distinct seed per replication: 7531, 15062, 22593, ...
    data_config = DataGenConfig(seed=7531 * run_number, **data_config_init)
    data = generate_data_from_config(data_config)

    # Baseline first, learned estimator second (same order on every run).
    beta_ng = beta_estimate_from_nongrouped(data, pi_eval, algo_config.discount)
    beta_ng_list.append(beta_ng)

    beta_learned = beta_estimate_from_e2e_learning(
        data, algo_config, grouping_config, pi_eval
    )
    beta_learned_list.append(beta_learned)

new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
beta_solver, min eigen of left matrix = 0.11791802
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.11791802
MCPImpl: num_above=12152, num_below=7748
beta_solver, min eigen of left matrix = 0.11791802
MCPImpl: num_above=18069, num_below=1831
beta_solver, min eigen of left matrix = 0.11791802
MCPImpl: num_above=15803, num_below=4097
beta_solver, min eigen of left matrix = 0.11791802
MCPImpl: num_above=12989, num_below=6911
beta_solver, min eigen of left matrix = 0.11791802
MCPImpl: num_above=15012, num_below=4888
beta_solver, min eigen of left matrix = 0.11791802
MCPImpl: num_above=11014, num_below=8886
beta_solver, min eigen of left matrix = 0.11791802
MCPImpl: num_above=18630, num_below=1270
beta_solver, min eigen of left matrix = 0.11791802
MCPImpl: num_above=16952, num_below=2948
beta_solver, min eigen of left matrix = 0.11791802
MCPImpl: num_above=1554

new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
beta_solver, min eigen of left matrix = 0.1250632
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.1250632
MCPImpl: num_above=12114, num_below=7786
beta_solver, min eigen of left matrix = 0.1250632
MCPImpl: num_above=18186, num_below=1714
beta_solver, min eigen of left matrix = 0.1250632
MCPImpl: num_above=16205, num_below=3695
beta_solver, min eigen of left matrix = 0.1250632
MCPImpl: num_above=11615, num_below=8285
beta_solver, min eigen of left matrix = 0.1250632
MCPImpl: num_above=18077, num_below=1823
beta_solver, min eigen of left matrix = 0.1250632
MCPImpl: num_above=14767, num_below=5133
beta_solver, min eigen of left matrix = 0.1250632
MCPImpl: num_above=15397, num_below=4503
beta_solver, min eigen of left matrix = 0.1250632
MCPImpl: num_above=14595, num_below=5305
beta_solver, min eigen of left matrix = 0.1250632
MCPImpl: num_above=12531, num_bel

new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
beta_solver, min eigen of left matrix = 0.12210948
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.12210948
MCPImpl: num_above=12337, num_below=7563
beta_solver, min eigen of left matrix = 0.12210948
MCPImpl: num_above=17623, num_below=2277
beta_solver, min eigen of left matrix = 0.12210948
MCPImpl: num_above=13865, num_below=6035
beta_solver, min eigen of left matrix = 0.12210948
MCPImpl: num_above=13827, num_below=6073
beta_solver, min eigen of left matrix = 0.12210948
MCPImpl: num_above=15519, num_below=4381
beta_solver, min eigen of left matrix = 0.12210948
MCPImpl: num_above=18668, num_below=1232
beta_solver, min eigen of left matrix = 0.12210948
MCPImpl: num_above=17035, num_below=2865
beta_solver, min eigen of left matrix = 0.12210948
MCPImpl: num_above=14002, num_below=5898
beta_solver, min eigen of left matrix = 0.12210948
MCPImpl: num_above=1508

new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
beta_solver, min eigen of left matrix = 0.12320886
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.12320886
MCPImpl: num_above=12193, num_below=7707
beta_solver, min eigen of left matrix = 0.12320886
MCPImpl: num_above=18137, num_below=1763
beta_solver, min eigen of left matrix = 0.12320886
MCPImpl: num_above=16395, num_below=3505
beta_solver, min eigen of left matrix = 0.12320886
MCPImpl: num_above=10882, num_below=9018
beta_solver, min eigen of left matrix = 0.12320886
MCPImpl: num_above=18983, num_below=917
beta_solver, min eigen of left matrix = 0.12320886
MCPImpl: num_above=17906, num_below=1994
beta_solver, min eigen of left matrix = 0.12320886
MCPImpl: num_above=15272, num_below=4628
beta_solver, min eigen of left matrix = 0.12320886
MCPImpl: num_above=14833, num_below=5067
beta_solver, min eigen of left matrix = 0.12320886
MCPImpl: num_above=13188

MCPImpl: num_above=12648, num_below=7252
kmeans center = [[ 1.02139641 -2.7356946  -2.76910046  2.75470023  1.57500735  2.81704792
  -1.00718715 -2.8263286   2.77579031  0.83118302]
 [-1.00981441  3.03549272  2.58873387 -2.7729535  -1.69333845 -2.95166206
   0.88746515  2.63002793 -2.61563223 -0.68395888]] and inertia = 2435.530869543477
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
beta_solver, min eigen of left matrix = 0.11962985
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.11962985
MCPImpl: num_above=12459, num_below=7441
beta_solver, min eigen of left matrix = 0.11962985
MCPImpl: num_above=17298, num_below=2602
beta_solver, min eigen of left matrix = 0.11962985
MCPImpl: num_above=13839, num_below=6061
beta_solver, min eigen of left matrix = 0.11962985
MCPImpl: num_above=15557, n

MCPImpl: num_above=11940, num_below=7960
beta_solver, min eigen of left matrix = (0.124578245+0j)
MCPImpl: num_above=18497, num_below=1403
beta_solver, min eigen of left matrix = (0.124578245+0j)
MCPImpl: num_above=17238, num_below=2662
beta_solver, min eigen of left matrix = (0.124578245+0j)
MCPImpl: num_above=14244, num_below=5656
beta_solver, min eigen of left matrix = (0.124578245+0j)
MCPImpl: num_above=14517, num_below=5383
kmeans center = [[ 0.95888564 -2.78767648 -2.55580628  2.75950487  1.58008538  2.84223203
  -1.05209233 -2.57810374  2.91062485  0.72469289]
 [-1.02505082  3.09486087  2.84291666 -2.77605069 -1.84021479 -2.87756526
   0.90828329  2.81199038 -2.72611207 -0.81657926]] and inertia = 3877.2751701748543
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
beta_solver, min eigen of left matrix = (0.12196571+0

new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
beta_solver, min eigen of left matrix = 0.124633126
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.124633126
MCPImpl: num_above=11820, num_below=8080
beta_solver, min eigen of left matrix = 0.124633126
MCPImpl: num_above=18535, num_below=1365
beta_solver, min eigen of left matrix = 0.124633126
MCPImpl: num_above=17813, num_below=2087
beta_solver, min eigen of left matrix = 0.124633126
MCPImpl: num_above=12947, num_below=6953
beta_solver, min eigen of left matrix = 0.124633126
MCPImpl: num_above=15209, num_below=4691
beta_solver, min eigen of left matrix = 0.124633126
MCPImpl: num_above=18125, num_below=1775
beta_solver, min eigen of left matrix = 0.124633126
MCPImpl: num_above=15145, num_below=4755
beta_solver, min eigen of left matrix = 0.124633126
MCPImpl: num_above=14812, num_below=5088
beta_solver, min eigen of left matrix = 0.124633126
MCPImpl: num_

MCPImpl: num_above=12864, num_below=7036
kmeans center = [[ 1.09562226 -2.88198013 -2.6648878   2.7701635   1.48762596  2.81749051
  -0.91888578 -2.78367887  2.5116108   0.86651786]
 [-1.0816266   2.94760772  2.63271289 -2.57843933 -1.73351981 -2.87468827
   0.98736125  2.75667492 -2.77789666 -0.81556117]] and inertia = 2639.5624973317335
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
beta_solver, min eigen of left matrix = (0.12053244+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.12053244+0j)
MCPImpl: num_above=11935, num_below=7965
beta_solver, min eigen of left matrix = (0.12053244+0j)
MCPImpl: num_above=18365, num_below=1535
beta_solver, min eigen of left matrix = (0.12053244+0j)
MCPImpl: num_above=17293, num_below=2607
beta_solver, min eigen of left matrix = (0.12053244+0j)
M

MCPImpl: num_above=12732, num_below=7168
beta_solver, min eigen of left matrix = (0.123594694+0j)
MCPImpl: num_above=16287, num_below=3613
beta_solver, min eigen of left matrix = (0.123594694+0j)
MCPImpl: num_above=13202, num_below=6698
beta_solver, min eigen of left matrix = (0.123594694+0j)
MCPImpl: num_above=16131, num_below=3769
beta_solver, min eigen of left matrix = (0.123594694+0j)
MCPImpl: num_above=12468, num_below=7432
beta_solver, min eigen of left matrix = (0.123594694+0j)
MCPImpl: num_above=17319, num_below=2581
kmeans center = [[ 1.04862646 -2.70523685 -2.42023603  2.65667725  1.43044293  2.80570763
  -0.90272389 -2.78062372  2.62821255  0.88350008]
 [-1.10987385  2.79222019  2.76944012 -2.76186319 -1.70907894 -2.80241006
   1.12853601  2.53762377 -2.70078793 -0.91958923]] and inertia = 7191.755279827532
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
new_l

MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.11917357
MCPImpl: num_above=12288, num_below=7612
beta_solver, min eigen of left matrix = 0.11917357
MCPImpl: num_above=18244, num_below=1656
beta_solver, min eigen of left matrix = 0.11917357
MCPImpl: num_above=16328, num_below=3572
beta_solver, min eigen of left matrix = 0.11917357
MCPImpl: num_above=12395, num_below=7505
beta_solver, min eigen of left matrix = 0.11917357
MCPImpl: num_above=16334, num_below=3566
beta_solver, min eigen of left matrix = 0.11917357
MCPImpl: num_above=12613, num_below=7287
beta_solver, min eigen of left matrix = 0.11917357
MCPImpl: num_above=16793, num_below=3107
beta_solver, min eigen of left matrix = 0.11917357
MCPImpl: num_above=15168, num_below=4732
beta_solver, min eigen of left matrix = 0.11917357
MCPImpl: num_above=16170, num_below=3730
kmeans center = [[ 0.95162423 -2.77419605 -2.79795591  2.72046283  1.6748206   3.03959038
  -0.92325328 -2.75946974  2.73689759  0.

new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
beta_solver, min eigen of left matrix = 0.1241066
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.1241066
MCPImpl: num_above=12207, num_below=7693
beta_solver, min eigen of left matrix = 0.1241066
MCPImpl: num_above=17807, num_below=2093
beta_solver, min eigen of left matrix = 0.1241066
MCPImpl: num_above=14974, num_below=4926
beta_solver, min eigen of left matrix = 0.1241066
MCPImpl: num_above=15971, num_below=3929
beta_solver, min eigen of left matrix = 0.1241066
MCPImpl: num_above=12951, num_below=6949
beta_solver, min eigen of left matrix = 0.1241066
MCPImpl: num_above=16457, num_below=3443
beta_solver, min eigen of left matrix = 0.1241066
MCPImpl: num_above=12253, num_below=7647
beta_solver, min eigen of left matrix = 0.1241066
MCPImpl: num_above=17102, num_below=2798
beta_solver, min eigen of left matrix = 0.1241066
MCPImpl: num_above=12984, num_bel

MCPImpl: num_above=14430, num_below=5470
kmeans center = [[ 0.95508071 -2.7617174  -2.51831178  2.74743735  1.41044083  2.94002475
  -1.02536405 -2.52107795  2.65905786  0.64862236]
 [-1.0487521   2.76191206  2.92942688 -2.68918803 -1.67391308 -2.65631334
   1.1163835   2.90935348 -2.69030003 -1.1225886 ]] and inertia = 3987.1047413001106
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
beta_solver, min eigen of left matrix = 0.12023358
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.12023358
MCPImpl: num_above=11892, num_below=8008
beta_solver, min eigen of left matrix = 0.12023358
MCPImpl: num_above=18467, num_below=1433
beta_solver, min eigen of left matrix = 0.12023358
MCPImpl: num_above=17516, num_below=2384
beta_solver, min eigen of left matrix = 0.12023358
MCPImpl: num_above=13404, 

MCPImpl: num_above=14289, num_below=5611
beta_solver, min eigen of left matrix = 0.12310763
MCPImpl: num_above=12437, num_below=7463
beta_solver, min eigen of left matrix = 0.12310763
MCPImpl: num_above=16429, num_below=3471
beta_solver, min eigen of left matrix = 0.12310763
MCPImpl: num_above=13969, num_below=5931
beta_solver, min eigen of left matrix = 0.12310763
MCPImpl: num_above=15897, num_below=4003
beta_solver, min eigen of left matrix = 0.12310763
MCPImpl: num_above=13309, num_below=6591
kmeans center = [[ 1.07289808 -2.96409821 -2.59289643  2.77779558  1.59405769  2.61061963
  -0.93068497 -2.65026742  2.83629863  0.87388073]
 [-1.07950588  2.63479353  2.79646333 -2.66529556 -1.53922684 -2.87415214
   1.02592017  2.92588806 -2.67342194 -0.93839218]] and inertia = 2694.2026418039436
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
new_labels.length=6000 matches num

new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
beta_solver, min eigen of left matrix = 0.1242341
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.1242341
MCPImpl: num_above=12064, num_below=7836
beta_solver, min eigen of left matrix = 0.1242341
MCPImpl: num_above=18173, num_below=1727
beta_solver, min eigen of left matrix = 0.1242341
MCPImpl: num_above=16188, num_below=3712
beta_solver, min eigen of left matrix = 0.1242341
MCPImpl: num_above=11482, num_below=8418
beta_solver, min eigen of left matrix = 0.1242341
MCPImpl: num_above=18061, num_below=1839
beta_solver, min eigen of left matrix = 0.1242341
MCPImpl: num_above=15438, num_below=4462
beta_solver, min eigen of left matrix = 0.1242341
MCPImpl: num_above=17640, num_below=2260
beta_solver, min eigen of left matrix = 0.1242341
MCPImpl: num_above=15611, num_below=4289
beta_solver, min eigen of left matrix = 0.1242341
MCPImpl: num_above=13210, num_bel

MCPImpl: num_above=15050, num_below=4850
beta_solver, min eigen of left matrix = 0.123035304
MCPImpl: num_above=11815, num_below=8085
kmeans center = [[ 1.04184694 -2.85171068 -2.71788343  2.8893449   1.51971901  2.71775182
  -0.96274837 -2.50135065  2.77402444  0.74899064]
 [-0.94383357  2.86118175  2.71202099 -2.59366463 -1.74278877 -2.95083071
   0.93088476  2.72913697 -2.6242079  -0.80485117]] and inertia = 2104.2349966438633
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
beta_solver, min eigen of left matrix = (0.121908374+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.121908374+0j)
MCPImpl: num_above=12522, num_below=7378
beta_solver, min eigen of left matrix = (0.121908374+0j)
MCPImpl: num_above=17365, num_below=2535
beta_solver, min eigen of left matrix = (0.121908374+0j)
M

MCPImpl: num_above=17701, num_below=2199
beta_solver, min eigen of left matrix = 0.12397472
MCPImpl: num_above=14357, num_below=5543
beta_solver, min eigen of left matrix = 0.12397472
MCPImpl: num_above=13282, num_below=6618
beta_solver, min eigen of left matrix = 0.12397472
MCPImpl: num_above=14940, num_below=4960
beta_solver, min eigen of left matrix = 0.12397472
MCPImpl: num_above=18824, num_below=1076
beta_solver, min eigen of left matrix = 0.12397472
MCPImpl: num_above=16930, num_below=2970
beta_solver, min eigen of left matrix = 0.12397472
MCPImpl: num_above=15099, num_below=4801
beta_solver, min eigen of left matrix = 0.12397472
MCPImpl: num_above=14121, num_below=5779
kmeans center = [[ 0.87125009 -3.06478397 -2.62649612  2.76453222  1.79642331  2.84277221
  -1.0887819  -2.72759356  2.71713356  0.92800404]
 [-1.14082964  2.86576099  2.57658945 -2.75614933 -1.39975131 -2.67794317
   0.97846233  2.46331598 -2.8338766  -0.64396339]] and inertia = 3597.699440475073
Label mismatch =

new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
beta_solver, min eigen of left matrix = 0.12143166
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.12143166
MCPImpl: num_above=12275, num_below=7625
beta_solver, min eigen of left matrix = 0.12143166
MCPImpl: num_above=17895, num_below=2005
beta_solver, min eigen of left matrix = 0.12143166
MCPImpl: num_above=15089, num_below=4811
beta_solver, min eigen of left matrix = 0.12143166
MCPImpl: num_above=11301, num_below=8599
beta_solver, min eigen of left matrix = 0.12143166
MCPImpl: num_above=18586, num_below=1314
beta_solver, min eigen of left matrix = 0.12143166
MCPImpl: num_above=17079, num_below=2821
beta_solver, min eigen of left matrix = 0.12143166
MCPImpl: num_above=13098, num_below=6802
beta_solver, min eigen of left matrix = 0.12143166
MCPImpl: num_above=16153, num_below=3747
beta_solver, min eigen of left matrix = 0.12143166
MCPImpl: num_above=1112

new_labels.length=6000 matches number of records
new_labels.length=6000 matches number of records
beta_solver, min eigen of left matrix = 0.11832493
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.11832493
MCPImpl: num_above=12408, num_below=7492
beta_solver, min eigen of left matrix = 0.11832493
MCPImpl: num_above=17558, num_below=2342
beta_solver, min eigen of left matrix = 0.11832493
MCPImpl: num_above=13325, num_below=6575
beta_solver, min eigen of left matrix = 0.11832493
MCPImpl: num_above=14590, num_below=5310
beta_solver, min eigen of left matrix = 0.11832493
MCPImpl: num_above=14742, num_below=5158
beta_solver, min eigen of left matrix = 0.11832493
MCPImpl: num_above=16782, num_below=3118
beta_solver, min eigen of left matrix = 0.11832493
MCPImpl: num_above=15531, num_below=4369
beta_solver, min eigen of left matrix = 0.11832493
MCPImpl: num_above=14445, num_below=5455
beta_solver, min eigen of left matrix = 0.11832493
MCPImpl: num_above=1368

In [5]:
# Sanity check: display the `betas` of the first experiment's learned estimate.
first_beta_learned = beta_learned_list[0]
first_beta_learned.betas

[array([-1.09837526,  2.8082234 ,  2.88110058, -2.87794663, -1.61368794,
        -2.93753647,  1.08416964,  2.85638628, -2.79561138, -0.87871883]),
 array([ 1.03968762, -2.87014253, -2.79828366,  2.90394625,  1.61533076,
         2.97185562, -1.18041133, -2.84607875,  2.96327181,  0.78605497])]

In [6]:
def _z_scores(mus, truths, sigmas):
    """Return ``(mu - truth) / sigma`` for each aligned (mu, truth, sigma) triple.

    The three sequences are walked in lockstep with ``zip``, so the result is
    as long as the shortest of them.
    """
    return [(mu - truth) / sigma for mu, truth, sigma in zip(mus, truths, sigmas)]


# Per-experiment value estimates, their standard errors and the resulting
# z-scores against v_truth, for the learned ("learned") and non-grouped ("ng")
# estimators.
mu_learned_list = []
sigma_learned_list = []
z_score_learned_list = []
mu_ng_list = []
sigma_ng_list = []
z_score_ng_list = []

num_truths = len(v_truth)

for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(_z_scores(v_mus, v_truth, v_sigmas))

    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    # The non-grouped estimate is compared against every entry of v_truth, so
    # its entries are repeated len(v_truth) times. Converting to a list first
    # guarantees that `*` means repetition: on a NumPy array it would instead
    # scale the values and silently drop all but the first comparison.
    # NOTE(review): the recorded outputs suggest the non-grouped result holds a
    # single entry -- confirm against compute_V_estimate.
    z_score_ng_list.append(
        _z_scores(
            list(ngv_mus) * num_truths,
            v_truth,
            list(ngv_sigmas) * num_truths,
        )
    )
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)

# Reports averaged over the two groups

In [7]:
Z_THRESHOLD = 1.96


def _overall_coverage(z_scores):
    """Return the |z| < Z_THRESHOLD mask and the fraction of True entries in it."""
    inside = np.abs(z_scores) < Z_THRESHOLD
    return inside, inside.sum() / inside.size


z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

# Coverage pooled over every experiment and every column of the z-score table.
learned_in_threshold, learned_in_threshold_perc = _overall_coverage(z_score_learned)
ng_in_threshold, ng_in_threshold_perc = _overall_coverage(z_score_ng)

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    # Everything needed to rebuild the reports without re-running experiments.
    result_payload = {
        "mu_learned_list": mu_learned_list,
        "sigma_learned_list": sigma_learned_list,
        "z_score_learned": z_score_learned,
        "mu_ng_list": mu_ng_list,
        "sigma_ng_list": sigma_ng_list,
        "z_score_ng": z_score_ng,
        "beta_learned_list": beta_learned_list,
        "beta_ng_list": beta_ng_list,
    }
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(result_payload, f)

learned_in_threshold_perc= 0.935 , ng_in_threshold_perc= 0.0


In [8]:
# Interval length is 2 * Z_THRESHOLD * sigma; the MSE is taken against v_truth.
# Both are averaged over all entries (no axis), giving one number per method.
interval_factor = 2 * Z_THRESHOLD

# Learned estimator.
learned_errors = mu_learned_list - v_truth
ac_acl = interval_factor * np.mean(sigma_learned_list)
ac_mse = np.mean(learned_errors ** 2)

# Non-grouped estimator.
ng_errors = mu_ng_list - v_truth
mv_acl = interval_factor * np.mean(sigma_ng_list)
mv_mse = np.mean(ng_errors ** 2)

In [9]:
def _print_pooled_report(title, acl, mse, ecp):
    """Print one method's pooled summary as four lines: title, ACL, MSE, ECP."""
    print(title)
    print(f"ACL: {acl}")
    print(f"MSE: {mse}")
    print(f"ECP: {ecp}")


_print_pooled_report(
    "ACPE results: (average over groups)", ac_acl, ac_mse, learned_in_threshold_perc
)
_print_pooled_report(
    "MVPE results: (average over groups)", mv_acl, mv_mse, ng_in_threshold_perc
)

ACPE results: (average over groups)
ACL: 0.13345486501716344
MSE: 0.00121464743509852
ECP: 0.935
MVPE results: (average over groups)
ACL: 0.1662323365544558
MSE: 1.5521078296266537
ECP: 0.0


# Reports that separate the two groups

In [10]:
Z_THRESHOLD = 1.96


def _per_column_coverage(z_scores):
    """Return the |z| < Z_THRESHOLD mask and, per column, the fraction of rows inside."""
    inside = np.abs(z_scores) < Z_THRESHOLD
    return inside, np.sum(inside, axis=0) / inside.shape[0]


z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

# Same coverage check as the pooled report, but reduced over experiments only
# (axis 0), which leaves one figure per column.
# NOTE: this rebinds learned_in_threshold_perc / ng_in_threshold_perc from
# scalars to arrays, so re-running the pooled report cell afterwards will
# print these per-column values instead.
learned_in_threshold, learned_in_threshold_perc = _per_column_coverage(z_score_learned)
ng_in_threshold, ng_in_threshold_perc = _per_column_coverage(z_score_ng)

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    # Same keys as the payload written after the pooled coverage computation;
    # the file at RESULT_FILE is overwritten.
    result_payload = {
        "mu_learned_list": mu_learned_list,
        "sigma_learned_list": sigma_learned_list,
        "z_score_learned": z_score_learned,
        "mu_ng_list": mu_ng_list,
        "sigma_ng_list": sigma_ng_list,
        "z_score_ng": z_score_ng,
        "beta_learned_list": beta_learned_list,
        "beta_ng_list": beta_ng_list,
    }
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(result_payload, f)

learned_in_threshold_perc= [0.9  0.97] , ng_in_threshold_perc= [0. 0.]


In [11]:
# Same metrics as the pooled version, reduced over experiments only (axis 0),
# so each result keeps one entry per column of its input.
interval_factor = 2 * Z_THRESHOLD

# Learned estimator.
learned_sq_errors = (mu_learned_list - v_truth) ** 2
ac_acl = interval_factor * np.mean(sigma_learned_list, axis=0)
ac_mse = np.mean(learned_sq_errors, axis=0)

# Non-grouped estimator.
ng_sq_errors = (mu_ng_list - v_truth) ** 2
mv_acl = interval_factor * np.mean(sigma_ng_list, axis=0)
mv_mse = np.mean(ng_sq_errors, axis=0)

In [12]:
def _print_group_report(title, mse, acl, ecp):
    """Print one method's per-group summary as four lines: title, MSE, ACL, ECP."""
    print(title)
    print(f"MSE: {mse}")
    print(f"ACL: {acl}")
    print(f"ECP: {ecp}")


_print_group_report(
    "ACPE results: Group1, Group 2", ac_mse, ac_acl, learned_in_threshold_perc
)
print("===")
_print_group_report("MVPE results: ", mv_mse, mv_acl, ng_in_threshold_perc)

ACPE results: Group1, Group 2
MSE: [0.001506  0.0009233]
ACL: [0.13365246 0.13325727]
ECP: [0.9  0.97]
===
MVPE results: 
MSE: [1.49338392 1.61083174]
ACL: [0.16623234]
ECP: [0. 0.]
