In [1]:
import copy
import os
import pickle
from datetime import datetime as dtdt

import attrs
import numpy as np

from hetero.config import DTYPE, AlgoConfig, DataGenConfig, GroupingConfig
from hetero.datagen import generate_data_from_config
from hetero.policies import AlternativePolicy
from hetero.tasks import (
    beta_estimate_from,
    beta_estimate_from_e2e_learning,
    beta_estimate_from_nongrouped,
    compute_UV_truths,
    compute_V_estimate,
)

# Reward coefficients used for data generation.
# NOTE(review): presumably one row per latent group (two groups here) — confirm
# against hetero.datagen.
group_reward_coeff = np.array([[-2, 2, 2, -2], [2, -2, -2, 2]], dtype=DTYPE)
action_reward_coeff = [-1, 1]

# Data-generating settings; they also appear in the output file names below.
FEATURE_TYPE = "LEGENDRE"
TRANS = "NORMCDF"
NOISE = "STUDENT"

# Shared keyword arguments for DataGenConfig.
data_config_init = {
    "num_trajectories": 100,
    "num_time_steps": 20,
    "group_reward_coeff_override": group_reward_coeff,
    "action_reward_coeff": action_reward_coeff,
    "num_burnin_steps": 100,
    "basis_expansion_method": FEATURE_TYPE,
    "transformation_method": TRANS,
    "add_intercept_column": True,
    "noise_type": NOISE,
    "noise_student_degree": 4,
}

# First run: set COMPUTE_TRUTH = True to generate the truth file, then set it
# back to False so later runs load the cached truth instead of recomputing it.
COMPUTE_TRUTH = False

# Label shared by the truth and result file names.
# NOTE(review): the meaning of this label (a date plus two numbers) is not
# documented here. Change the truth file name whenever the settings above change.
RUN_LABEL = "20230528_2.68_2.89"
FILE_PREFIX = f"hetero/data/{FEATURE_TYPE}_{TRANS}_{NOISE}"

TRUTH_FILE = f"{FILE_PREFIX}_truth_{RUN_LABEL}.pkl"
print("truth file name =", TRUTH_FILE)

# Result files are stamped with the sample sizes and the wall-clock start time.
time_tag = dtdt.now().strftime("%Y%m%d_%H-%M-%S")
tag = "N={}_T={}_{}".format(
    data_config_init["num_trajectories"],
    data_config_init["num_time_steps"],
    time_tag,
)
RESULT_FILE = f"{FILE_PREFIX}_result_{RUN_LABEL}_{tag}.pkl"
print("result file name =", RESULT_FILE)

# Keep True for real runs; False skips writing RESULT_FILE.
SAVE_RESULT = True
if not SAVE_RESULT:
    print("Result will NOT be saved, only use this for experimental runs!!!")

NUM_EXPERIMENTS = 100

truth file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_truth_20230528_2.68_2.89.pkl
result file name = hetero/data/LEGENDRE_NORMCDF_STUDENT_result_20230528_2.68_2.89_N=100_T=20_20230602_19-37-10.pkl


=====================================================================================================
# Algorithm 

- Set the algorithm configuration below.

In [2]:

# Settings for the end-to-end grouped estimator.
# NOTE(review): `discount` is not passed here, so the `algo_config.discount`
# used later comes from AlgoConfig's default — confirm that is intended.
algo_config = AlgoConfig(
    # Iteration budget and initialisation.
    max_num_iters=2,
    use_group_wise_regression_init=True,
    # Tuning parameters; their semantics are defined in hetero.config.
    gam=2.7,
    lam=2.0,
    rho=2.0,
    nu_coeff=1e-5,
    delta_coeff=1e-5,
    # Outlier trimming; presumably lower/upper percentile bounds — see hetero.config.
    should_remove_outlier=True,
    outlier_lower_perc=2,
    outlier_upper_perc=98,
)

# Policy being evaluated (the meaning of the constructor argument is defined in
# hetero.policies).
pi_eval = AlternativePolicy(2)

# Grouping step configuration, left at its defaults.
grouping_config = GroupingConfig()

