In [1]:
import copy
import pickle
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

group_reward_coeff = np.array(
    [
        [-2, 2, 2, -2],
        [2, -2, -2, 2],
    ],
    dtype=DTYPE,
)

action_reward_coeff = [-1, 1]

FEATURE_TYPE = "LEGENDRE"
TRANS = "NORMCDF"
NOISE = "STUDENT"

data_config_init = dict(
    num_trajectories=20,
    num_time_steps=40,
    group_reward_coeff_override=group_reward_coeff,
    action_reward_coeff=action_reward_coeff,
    num_burnin_steps=100,
    basis_expansion_method=FEATURE_TYPE,
    transformation_method=TRANS,
    add_intercept_column=True,
    noise_type=NOISE,
    noise_student_degree=4,
)

# First time runner: set COMPUTE_TRUTH = True
# Change the flag to False after generated the truth file.
COMPUTE_TRUTH = False #True
# Change the truth file name if settings are changed.
TRUTH_FILE = f"hetero/data/{FEATURE_TYPE}_{TRANS}_{NOISE}_truth_20230528_2.68_2.89.pkl"
print("truth file name =", TRUTH_FILE)

time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = f'N={data_config_init["num_trajectories"]}_T={data_config_init["num_time_steps"]}_{time_tag}'
RESULT_FILE = f"hetero/data/{FEATURE_TYPE}_{TRANS}_{NOISE}_result_20230528_2.68_2.89_{tag}.pkl"
print("result file name =", RESULT_FILE)

SAVE_RESULT = True
if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

NUM_EXPERIMENTS = 100

truth file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_result_20230528_2.68_2.89_N=20_T=40_20230602_19-28-02.pkl


=====================================================================================================
# Algorithm 

- Set configure below.

In [2]:

algo_config = AlgoConfig(
    max_num_iters=2,
    gam=2.7,
    lam=2.0,
    rho=2.0,
    should_remove_outlier=True,
    outlier_lower_perc=2,
    outlier_upper_perc=98,
    nu_coeff=1e-5,
    delta_coeff=1e-5,
    use_group_wise_regression_init=True,
)

pi_eval = AlternativePolicy(2)

grouping_config = GroupingConfig()

In [3]:
if COMPUTE_TRUTH:
    truth_data_config_init = copy.copy(data_config_init)
    truth_data_config_init.pop("num_trajectories")
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=100,
        num_trajectories=1000,
    )  # For best results, use num_repeats=10
    u_truth = us.mean(axis=0)
    v_truth = vs.mean(axis=0)

    with open(TRUTH_FILE, "wb") as f:
        pickle.dump(
            dict(
                u_truth=u_truth,
                v_truth=v_truth,
                data_config_dict=truth_data_config_init,
                algo_config_dict=attrs.asdict(algo_config),
            ),
            f,
        )
else:
    with open(TRUTH_FILE, "rb") as f:
        loaded = pickle.load(f)
        u_truth = loaded["u_truth"]
        v_truth = loaded["v_truth"]

In [4]:
beta_ng_list = []
beta_learned_list = []

for i in range(NUM_EXPERIMENTS):
    data_config = DataGenConfig(seed=7531 * (i + 1), **data_config_init)
    data = generate_data_from_config(data_config)
    beta_ng_list.append(
        beta_estimate_from_nongrouped(data, pi_eval, algo_config.discount)
    )
    beta_learned_list.append(
        beta_estimate_from_e2e_learning(data, algo_config, grouping_config, pi_eval)
    )

new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = (0.32294682+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.32294682+0j)
MCPImpl: num_above=450, num_below=330
kmeans center = [[-0.80543279  2.6558304   2.7773659  -2.43838116 -1.9048328  -2.61387618
   1.23645853  2.3467529  -2.51464342 -0.87147047]
 [ 1.10704633 -3.25724832 -2.76908969  2.81145655  1.66344601  3.43748822
  -1.12915355 -2.91098619  2.74759139  0.39059904]] and inertia = 548.6041824321708
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = (0.34699604+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.34699604+0j)
MCPImpl: num_above=447, num_below=333
kmeans center 

