eval.py
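"""Evaluate a trained Trajectory Transformer checkpoint.

Loads the saved run config, discretizer, and model weights from
`checkpoints_path`, optionally quantizes the model, runs beam-search planning
rollouts (vectorized or sequential), and prints the mean and standard deviation
of episode reward and normalized score.
"""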
import os
import argparse

import numpy as np
import torch
from omegaconf import OmegaConf
from stable_baselines3.common.vec_env import DummyVecEnv
from tqdm.auto import trange

from trajectory.models.trajectory import TrajectoryModel
from trajectory.utils.common import set_seed
from trajectory.utils.env import create_env, rollout, vec_rollout

from optimizations import quantizer


def create_argparser():
    parser = argparse.ArgumentParser(
        description="Trajectory Transformer evaluation hyperparameters. All can be set from the command line."
    )
    parser.add_argument("--config", default="configs/medium/halfcheetah_medium.yaml")
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--device", default="cpu", type=str)
    return parser


def run_experiment(config, seed, device):
    set_seed(seed=seed)

    # Load the training-time config and the discretizer saved alongside the checkpoint.
    run_config = OmegaConf.load(os.path.join(config.checkpoints_path, "config.yaml"))
    discretizer = torch.load(os.path.join(config.checkpoints_path, "discretizer.pt"), map_location=device)

    # Rebuild the model from the saved hyperparameters and restore its weights for inference.
    model = TrajectoryModel(**run_config.model)
    model.to(device)
    model.eval()
    model.load_state_dict(torch.load(os.path.join(config.checkpoints_path, config.model_name), map_location=device))

    if config.quantize:
        print(f"Using [{config.q_type}] quantized model.")
        # Quantize the model, using a dummy token context as the example input.
        example_context = torch.ones((1600, 1)).int().to(device)
        model = quantizer(model, (example_context), q_type=config.q_type)

    if config.vectorized:
        # Run all evaluation episodes in parallel inside a vectorized environment.
        env = DummyVecEnv([lambda: create_env(run_config.dataset.env_name) for _ in range(config.num_episodes)])
        rewards = vec_rollout(
            vec_env=env,
            model=model,
            discretizer=discretizer,
            beam_context_size=config.beam_context,
            beam_width=config.beam_width,
            beam_steps=config.beam_steps,
            plan_every=config.plan_every,
            sample_expand=config.sample_expand,
            k_act=config.k_act,
            k_obs=config.k_obs,
            k_reward=config.k_reward,
            temperature=config.temperature,
            discount=config.discount,
            max_steps=env.envs[0].max_episode_steps,
            device=device,
        )
        scores = [env.envs[0].get_normalized_score(r) for r in rewards]
    else:
        # Evaluate episodes one at a time, rendering each rollout to its own subdirectory.
        rewards, scores = [], []
        env = create_env(run_config.dataset.env_name)
        for i in trange(config.num_episodes, desc="Evaluation (not vectorized)"):
            reward = rollout(
                env=env,
                model=model,
                discretizer=discretizer,
                beam_context_size=config.beam_context,
                beam_width=config.beam_width,
                beam_steps=config.beam_steps,
                plan_every=config.plan_every,
                sample_expand=config.sample_expand,
                k_act=config.k_act,
                k_obs=config.k_obs,
                k_reward=config.k_reward,
                temperature=config.temperature,
                discount=config.discount,
                max_steps=config.get("max_steps", None) or env.max_episode_steps,
                render_path=os.path.join(config.render_path, str(i)),
                device=device,
            )
            rewards.append(reward)
            scores.append(env.get_normalized_score(reward))

    reward_mean, reward_std = np.mean(rewards), np.std(rewards)
    score_mean, score_std = np.mean(scores), np.std(scores)

    print(f"Evaluation on {run_config.dataset.env_name}")
    print(f"Mean reward: {reward_mean} ± {reward_std}")
    print(f"Mean score: {score_mean} ± {score_std}")


def main():
    args, override = create_argparser().parse_known_args()
    # Unrecognized CLI arguments are treated as OmegaConf overrides for the config file.
    config = OmegaConf.merge(
        OmegaConf.load(args.config),
        OmegaConf.from_cli(override),
    )
    run_experiment(
        config=config,
        seed=args.seed,
        device=args.device,
    )


if __name__ == "__main__":
    main()
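
# Example invocation; unrecognized `key=value` arguments override entries in the config file:
#   python eval.py --config configs/medium/halfcheetah_medium.yaml --device cuda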