In [None]:
from environments.toy_env import ToyEnv
from utils.policy_evaluation import evaluate_policy
from policies.generic_policies import EpsilonSmoothPolicy
from policies.toy_env_policies import ThresholdPolicy
from utils.offline_dataset import OfflineRLDataset
from models.fnn_nuisance_model import FeedForwardNuisanceModel
from models.fnn_critic import FeedForwardCritic
from learners.iterative_sieve_critic import IterativeSieveLearner

# --- experiment hyper-parameters -------------------------------------------
gamma = 0.9                 # presumably the discount factor; forwarded as `gamma=` everywhere
s_threshold = 2.0           # threshold of the evaluation policy, also used as the env's s_init
adversarial_lambda = 4.0    # forwarded to the learners, the DR estimator and the adversarial env
num_sample = 10000          # `num_sample` for each sampled dataset
batch_size = 1024           # batch size of the test-set batch loaders

# `None` is forwarded unchanged to every `.to(device)` / `device=` argument.
device = None

# --- on-disk locations of the sampled datasets -----------------------------
dataset_path_train = "tmp_dataset/train_data"
dataset_path_test = "tmp_dataset/test_data"

# --- environment and evaluation policy -------------------------------------
# `pi_e_name` is the key under which the evaluation policy is registered on
# the datasets and later looked up by the learners / estimators.
pi_e_name = "pi_e"
env = ToyEnv(s_init=s_threshold, adversarial=False)
pi_e = ThresholdPolicy(env, s_threshold=s_threshold)

In [None]:
## build datasets and save them

# Behaviour policy used for sampling: a ThresholdPolicy (threshold 1.5)
# wrapped in an EpsilonSmoothPolicy with epsilon=0.1.
pi_base = ThresholdPolicy(env, s_threshold=1.5)
pi_b = EpsilonSmoothPolicy(env, pi_base=pi_base, epsilon=0.1)

# Train and test sets are sampled with exactly the same settings; the train
# set is sampled first, then the test set.
sampling_kwargs = dict(env=env, pi=pi_b, burn_in=1000,
                       num_sample=num_sample, thin=10)

dataset = OfflineRLDataset()
dataset.sample_new_trajectory(**sampling_kwargs)

test_dataset = OfflineRLDataset()
test_dataset.sample_new_trajectory(**sampling_kwargs)

# Register the evaluation policy under `pi_e_name` on both datasets before
# they are written to disk.
dataset.apply_eval_policy(pi_e_name, pi_e)
test_dataset.apply_eval_policy(pi_e_name, pi_e)

# Persist both splits so later cells can reload them instead of re-sampling.
dataset.save_dataset(dataset_path_train)
test_dataset.save_dataset(dataset_path_test)

dataset.to(device)
test_dataset.to(device)

In [None]:
## double check that dataset is loadable and looks correct

# Reload the training split from disk and pull a single batch of 10 rows.
dataset_tmp = OfflineRLDataset.load_dataset(dataset_path_train)
dataset_tmp.to(device)
dl = dataset_tmp.get_batch_loader(batch_size=10)
batch = next(iter(dl))

# Evaluate the environment's initial basis functions on that batch and print
# the overall shape, the sum over axis 1, and then each row individually.
f = env.init_basis_func(batch["s"], batch["a"])
print(f.shape)
print(f.sum(1))
for basis_row in f:
    print(basis_row)

In [None]:
## train the nuisance model with q / xi, eta and w all enabled
## (train_q_xi=True, train_eta=True, train_w=True below)

s_dim = env.get_s_dim()
num_a = env.get_num_a()

# Configuration dict handed to FeedForwardNuisanceModel. Every "*_do" entry
# shares the single value `model_do` (presumably a dropout rate -- confirm
# against the model class).
# A smaller variant of this config (16/32-wide layers) is the one that is
# instantiated live in a later cell of this notebook.
model_do = 0.05
model_config = {
    "s_embed_dim": 32,
    "s_embed_layers": [32],
    "s_embed_do": model_do,
    "a_embed_dim": 32,
    "sa_feature_dim": 64,
    "sa_feature_layers": [64],
    "sa_feature_do": model_do,
    "q_layers": [64, 64],
    "q_do": model_do,
    "beta_layers": [64, 64],
    "beta_do": model_do,
    "w_layers": [64, 64],
    "w_do": model_do,
    "eta_layers": [64, 64],
    "eta_do": model_do,
}
model = FeedForwardNuisanceModel(s_dim=s_dim, num_a=num_a, gamma=gamma,
                                 config=model_config, device=device)

