In [1]:
import copy
import pickle
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

# Reward coefficients handed to the data generator through the
# `group_reward_coeff_override` / `action_reward_coeff` entries below.
group_reward_coeff = np.array([[-2, 2, 2, -2], [2, -2, -2, 2]], dtype=DTYPE)
action_reward_coeff = [-1, 1]

# Data-generation choices; they are also embedded in the file names below.
FEATURE_TYPE = "LEGENDRE"
TRANS = "NORMCDF"
NOISE = "STUDENT"

# Base keyword arguments for data generation, reused by later cells.
data_config_init = {
    "num_trajectories": 100,
    "num_time_steps": 40,
    "group_reward_coeff_override": group_reward_coeff,
    "action_reward_coeff": action_reward_coeff,
    "num_burnin_steps": 100,
    "basis_expansion_method": FEATURE_TYPE,
    "transformation_method": TRANS,
    "add_intercept_column": True,
    "noise_type": NOISE,
    "noise_student_degree": 4,
}

# First-time run: set COMPUTE_TRUTH = True to generate the truth file,
# then switch it back to False so that later runs load the saved file.
COMPUTE_TRUTH = False
# Change the truth file name whenever the settings above are changed.
TRUTH_FILE = f"hetero/data/{FEATURE_TYPE}_{TRANS}_{NOISE}_truth_20230528_2.68_2.89.pkl"
print("truth file name =", TRUTH_FILE)

# The result file name carries N, T and a timestamp taken when this cell runs.
time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = f'N={data_config_init["num_trajectories"]}_T={data_config_init["num_time_steps"]}_{time_tag}'
RESULT_FILE = f"hetero/data/{FEATURE_TYPE}_{TRANS}_{NOISE}_result_20230528_2.68_2.89_{tag}.pkl"
print("result file name =", RESULT_FILE)

# When False, the reporting cells skip writing RESULT_FILE.
SAVE_RESULT = True
if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

# Number of independently seeded repetitions in the experiment loop.
NUM_EXPERIMENTS = 100

truth file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_result_20230528_2.68_2.89_N=100_T=40_20230602_20-51-59.pkl


=====================================================================================================
# Algorithm 

- Set the configuration below.

In [2]:

# Policy under evaluation and grouping settings (constructed with defaults).
pi_eval = AlternativePolicy(2)
grouping_config = GroupingConfig()

# Keyword settings for the algorithm, gathered in one place and then
# unpacked into AlgoConfig.
algo_settings = {
    "max_num_iters": 1,
    "gam": 2.7,
    "lam": 2.0,
    "rho": 2.0,
    "should_remove_outlier": True,
    "outlier_lower_perc": 2,
    "outlier_upper_perc": 98,
    "nu_coeff": 1e-5,
    "delta_coeff": 1e-5,
    "use_group_wise_regression_init": True,
}
algo_config = AlgoConfig(**algo_settings)

In [3]:
if not COMPUTE_TRUTH:
    # Load previously saved truths.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(TRUTH_FILE, "rb") as truth_fh:
        loaded = pickle.load(truth_fh)
    u_truth, v_truth = loaded["u_truth"], loaded["v_truth"]
else:
    # Shallow copy of the base config without "num_trajectories";
    # compute_UV_truths receives its own num_trajectories below.
    truth_data_config_init = {
        key: value
        for key, value in data_config_init.items()
        if key != "num_trajectories"
    }
    # Original note: "For best results, use num_repeats=10"
    # (the value used here is 100 -- TODO confirm which is intended).
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=100,
        num_trajectories=1000,
    )
    # Average along the first axis of the returned arrays.
    u_truth, v_truth = us.mean(axis=0), vs.mean(axis=0)

    # Save the truths together with the configs that produced them.
    truth_payload = {
        "u_truth": u_truth,
        "v_truth": v_truth,
        "data_config_dict": truth_data_config_init,
        "algo_config_dict": attrs.asdict(algo_config),
    }
    with open(TRUTH_FILE, "wb") as truth_fh:
        pickle.dump(truth_payload, truth_fh)

In [4]:
# One entry per experiment: the non-grouped estimate and the end-to-end
# learned estimate.
beta_ng_list, beta_learned_list = [], []

