In [1]:
import copy
import pickle
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

# Names of the basis expansion, state transformation and noise family.
# They are passed to the data generator below and embedded in output file names.
FEATURE_TYPE = "LEGENDRE"
TRANS = "NORMCDF"
NOISE = "STUDENT"

# Reward coefficients, two rows of four values; the second row is the exact
# negation of the first (presumably one row per latent group — confirm).
group_reward_coeff = np.array(
    [[-2, 2, 2, -2], [2, -2, -2, 2]],
    dtype=DTYPE,
)

action_reward_coeff = [-1, 1]

# Keyword arguments shared by every DataGenConfig built in this notebook;
# only the seed is supplied separately per experiment.
data_config_init = {
    "num_trajectories": 20,
    "num_time_steps": 30,
    "group_reward_coeff_override": group_reward_coeff,
    "action_reward_coeff": action_reward_coeff,
    "num_burnin_steps": 100,
    "basis_expansion_method": FEATURE_TYPE,
    "transformation_method": TRANS,
    "add_intercept_column": True,
    "noise_type": NOISE,
    "noise_student_degree": 4,
}

# Number of independent replications and whether their results get pickled.
NUM_EXPERIMENTS = 100
SAVE_RESULT = True

# First-time runners: set COMPUTE_TRUTH = True to generate the truth file,
# then switch it back to False so later runs just load it.
COMPUTE_TRUTH = False

# Pieces shared by both output paths. Change SETTINGS_TAG whenever the
# settings change so that old truth/result files are not reused by mistake.
OUTPUT_PREFIX = f"hetero/data/{FEATURE_TYPE}_{TRANS}_{NOISE}"
SETTINGS_TAG = "20230528_2.68_2.89"

TRUTH_FILE = f"{OUTPUT_PREFIX}_truth_{SETTINGS_TAG}.pkl"
print("truth file name =", TRUTH_FILE)

# The wall-clock timestamp makes every run write to its own result file.
time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
num_traj = data_config_init["num_trajectories"]
num_steps = data_config_init["num_time_steps"]
tag = f"N={num_traj}_T={num_steps}_{time_tag}"
RESULT_FILE = f"{OUTPUT_PREFIX}_result_{SETTINGS_TAG}_{tag}.pkl"
print("result file name =", RESULT_FILE)

if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

truth file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_result_20230528_2.68_2.89_N=20_T=30_20230602_19-21-52.pkl


=====================================================================================================
# Algorithm 

- Set the algorithm configuration below.

In [2]:

# Policy to be evaluated and grouping settings (GroupingConfig defaults).
pi_eval = AlternativePolicy(2)
grouping_config = GroupingConfig()

# Settings for the learning algorithm. The exact role of each hyper-parameter
# is defined in hetero.config.AlgoConfig and is not restated here.
algo_config = AlgoConfig(
    # iteration budget
    max_num_iters=2,
    # tuning constants — NOTE(review): confirm meanings against AlgoConfig
    gam=2.7,
    lam=2.0,
    rho=2.0,
    nu_coeff=1e-5,
    delta_coeff=1e-5,
    # outlier handling: enabled, with 2nd / 98th percentile bounds
    should_remove_outlier=True,
    outlier_lower_perc=2,
    outlier_upper_perc=98,
    # initialisation strategy
    use_group_wise_regression_init=True,
)

In [3]:
# Obtain the reference ("truth") quantities u_truth / v_truth: either compute
# them and cache them in TRUTH_FILE, or load a previously cached copy.
if COMPUTE_TRUTH:
    # The truth run passes its own num_trajectories, so remove the
    # experiment-sized value from a shallow copy; data_config_init is untouched.
    truth_data_config_init = copy.copy(data_config_init)
    truth_data_config_init.pop("num_trajectories")
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=100,
        num_trajectories=1000,
    )  # For best results, use num_repeats=10
    # NOTE(review): the comment above recommends num_repeats=10 while the call
    # passes 100 — confirm which value is intended.

    # Average along axis 0 (presumably the repeat axis — confirm).
    u_truth = us.mean(axis=0)
    v_truth = vs.mean(axis=0)

    # Store the configs next to the truths so the file documents its own settings.
    with open(TRUTH_FILE, "wb") as f:
        pickle.dump(
            dict(
                u_truth=u_truth,
                v_truth=v_truth,
                data_config_dict=truth_data_config_init,
                algo_config_dict=attrs.asdict(algo_config),
            ),
            f,
        )
else:
    try:
        # SECURITY: pickle.load can execute arbitrary code; only load truth
        # files that were produced by this notebook or another trusted source.
        with open(TRUTH_FILE, "rb") as f:
            loaded = pickle.load(f)
    except FileNotFoundError as err:
        # Same exception type as before, but tell first-time runners what to do.
        raise FileNotFoundError(
            f"Truth file {TRUTH_FILE!r} not found. Set COMPUTE_TRUTH = True "
            "to generate it, or point TRUTH_FILE at an existing truth file."
        ) from err
    u_truth = loaded["u_truth"]
    v_truth = loaded["v_truth"]