# The critic is passed to the learner as a class plus constructor kwargs
# (`critic_class` / `critic_kwargs`) rather than as an instance.
critic_class = FeedForwardCritic
critic_do = 0.05
critic_config = {
    "s_embed_dim": 8,
    "s_embed_layers": [8],
    "s_embed_do": critic_do,
    "a_embed_dim": 8,
    "critic_layers": [16],
    "critic_do": critic_do,
}
critic_kwargs = {
    "s_dim": s_dim,
    "num_a": num_a,
    "config": critic_config
}

learner_1 = IterativeSieveLearner(
    nuisance_model=model, gamma=gamma, adversarial_lambda=adversarial_lambda,
    train_q_xi=True, train_eta=True, train_w=True,
)

# Initial state/action for the evaluation policy and a loader over the
# held-out test set; both are bundled into `evaluate_pv_kwargs` for the
# learner.
s_init, a_init = env.get_s_a_init(pi_e)
dl_test = test_dataset.get_batch_loader(batch_size=batch_size)
evaluate_pv_kwargs = {
    "s_init": s_init, "a_init": a_init,
    "dl_test": dl_test, "pi_e_name": pi_e_name,
}
learner_1.train(
    dataset, pi_e_name=pi_e_name, verbose=True, device=device,
    init_basis_func=env.bias_basis_func, num_init_basis=1,
    # Alternative settings that were tried (kept for reference):
    # init_basis_func=env.flexible_basis_func,
    # num_init_basis=env.get_num_init_basis_func(),
    # model_lr=1e-4,
    evaluate_pv_kwargs=evaluate_pv_kwargs, critic_class=critic_class,
    s_init=s_init, critic_kwargs=critic_kwargs,
)
# Checkpoint the trained model; a later cell reloads it from "tmp_model".
model.save_model("tmp_model")

In [None]:
## load pre-sampled datasets and the saved nuisance model

# Reloads from disk the train/test splits written earlier with
# `save_dataset` and the checkpoint written with `model.save_model("tmp_model")`,
# rebinding `dataset`, `test_dataset` and `model`.
# NOTE(review): when the notebook is run top to bottom, `model` is rebound
# again by the next cell, so the checkpoint loaded here is not the model
# used afterwards -- confirm that this is intended.
dataset = OfflineRLDataset.load_dataset(dataset_path_train)
dataset.to(device)
test_dataset = OfflineRLDataset.load_dataset(dataset_path_test)
test_dataset.to(device)
model = FeedForwardNuisanceModel.load_model("tmp_model")
model.to(device)


In [None]:

## build a fresh, smaller nuisance model

# NOTE(review): this rebinds `model` to a newly constructed network, which
# replaces the instance loaded from "tmp_model" in the previous cell. Skip
# this cell if the loaded checkpoint is the one that should be used -- confirm
# the intended workflow.
s_dim = env.get_s_dim()
num_a = env.get_num_a()

# Same keys as the larger config used earlier in the notebook, with
# 16/32-wide layers instead of 32/64. Every "*_do" entry shares `model_do`
# (presumably a dropout rate -- confirm against the model class).
model_do = 0.05
model_config = {
    "s_embed_dim": 16,
    "s_embed_layers": [16],
    "s_embed_do": model_do,
    "a_embed_dim": 16,
    "sa_feature_dim": 32,
    "sa_feature_layers": [32],
    "sa_feature_do": model_do,
    "q_layers": [32, 16],
    "q_do": model_do,
    "beta_layers": [32, 16],
    "beta_do": model_do,
    "w_layers": [32, 16],
    "w_do": model_do,
    "eta_layers": [32, 16],
    "eta_do": model_do,
}
model = FeedForwardNuisanceModel(s_dim=s_dim, num_a=num_a, gamma=gamma,
                                 config=model_config, device=device)


In [None]:
## train the current `model` with only the q / xi component enabled
## (train_q_xi=True, train_eta=False, train_w=False below)
# NOTE(review): this cell was originally headed "train the w model" and the
# checkpoint is saved as "tmp_model_w", but `train_w` is False here -- confirm
# which of the flags or the naming reflects the intended experiment.

s_dim = env.get_s_dim()
num_a = env.get_num_a()

# The critic is passed to the learner as a class plus constructor kwargs.
critic_class = FeedForwardCritic
critic_do = 0.05
critic_config = {
    "s_embed_dim": 8,
    "s_embed_layers": [8],
    "s_embed_do": critic_do,
    "a_embed_dim": 8,
    "critic_layers": [16],
    "critic_do": critic_do,
}
critic_kwargs = {
    "s_dim": s_dim,
    "num_a": num_a,
    "config": critic_config
}

