In [1]:
import copy
import pickle
import warnings
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

# Reward coefficients, one row per group (judging by the parameter name). The
# second row is the exact negation of the first.
group_reward_coeff = np.array(
    [[-2, 2, 2, -2], [2, -2, -2, 2]],
    dtype=DTYPE,
)

action_reward_coeff = [-1, 1]

# Basis-expansion / transformation / noise settings; they are also embedded in
# the truth and result file names.
FEATURE_TYPE = "LEGENDRE"
TRANS = "NORMCDF"
NOISE = "STUDENT"

# Keyword arguments for DataGenConfig shared by every experiment; the per-run
# seed is supplied separately in the experiment loop.
data_config_init = {
    "num_trajectories": 50,
    "num_time_steps": 20,
    "group_reward_coeff_override": group_reward_coeff,
    "action_reward_coeff": action_reward_coeff,
    "num_burnin_steps": 100,
    "basis_expansion_method": FEATURE_TYPE,
    "transformation_method": TRANS,
    "add_intercept_column": True,
    "noise_type": NOISE,
    "noise_student_degree": 4,
}

# The ground truth is expensive to simulate, so it is cached in TRUTH_FILE.
# First run: set COMPUTE_TRUTH = True to generate the file, then set it back
# to False so later runs load the cached truth.
COMPUTE_TRUTH = False

# Shared pieces of the truth/result file names. Both files use the same tag, so
# bump TRUTH_TAG (in one place) whenever the data-generation settings change;
# otherwise a stale truth file would be reused.
FILE_PREFIX = f"hetero/data/{FEATURE_TYPE}_{TRANS}_{NOISE}"
TRUTH_TAG = "20230528_2.68_2.89"

TRUTH_FILE = f"{FILE_PREFIX}_truth_{TRUTH_TAG}.pkl"
print("truth file name =", TRUTH_FILE)

# Result files additionally record sample size, horizon and a timestamp, so
# repeated runs never overwrite each other.
time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = f'N={data_config_init["num_trajectories"]}_T={data_config_init["num_time_steps"]}_{time_tag}'
RESULT_FILE = f"{FILE_PREFIX}_result_{TRUTH_TAG}_{tag}.pkl"
print("result file name =", RESULT_FILE)

SAVE_RESULT = True
if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

# Number of independent replications (each with its own seed).
NUM_EXPERIMENTS = 100

truth file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_result_20230528_2.68_2.89_N=50_T=20_20230602_19-30-04.pkl


=====================================================================================================
# Algorithm 

- Set the estimator configuration (`AlgoConfig`), the policy to evaluate, and the grouping configuration below.

In [2]:

# Hyper-parameters consumed by the end-to-end estimator; its `discount` is also
# used when computing the ground truth. Their precise meaning is defined in
# hetero.config.AlgoConfig.
algo_params = dict(
    max_num_iters=1,
    gam=2.7,
    lam=2.0,
    rho=2.0,
    should_remove_outlier=True,
    # Lower/upper percentile bounds for outlier removal (by parameter name).
    outlier_lower_perc=2,
    outlier_upper_perc=98,
    nu_coeff=1e-5,
    delta_coeff=1e-5,
    use_group_wise_regression_init=True,
)
algo_config = AlgoConfig(**algo_params)

# Policy whose value is estimated. NOTE(review): the argument 2 is presumably the
# number of actions (two action reward coefficients above) — confirm in hetero.policies.
pi_eval = AlternativePolicy(2)

# Grouping step uses the library defaults.
grouping_config = GroupingConfig()

In [3]:
if COMPUTE_TRUTH:
    # Simulate the ground truth and cache it in TRUTH_FILE. `num_trajectories` is
    # removed from the settings because compute_UV_truths receives its own
    # (larger) value below.
    truth_data_config_init = copy.copy(data_config_init)
    truth_data_config_init.pop("num_trajectories")
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=100,
        num_trajectories=1000,
    )  # Larger num_repeats / num_trajectories: slower but a less noisy truth.
    # Average over the repeats (presumably stacked along axis 0 — TODO confirm).
    u_truth = us.mean(axis=0)
    v_truth = vs.mean(axis=0)

    with open(TRUTH_FILE, "wb") as f:
        pickle.dump(
            dict(
                u_truth=u_truth,
                v_truth=v_truth,
                data_config_dict=truth_data_config_init,
                algo_config_dict=attrs.asdict(algo_config),
            ),
            f,
        )
