In [1]:
import copy
import pickle
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

# Reward coefficients, presumably one row per group (two rows here).
group_reward_coeff = np.array([[-2, 2, 2, -2], [2, -2, -2, 2]], dtype=DTYPE)

action_reward_coeff = [-1, 1]

# Generator settings; their values are embedded in the truth/result file names below.
FEATURE_TYPE = "LEGENDRE"
TRANS = "NORMCDF"
NOISE = "STUDENT"

# Keyword arguments unpacked into DataGenConfig for every experiment. A copy
# without "num_trajectories" is what the truth computation receives.
data_config_init = {
    "num_trajectories": 50,
    "num_time_steps": 30,
    "group_reward_coeff_override": group_reward_coeff,
    "action_reward_coeff": action_reward_coeff,
    "num_burnin_steps": 100,
    "basis_expansion_method": FEATURE_TYPE,
    "transformation_method": TRANS,
    "add_intercept_column": True,
    "noise_type": NOISE,
    "noise_student_degree": 4,
}

# First run: set COMPUTE_TRUTH = True so the truth file gets generated.
# Switch it back to False afterwards so later runs load the saved file.
COMPUTE_TRUTH = False
# Rename the truth file whenever the settings above change.
TRUTH_FILE = f"hetero/data/{FEATURE_TYPE}_{TRANS}_{NOISE}_truth_20230528_2.68_2.89.pkl"
print("truth file name =", TRUTH_FILE)

# The result file name carries N, T and a to-the-second timestamp of this cell's execution.
time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = f'N={data_config_init["num_trajectories"]}_T={data_config_init["num_time_steps"]}_{time_tag}'
RESULT_FILE = f"hetero/data/{FEATURE_TYPE}_{TRANS}_{NOISE}_result_20230528_2.68_2.89_{tag}.pkl"
print("result file name =", RESULT_FILE)

# Set to False for throw-away runs whose results should not be written to disk.
SAVE_RESULT = True
if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

# Number of independently seeded datasets simulated by the experiment loop.
NUM_EXPERIMENTS = 100


truth file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_result_20230528_2.68_2.89_N=50_T=30_20230602_19-31-34.pkl


=====================================================================================================
# Algorithm 

- Set the algorithm configuration below.

In [2]:

# Algorithm hyperparameters. `discount` is not passed here, so whatever default
# AlgoConfig defines is what later cells read through `algo_config.discount`.
algo_config = AlgoConfig(
    # Iteration budget and initialisation strategy.
    max_num_iters=2,
    use_group_wise_regression_init=True,
    # Scalar tuning parameters; their exact roles are defined in hetero.config
    # (NOTE(review): presumably penalty/solver settings, confirm there).
    gam=2.7,
    lam=2.0,
    rho=2.0,
    nu_coeff=1e-5,
    delta_coeff=1e-5,
    # Outlier removal, bounded by lower/upper percentiles.
    should_remove_outlier=True,
    outlier_lower_perc=2,
    outlier_upper_perc=98,
)

# Policy handed to every truth/estimation routine in this notebook.
pi_eval = AlternativePolicy(2)

# Grouping settings left at the GroupingConfig defaults.
grouping_config = GroupingConfig()

In [3]:
# Obtain u_truth / v_truth: load them from TRUTH_FILE, or (first run) compute
# them and write TRUTH_FILE.
if not COMPUTE_TRUTH:
    # CAUTION: pickle.load can execute arbitrary code stored in the file; only
    # point TRUTH_FILE at files produced by this notebook.
    with open(TRUTH_FILE, "rb") as truth_fh:
        truth_payload = pickle.load(truth_fh)
    u_truth = truth_payload["u_truth"]
    v_truth = truth_payload["v_truth"]
