In [1]:
import copy
import pickle
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

# --- Run switches ----------------------------------------------------------
# First-time runners: leave COMPUTE_TRUTH = True so the truth file is created.
# Switch it to False once the truth file has been generated.
COMPUTE_TRUTH = True
# When False, the result pickle is not written (experimental runs only).
SAVE_RESULT = True
# Number of simulated data sets processed by the experiment loop.
NUM_EXPERIMENTS = 100

# --- Data-generation settings ----------------------------------------------
FEATURE_TYPE, TRANS, NOISE = "LEGENDRE", "NORMCDF", "STUDENT"

# Reward coefficients handed to the data generator through
# `group_reward_coeff_override` / `action_reward_coeff` below.
group_reward_coeff = np.array([[-2, 2, 2, -2], [2, -2, -2, 2]], dtype=DTYPE)
action_reward_coeff = [-1, 1]

# Keyword arguments shared by every DataGenConfig built in this notebook
# (the per-experiment seed is added later).
data_config_init = {
    "num_trajectories": 20,
    "num_time_steps": 20,
    "group_reward_coeff_override": group_reward_coeff,
    "action_reward_coeff": action_reward_coeff,
    "num_burnin_steps": 100,
    "basis_expansion_method": FEATURE_TYPE,
    "transformation_method": TRANS,
    "add_intercept_column": True,
    "noise_type": NOISE,
    "noise_student_degree": 4,
}

# --- Output locations ------------------------------------------------------
# Change the truth file name whenever the settings above are changed.
TRUTH_FILE = f"hetero/data/{FEATURE_TYPE}_{TRANS}_{NOISE}_truth_20230528_2.68_2.89.pkl"
print("truth file name =", TRUTH_FILE)

# The result file name embeds the sample sizes and a wall-clock timestamp, so
# every execution of this cell yields a new file name.
time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = f'N={data_config_init["num_trajectories"]}_T={data_config_init["num_time_steps"]}_{time_tag}'
RESULT_FILE = f"hetero/data/{FEATURE_TYPE}_{TRANS}_{NOISE}_result_20230528_2.68_2.89_{tag}.pkl"
print("result file name =", RESULT_FILE)

if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

truth file name = hetero/data/LEGENDRE_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/LEGENDRE_result_20230528_2.68_2.89_N=20_T=20_20230602_19-12-51.pkl


=====================================================================================================
# Algorithm 

- Set the algorithm configuration below.

In [2]:

# Algorithm hyper-parameters, gathered in one mapping so they can be scanned
# and tweaked in a single place before the config object is built.
# NOTE(review): the meaning of each field is defined by AlgoConfig — see hetero.config.
_algo_settings = {
    "max_num_iters": 2,
    "gam": 2.7,
    "lam": 2.0,
    "rho": 2.0,
    "should_remove_outlier": True,
    "outlier_lower_perc": 2,
    "outlier_upper_perc": 98,
    "nu_coeff": 1e-5,
    "delta_coeff": 1e-5,
    "use_group_wise_regression_init": True,
}
algo_config = AlgoConfig(**_algo_settings)

# Policy object passed to the truth computation and to both estimators below.
pi_eval = AlternativePolicy(2)

# Grouping step uses the library defaults.
grouping_config = GroupingConfig()

In [3]:
# Obtain u_truth / v_truth: either recompute them and store them in TRUTH_FILE,
# or read back a previously stored copy.
if not COMPUTE_TRUTH:
    # TRUTH_FILE is produced by the other branch of this cell; only unpickle
    # files you generated yourself.
    with open(TRUTH_FILE, "rb") as f:
        loaded = pickle.load(f)
    u_truth = loaded["u_truth"]
    v_truth = loaded["v_truth"]