for run_idx in range(1, NUM_EXPERIMENTS + 1):
    # Each repetition gets its own seed (7531, 2 * 7531, ...).
    data_config = DataGenConfig(seed=7531 * run_idx, **data_config_init)
    data = generate_data_from_config(data_config)

    # The non-grouped estimate runs first, then the end-to-end learner,
    # both on the same generated data.
    beta_ng = beta_estimate_from_nongrouped(data, pi_eval, algo_config.discount)
    beta_ng_list.append(beta_ng)

    beta_learned = beta_estimate_from_e2e_learning(
        data, algo_config, grouping_config, pi_eval
    )
    beta_learned_list.append(beta_learned)

new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = (0.09711564+0j)
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.11698307 -2.79210784 -2.72765136  2.91444181  1.49520674  2.8087918
  -1.0812657  -2.77855847  2.86131932  0.895992  ]
 [-1.10582598  2.93220131  2.92613584 -2.82294296 -1.76469278 -2.8838273
   1.07637123  2.89561457 -2.91087771 -0.88317579]] and inertia = 127.46361837353281
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = (0.10115076+0j)
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.18266917 -2.8783947  -2.84336967  2.7921839   1.62559649  2.93838724
  -1.05198432 -2.80890877  2.84624804  0.77057125]
 [-1.04314016  2.88884104  2.80271885 -2.80015611 -1.66

new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = 0.09666582
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.07190705 -2.90144019 -2.90736079  2.73249242  1.75396533  2.81022177
  -1.16555572 -2.78183331  2.7838928   0.90748668]
 [-1.03744223  2.96221788  2.81072877 -2.82934274 -1.68564277 -2.95439813
   1.0386462   2.85910559 -2.77511522 -0.81815355]] and inertia = 124.29825872214349
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = 0.101031326
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.05784857 -2.95766623 -2.91924374  2.80499705  1.7481658   2.86902876
  -1.07101509 -2.99794302  2.82754867  0.9365681 ]
 [-1.06708524  2.89902901  2.74340391 -2.81140015 -1.57172923 

new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = 0.09955815
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.16179749 -2.8981307  -2.78778069  2.91269264  1.56267945  2.91085465
  -1.16160528 -2.78013182  3.03796994  0.7234113 ]
 [-1.0498408   2.8255356   2.78678888 -2.82177252 -1.60757239 -2.89850504
   1.05504154  2.63622664 -2.90281522 -0.65308771]] and inertia = 121.56043359558502
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = 0.098167166
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.1021424  -2.86063233 -2.85960051  2.81368081  1.62215958  2.66367351
  -1.09109003 -2.73695351  2.74095829  0.97517478]
 [-1.16483392  2.83611876  2.9515175  -2.82985374 -1.618415   

new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = (0.099904604+0j)
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 0.99799673 -2.88392933 -2.91870004  2.62441645  1.84510529  2.96511727
  -1.11406527 -2.85543278  2.64461238  0.91415915]
 [-1.04695109  3.05832005  2.86542351 -2.78351424 -1.79556728 -2.88094047
   1.06216055  2.99861659 -2.71748818 -0.98760024]] and inertia = 117.7918355606328
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = 0.09874569
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.08592578 -2.96538041 -2.84052086  2.83597446  1.65882012  2.83133877
  -1.11491681 -2.7261524   2.86950446  0.79436511]
 [-1.08887775  2.92906679  2.77230861 -3.03949523 -1.53603

new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = 0.0978289
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.0269685  -2.89919934 -2.70427414  2.80913672  1.616236    2.92152454
  -1.07216993 -2.8466983   2.80282705  0.86227502]
 [-1.07796748  2.93755819  2.80765499 -2.77958621 -1.68679643 -2.86433805
   1.05416257  2.80258181 -2.90506864 -0.76924065]] and inertia = 122.09302589034422
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = 0.09834202
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.11777919 -2.97491288 -2.74957671  2.87671784  1.65278091  2.9575036
  -1.16546723 -2.78032155  2.83731782  0.87262883]
 [-1.09935851  2.83547419  2.87377205 -2.8665116  -1.65118446 -3.

new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = 0.101050265
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.06038222 -2.77107185 -2.84101148  2.84048606  1.60702488  2.88773464
  -1.06015668 -2.83054048  2.77314212  0.86746647]
 [-1.03551701  2.84974515  2.81567515 -2.80228874 -1.60226907 -2.82264382
   1.1675557   2.68543947 -2.8730099  -0.78748432]] and inertia = 119.78689855884787
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = (0.09796093+0j)
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.00664884 -2.84071257 -2.82725965  2.7362224   1.69492879  3.04057526
  -1.10215197 -2.83616697  2.89514844  0.73190581]
 [-1.05994649  2.96635081  2.77002128 -2.94095652 -1.6628

new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = 0.097010225
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 0.96611431 -2.88444783 -2.80762642  2.82262893  1.63909274  2.950183
  -1.03269262 -2.65019533  2.99281603  0.49930964]
 [-1.03142165  2.95799692  2.72214326 -2.83104367 -1.60388534 -2.91134723
   1.02060629  2.79467461 -2.85375728 -0.73848844]] and inertia = 126.86237318956466
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = 0.09918017
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.05750265 -2.84521203 -2.82681463  2.65818487  1.7251681   2.90368647
  -1.15273292 -2.99433235  2.77703359  0.99320806]
 [-1.14780285  2.84761719  2.87054757 -2.82235391 -1.60593592 -2

new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = (0.097606264+0j)
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.13124376 -2.92346596 -2.80591364  2.81482705  1.60283395  2.92326883
  -1.0946556  -2.92447653  2.78546581  0.90033735]
 [-1.08693264  2.70957733  2.86938727 -2.74070616 -1.59921257 -2.86246986
   1.1477102   2.94199128 -2.71416775 -1.01833536]] and inertia = 127.39894352748628
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
new_labels.length=8000 matches number of records
beta_solver, min eigen of left matrix = 0.10005048
MCPImpl: num_above=10000, num_below=9900
kmeans center = [[ 1.08759779 -2.89621454 -2.90046804  2.82651433  1.72635637  2.88603031
  -1.10426558 -3.00895365  2.78216742  1.00739776]
 [-1.10112615  3.14188571  2.90576187 -2.905895   -1.7815

In [5]:
# Inspect the `betas` attribute of the first experiment's learned estimate.
beta_learned_list[0].betas

[array([-1.13168126,  2.88023033,  2.95894047, -2.91015014, -1.68966461,
        -2.89715212,  1.10132794,  2.92539368, -2.93005162, -0.8917558 ]),
 array([ 1.11786714, -2.83703551, -2.73071427,  2.8841388 ,  1.54645023,
         2.79443523, -1.09286539, -2.76965158,  2.91968194,  0.86873661])]

In [6]:
def _z_scores(mus, truths, sigmas):
    """Return ``(mu - truth) / sigma`` for each zipped triple, as a list."""
    return [(m - t) / s for m, t, s in zip(mus, truths, sigmas)]


# Per-experiment accumulators for the learned and the non-grouped estimates.
mu_learned_list, sigma_learned_list, z_score_learned_list = [], [], []
mu_ng_list, sigma_ng_list, z_score_ng_list = [], [], []

for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    # Learned estimate: means and sigmas are paired with v_truth entry by entry.
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    z_score_learned = _z_scores(v_mus, v_truth, v_sigmas)
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_score_learned)

    # Non-grouped estimate: `* len(v_truth)` repeats the returned sequences so
    # they can be zipped against every entry of v_truth.
    # NOTE(review): assumes compute_V_estimate returns plain lists here, so `*`
    # repeats rather than scales -- TODO confirm.
    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    z_score_ng = _z_scores(
        ngv_mus * len(v_truth), v_truth, ngv_sigmas * len(v_truth)
    )
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_score_ng)