In [4]:
# Run NUM_EXPERIMENTS independent replications. Each one generates a fresh
# data set and fits both estimators on that same data set.
beta_ng_list, beta_learned_list = [], []

for run_idx in range(NUM_EXPERIMENTS):
    # Deterministic, distinct seed per replication: 7531, 15062, 22593, ...
    data_config = DataGenConfig(seed=7531 * (run_idx + 1), **data_config_init)
    data = generate_data_from_config(data_config)

    # Baseline that ignores grouping.
    beta_ng = beta_estimate_from_nongrouped(data, pi_eval, algo_config.discount)
    beta_ng_list.append(beta_ng)

    # End-to-end estimate that uses the grouping configuration.
    beta_learned = beta_estimate_from_e2e_learning(
        data, algo_config, grouping_config, pi_eval
    )
    beta_learned_list.append(beta_learned)

new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = (0.36857423+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.36857423+0j)
MCPImpl: num_above=482, num_below=298
kmeans center = [[-0.95671193  2.60821005  2.10845756 -2.8483227  -1.25230806 -2.45597742
   1.07234406  2.26095402 -2.70290114 -0.89745401]
 [ 0.960868   -3.25163876 -2.84485388  2.40421746  2.04251224  3.20268551
  -0.90158307 -2.90116388  2.6120388   0.67796106]] and inertia = 644.4033518076014
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = 0.38309756
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.38309756
MCPImpl: num_above=445, num_below=335
kmeans center = [[-1.194

new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = (0.37712154+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.37712154+0j)
MCPImpl: num_above=440, num_below=340
kmeans center = [[-1.23938757  3.17463701  2.5063893  -2.23622944 -1.80556271 -3.67314221
   1.29357675  2.49369031 -2.93241202 -0.17990498]
 [ 0.90776988 -2.59609538 -2.90690551  2.73192261  1.52332797  2.99448189
  -1.01360643 -2.7584412   2.84251476  0.61455889]] and inertia = 524.7968409730415
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = 0.38039184
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.38039184
MCPImpl: num_above=500, num_below=280
kmeans center = [[-1.152

new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = 0.38214406
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.38214406
MCPImpl: num_above=466, num_below=314
kmeans center = [[-1.0342418   2.245987    2.68610527 -2.48430106 -1.42791619 -2.72706859
   1.25366697  2.92462842 -2.29108161 -1.22618956]
 [ 1.11239708 -3.19396301 -2.99194596  2.30865312  2.15104004  3.35462442
  -1.05320345 -2.72462479  2.66047616  0.5352675 ]] and inertia = 581.842997389819
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = 0.36076435
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.36076435
MCPImpl: num_above=461, num_below=319
kmeans center = [[-0.77258689  3.03

new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = 0.3806718
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.3806718
MCPImpl: num_above=448, num_below=332
kmeans center = [[-0.93351084  3.30999893  2.55024579 -2.87443985 -1.82321726 -2.97857947
   0.92444417  2.46738958 -2.63194479 -0.85733898]
 [ 1.02585072 -3.09419617 -2.59505064  2.78216806  1.68968069  2.66985078
  -0.89932428 -2.68779633  3.07869318  0.5279713 ]] and inertia = 541.5576962627148
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = 0.36459014
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.36459014
MCPImpl: num_above=463, num_below=317
kmeans center = [[-1.00010766  2.725

new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = 0.3525951
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.3525951
MCPImpl: num_above=449, num_below=331
kmeans center = [[-1.06766445  3.13064808  3.00644702 -2.13655407 -2.06811639 -2.77615544
   1.11007044  2.89923985 -2.28700784 -1.16488929]
 [ 1.01024518 -2.61507991 -2.42820801  2.94398834  1.24972754  2.93991272
  -1.11213297 -2.4379372   2.88652429  0.52637331]] and inertia = 551.7547029633633
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = 0.368183
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.368183
MCPImpl: num_above=467, num_below=313
kmeans center = [[-0.82277331  2.2715670

MCPImpl: num_above=443, num_below=337
kmeans center = [[-1.2911617   3.00609072  2.83638735 -2.9178593  -1.70295189 -3.38078708
   1.03311157  2.79742921 -2.86042302 -0.6207376 ]
 [ 1.09735015 -3.31960235 -2.7199431   2.96643294  1.76990314  2.96616052
  -1.07360656 -2.87237137  2.73444192  1.13715755]] and inertia = 497.55707013871046
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = 0.38349408
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.38349408
MCPImpl: num_above=440, num_below=340
kmeans center = [[-1.25111373  2.93688728  2.6731514  -2.61928903 -1.6131669  -2.84942654
   1.44169742  2.76919133 -2.42469865 -1.31232248]
 [ 1.27707714 -2.83382447 -3.18911698  2.46812125  1.931399    3.43521699
  -1.19536818 -3.00525045  3.12977022  0.48537501]] and 

MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.37508798
MCPImpl: num_above=465, num_below=315
kmeans center = [[-1.21881459  3.13161155  2.95893349 -2.58888091 -2.08895692 -2.79978582
   1.0779324   2.46862137 -2.12907157 -1.17549447]
 [ 1.22151533 -3.03521116 -2.6743713   2.48605675  1.68598294  2.65165667
  -1.0647717  -3.21737725  3.04908148  1.08750205]] and inertia = 563.9423919111795
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = 0.3977438
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.3977438
MCPImpl: num_above=411, num_below=369
kmeans center = [[-0.9572121   2.5065176   2.86859882 -2.5549312  -1.69184609 -2.8665267
   1.22185143  2.73778033 -2.77952826 -0.9531684 ]
 [ 1.13340308 -3.2401351  -2.78349747  2.5760

beta_solver, min eigen of left matrix = 0.38688347
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.38688347
MCPImpl: num_above=497, num_below=283
kmeans center = [[-1.10840054  3.27803268  2.74286619 -2.83488945 -1.70815752 -3.01091506
   1.21225737  2.72554593 -3.0224987  -0.61937963]
 [ 1.21378197 -2.39229353 -2.97505361  2.90501426  1.37268687  3.11576413
  -1.00987312 -2.85839339  2.5407621   0.75486783]] and inertia = 646.2976591207288
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = 0.368152
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.368152
MCPImpl: num_above=494, num_below=286
kmeans center = [[-1.10546247  2.90570233  2.57795098 -2.75763636 -1.57912484 -2.75381102
   1.57857611  2.48795981 -2.91178862 -0.86652

beta_solver, min eigen of left matrix = 0.3734565
MCPImpl: num_above=467, num_below=313
kmeans center = [[-0.91500675  2.97207133  2.55791799 -2.24570275 -1.76988048 -2.58881979
   1.20757451  2.54500547 -2.51250558 -1.0455914 ]
 [ 1.19840436 -2.66885557 -3.0456952   2.94060335  1.5854749   3.13949602
  -0.83769126 -2.82118692  2.58050186  0.70408351]] and inertia = 597.0430359866892
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
new_labels.length=1200 matches number of records
beta_solver, min eigen of left matrix = (0.38778588+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.38778588+0j)
MCPImpl: num_above=486, num_below=294
kmeans center = [[-0.84147267  2.94682213  2.44247498 -2.59162188 -1.64428176 -2.71089124
   1.05770933  2.22025364 -2.6904215  -0.52958523]
 [ 1.15268698 -2.73222552 -2.69498694  2.59512079  1.66440916  3.048842

In [5]:
# Sanity check: display the `betas` attribute of the end-to-end estimate
# obtained in the first experiment.
beta_learned_list[0].betas

[array([-1.05547452,  3.17671026,  2.28370074, -2.7408655 , -1.67856539,
        -2.82143241,  0.96647793,  2.39598621, -2.76245355, -0.73861912]),
 array([ 0.97113797, -3.08212394, -2.85479678,  2.83905844,  1.67612835,
         2.83935417, -0.97075665, -2.9908613 ,  2.63522642,  0.84299434])]

In [6]:
# Per-experiment value estimates (mu), their standard deviations (sigma) and
# the resulting z-scores against v_truth, for both estimators.
mu_learned_list, sigma_learned_list, z_score_learned_list = [], [], []
mu_ng_list, sigma_ng_list, z_score_ng_list = [], [], []


def _z_scores(mus, truths, sigmas):
    """Return [(mu - truth) / sigma, ...] over the zipped inputs."""
    return [(mu - truth) / sigma for mu, truth, sigma in zip(mus, truths, sigmas)]


for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    # Estimator with learned groups.
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    z_score_learned = _z_scores(v_mus, v_truth, v_sigmas)
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_score_learned)

    # Non-grouped estimator. `seq * len(v_truth)` repeats the returned sequence
    # so it can be zipped against every entry of v_truth.
    # NOTE(review): this relies on compute_V_estimate returning Python lists;
    # an ndarray would be scaled element-wise instead of repeated — confirm.
    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    z_score_ng = _z_scores(
        ngv_mus * len(v_truth), v_truth, ngv_sigmas * len(v_truth)
    )
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_score_ng)

# Reports that average over two groups

In [7]:
# Two-sided z threshold used for the coverage check.
Z_THRESHOLD = 1.96

z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

# Boolean masks: True where |z| is strictly below the threshold.
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# Fraction of True entries pooled over all experiments and all columns.
learned_in_threshold_perc = learned_in_threshold.mean()
ng_in_threshold_perc = ng_in_threshold.mean()
print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    # Everything needed to recompute the reported metrics offline.
    result_payload = {
        "mu_learned_list": mu_learned_list,
        "sigma_learned_list": sigma_learned_list,
        "z_score_learned": z_score_learned,
        "mu_ng_list": mu_ng_list,
        "sigma_ng_list": sigma_ng_list,
        "z_score_ng": z_score_ng,
        "beta_learned_list": beta_learned_list,
        "beta_ng_list": beta_ng_list,
    }
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(result_payload, f)

learned_in_threshold_perc= 0.96 , ng_in_threshold_perc= 0.0


In [8]:
# Width factor of a symmetric interval mu +/- Z_THRESHOLD * sigma.
interval_width_factor = 2 * Z_THRESHOLD

# Learned-grouping estimator: interval length from the mean sigma, and mean
# squared error against v_truth, both pooled over experiments and columns.
ac_acl = interval_width_factor * np.mean(sigma_learned_list)
ac_mse = np.mean(np.square(mu_learned_list - v_truth))

# Non-grouped estimator: same two metrics.
mv_acl = interval_width_factor * np.mean(sigma_ng_list)
mv_mse = np.mean(np.square(mu_ng_list - v_truth))

In [9]:
# Pooled report: one (title, ACL, MSE, ECP) row per estimator.
pooled_reports = (
    ("ACPE results: (average over groups)", ac_acl, ac_mse, learned_in_threshold_perc),
    ("MVPE results: (average over groups)", mv_acl, mv_mse, ng_in_threshold_perc),
)

for title, acl, mse, ecp in pooled_reports:
    print(title)
    print(f"ACL: {acl}")
    print(f"MSE: {mse}")
    print(f"ECP: {ecp}")

ACPE results: (average over groups)
ACL: 0.2979337816789931
MSE: 0.0053470257607602465
ECP: 0.96
MVPE results: (average over groups)
ACL: 0.3713872377948651
MSE: 1.562883057462219
ECP: 0.0


# Reports that separate two groups

In [10]:
# Two-sided z threshold used for the coverage check.
Z_THRESHOLD = 1.96

z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

# Boolean masks: True where |z| is strictly below the threshold.
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# Fraction of True entries per column, taken along axis 0 (the experiment axis).
# This rebinds the *_in_threshold_perc names from scalars to per-column arrays.
learned_in_threshold_perc = learned_in_threshold.mean(axis=0)
ng_in_threshold_perc = ng_in_threshold.mean(axis=0)

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    # Everything needed to recompute the reported metrics offline.
    result_payload = {
        "mu_learned_list": mu_learned_list,
        "sigma_learned_list": sigma_learned_list,
        "z_score_learned": z_score_learned,
        "mu_ng_list": mu_ng_list,
        "sigma_ng_list": sigma_ng_list,
        "z_score_ng": z_score_ng,
        "beta_learned_list": beta_learned_list,
        "beta_ng_list": beta_ng_list,
    }
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(result_payload, f)

learned_in_threshold_perc= [0.95 0.97] , ng_in_threshold_perc= [0. 0.]


In [11]:
# Width factor of a symmetric interval mu +/- Z_THRESHOLD * sigma.
interval_width_factor = 2 * Z_THRESHOLD

# Learned-grouping estimator: metrics per column, averaged over experiments.
ac_acl = interval_width_factor * np.mean(sigma_learned_list, axis=0)
ac_mse = np.mean(np.square(mu_learned_list - v_truth), axis=0)

# Non-grouped estimator. mu_ng_list is broadcast against v_truth, so the MSE
# has one entry per truth column, while the ACL keeps the column count of
# sigma_ng_list itself.
mv_acl = interval_width_factor * np.mean(sigma_ng_list, axis=0)
mv_mse = np.mean(np.square(mu_ng_list - v_truth), axis=0)

In [12]:
# Per-group report: one (title, MSE, ACL, ECP) row per estimator, with a
# "===" separator line between consecutive estimators.
per_group_reports = (
    ("ACPE results: Group1, Group 2", ac_mse, ac_acl, learned_in_threshold_perc),
    ("MVPE results: ", mv_mse, mv_acl, ng_in_threshold_perc),
)

for position, (title, mse, acl, ecp) in enumerate(per_group_reports):
    if position > 0:
        print("===")
    print(title)
    print(f"MSE: {mse}")
    print(f"ACL: {acl}")
    print(f"ECP: {ecp}")

ACPE results: Group1, Group 2
MSE: [0.00579391 0.00490014]
ACL: [0.29874625 0.29712131]
ECP: [0.95 0.97]
===
MVPE results: 
MSE: [1.54873517 1.57703095]
ACL: [0.37138724]
ECP: [0. 0.]