else:
    # NOTE: pickle.load can execute arbitrary code; only load truth files you generated.
    with open(TRUTH_FILE, "rb") as f:
        loaded = pickle.load(f)
    u_truth = loaded["u_truth"]
    v_truth = loaded["v_truth"]

    # A truth file is matched to the current settings only by its hand-edited
    # name, so compare the settings stored inside it against the current ones.
    # A stale truth silently invalidates every z-score computed below.
    stale_settings = []
    saved_data_config = loaded.get("data_config_dict")
    if saved_data_config is not None:
        expected_data_config = {
            key: value
            for key, value in data_config_init.items()
            if key != "num_trajectories"
        }
        stale_settings = sorted(
            key
            for key in expected_data_config.keys() | saved_data_config.keys()
            if key not in expected_data_config
            or key not in saved_data_config
            # array_equal handles scalars, strings, lists and ndarrays alike.
            or not np.array_equal(expected_data_config[key], saved_data_config[key])
        )
    saved_discount = loaded.get("algo_config_dict", {}).get("discount")
    if saved_discount is not None and not np.isclose(saved_discount, algo_config.discount):
        stale_settings.append("discount")
    if stale_settings:
        warnings.warn(
            f"{TRUTH_FILE} was generated with different settings for {stale_settings}; "
            "set COMPUTE_TRUTH = True (and update the truth file name) to regenerate it."
        )

In [4]:
# Run NUM_EXPERIMENTS independent replications, each with a deterministic seed,
# and collect both estimators' coefficient estimates.
beta_ng_list, beta_learned_list = [], []

for run_idx in range(NUM_EXPERIMENTS):
    run_seed = 7531 * (run_idx + 1)
    run_config = DataGenConfig(seed=run_seed, **data_config_init)
    run_data = generate_data_from_config(run_config)

    # Non-grouped baseline estimate.
    ng_estimate = beta_estimate_from_nongrouped(run_data, pi_eval, algo_config.discount)
    beta_ng_list.append(ng_estimate)

    # Grouped estimate from end-to-end learning.
    learned_estimate = beta_estimate_from_e2e_learning(
        run_data, algo_config, grouping_config, pi_eval
    )
    beta_learned_list.append(learned_estimate)

new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = 0.24322975
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.14692249 -2.95176024 -2.77519263  2.85846333  1.68236943  2.54393467
  -1.1555106  -2.67626254  2.67925176  1.13213849]
 [-0.96336648  3.0026941   2.55163416 -2.91406425 -1.50798927 -2.88321968
   1.02354372  3.13186362 -2.94585661 -0.90432525]] and inertia = 250.7698131066054
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = 0.25521845
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 0.97273351 -3.20679263 -2.65610652  2.97981374  1.7018105   2.61814966
  -1.15957419 -2.91961894  2.77228986  1.16176774]
 [-1.06856524  2.80215158  2.72259641 -2.67373027 -1.59073445 -2.8

MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 0.99087444 -2.70941659 -2.68255059  2.59555566  1.57311822  3.0454259
  -0.94688715 -2.73841705  2.78992682  0.60554372]
 [-1.06132502  2.78878919  2.96289156 -2.52783709 -1.82063589 -3.00625796
   0.95850374  2.86168731 -2.84099007 -0.67654595]] and inertia = 240.6431499070938
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = (0.27051264+0j)
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 0.94591268 -3.13743964 -2.53575151  2.65164475  1.75803363  3.02294523
  -1.10115781 -2.68983022  2.76567837  0.70411303]
 [-1.00440975  2.59877382  2.75501468 -2.85592505 -1.49378574 -2.98214335
   1.00585961  2.61029277 -2.88956996 -0.56778988]] and inertia = 234.3000272452084
Label mismatch = 0
new_labels.length=100 matches num_

MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 0.94345107 -2.83395414 -2.96099508  2.75729389  1.82805415  3.06636026
  -0.93794991 -2.92190184  2.62956882  0.88042643]
 [-0.92335701  3.22572663  2.7020531  -2.97164391 -1.62176023 -3.00445743
   0.98323534  2.64967926 -2.88288316 -0.4712285 ]] and inertia = 264.91465908733346
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = 0.25708845
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.12393625 -2.66157306 -2.76542139  2.79444831  1.47667416  2.66651752
  -1.10727032 -2.64706843  2.85185244  0.8576866 ]
 [-1.04960032  2.96092993  2.65007523 -2.427674   -1.86891885 -2.98885111
   1.06245777  2.85213584 -2.71302403 -0.85127342]] and inertia = 222.60714944721047
Label mismatch = 0
new_labels.length=100 matches num_un

MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.17200401 -2.88359155 -2.55812395  2.87861366  1.40788608  2.73907576
  -1.17560377 -2.80991453  2.73533262  0.96659769]
 [-1.10548013  2.93775095  2.95553502 -2.6137236  -1.91314093 -3.14252288
   0.91970075  2.97238613 -2.64711415 -0.81475996]] and inertia = 246.14187275719001
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = (0.23770674+0j)
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 0.97679404 -2.91076585 -3.07113779  2.8231298   2.02749041  3.1339388
  -1.00619363 -2.82399677  2.86530486  0.70805602]
 [-0.94349858  2.98405372  3.06687873 -2.65508331 -2.0582081  -2.90856067
   0.95871005  2.84880241 -2.60904722 -0.93207204]] and inertia = 264.1989395317422
Label mismatch = 0
new_labels.length=100 matches num

MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.06732998 -2.99117754 -2.6846207   3.21749393  1.34123445  2.9352123
  -0.93649145 -2.79203704  2.89243872  0.61739725]
 [-1.20294334  2.96974904  2.70190503 -3.17757975 -1.3617562  -2.55923826
   1.16924033  2.78553166 -2.82911364 -1.05346789]] and inertia = 254.85004328694149
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = (0.25025532+0j)
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 0.93028368 -2.81230073 -3.01285313  2.82371377  1.84074069  3.10349069
  -1.05463306 -2.87784756  2.96667234  0.69051613]
 [-0.82636333  2.87899492  2.6925398  -2.66908882 -1.78153583 -2.90964444
   0.95476909  2.86323279 -2.99700757 -0.70954418]] and inertia = 254.7954537714229
Label mismatch = 0
new_labels.length=100 matches num

MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 0.86572443 -3.0219556  -2.91558473  2.76475727  2.0062129   2.92307112
  -0.94117319 -2.87346856  2.58041963  0.98937646]
 [-1.02686435  2.89707455  2.78856082 -2.71618895 -1.79122378 -2.82819385
   1.03154     2.97749815 -2.70645694 -1.07048444]] and inertia = 237.36720235871238
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = 0.26252097
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 0.98530769 -2.81712821 -3.07648051  3.09591166  1.61857352  2.81600154
  -1.16847814 -3.02731896  2.61719035  1.16468387]
 [-1.03375673  2.83179304  2.9030443  -2.78117262 -1.68861979 -2.76762053
   0.93144121  3.04293009 -2.85071135 -0.93319713]] and inertia = 259.4622734142068
Label mismatch = 0
new_labels.length=100 matches num_uni

MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.07788447 -2.95961267 -2.95760514  3.04012683  1.69099754  2.84291712
  -1.04858463 -2.87361647  2.84742981  0.86399891]
 [-1.07982534  3.1350833   3.00124911 -2.81135258 -1.89172811 -2.79146549
   1.01194172  2.69325157 -2.7914777  -0.83615446]] and inertia = 241.98618984966663
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = (0.25298014+0j)
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.05975105 -3.08598215 -2.77687398  2.69982633  1.72583408  3.14683287
  -1.03069531 -2.6635229   2.60390427  0.7177683 ]
 [-1.00470237  2.87951577  2.75082598 -3.12236921 -1.50853576 -2.88119397
   1.18711638  2.89130417 -2.95218659 -0.88754494]] and inertia = 265.1957766870586