In [7]:
# Show the learned z-scores left over from the final iteration of the loop above.
z_score_learned

[-0.8117784976997229, -0.20937281619415204]

In [8]:
# Show the last iteration's learned means and sigmas next to the truth values.
v_mus, v_truth, v_sigmas

([-1.266620966415501, 1.2445002899201323],
 array([-1.2428372,  1.2507738], dtype=float32),
 [0.029298356454185, 0.029963286029091773])

# Reports that average over two groups

In [9]:
# Stack the per-experiment z-score lists into arrays (first axis = experiment).
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

Z_THRESHOLD = 1.96
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# Fraction of all z-scores (pooled over every entry) whose magnitude is
# below the threshold; the mean of a boolean array equals sum / size.
learned_in_threshold_perc = learned_in_threshold.mean()
ng_in_threshold_perc = ng_in_threshold.mean()
print(
    "learned_in_threshold_perc=", learned_in_threshold_perc,
    ", ng_in_threshold_perc=", ng_in_threshold_perc,
)

if SAVE_RESULT:
    # Persist the raw per-experiment quantities alongside the z-score arrays.
    result_payload = {
        "mu_learned_list": mu_learned_list,
        "sigma_learned_list": sigma_learned_list,
        "z_score_learned": z_score_learned,
        "mu_ng_list": mu_ng_list,
        "sigma_ng_list": sigma_ng_list,
        "z_score_ng": z_score_ng,
        "beta_learned_list": beta_learned_list,
        "beta_ng_list": beta_ng_list,
    }
    with open(RESULT_FILE, "wb") as result_fh:
        pickle.dump(result_payload, result_fh)