In [3]:
# Ground-truth U/V values used to score every experiment. They are expensive to
# compute, so they are cached in TRUTH_FILE. Recompute when COMPUTE_TRUTH is set
# or when the cache file does not exist yet; otherwise load the cached values.

# The truth uses its own trajectory count, so drop the per-experiment
# `num_trajectories` from a shallow copy of the shared data settings.
truth_data_config_init = copy.copy(data_config_init)
truth_data_config_init.pop("num_trajectories")

if COMPUTE_TRUTH or not os.path.exists(TRUTH_FILE):
    us, vs = compute_UV_truths(
        truth_data_config_init,
        algo_config.discount,
        pi_eval,
        num_repeats=100,
        num_trajectories=1000,
    )
    # Average over axis 0 (presumably the `num_repeats` Monte Carlo repeats).
    # More repeats average out more Monte Carlo noise.
    u_truth = us.mean(axis=0)
    v_truth = vs.mean(axis=0)

    # Create the cache directory on first use so the dump below cannot fail on it.
    truth_dir = os.path.dirname(TRUTH_FILE)
    if truth_dir:
        os.makedirs(truth_dir, exist_ok=True)
    with open(TRUTH_FILE, "wb") as f:
        pickle.dump(
            dict(
                u_truth=u_truth,
                v_truth=v_truth,
                data_config_dict=truth_data_config_init,
                algo_config_dict=attrs.asdict(algo_config),
            ),
            f,
        )
else:
    # pickle.load can execute arbitrary code: only load truth files that this
    # notebook produced.
    with open(TRUTH_FILE, "rb") as f:
        loaded = pickle.load(f)
    u_truth = loaded["u_truth"]
    v_truth = loaded["v_truth"]

    # Warn (without failing) when the cached truth was generated with different
    # data settings from the ones configured above. Reusing it would silently
    # score the experiments against the wrong truth. np.array_equal is used
    # because some settings are NumPy arrays.
    saved_config = loaded.get("data_config_dict", {})
    mismatched_keys = sorted(
        key
        for key in set(saved_config) | set(truth_data_config_init)
        if not np.array_equal(saved_config.get(key), truth_data_config_init.get(key))
    )
    if mismatched_keys:
        print(
            "WARNING: truth file was generated with different settings for:",
            mismatched_keys,
        )

In [4]:
# Repeat the simulation NUM_EXPERIMENTS times. Each replicate gets its own seed
# (7531 * run_number), so every dataset is reproducible. On each replicate, fit
# both the non-grouped and the end-to-end grouped estimators.
beta_ng_list = []
beta_learned_list = []

for run_number in range(1, NUM_EXPERIMENTS + 1):
    data_config = DataGenConfig(seed=7531 * run_number, **data_config_init)
    data = generate_data_from_config(data_config)

    ng_beta = beta_estimate_from_nongrouped(data, pi_eval, algo_config.discount)
    beta_ng_list.append(ng_beta)

    learned_beta = beta_estimate_from_e2e_learning(
        data, algo_config, grouping_config, pi_eval
    )
    beta_learned_list.append(learned_beta)

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.15831704
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.15831704
MCPImpl: num_above=13898, num_below=6002
kmeans center = [[ 1.18708021 -2.90183154 -2.64112197  2.95300157  1.44757148  2.76931768
  -1.22485987 -2.59704889  2.9426306   0.76873594]
 [-1.21432977  2.77521468  2.75303926 -2.78811103 -1.60641951 -2.80714134
   1.23143624  2.65530498 -2.73401882 -1.05245369]] and inertia = 3978.160303125423
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.1576327+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.1576327+0j)
MCPImpl: num_above=14354, num_below=5546
kmeans center

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.1501897
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.1501897
MCPImpl: num_above=14253, num_below=5647
kmeans center = [[ 1.16393395 -2.81791091 -2.8274689   2.53516908  1.77862165  2.63525165
  -1.0227578  -2.89812073  3.07264251  0.87668134]
 [-1.10374923  2.87747602  2.69643498 -3.04702946 -1.39167812 -2.949561
   1.13951485  2.61336424 -2.66560568 -0.79649938]] and inertia = 4129.57992092343
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.15161552
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.15161552
MCPImpl: num_above=14750, num_below=5150
kmeans center = [[ 1.20145

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.15777941+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.15777941+0j)
MCPImpl: num_above=14057, num_below=5843
kmeans center = [[ 1.07964241 -2.77340436 -2.84939454  2.70807175  1.62924931  2.40632551
  -1.21407129 -2.78963796  2.67387281  1.17597796]
 [-1.04890888  2.99526345  2.76951248 -2.54990347 -1.85422899 -3.07660717
   0.97552712  2.60791494 -2.91283918 -0.51303311]] and inertia = 3993.9419490724003
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.15318839+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.15318839+0j)
MCPImpl: num_above=14933, num_below=4967


