-
Notifications
You must be signed in to change notification settings - Fork 0
/
build_datasets.py
54 lines (40 loc) · 1.93 KB
/
build_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import json
from environments.toy_env import ToyEnv
# from experiment_config import config
from policies.generic_policies import EpsilonSmoothPolicy
from policies.toy_env_policies import ThresholdPolicy
from utils.offline_dataset import OfflineRLDataset
def main(config_path="experiment_config.json"):
    """Build and save offline RL train/test datasets for each replication.

    Reads experiment settings from the JSON file at ``config_path``, then for
    each replication index in ``num_rep_range``: constructs a fresh toy
    environment, samples behavior-policy trajectories into a train and a test
    ``OfflineRLDataset``, attaches the evaluation policy's actions to both,
    and saves them to per-replication paths.

    Args:
        config_path: Path to the JSON experiment configuration file.

    Raises:
        KeyError: If a required key is missing from the configuration
            (raised before any sampling work begins).
    """
    with open(config_path) as f:
        config = json.load(f)

    # Read every config value up front: all of these are loop-invariant
    # (the original re-read several inside the loop each iteration), and
    # failing fast here avoids wasting sampling work on a bad config.
    num_rep_range = config["num_rep_range"]
    s_threshold = config["s_threshold"]
    num_sample = config["num_sample"]
    pi_b_threshold = config["pi_b_threshold"]
    pi_b_epsilon = config["pi_b_epsilon"]
    pi_e_name = config["pi_e_name"]
    base_dataset_path_train = config["base_dataset_path_train"]
    base_dataset_path_test = config["base_dataset_path_test"]
    burn_in = config["burn_in"]
    thin = config["thin"]

    for i in num_rep_range:
        print(f"building datasets for rep {i}")

        # Fresh environment and policies per replication so no state
        # leaks across repetitions.
        env = ToyEnv(s_init=s_threshold, adversarial=False)
        pi_e = ThresholdPolicy(env, s_threshold=s_threshold)
        # Behavior policy: threshold policy smoothed with epsilon noise.
        pi_base = ThresholdPolicy(env, s_threshold=pi_b_threshold)
        pi_b = EpsilonSmoothPolicy(env, pi_base=pi_base, epsilon=pi_b_epsilon)

        # Sample independent behavior-policy trajectories for the train
        # and test splits.
        dataset = OfflineRLDataset()
        dataset.sample_new_trajectory(env=env, pi=pi_b, burn_in=burn_in,
                                      num_sample=num_sample, thin=thin)
        test_dataset = OfflineRLDataset()
        test_dataset.sample_new_trajectory(env=env, pi=pi_b, burn_in=burn_in,
                                           num_sample=num_sample, thin=thin)

        # Attach the evaluation policy's actions so off-policy evaluation
        # can later be run from the saved data alone.
        dataset.apply_eval_policy(pi_e_name, pi_e)
        test_dataset.apply_eval_policy(pi_e_name, pi_e)

        # Save each replication under its base path suffixed with the index.
        dataset_path_train = "_".join([base_dataset_path_train, str(i)])
        dataset_path_test = "_".join([base_dataset_path_test, str(i)])
        dataset.save_dataset(dataset_path_train)
        test_dataset.save_dataset(dataset_path_test)
# Run the dataset-building pipeline with the default config path when
# executed as a script; importing this module triggers nothing.
if __name__ == "__main__":
    main()