learned_in_threshold_perc= 0.97 , ng_in_threshold_perc= 0.0


In [10]:
# Learned estimate: ACL is 2 * Z_THRESHOLD times the mean sigma; MSE is the
# mean squared difference between the estimated means and v_truth.
ac_acl = 2 * Z_THRESHOLD * np.asarray(sigma_learned_list).mean()
ac_mse = np.square(np.asarray(mu_learned_list) - v_truth).mean()

# Same two quantities for the non-grouped estimate.
mv_acl = 2 * Z_THRESHOLD * np.asarray(sigma_ng_list).mean()
mv_mse = np.square(np.asarray(mu_ng_list) - v_truth).mean()

In [11]:
# Print ACL / MSE / ECP for each method, pooled over groups.
pooled_reports = (
    ("ACPE results: (average over groups)", ac_acl, ac_mse, learned_in_threshold_perc),
    ("MVPE results: (average over groups)", mv_acl, mv_mse, ng_in_threshold_perc),
)
for header, acl, mse, ecp in pooled_reports:
    print(header)
    print(f"ACL: {acl}")
    print(f"MSE: {mse}")
    print(f"ECP: {ecp}")

ACPE results: (average over groups)
ACL: 0.11517639197522714
MSE: 0.0007594814179812259
ECP: 0.97
MVPE results: (average over groups)
ACL: 0.14371708455888213
MSE: 1.5564858876388887
ECP: 0.0


# Reports that separate two groups

In [12]:
# Stack the per-experiment z-score lists into arrays (first axis = experiment).
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

Z_THRESHOLD = 1.96
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# Per-column fraction of experiments whose z-score magnitude is below the
# threshold; the mean over axis 0 equals sum(axis=0) / shape[0].
learned_in_threshold_perc = learned_in_threshold.mean(axis=0)
ng_in_threshold_perc = ng_in_threshold.mean(axis=0)

print(
    "learned_in_threshold_perc=", learned_in_threshold_perc,
    ", ng_in_threshold_perc=", ng_in_threshold_perc,
)

if SAVE_RESULT:
    # Writes the same keys to RESULT_FILE as the pooled-report cell does.
    result_payload = {
        "mu_learned_list": mu_learned_list,
        "sigma_learned_list": sigma_learned_list,
        "z_score_learned": z_score_learned,
        "mu_ng_list": mu_ng_list,
        "sigma_ng_list": sigma_ng_list,
        "z_score_ng": z_score_ng,
        "beta_learned_list": beta_learned_list,
        "beta_ng_list": beta_ng_list,
    }
    with open(RESULT_FILE, "wb") as result_fh:
        pickle.dump(result_payload, result_fh)

learned_in_threshold_perc= [0.94 1.  ] , ng_in_threshold_perc= [0. 0.]


In [13]:
# Learned estimate, reduced over the experiment axis only, so one value
# remains per column.
ac_acl = 2 * Z_THRESHOLD * np.asarray(sigma_learned_list).mean(axis=0)
ac_mse = np.square(np.asarray(mu_learned_list) - v_truth).mean(axis=0)

# Same two quantities for the non-grouped estimate.
mv_acl = 2 * Z_THRESHOLD * np.asarray(sigma_ng_list).mean(axis=0)
mv_mse = np.square(np.asarray(mu_ng_list) - v_truth).mean(axis=0)

In [14]:
# Print MSE / ACL / ECP per method, with "===" between the two sections.
group_reports = (
    ("ACPE results: Group1, Group 2", ac_mse, ac_acl, learned_in_threshold_perc),
    ("MVPE results: ", mv_mse, mv_acl, ng_in_threshold_perc),
)
for position, (header, mse, acl, ecp) in enumerate(group_reports):
    if position:
        print("===")
    print(header)
    print(f"MSE: {mse}")
    print(f"ACL: {acl}")
    print(f"ECP: {ecp}")

ACPE results: Group1, Group 2
MSE: [0.00090265 0.00061632]
ACL: [0.11530798 0.11504481]
ECP: [0.94 1.  ]
===
MVPE results: 
MSE: [1.53640846 1.57656331]
ACL: [0.14371708]
ECP: [0. 0.]