else:
    # Same settings as the experiments, minus "num_trajectories", which is
    # passed explicitly to compute_UV_truths below.
    truth_data_config_init = dict(data_config_init)
    del truth_data_config_init["num_trajectories"]

    # NOTE(review): an earlier comment here recommended num_repeats=10 "for best
    # results" while the call uses 100 — confirm which value is intended.
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=100,
        num_trajectories=1000,
    )
    # Collapse the leading axis of the returned arrays by averaging.
    u_truth, v_truth = us.mean(axis=0), vs.mean(axis=0)

    # Store the truths together with the settings that produced them.
    truth_payload = {
        "u_truth": u_truth,
        "v_truth": v_truth,
        "data_config_dict": truth_data_config_init,
        "algo_config_dict": attrs.asdict(algo_config),
    }
    with open(TRUTH_FILE, "wb") as f:
        pickle.dump(truth_payload, f)

In [4]:
# One entry per simulated experiment: the non-grouped estimate and the
# end-to-end learned estimate computed from the same generated data set.
beta_ng_list, beta_learned_list = [], []

for i in range(NUM_EXPERIMENTS):
    # Distinct, deterministic seed per experiment: 7531, 15062, 22593, ...
    seed = 7531 * (i + 1)
    data_config = DataGenConfig(seed=seed, **data_config_init)
    data = generate_data_from_config(data_config)

    ng_estimate = beta_estimate_from_nongrouped(data, pi_eval, algo_config.discount)
    beta_ng_list.append(ng_estimate)

    learned_estimate = beta_estimate_from_e2e_learning(
        data, algo_config, grouping_config, pi_eval
    )
    beta_learned_list.append(learned_estimate)