else:
    # The truth run supplies its own num_trajectories, so drop it from the
    # (shallow) copy of the shared settings.
    truth_data_config_init = copy.copy(data_config_init)
    del truth_data_config_init["num_trajectories"]

    # NOTE(review): an earlier comment here recommended num_repeats=10 while the
    # call passes 100 -- confirm which value is intended.
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=100,
        num_trajectories=1000,
    )
    # Average over the leading axis of the returned arrays.
    u_truth = us.mean(axis=0)
    v_truth = vs.mean(axis=0)

    # Store the settings next to the truths so the file documents how it was made.
    with open(TRUTH_FILE, "wb") as truth_fh:
        truth_record = {
            "u_truth": u_truth,
            "v_truth": v_truth,
            "data_config_dict": truth_data_config_init,
            "algo_config_dict": attrs.asdict(algo_config),
        }
        pickle.dump(truth_record, truth_fh)

In [4]:
# Collect one non-grouped and one end-to-end-learned beta estimate per dataset.
beta_ng_list = []
beta_learned_list = []

for run_number in range(1, NUM_EXPERIMENTS + 1):
    # A distinct, reproducible seed per experiment: 7531, 15062, ...
    data_config = DataGenConfig(seed=7531 * run_number, **data_config_init)
    data = generate_data_from_config(data_config)

    # Both estimators are fitted on the same simulated dataset.
    beta_ng = beta_estimate_from_nongrouped(data, pi_eval, algo_config.discount)
    beta_ng_list.append(beta_ng)

    beta_learned = beta_estimate_from_e2e_learning(
        data, algo_config, grouping_config, pi_eval
    )
    beta_learned_list.append(beta_learned)