Label mismatch = 0
new_labels.length=100 matches nu

MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.1603752  -3.00891737 -2.88047897  2.98557547  1.64332993  2.78649361
  -0.89397437 -2.86657004  2.62344925  1.0459099 ]
 [-0.88429476  2.85291991  2.68071018 -2.84566417 -1.60141644 -3.10280832
   1.12216391  2.6680276  -2.76268576 -0.66969242]] and inertia = 243.37113221780777
Label mismatch = 0
new_labels.length=100 matches num_unique_labels
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
new_labels.length=2000 matches number of records
beta_solver, min eigen of left matrix = 0.25226626
MCPImpl: num_above=2500, num_below=2450
kmeans center = [[ 1.03545337 -3.04358093 -2.77452105  2.85916784  1.63836044  2.9824449
  -1.04549307 -2.5836372   2.48923545  0.80227661]
 [-0.97073565  2.97405873  2.68266124 -2.8414464  -1.61296427 -3.11656226
   1.20551018  2.85427346 -2.69900719 -0.87633861]] and inertia = 232.4449962293804
Label mismatch = 0
new_labels.length=100 matches num_uniq

In [5]:
# Peek at the coefficient vectors learned in the first experiment
# (presumably one array per group).
first_run_estimate = beta_learned_list[0]
first_run_estimate.betas

[array([-1.10325295,  2.99754758,  2.58230748, -3.02438533, -1.41873896,
        -2.77774583,  1.06721488,  3.04395721, -3.04658301, -0.87861699]),
 array([ 1.20829972, -2.94552576, -3.02892341,  2.81618039,  1.78166875,
         2.51148362, -1.21754294, -2.91799362,  2.78961736,  1.22356028])]

In [6]:
# Per-experiment value estimates (mu), standard errors (sigma) and z-scores
# (mu - truth) / sigma for the grouped (learned) and non-grouped estimators.
mu_learned_list = []
sigma_learned_list = []
z_score_learned_list = []
mu_ng_list = []
sigma_ng_list = []
z_score_ng_list = []

v_truth_arr = np.asarray(v_truth)

for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    # Grouped estimator: one (mu, sigma) per group, paired positionally with v_truth.
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    if not (len(v_mus) == len(v_sigmas) == len(v_truth_arr)):
        # zip() would silently truncate and pair estimates with the wrong truths.
        raise ValueError(
            f"learned estimate has {len(v_mus)} means / {len(v_sigmas)} sigmas "
            f"but v_truth has {len(v_truth_arr)} groups"
        )
    z_learned_run = [
        (mu - truth) / sigma for mu, truth, sigma in zip(v_mus, v_truth, v_sigmas)
    ]
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_learned_run)

    # Non-grouped estimator: its pooled (mu, sigma) is scored against every
    # group's truth. NumPy broadcasting repeats it across groups regardless of
    # the return type; `list * n` only repeats lists and would instead multiply
    # the estimate by n if compute_V_estimate returned an ndarray.
    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    z_ng_run = list((np.asarray(ngv_mus) - v_truth_arr) / np.asarray(ngv_sigmas))
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_ng_run)

# Reports averaged over both groups

In [7]:
# Empirical coverage of the two-sided 95% normal interval (|z| < 1.96), pooled
# over every (experiment, group) pair.
Z_THRESHOLD = 1.96
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# Mean of a boolean array = fraction of covered entries.
learned_in_threshold_perc = learned_in_threshold.mean()
ng_in_threshold_perc = ng_in_threshold.mean()

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

if SAVE_RESULT:
    result_payload = {
        "mu_learned_list": mu_learned_list,
        "sigma_learned_list": sigma_learned_list,
        "z_score_learned": z_score_learned,
        "mu_ng_list": mu_ng_list,
        "sigma_ng_list": sigma_ng_list,
        "z_score_ng": z_score_ng,
        "beta_learned_list": beta_learned_list,
        "beta_ng_list": beta_ng_list,
    }
    with open(RESULT_FILE, "wb") as result_fh:
        pickle.dump(result_payload, result_fh)

learned_in_threshold_perc= 0.975 , ng_in_threshold_perc= 0.0