new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.15828243
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.15828243
MCPImpl: num_above=14173, num_below=5727
kmeans center = [[ 1.19851939 -2.83069257 -2.9323872   2.43677753  1.82381295  2.75724159
  -1.18772025 -2.95028863  2.33613524  1.22253489]
 [-1.05895705  2.86219335  2.76491973 -2.5410271  -1.69223532 -2.90724784
   1.1061974   2.69733858 -2.42345561 -0.92689504]] and inertia = 4037.0927762446054
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.16032295
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.16032295
MCPImpl: num_above=14649, num_below=5251
kmeans center = [[ 1

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.16165566
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.16165566
MCPImpl: num_above=14478, num_below=5422
kmeans center = [[ 1.16770832 -2.77682434 -2.81310542  2.65512216  1.55131272  2.69732953
  -1.10328846 -2.81402137  2.50288766  1.02024244]
 [-1.181579    2.7725284   2.67779862 -2.55271422 -1.5348131  -2.84931821
   1.22091498  2.83167529 -2.84658188 -0.83629275]] and inertia = 4224.528710558672
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.15712547+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.15712547+0j)
MCPImpl: num_above=14578, num_below=5322
kmeans cent

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.15726395
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.15726395
MCPImpl: num_above=14551, num_below=5349
kmeans center = [[ 1.08579636 -2.83747333 -2.98818751  2.43060064  1.81299601  3.00537464
  -1.36956641 -2.65764295  2.6320613   0.83891734]
 [-1.36309497  2.90846472  3.03551819 -3.00705509 -1.44571002 -2.82445601
   0.98472967  2.98326711 -2.97626938 -0.69042   ]] and inertia = 4258.402361787837
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.15662919
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.15662919
MCPImpl: num_above=14794, num_below=5106
kmeans center = [[ 1.

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.16003326
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.16003326
MCPImpl: num_above=14530, num_below=5370
kmeans center = [[ 1.09651104 -2.74802039 -2.79257386  2.78332768  1.50912164  2.95098427
  -1.00740386 -2.89462424  2.83947212  0.72909621]
 [-1.14278868  3.05966937  2.71293452 -2.72975428 -1.6147879  -2.72281577
   1.26974684  2.77090243 -2.87313503 -0.92734052]] and inertia = 4209.6180506201
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.160246
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.160246
MCPImpl: num_above=14372, num_below=5528
kmeans center = [[ 1.210948

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.1602924
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.1602924
MCPImpl: num_above=14418, num_below=5482
kmeans center = [[ 1.24073963 -3.09506208 -2.68116829  2.85078695  1.46315503  2.77252868
  -1.18416926 -2.7753248   2.57748689  0.94595502]
 [-0.97291994  2.63940198  2.83573289 -2.80861692 -1.40026285 -2.88755781
   1.03359759  2.68215053 -2.8245253  -0.52929147]] and inertia = 4210.959356477089
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = (0.15808435+0j)
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = (0.15808435+0j)
MCPImpl: num_above=14439, num_below=5461
kmeans center

new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.16357933
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.16357933
MCPImpl: num_above=14379, num_below=5521
kmeans center = [[ 0.93310645 -2.92448854 -2.65624913  2.83672869  1.58529553  2.78441
  -1.10616868 -2.36350326  2.58351522  0.76675231]
 [-1.13236777  3.03534772  2.71228349 -2.69631462 -1.68463086 -2.9317929
   1.13609188  2.76121048 -2.78696381 -0.77983589]] and inertia = 4233.492047639062
Label mismatch = 0
new_labels.length=200 matches num_unique_labels
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
new_labels.length=4000 matches number of records
beta_solver, min eigen of left matrix = 0.15733325
MCPImpl: num_above=10000, num_below=9900
beta_solver, min eigen of left matrix = 0.15733325
MCPImpl: num_above=14617, num_below=5283
kmeans center = [[ 1.1104

In [5]:
# Quick look at the coefficient vectors estimated in the first experiment
# (the displayed output shows one array per estimated group).
beta_learned_list[0].betas

[array([-1.13541424,  2.84603798,  2.85190309, -2.92513476, -1.59776357,
        -2.87165767,  1.1358938 ,  2.76265427, -2.88373604, -0.8898509 ]),
 array([ 1.17887679, -2.95271453, -2.76481393,  2.99874504,  1.49515949,
         2.80739732, -1.15199582, -2.70936204,  2.84731498,  0.81818536])]

In [6]:
# For every experiment, turn both estimators' betas into value estimates
# (mu, sigma). Score each estimate against the per-group truth as
# z = (mu - truth) / sigma.
mu_learned_list = []
sigma_learned_list = []
z_score_learned_list = []
mu_ng_list = []
sigma_ng_list = []
z_score_ng_list = []

for beta_learned, beta_ng in zip(beta_learned_list, beta_ng_list):
    # Grouped (end-to-end learned) estimator: one (mu, sigma) per group, each
    # compared with the matching group's truth.
    v_mus, v_sigmas = compute_V_estimate(u_truth, beta_learned)
    z_score_learned = [
        (mu - truth) / sigma for mu, truth, sigma in zip(v_mus, v_truth, v_sigmas)
    ]
    mu_learned_list.append(v_mus)
    sigma_learned_list.append(v_sigmas)
    z_score_learned_list.append(z_score_learned)

    # Non-grouped estimator: a single (mu, sigma) shared by all groups, scored
    # against each group's truth. Broadcast explicitly to one entry per group.
    # `seq * len(v_truth)` only replicates a Python list; on an ndarray it would
    # scale the values, and zip would silently truncate to one group.
    # A size that cannot broadcast now raises instead of being truncated.
    ngv_mus, ngv_sigmas = compute_V_estimate(u_truth, beta_ng)
    ng_mu_per_group = np.broadcast_to(np.ravel(ngv_mus), np.shape(v_truth))
    ng_sigma_per_group = np.broadcast_to(np.ravel(ngv_sigmas), np.shape(v_truth))
    z_score_ng = [
        (mu - truth) / sigma
        for mu, truth, sigma in zip(ng_mu_per_group, v_truth, ng_sigma_per_group)
    ]
    mu_ng_list.append(ngv_mus)
    sigma_ng_list.append(ngv_sigmas)
    z_score_ng_list.append(z_score_ng)

# Reports averaged over the two groups

In [7]:
# Pooled empirical coverage: the fraction of all (experiment, group) z-scores
# whose magnitude is below Z_THRESHOLD (1.96, the two-sided 95% normal quantile).
Z_THRESHOLD = 1.96
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)

learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# The mean of a boolean array is the fraction of True entries.
learned_in_threshold_perc = learned_in_threshold.mean()
ng_in_threshold_perc = ng_in_threshold.mean()

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

# Persist per-experiment estimates, z-scores and betas for both estimators.
if SAVE_RESULT:
    with open(RESULT_FILE, "wb") as result_fh:
        pickle.dump(
            {
                "mu_learned_list": mu_learned_list,
                "sigma_learned_list": sigma_learned_list,
                "z_score_learned": z_score_learned,
                "mu_ng_list": mu_ng_list,
                "sigma_ng_list": sigma_ng_list,
                "z_score_ng": z_score_ng,
                "beta_learned_list": beta_learned_list,
                "beta_ng_list": beta_ng_list,
            },
            result_fh,
        )

learned_in_threshold_perc= 0.96 , ng_in_threshold_perc= 0.0


In [8]:
# Average confidence-interval length (2 * z * sigma) and MSE against the truth,
# both pooled over all experiments and groups.
ac_acl = 2 * Z_THRESHOLD * np.mean(sigma_learned_list)
ac_mse = np.mean(np.square(np.subtract(mu_learned_list, v_truth)))

mv_acl = 2 * Z_THRESHOLD * np.mean(sigma_ng_list)
mv_mse = np.mean(np.square(np.subtract(mu_ng_list, v_truth)))

In [9]:
# Print ACL / MSE / ECP for each estimator, pooled over groups.
for report_title, report_acl, report_mse, report_ecp in (
    ("ACPE results: (average over groups)", ac_acl, ac_mse, learned_in_threshold_perc),
    ("MVPE results: (average over groups)", mv_acl, mv_mse, ng_in_threshold_perc),
):
    print(report_title)
    print(f"ACL: {report_acl}")
    print(f"MSE: {report_mse}")
    print(f"ECP: {report_ecp}")

ACPE results: (average over groups)
ACL: 0.1630770841088718
MSE: 0.001545880952244807
ECP: 0.96
MVPE results: (average over groups)
ACL: 0.20352345303736818
MSE: 1.5521851104754842
ECP: 0.0


# Reports for each of the two groups separately

In [10]:
# Per-group empirical coverage: average the "inside the threshold" indicator
# over experiments only (axis 0), giving one coverage value per group.
z_score_learned = np.array(z_score_learned_list)
z_score_ng = np.array(z_score_ng_list)
Z_THRESHOLD = 1.96  # two-sided 95% normal quantile
learned_in_threshold = np.abs(z_score_learned) < Z_THRESHOLD
ng_in_threshold = np.abs(z_score_ng) < Z_THRESHOLD

# The mean of a boolean array along axis 0 is the per-column fraction of True.
learned_in_threshold_perc = learned_in_threshold.mean(axis=0)
ng_in_threshold_perc = ng_in_threshold.mean(axis=0)

print(
    "learned_in_threshold_perc=",
    learned_in_threshold_perc,
    ", ng_in_threshold_perc=",
    ng_in_threshold_perc,
)

# No save here: the pooled-report cell already pickles exactly these arrays and
# lists to RESULT_FILE when SAVE_RESULT is set. Re-dumping the same payload to
# the same path was redundant.

learned_in_threshold_perc= [0.96 0.96] , ng_in_threshold_perc= [0. 0.]


In [11]:
# Same ACL / MSE metrics as the pooled report, averaged over experiments only
# (axis 0). The learned estimator gets one value per group. The non-grouped
# sigma is a single value per experiment, so mv_acl has a single entry.
ac_acl = 2 * Z_THRESHOLD * np.mean(sigma_learned_list, axis=0)
ac_mse = np.square(np.subtract(mu_learned_list, v_truth)).mean(axis=0)

mv_acl = 2 * Z_THRESHOLD * np.mean(sigma_ng_list, axis=0)
mv_mse = np.square(np.subtract(mu_ng_list, v_truth)).mean(axis=0)

In [12]:
# Print per-group MSE / ACL / ECP for each estimator, separated by a rule line.
per_group_reports = (
    ("ACPE results: Group1, Group 2", ac_mse, ac_acl, learned_in_threshold_perc),
    ("MVPE results: ", mv_mse, mv_acl, ng_in_threshold_perc),
)
for report_idx, (title, mse, acl, ecp) in enumerate(per_group_reports):
    if report_idx:
        print("===")
    print(title)
    print(f"MSE: {mse}")
    print(f"ACL: {acl}")
    print(f"ECP: {ecp}")

ACPE results: Group1, Group 2
MSE: [0.00153565 0.00155611]
ACL: [0.16310101 0.16305316]
ECP: [0.96 0.96]
===
MVPE results: 
MSE: [1.5385389  1.56583132]
ACL: [0.20352345]
ECP: [0. 0.]