new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.21172589
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.21172589
MCPImpl: num_above=3289, num_below=1661
kmeans center = [[ 1.22256673 -2.93378604 -2.65545035  2.69965906  1.66057346  2.68645154
  -1.04774924 -2.8416291   2.48031286  1.28303017]
 [-1.22020988  2.63031888  2.84601028 -2.66773059 -1.5659217  -2.59432909
   1.2581601   2.57177564 -2.76680787 -0.93748003]] and inertia = 1812.961771911872
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.21698116
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.21698116
MCPImpl: num_above=3401, num_below=1549
kmeans center = [[ 1.1405

MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.2065867+0j)
MCPImpl: num_above=3489, num_below=1461
kmeans center = [[ 1.36691658 -2.92878167 -2.81827543  2.71620623  1.59034955  3.0477975
  -1.0812906  -2.54392981  2.71710921  0.63102675]
 [-1.12899819  2.57445794  2.82565086 -2.65421676 -1.5714187  -2.88290877
   1.27046981  2.84629492 -2.81386177 -0.97297828]] and inertia = 2078.382996060377
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.20477371
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.20477371
MCPImpl: num_above=3270, num_below=1680
kmeans center = [[ 1.24882939 -2.56350396 -2.80148181  2.71824239  1.53138001  2.99096533
  -1.02842742 -2.92710788  2.87232641  0.82796167]
 [-1.0963192   2.80018278  2.9

MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.20810655
MCPImpl: num_above=3517, num_below=1433
kmeans center = [[ 1.17208142 -2.56864703 -2.7321359   2.64556476  1.58866747  2.37639867
  -1.20427104 -2.7537611   3.00272306  1.07264775]
 [-1.29678649  3.00785307  2.69941206 -2.50092598 -1.67022839 -3.2344043
   1.15118571  2.78187764 -2.60370886 -0.77616253]] and inertia = 2045.9365773554268
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.20451838
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.20451838
MCPImpl: num_above=3577, num_below=1373
kmeans center = [[ 1.18306647 -3.02220905 -2.76829183  2.74560503  1.77732406  3.06835061
  -1.3639659  -2.42405739  2.78229571  0.68505847]
 [-1.32833899  2.57327855  2.7465

MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.20415588
MCPImpl: num_above=3402, num_below=1548
kmeans center = [[ 1.23483907 -2.94329249 -2.78940102  3.17076924  1.24654688  2.63533702
  -1.231669   -2.62252619  2.89323301  0.73585903]
 [-1.00237648  2.72971235  2.587662   -2.63685158 -1.66093546 -2.6050231
   1.15526522  2.7159969  -2.78303447 -1.08986107]] and inertia = 1936.7831308044047
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = (0.2065918+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.2065918+0j)
MCPImpl: num_above=3323, num_below=1627
kmeans center = [[ 1.00685914 -3.11193312 -2.82956686  2.50706386  2.03986264  3.09516854
  -0.98388418 -2.71831367  2.72003525  0.79436734]
 [-1.33711482  2.44687877

beta_solver, min eigen of left matrix = 0.2053649
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.2053649
MCPImpl: num_above=3399, num_below=1551
kmeans center = [[ 1.24367564 -3.0428665  -2.84140987  2.77757365  1.58842167  2.99825714
  -1.11531478 -2.78254799  3.04358847  0.63223483]
 [-1.3807595   2.67750377  2.82941178 -2.58029727 -1.47924143 -2.63410391
   1.1794936   2.88341437 -2.65246752 -1.06915769]] and inertia = 1929.3080725394454
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.2043946
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.2043946
MCPImpl: num_above=3426, num_below=1524
kmeans center = [[ 1.05436777 -2.61875277 -2.70607345  2.71795752  1.38696385  2.77565557
  -1.12483457 -2.84054254  2.8681795

new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = (0.2134258+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.2134258+0j)
MCPImpl: num_above=3395, num_below=1555
kmeans center = [[ 1.26689161 -2.84203922 -2.75997834  2.886412    1.45715491  2.95139054
  -1.1728458  -2.89993477  2.85289555  0.85397075]
 [-1.07452426  2.96721202  2.67036109 -2.73029742 -1.7826269  -2.91288955
   1.11405796  2.42012145 -2.66697906 -0.8517193 ]] and inertia = 1933.2686788949707
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = (0.20686243+0j)
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = (0.20686243+0j)
MCPImpl: num_above=3438, num_below=1512
kmeans

beta_solver, min eigen of left matrix = 0.19705334
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.19705334
MCPImpl: num_above=3533, num_below=1417
kmeans center = [[ 1.36220602 -2.81265521 -2.77182388  2.56285571  1.43191839  2.85124894
  -1.19836562 -2.81582453  2.81453681  0.8846409 ]
 [-1.06860908  2.82349815  3.06708795 -2.96561536 -1.69656381 -2.83591022
   1.22170864  2.85856601 -2.58393849 -1.17845874]] and inertia = 2018.293118645499
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.20261604
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.20261604
MCPImpl: num_above=3396, num_below=1554
kmeans center = [[ 1.18211937 -2.88922046 -2.88843664  2.92296399  1.60825268  3.02730246
  -1.21727669 -2.92591908  2.9901

new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.20569935
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.20569935
MCPImpl: num_above=3361, num_below=1589
kmeans center = [[ 1.12037836 -3.08831929 -2.90386264  2.65451447  1.74951569  2.71752762
  -1.23040738 -2.65936558  2.76865343  0.9443608 ]
 [-1.27496726  2.71647986  2.92383514 -2.66439199 -1.71302209 -2.6344696
   1.04231298  2.66582141 -2.76326905 -0.99958028]] and inertia = 1869.3096011363114
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.20225376
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.20225376
MCPImpl: num_above=3454, num_below=1496
kmeans center = [[ 0.97060128 -2.43898447 -2.96232731  2.58992658  1.64219

beta_solver, min eigen of left matrix = 0.21261662
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.21261662
MCPImpl: num_above=3384, num_below=1566
kmeans center = [[ 1.24324424 -2.76663221 -2.65418174  2.69499743  1.50121219  2.88575262
  -1.38699931 -2.76417347  2.59055819  1.01777663]
 [-1.122367    2.97212049  2.73454264 -2.80050323 -1.59390691 -3.07759113
   1.04030873  2.78280625 -3.04399018 -0.42892596]] and inertia = 1920.8476878723031
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
new_labels.length=3000 matches number of records
beta_solver, min eigen of left matrix = 0.21701114
MCPImpl: num_above=2500, num_below=2450
beta_solver, min eigen of left matrix = 0.21701114
MCPImpl: num_above=3296, num_below=1654
kmeans center = [[ 1.16705752 -2.97980108 -2.76476584  2.85910693  1.68256989  2.82085869
  -0.9996035  -2.89475315  3.046

In [5]:
# Sanity check: show the `betas` attribute of the first experiment's learned
# estimate (a list with two arrays in the recorded output).
beta_learned_list[0].betas

[array([-1.14825252,  2.7560394 ,  2.80102004, -2.85296955, -1.50420937,
        -2.72469127,  1.07996754,  2.81299547, -2.71056711, -1.00898266]),
 array([ 1.10722816, -3.07183181, -2.9288532 ,  2.79775352,  1.86935299,
         2.79250136, -0.99397352, -2.80506411,  2.61283611,  1.10567068])]

In [6]:
# For every experiment, turn both beta estimates into value estimates (mu),
# their standard errors (sigma) and z-scores against the per-group truth
# v_truth: z = (mu - truth) / sigma.
mu_learned_list = []
sigma_learned_list = []
z_score_learned_list = []
mu_ng_list = []
sigma_ng_list = []
z_score_ng_list = []

for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    # Learned (grouped) estimator: mus/sigmas are paired element-wise with v_truth.
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    z_score_learned = [
        (mu - truth) / sigma for mu, truth, sigma in zip(v_mus, v_truth, v_sigmas)
    ]
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_score_learned)

    # Non-grouped estimator: its estimate is repeated so it can be compared
    # with every entry of v_truth. Convert to a list first: `seq * n` repeats a
    # list but would multiply the values of an ndarray by n (and zip would then
    # silently truncate), so list(...) keeps the replication explicit.
    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    z_score_ng = [
        (mu - truth) / sigma
        for mu, truth, sigma in zip(
            list(ngv_mus) * len(v_truth), v_truth, list(ngv_sigmas) * len(v_truth)
        )
    ]
    # The un-replicated values are stored, so these lists keep the estimator's
    # own length rather than len(v_truth).
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_score_ng)

# Reports averaged over the two groups

In [7]:
# Pooled coverage: fraction of all z-scores (every experiment and group) whose
# magnitude is below 1.96, the two-sided 95% normal quantile.
Z_THRESHOLD = 1.96
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

learned_in_threshold_perc = learned_in_threshold.sum() / learned_in_threshold.size
ng_in_threshold_perc = ng_in_threshold.sum() / ng_in_threshold.size
print("learned_in_threshold_perc=", learned_in_threshold_perc, ", ng_in_threshold_perc=", ng_in_threshold_perc)

# Persist the raw per-experiment quantities so reports can be rebuilt later.
if SAVE_RESULT:
    with open(RESULT_FILE, "wb") as result_fh:
        result_payload = {
            "mu_learned_list": mu_learned_list,
            "sigma_learned_list": sigma_learned_list,
            "z_score_learned": z_score_learned,
            "mu_ng_list": mu_ng_list,
            "sigma_ng_list": sigma_ng_list,
            "z_score_ng": z_score_ng,
            "beta_learned_list": beta_learned_list,
            "beta_ng_list": beta_ng_list,
        }
        pickle.dump(result_payload, result_fh)

learned_in_threshold_perc= 0.935 , ng_in_threshold_perc= 0.0


In [8]:
# Pooled metrics over all experiments and groups.
# ACL = 2 * Z_THRESHOLD * mean(sigma), i.e. the mean width of the +/- z*sigma interval.
# MSE = mean squared difference between the estimates and v_truth.
learned_residuals = mu_learned_list - v_truth
ng_residuals = mu_ng_list - v_truth

ac_acl = 2 * Z_THRESHOLD * np.mean(sigma_learned_list)
ac_mse = np.mean(learned_residuals**2)

mv_acl = 2 * Z_THRESHOLD * np.mean(sigma_ng_list)
mv_mse = np.mean(ng_residuals**2)

In [9]:
# Print the pooled metrics for both estimators in the same ACL / MSE / ECP layout.
pooled_reports = (
    ("ACPE results: (average over groups)", ac_acl, ac_mse, learned_in_threshold_perc),
    ("MVPE results: (average over groups)", mv_acl, mv_mse, ng_in_threshold_perc),
)
for report_title, report_acl, report_mse, report_ecp in pooled_reports:
    print(report_title)
    print(f"ACL: {report_acl}")
    print(f"MSE: {report_mse}")
    print(f"ECP: {report_ecp}")

ACPE results: (average over groups)
ACL: 0.18896357540956257
MSE: 0.002425944262492409
ECP: 0.935
MVPE results: (average over groups)
ACL: 0.23573208478249486
MSE: 1.5541338314281112
ECP: 0.0


# Reports that separate the two groups

In [10]:
# Per-group coverage: same |z| < Z_THRESHOLD test as the pooled report, but
# counted along the experiment axis only, giving one fraction per group.
# This rebinds learned_in_threshold_perc / ng_in_threshold_perc from the pooled
# scalars to per-group arrays.
Z_THRESHOLD = 1.96
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

learned_in_threshold_perc = learned_in_threshold.sum(axis=0) / len(learned_in_threshold)
ng_in_threshold_perc = ng_in_threshold.sum(axis=0) / len(ng_in_threshold)

print("learned_in_threshold_perc=", learned_in_threshold_perc, ", ng_in_threshold_perc=", ng_in_threshold_perc)

# Writes RESULT_FILE with the same keys as the pooled-report cell; running both
# cells simply overwrites the file.
if SAVE_RESULT:
    with open(RESULT_FILE, "wb") as result_fh:
        result_payload = {
            "mu_learned_list": mu_learned_list,
            "sigma_learned_list": sigma_learned_list,
            "z_score_learned": z_score_learned,
            "mu_ng_list": mu_ng_list,
            "sigma_ng_list": sigma_ng_list,
            "z_score_ng": z_score_ng,
            "beta_learned_list": beta_learned_list,
            "beta_ng_list": beta_ng_list,
        }
        pickle.dump(result_payload, result_fh)

learned_in_threshold_perc= [0.91 0.96] , ng_in_threshold_perc= [0. 0.]


In [11]:
# Per-group metrics: reduce over the experiment axis (axis=0) only.
# ACL = 2 * Z_THRESHOLD * mean(sigma); MSE = mean squared difference from v_truth.
learned_sq_err = (mu_learned_list - v_truth) ** 2
ng_sq_err = (mu_ng_list - v_truth) ** 2

ac_acl = 2 * Z_THRESHOLD * np.mean(sigma_learned_list, axis=0)
ac_mse = np.mean(learned_sq_err, axis=0)

# mv_acl keeps the length of the stored non-grouped sigmas (a single entry in
# the recorded output), while mv_mse broadcasts against v_truth.
mv_acl = 2 * Z_THRESHOLD * np.mean(sigma_ng_list, axis=0)
mv_mse = np.mean(ng_sq_err, axis=0)

In [12]:
# Per-group report for both estimators, assembled line by line and printed once.
report_lines = [
    "ACPE results: Group1, Group 2",
    f"MSE: {ac_mse}",
    f"ACL: {ac_acl}",
    f"ECP: {learned_in_threshold_perc}",
    "===",
    "MVPE results: ",
    f"MSE: {mv_mse}",
    f"ACL: {mv_acl}",
    f"ECP: {ng_in_threshold_perc}",
]
print("\n".join(report_lines))

ACPE results: Group1, Group 2
MSE: [0.0026913  0.00216059]
ACL: [0.18898262 0.18894453]
ECP: [0.91 0.96]
===
MVPE results: 
MSE: [1.49468768 1.61357998]
ACL: [0.23573208]
ECP: [0. 0.]