learner_2 = IterativeSieveLearner(
    nuisance_model=model,
    gamma=gamma, adversarial_lambda=adversarial_lambda,
    train_q_xi=True, train_eta=False, train_w=False,
)
# Initial state/action for the evaluation policy and a loader over the
# held-out test set, bundled into `evaluate_pv_kwargs` for the learner.
s_init, a_init = env.get_s_a_init(pi_e)
dl_test = test_dataset.get_batch_loader(batch_size=batch_size)
evaluate_pv_kwargs = {
    "s_init": s_init, "a_init": a_init,
    "dl_test": dl_test, "pi_e_name": pi_e_name,
}
# Unlike the earlier training cell, this run uses `env.init_basis_func` with
# `env.get_num_init_basis_func()` basis functions and spells out the
# optimisation settings explicitly instead of relying on the defaults.
learner_2.train(
    dataset, pi_e_name=pi_e_name, verbose=True, device=device,
    init_basis_func=env.init_basis_func, num_init_basis=env.get_num_init_basis_func(),
    evaluate_pv_kwargs=evaluate_pv_kwargs, critic_class=critic_class,
    s_init=s_init, critic_kwargs=critic_kwargs, model_lr=1e-3, model_eval_freq=5,
    model_max_epoch=500, model_grad_clip=None, model_grad_clip_final=None,
    model_min_epoch=50, model_max_no_improve=3, total_num_iterations=20,
    model_reg_alpha=1e-3, model_reg_alpha_final=1e-3, critic_eval_freq=5,
    gamma_tik=1e-3, gamma_0=1e-3,
)
model.save_model("tmp_model_w")

In [None]:
## spot-check the eta and w outputs of the current model

def _eta_and_w(b):
    """Return the pair (eta, w) that the global `model` produces for batch `b`.

    `b` is a batch dict with at least the keys "s" and "a"; eta is computed
    first, then w.
    """
    return model.get_eta(s=b["s"], a=b["a"]), model.get_w(s=b["s"])


# Raw values on one small batch of 10 rows.
dl = iter(dataset.get_batch_loader(batch_size=10))
batch = next(dl)
eta, w = _eta_and_w(batch)
print(f"eta: {eta}")
print(f"w: {w}")

# Per-batch means over the whole training set, in batches of 1000.
for batch in dataset.get_batch_loader(batch_size=1000):
    eta, w = _eta_and_w(batch)
    print(f"eta mean: {eta.mean()}")
    print(f"w mean: {w.mean()}")
    print("")
    print("")


In [None]:
## evaluate model using 3 policy value estimators

s_init, a_init = env.get_s_a_init(pi_e)
dl_test = test_dataset.get_batch_loader(batch_size=batch_size)
# NOTE(review): `dl_test` is handed to several estimators in a row; this
# assumes the batch loader can be iterated more than once -- confirm.

# 1) q-based estimate from the initial state/action.
q_pv = model.estimate_policy_val_q(s_init=s_init, a_init=a_init, gamma=gamma)

# 2) w-based estimate on the test set, plain and normalized.
w_pv = model.estimate_policy_val_w(dl=dl_test)
w_pv_norm = model.estimate_policy_val_w(dl=dl_test, normalize=True)

# 3) `estimate_policy_val_dr` estimate, plain and normalized; both calls share
#    the same arguments apart from `normalize`.
dr_kwargs = dict(
    s_init=s_init, a_init=a_init, pi_e_name=pi_e_name, dl=dl_test,
    adversarial_lambda=adversarial_lambda, gamma=gamma,
)
dr_pv = model.estimate_policy_val_dr(**dr_kwargs)
dr_pv_norm = model.estimate_policy_val_dr(**dr_kwargs, normalize=True)

print("EVALUATING FINAL BEST MODEL:")
print(f"Q-estimated v(pi_e): {q_pv}")
print(f"W-estimated v(pi_e): {w_pv}")
print(f"W-estimated v(pi_e) (normalized): {w_pv_norm}")
print(f"DS/DV-estimated v(pi_e): {dr_pv}")
print(f"DS/DV-estimated v(pi_e) (normalized): {dr_pv_norm}")
print("")

# Reference value: evaluate pi_e directly in the adversarial variant of the
# environment, using the same adversarial_lambda as the estimators.
env_eval = ToyEnv(s_init=s_threshold, adversarial=True,
                  adversarial_lambda=adversarial_lambda)
pi_e_val = evaluate_policy(env_eval, pi_e, gamma, min_prec=1e-4)
print(f"true v(pi_e): {pi_e_val}")
print("")