new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = 0.40588307
MCPImpl: num_above=401, num_below=379
beta_solver, min eigen of left matrix = 0.40588307
MCPImpl: num_above=492, num_below=288
kmeans center = [[-0.70212888  2.9752624   2.32654549 -3.02939799 -1.73997112 -2.83603754
   1.0655351   2.81028643 -2.34317004 -1.42689089]
 [ 0.99991076 -3.40743428 -2.70345165  3.16659759  1.69322562  3.1387003
  -0.62876525 -2.79882421  2.5731582   0.65077009]] and inertia = 613.9375596142334
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = 0.42075503
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.42075503
MCPImpl: num_above=466, num_below=314
kmeans center = [[-1.18859896  2.7307812

new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = (0.41054553+0j)
MCPImpl: num_above=405, num_below=375
beta_solver, min eigen of left matrix = (0.41054553+0j)
MCPImpl: num_above=475, num_below=305
kmeans center = [[-0.83223208  2.73387038  2.42408596 -2.62273771 -1.53937848 -2.93231844
   0.88481018  2.31810835 -2.3925519  -0.52468082]
 [ 0.93756747 -3.10563202 -3.19453894  2.89258972  1.85786742  3.02255016
  -1.18201782 -2.76946855  2.8662812   0.81216958]] and inertia = 525.4350576942727
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = 0.43321526
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.43321526
MCPImpl: num_above=459, num_below=321
kmeans center = [[-1.24766814

new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = (0.4005364+0j)
MCPImpl: num_above=401, num_below=379
beta_solver, min eigen of left matrix = (0.4005364+0j)
MCPImpl: num_above=449, num_below=331
kmeans center = [[-1.00723563  3.59998428  2.25086288 -2.69871437 -1.8224869  -3.01191352
   0.80055213  2.81032764 -2.33549054 -1.07113195]
 [ 1.10259904 -2.14861115 -2.96171938  2.54705299  1.43229686  2.49271138
  -1.22679753 -2.52244098  2.41415825  1.04939624]] and inertia = 441.87341265084194
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = 0.39202347
MCPImpl: num_above=404, num_below=376
beta_solver, min eigen of left matrix = 0.39202347
MCPImpl: num_above=441, num_below=339
kmeans center = [[-0.86355323 

new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = (0.40304232+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.40304232+0j)
MCPImpl: num_above=508, num_below=272
kmeans center = [[-0.8155612   2.88286323  2.53175563 -2.33837177 -1.60412087 -2.94439765
   1.23703076  2.31167795 -2.57916606 -0.66579787]
 [ 1.1672048  -3.61066779 -2.66512354  2.80070699  1.53377854  2.5642956
  -0.61568665 -3.07447561  2.64239527  0.64491732]] and inertia = 574.9085647913073
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = (0.43409702+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.43409702+0j)
MCPImpl: num_above=482, num_below=298
kmeans center = [[-1

new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = (0.38880244+0j)
MCPImpl: num_above=401, num_below=379
beta_solver, min eigen of left matrix = (0.38880244+0j)
MCPImpl: num_above=497, num_below=283
kmeans center = [[-0.52305344  2.818947    1.97296713 -2.95167468 -1.30878554 -3.60403153
   1.17960091  2.63545558 -2.99170915 -0.27221745]
 [ 0.9806705  -2.84407851 -2.5935566   2.54965484  1.68529117  2.43576652
  -0.98970468 -2.91201265  2.3005339   1.27568522]] and inertia = 598.161036745525
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = (0.4058615+0j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.4058615+0j)
MCPImpl: num_above=461, num_below=319
kmeans center = [[-0.9

new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = 0.41819277
MCPImpl: num_above=420, num_below=360
beta_solver, min eigen of left matrix = 0.41819277
MCPImpl: num_above=485, num_below=295
kmeans center = [[-1.10748367  3.22808989  2.46685996 -2.36454011 -1.61108375 -2.50697448
   0.61056298  2.39486748 -2.25294694 -0.83362092]
 [ 0.67579812 -3.22948884 -2.58742551  2.82454894  1.76479584  3.18852481
  -1.41516172 -2.27787326  2.62176026  0.66707527]] and inertia = 546.6596743242135
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = 0.4124608
MCPImpl: num_above=402, num_below=378
beta_solver, min eigen of left matrix = 0.4124608
MCPImpl: num_above=486, num_below=294
kmeans center = [[-1.29019173  3.17732437

new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = 0.43702915
MCPImpl: num_above=401, num_below=379
beta_solver, min eigen of left matrix = 0.43702915
MCPImpl: num_above=428, num_below=352
kmeans center = [[-1.24871803  2.9702531   2.43318499 -2.54038864 -1.42431259 -2.6262159
   0.96905182  2.77613493 -2.56739252 -0.91451872]
 [ 1.27658098 -2.82644171 -2.60040869  2.41427476  1.55305366  2.88334366
  -1.07741425 -2.72449207  3.45014824  0.38032199]] and inertia = 400.2550902133145
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = 0.38852325
MCPImpl: num_above=407, num_below=373
beta_solver, min eigen of left matrix = 0.38852325
MCPImpl: num_above=452, num_below=328
kmeans center = [[-1.0179079   3.3252991

new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = (0.41751504+0j)
MCPImpl: num_above=403, num_below=377
beta_solver, min eigen of left matrix = (0.41751504+0j)
MCPImpl: num_above=453, num_below=327
kmeans center = [[-0.89013156  2.86129577  2.91181622 -3.09446731 -1.5974744  -3.38418551
   0.93939741  2.53821175 -2.52311718 -0.54668135]
 [ 1.43056973 -2.97667931 -2.81256314  2.2973346   1.96821967  2.56410158
  -1.08931674 -2.48822666  2.19272936  1.18753666]] and inertia = 495.02311902051906
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = (0.4389216-0.007096257j)
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = (0.4389216-0.007096257j)
MCPImpl: num_above=463, num_below=317


new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = 0.41792578
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.41792578
MCPImpl: num_above=506, num_below=274
kmeans center = [[-1.0153336   3.06595583  2.39719503 -2.26734688 -1.83689755 -2.95522548
   1.06114014  2.48235205 -2.82582935 -0.5195315 ]
 [ 1.05220969 -3.09139532 -2.10181274  2.81332836  1.47485392  2.34085588
  -0.84548729 -2.613754    2.14171459  1.41483885]] and inertia = 633.1976247094103
Label mismatch = 0
new_labels.length=40 matches num_unique_labels
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
new_labels.length=800 matches number of records
beta_solver, min eigen of left matrix = 0.4252992
MCPImpl: num_above=400, num_below=380
beta_solver, min eigen of left matrix = 0.4252992
MCPImpl: num_above=476, num_below=304
kmeans center = [[-1.0545491   2.50671356

In [5]:
# Inspect the first experiment's learned estimate via its `betas` attribute
# (displayed as a list of coefficient arrays — presumably one per group).
beta_learned_list[0].betas

[array([-0.89653033,  3.36128279,  3.01499531, -3.2003674 , -1.95845304,
        -3.00465019,  1.01820777,  3.15902687, -2.64944452, -1.25372578]),
 array([ 1.14358516, -3.09349739, -3.07312925,  3.30359994,  1.5922899 ,
         2.79827571, -1.00328117, -3.53721461,  2.89011123,  1.14055304])]

In [6]:
def _z_scores(mus, truths, sigmas):
    """Return ``(mu - truth) / sigma`` for each aligned (mu, truth, sigma) triple.

    The three inputs are iterated in parallel with ``zip``, so the result has
    as many entries as the shortest input.
    """
    return [(mu - truth) / sigma for mu, truth, sigma in zip(mus, truths, sigmas)]


# Per-experiment value estimates (mu), their sigmas and the resulting z-scores
# against v_truth, for the learned ("learned") and non-grouped ("ng") estimators.
mu_learned_list = []
sigma_learned_list = []
z_score_learned_list = []
mu_ng_list = []
sigma_ng_list = []
z_score_ng_list = []

num_truths = len(v_truth)

for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    z_score_learned = _z_scores(v_mus, v_truth, v_sigmas)
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_score_learned)

    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    # The non-grouped estimate is repeated so it can be compared with every
    # entry of v_truth. list(...) guarantees sequence repetition: with a NumPy
    # array, `* num_truths` would instead multiply the values element-wise and
    # zip would silently truncate the comparison.
    z_score_ng = _z_scores(
        list(ngv_mus) * num_truths, v_truth, list(ngv_sigmas) * num_truths
    )
    # The un-repeated estimates are stored, exactly as returned.
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_score_ng)

# Reports averaged over the two groups

In [7]:
# Coverage over all experiments and groups pooled together: the fraction of
# z-scores whose absolute value is below Z_THRESHOLD.
Z_THRESHOLD = 1.96
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# Mean of a boolean mask == (number of True entries) / (number of entries).
learned_in_threshold_perc = learned_in_threshold.mean()
ng_in_threshold_perc = ng_in_threshold.mean()

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    # Everything needed to rebuild the reports without re-running experiments.
    result_payload = {
        "mu_learned_list": mu_learned_list,
        "sigma_learned_list": sigma_learned_list,
        "z_score_learned": z_score_learned,
        "mu_ng_list": mu_ng_list,
        "sigma_ng_list": sigma_ng_list,
        "z_score_ng": z_score_ng,
        "beta_learned_list": beta_learned_list,
        "beta_ng_list": beta_ng_list,
    }
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(result_payload, f)

learned_in_threshold_perc= 0.96 , ng_in_threshold_perc= 0.0


In [8]:
# Interval width is 2 * Z_THRESHOLD sigmas; ACL multiplies that by the mean
# sigma, and MSE is the mean squared gap between estimate and v_truth. Both are
# pooled over every experiment and group.
_interval_factor = 2 * Z_THRESHOLD

# "ac_*": learned estimator.
_learned_errors = np.asarray(mu_learned_list) - v_truth
ac_acl = _interval_factor * np.mean(sigma_learned_list)
ac_mse = np.mean(_learned_errors ** 2)

# "mv_*": non-grouped estimator.
_ng_errors = np.asarray(mu_ng_list) - v_truth
mv_acl = _interval_factor * np.mean(sigma_ng_list)
mv_mse = np.mean(_ng_errors ** 2)

In [9]:
# Pooled report: one line per metric, learned estimator (ACPE) first, then the
# non-grouped estimator (MVPE).
print(
    "ACPE results: (average over groups)",
    f"ACL: {ac_acl}",
    f"MSE: {ac_mse}",
    f"ECP: {learned_in_threshold_perc}",
    sep="\n",
)

print(
    "MVPE results: (average over groups)",
    f"ACL: {mv_acl}",
    f"MSE: {mv_mse}",
    f"ECP: {ng_in_threshold_perc}",
    sep="\n",
)

ACPE results: (average over groups)
ACL: 0.3663329519924552
MSE: 0.008459958854660048
ECP: 0.96
MVPE results: (average over groups)
ACL: 0.45597919892359035
MSE: 1.571508557771207
ECP: 0.0


# Reports separated by group

In [10]:
# Per-group coverage: for each column of the z-score arrays, the fraction of
# experiments whose absolute z-score is below Z_THRESHOLD.
# NOTE(review): this rebinds learned_in_threshold_perc / ng_in_threshold_perc
# from the pooled values to per-column arrays, so the pooled report cell must
# not be re-run after this one without recomputing them.
Z_THRESHOLD = 1.96
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# Column-wise mean of a boolean mask == per-column count of True / number of rows.
learned_in_threshold_perc = learned_in_threshold.mean(axis=0)
ng_in_threshold_perc = ng_in_threshold.mean(axis=0)

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    # NOTE(review): same keys and source objects as the save in the pooled
    # report cell, written to the same RESULT_FILE — consider keeping one save.
    result_payload = {
        "mu_learned_list": mu_learned_list,
        "sigma_learned_list": sigma_learned_list,
        "z_score_learned": z_score_learned,
        "mu_ng_list": mu_ng_list,
        "sigma_ng_list": sigma_ng_list,
        "z_score_ng": z_score_ng,
        "beta_learned_list": beta_learned_list,
        "beta_ng_list": beta_ng_list,
    }
    with open(RESULT_FILE, "wb") as f:
        pickle.dump(result_payload, f)

learned_in_threshold_perc= [0.97 0.95] , ng_in_threshold_perc= [0. 0.]


In [11]:
# Same metrics as the pooled version, but reduced over the experiment axis only
# (axis=0), leaving one value per remaining column.
_interval_factor = 2 * Z_THRESHOLD

# "ac_*": learned estimator.
_learned_errors = np.asarray(mu_learned_list) - v_truth
ac_acl = _interval_factor * np.mean(sigma_learned_list, axis=0)
ac_mse = np.mean(_learned_errors ** 2, axis=0)

# "mv_*": non-grouped estimator.
_ng_errors = np.asarray(mu_ng_list) - v_truth
mv_acl = _interval_factor * np.mean(sigma_ng_list, axis=0)
mv_mse = np.mean(_ng_errors ** 2, axis=0)

In [12]:
# Per-group report: learned estimator (ACPE) first, a separator line, then the
# non-grouped estimator (MVPE).
print(
    "ACPE results: Group1, Group 2",
    f"MSE: {ac_mse}",
    f"ACL: {ac_acl}",
    f"ECP: {learned_in_threshold_perc}",
    sep="\n",
)

print(
    "===",
    "MVPE results: ",
    f"MSE: {mv_mse}",
    f"ACL: {mv_acl}",
    f"ECP: {ng_in_threshold_perc}",
    sep="\n",
)

ACPE results: Group1, Group 2
MSE: [0.08743837 0.09630394]
ACL: [0.36304876 0.36961714]
ECP: [0.97 0.95]
===
MVPE results: 
MSE: [1.25815009 1.24902981]
ACL: [0.4559792]
ECP: [0. 0.]


In [1]:
# Scratch check: squares of the two ACPE values shown on the per-group "MSE"
# line of the report output above (numbers are hard-coded from that output).
# NOTE(review): the squares average to roughly the pooled ACPE MSE printed
# earlier, which suggests the per-group "MSE" output held root-MSE values when
# it was produced — confirm, and re-run the notebook so outputs match the code.
0.08743837**2, 0.09630394**2

(0.0076454685482569, 0.009274448859523601)

In [2]:
# Scratch check: squares of the two MVPE values shown on the per-group "MSE"
# line of the report output above (numbers are hard-coded from that output).
# NOTE(review): the squares average to roughly the pooled MVPE MSE printed
# earlier — same suspicion as for the ACPE check; confirm against a fresh run.
1.25815009**2, 1.24902981**2

(1.582941648967008, 1.5600754662686358)