In [8]:
# Pooled summary metrics, averaged over all experiments and groups:
#   ACL: mean length of the +/- Z_THRESHOLD * sigma interval.
#   MSE: mean squared difference between the value estimates and v_truth.
interval_width = 2 * Z_THRESHOLD

ac_acl = interval_width * np.mean(sigma_learned_list)
ac_mse = np.square(mu_learned_list - v_truth).mean()

# The non-grouped estimates broadcast against every group's truth.
mv_acl = interval_width * np.mean(sigma_ng_list)
mv_mse = np.square(mu_ng_list - v_truth).mean()

In [9]:
# One section per estimator: ACPE = grouped (learned), MVPE = non-grouped.
pooled_sections = (
    ("ACPE results: (average over groups)", ac_acl, ac_mse, learned_in_threshold_perc),
    ("MVPE results: (average over groups)", mv_acl, mv_mse, ng_in_threshold_perc),
)
for header, acl, mse, ecp in pooled_sections:
    print(header)
    print(f"ACL: {acl}")
    print(f"MSE: {mse}")
    print(f"ECP: {ecp}")

ACPE results: (average over groups)
ACL: 0.2309703886472189
MSE: 0.0026643771763127464
ECP: 0.975
MVPE results: (average over groups)
ACL: 0.28797216589895486
MSE: 1.5541961331179344
ECP: 0.0


# Reports for each group separately

In [10]:
# Per-group empirical coverage: the same |z| < Z_THRESHOLD test as the pooled
# report, but averaged over experiments only (axis 0) -> one value per group.
# NOTE: this cell rebinds names read by the pooled report cells (e.g.
# learned_in_threshold_perc becomes an array); re-run the pooled coverage cell
# before re-printing the pooled summary.
Z_THRESHOLD = 1.96
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# Column means of a boolean (experiment x group) array = per-group coverage.
learned_in_threshold_perc = learned_in_threshold.mean(axis=0)
ng_in_threshold_perc = ng_in_threshold.mean(axis=0)

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

# RESULT_FILE is not rewritten here: the pooled coverage cell already saves
# exactly these lists and z-score arrays when SAVE_RESULT is set.

learned_in_threshold_perc= [0.97 0.98] , ng_in_threshold_perc= [0. 0.]


In [11]:
# Per-group metrics: averaged over experiments only (axis 0), one value per group.
#   ACL: mean length of the +/- Z_THRESHOLD * sigma interval.
#   MSE: mean squared difference between the value estimates and v_truth.
# np.asarray makes the subtraction work whether the inputs are lists or arrays.
v_truth_arr = np.asarray(v_truth)

ac_acl = 2 * Z_THRESHOLD * np.mean(sigma_learned_list, axis=0)
ac_mse = np.mean((np.asarray(mu_learned_list) - v_truth_arr) ** 2, axis=0)

# The non-grouped estimator has a single pooled interval that every group shares
# (its z-scores reuse that sigma for each group). Broadcast it so the MVPE row
# has one entry per group, like the MSE and ECP rows.
mv_acl = np.broadcast_to(
    2 * Z_THRESHOLD * np.mean(sigma_ng_list, axis=0), v_truth_arr.shape
)
mv_mse = np.mean((np.asarray(mu_ng_list) - v_truth_arr) ** 2, axis=0)

In [12]:
# Per-group report; ACPE = grouped (learned), MVPE = non-grouped.
group_sections = (
    ("ACPE results: Group1, Group 2", ac_mse, ac_acl, learned_in_threshold_perc),
    ("MVPE results: ", mv_mse, mv_acl, ng_in_threshold_perc),
)
for section_idx, (header, mse, acl, ecp) in enumerate(group_sections):
    if section_idx > 0:
        print("===")
    print(header)
    print(f"MSE: {mse}")
    print(f"ACL: {acl}")
    print(f"ECP: {ecp}")

ACPE results: Group1, Group 2
MSE: [0.00302398 0.00230478]
ACL: [0.23160256 0.23033822]
ECP: [0.97 0.98]
===
MVPE results: 
MSE: [1.5257716  1.58262066]
ACL: [0.28797217]
ECP: [0. 0.]