beta_solver, min eigen of left matrix = 0.33260986
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.33260986
MCPImpl: num_above=465, num_below=315
kmeans center = [[-1.38334535  2.92658124  2.63477157 -2.94924159 -1.52044337 -2.70223978
   1.16251154  2.55091122 -2.77661876 -1.0407032 ]
 [ 1.26197581 -2.91814992 -2.62858071  2.45870847  1.35397906  2.65031539
  -1.25563522 -2.56483192  2.84122976  0.54980281]] and inertia = 591.3369561874542
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = 0.34232634
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.34232634
MCPImpl: num_above=450, num_below=330
kmeans center = [[-1.03567965  2.90781175  2.63733129 -2.57871774 -1.63461    -2.67964109
   0.94897322  2.79909499 -2.71305761 -1.0

new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = (0.3406431+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.3406431+0j)
MCPImpl: num_above=476, num_below=304
kmeans center = [[-1.73595055  2.85246626  2.85646193 -2.87653498 -1.45421252 -2.89163405
   1.3126686   2.8856996  -2.80392863 -1.04513224]
 [ 1.09553433 -2.91409985 -2.99808524  2.96224307  1.74074384  2.95348947
  -0.93399081 -2.74776722  2.8951494   0.67344734]] and inertia = 597.1687800249557
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = 0.33309045
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.33309045
MCPImpl: num_above=448, num_below=332
kmeans center = [[-1.22369

MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.34627917+0j)
MCPImpl: num_above=443, num_below=337
kmeans center = [[-0.99785997  2.86084141  2.98063754 -2.72453801 -1.62287815 -3.05887352
   1.15696316  2.79163176 -2.38765149 -0.93848042]
 [ 1.2192781  -3.31133615 -2.854102    2.92870309  1.78244279  2.65865366
  -0.92700478 -2.82675486  3.08892472  0.73317421]] and inertia = 554.5681615738183
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = 0.32378078
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.32378078
MCPImpl: num_above=467, num_below=313
kmeans center = [[-1.10411846  2.89969587  2.80566183 -3.06836175 -1.75158791 -2.7675352
   1.13438503  2.40316791 -2.82185887 -0.7333095 ]
 [ 1.39170201 -2.60626639 -2.76834612 

beta_solver, min eigen of left matrix = 0.33984122
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.33984122
MCPImpl: num_above=446, num_below=334
kmeans center = [[-1.0204324   2.87842717  2.3271715  -2.81154421 -1.29629719 -2.71144976
   1.23697627  2.59698148 -2.88271466 -0.86587671]
 [ 1.00999637 -2.72523683 -3.00778618  2.92795899  1.68460896  2.84209277
  -1.28184584 -2.71734413  2.36855551  1.16309682]] and inertia = 505.566674915732
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = (0.3243695+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.3243695+0j)
MCPImpl: num_above=448, num_below=332
kmeans center = [[-0.99392412  2.63687426  2.8124394  -2.70321998 -1.57129468 -3.00839899
   1.25951636  3.07785554 -2.741650

MCPImpl: num_above=462, num_below=318
kmeans center = [[-1.31032669  2.96297476  2.72320362 -2.59544563 -1.6519495  -2.92649342
   1.03877615  2.61446041 -2.64708582 -0.71163757]
 [ 0.97521561 -2.37368405 -2.60721387  2.49620358  1.41854719  2.96097743
  -1.40558293 -2.5518854   3.01166916  0.6573093 ]] and inertia = 576.7241019528392
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = (0.3456734+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.3456734+0j)
MCPImpl: num_above=430, num_below=350
kmeans center = [[-1.33263832  2.6181045   3.05826336 -2.93060077 -1.45434406 -3.00503034
   1.12221358  2.72181658 -2.69905962 -0.86911958]
 [ 1.26042284 -3.01070348 -2.7477068   2.41487756  1.8695503   2.92093285
  -1.07958307 -3.1221179   2.77386505  0.9417912 

MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.3421925
MCPImpl: num_above=466, num_below=314
kmeans center = [[-1.23856455  2.88064394  2.63158504 -2.45728808 -1.60688772 -2.43632964
   1.07301046  2.60266708 -2.6507194  -1.13732191]
 [ 1.08226282 -2.63288477 -2.76226234  2.40759042  1.7190451   2.97447064
  -1.11750417 -2.78161027  2.70554023  0.69104155]] and inertia = 592.6540777644561
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = 0.34696054
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.34696054
MCPImpl: num_above=456, num_below=324
kmeans center = [[-1.02070868  2.8789444   2.6388107  -2.46940031 -2.00265666 -2.97109338
   1.30495545  3.16952909 -2.84825301 -1.24822954]
 [ 1.12393116 -2.74264576 -3.02122525  2.94

beta_solver, min eigen of left matrix = (0.33105266+0j)
MCPImpl: num_above=469, num_below=311
kmeans center = [[-1.35966257  2.9823707   2.71389767 -2.61036578 -1.64300816 -2.75838101
   1.22693107  2.7901723  -2.78671515 -0.85671154]
 [ 0.96820689 -2.93527248 -3.08717499  2.40107726  2.11228399  2.84488291
  -1.26002061 -3.21507642  2.79242385  1.23147925]] and inertia = 541.0787549440909
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = (0.3269735+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.3269735+0j)
MCPImpl: num_above=464, num_below=316
kmeans center = [[-1.08133168  3.12436065  2.75713113 -2.83850744 -1.63067977 -2.9027539
   1.15313388  2.6117926  -2.7625026  -0.62179864]
 [ 0.84652818 -2.75327706 -2.62574681  2.68913534  1.65137221  2.616

MCPImpl: num_above=423, num_below=357
kmeans center = [[-1.31826025  2.70051767  2.73859128 -2.79220676 -1.35374792 -2.82322241
   0.96746598  2.68802865 -2.69557932 -0.87434954]
 [ 1.25039921 -2.62581083 -3.09124224  2.65668032  1.62629121  2.81281974
  -1.34909228 -2.57030127  2.79899504  0.78877012]] and inertia = 452.39439267799037
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
new_labels.length=1600 matches number of records
beta_solver, min eigen of left matrix = 0.32850948
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.32850948
MCPImpl: num_above=486, num_below=294
kmeans center = [[-1.19490585  2.3222569   3.04820122 -2.71460155 -1.33776067 -2.78938349
   1.15436821  2.82010098 -2.67109407 -0.75924468]
 [ 1.08377139 -3.31016261 -2.6658146   2.79393247  1.82085622  2.93959182
  -1.08477251 -3.01923867  3.03340506  0.80547339]] and 

In [5]:
beta_learned_list[0].betas

[array([-0.86866768,  2.98572489,  2.88142249, -2.59050249, -2.04725239,
        -2.9300162 ,  1.12203073,  2.48283532, -2.71399135, -0.74628915]),
 array([ 1.15650656, -3.04619695, -2.75552644,  2.87002622,  1.53738659,
         2.94252536, -1.27393997, -2.85006352,  3.05926713,  0.68237785])]

In [6]:
mu_learned_list = []
sigma_learned_list = []
z_score_learned_list = []
mu_ng_list = []
sigma_ng_list = []
z_score_ng_list = []

for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    z_score_learned = [
        (mu - truth) / sigma for mu, truth, sigma in zip(v_mus, v_truth, v_sigmas)
    ]
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_score_learned)

    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    z_score_ng = [
        (mu - truth) / sigma
        for mu, truth, sigma in zip(
            ngv_mus * len(v_truth), v_truth, ngv_sigmas * len(v_truth)
        )
    ]
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_score_ng)

# Reports that average over two groups

In [7]:
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)
Z_THRESHOLD = 1.96
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD
learned_in_threshold_perc = learned_in_threshold.sum() / learned_in_threshold.size
ng_in_threshold_perc = ng_in_threshold.sum() / ng_in_threshold.size
print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)
if SAVE_RESULT:
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(
            dict(
                mu_learned_list=mu_learned_list,
                sigma_learned_list=sigma_learned_list,
                z_score_learned=z_score_learned,
                mu_ng_list=mu_ng_list,
                sigma_ng_list=sigma_ng_list,
                z_score_ng=z_score_ng,
                beta_learned_list=beta_learned_list,
                beta_ng_list=beta_ng_list,
            ),
            f,
        )

learned_in_threshold_perc= 0.95 , ng_in_threshold_perc= 0.0


In [8]:
ac_acl = 2*Z_THRESHOLD*np.mean(sigma_learned_list)
ac_mse = np.mean( (mu_learned_list - v_truth)**2 )

mv_acl = 2*Z_THRESHOLD*np.mean(sigma_ng_list)
mv_mse = np.mean( (mu_ng_list - v_truth)**2 )

In [9]:
print("ACPE results: (average over groups)")
print(f"ACL: {ac_acl}") 
print(f"MSE: {ac_mse}")
print(f"ECP: {learned_in_threshold_perc}")


print("MVPE results: (average over groups)")
print(f"ACL: {mv_acl}") 
print(f"MSE: {mv_mse}")
print(f"ECP: {ng_in_threshold_perc}")

ACPE results: (average over groups)
ACL: 0.25752530389093564
MSE: 0.00521173257159066
ECP: 0.95
MVPE results: (average over groups)
ACL: 0.32211410059434287
MSE: 1.561364601166128
ECP: 0.0


# Reports that seperate two groups

In [10]:
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)
Z_THRESHOLD = 1.96
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

learned_in_threshold_perc = np.sum(learned_in_threshold, axis=0) / learned_in_threshold.shape[0]
ng_in_threshold_perc = np.sum(ng_in_threshold, axis=0) / ng_in_threshold.shape[0]

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(
            dict(
                mu_learned_list=mu_learned_list,
                sigma_learned_list=sigma_learned_list,
                z_score_learned=z_score_learned,
                mu_ng_list=mu_ng_list,
                sigma_ng_list=sigma_ng_list,
                z_score_ng=z_score_ng,
                beta_learned_list=beta_learned_list,
                beta_ng_list=beta_ng_list,
            ),
            f,
        )

learned_in_threshold_perc= [0.94 0.96] , ng_in_threshold_perc= [0. 0.]


In [11]:
ac_acl = 2*Z_THRESHOLD*np.mean(sigma_learned_list, axis=0)
ac_mse = np.mean((mu_learned_list - v_truth)**2, axis=0)

mv_acl = 2*Z_THRESHOLD*np.mean(sigma_ng_list, axis=0)
mv_mse = np.mean( (mu_ng_list - v_truth)**2, axis=0 )

In [12]:
print("ACPE results: Group1, Group 2")
print(f"MSE: {ac_mse}")
print(f"ACL: {ac_acl}") 
print(f"ECP: {learned_in_threshold_perc}")

print("===")
print("MVPE results: ")
print(f"MSE: {mv_mse}")
print(f"ACL: {mv_acl}") 
print(f"ECP: {ng_in_threshold_perc}")

ACPE results: Group1, Group 2
MSE: [0.00504553 0.00537793]
ACL: [0.25735495 0.25769566]
ECP: [0.94 0.96]
===
MVPE results: 
MSE: [1.53775892 1.58497028]
ACL: [0.3221141]
ECP: [0. 0.]
