-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
158 lines (135 loc) · 4.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from pathlib import Path
from absl import app, flags
import gymnasium as gym
import numpy as np
import rl.config as cfg
from envs import make_vec_task_env, make_vec_env
from evaluation.alternative_tasks import rollout_vec_env
from evaluation.plots import plot_tasks_successes
from gmp.config import GmpParams
from gmp.gmp import GMP
from gmp.latent_space import random_ball_numpy
# --- Command-line flags -------------------------------------------------------
# All values are read in main() and forwarded to make_config().

# Reproducibility and task selection.
flags.DEFINE_integer(name="seed", default=0, help="Seed for reproducibility.")
flags.DEFINE_enum(
    name="task",
    default="cartpole",
    enum_values=["cartpole", "ring"],
    help="Task name.",
)

# Loss terms (forwarded to GmpParams).
flags.DEFINE_integer(
    name="diversity_latent_samples",
    default=8,
    help="Number of samples for the diversity loss term.",
)
flags.DEFINE_float(
    name="latent_coef",
    default=0.05,
    help="Coefficient in front of the loss term.",
)

# Encoder.
flags.DEFINE_integer(
    name="hidden_size",
    default=64,
    help="Size of the Dense layers inside the encoder.",
)
flags.DEFINE_enum(
    name="activation_fn",
    default="relu",
    enum_values=["tanh", "relu"],
    help="Activation function inside the encoder.",
)

# Mapping network.
flags.DEFINE_integer(
    name="m_hidden_size",
    default=16,
    help="Size of the Dense layers inside the mapping network.",
)
flags.DEFINE_enum(
    name="m_activation_fn",
    default="relu",
    enum_values=["tanh", "relu"],
    help="Activation function inside the mapping network.",
)
flags.DEFINE_integer(
    name="m_n_layers",
    default=4,
    help="Number of layers inside the mapping network.",
)

# Generator.
flags.DEFINE_enum(
    name="architecture",
    default="style",
    enum_values=["style", "multiplicative"],
    help="Architecture of the generator.",
)
flags.DEFINE_integer(
    name="n_blocks",
    default=1,
    help="Number of blocks in the style architecture.",
)

FLAGS = flags.FLAGS
def make_config(
    seed: int,
    task: str,
    *,
    diversity_latent_samples: int,
    latent_coef: float,
    hidden_size: int,
    activation_fn: str,
    m_hidden_size: int,
    m_activation_fn: str,
    m_n_layers: int,
    architecture: str,
    n_blocks: int,
    n_envs: int = 8,
) -> tuple[cfg.AlgoConfig, gym.vector.VectorEnv]:
    """Build the vectorized training environments and the GMP algorithm config.

    Args:
        seed: Seed passed to ``make_vec_env`` and stored in the ``AlgoConfig``.
        task: Task name passed to ``make_vec_env``. ``"cartpole"`` uses a step
            budget of 500_000, every other task 100_000; the budget is divided
            by ``n_envs``.
        diversity_latent_samples, latent_coef, hidden_size, activation_fn,
        m_hidden_size, m_activation_fn, m_n_layers, architecture, n_blocks:
            Forwarded unchanged to ``GmpParams``.
        n_envs: Number of parallel environments; values below 2 are raised to 2.

    Returns:
        ``(config, envs)``: the ``AlgoConfig`` and the training environments,
        in that order (matching how ``main`` unpacks the result).
    """
    # At least two parallel environments are always used.
    n_envs = max(2, n_envs)
    envs, env_cfg = make_vec_env(task, seed, n_envs)

    # The step budget is divided by the number of parallel envs
    # (presumably so the total number of transitions stays fixed -- TODO confirm
    # against GMP.train's interpretation of n_env_steps).
    step_budget = 500_000 if task == "cartpole" else 100_000
    n_env_steps = step_budget // n_envs

    gmp_params = GmpParams(
        gamma=0.99,
        _lambda=0.95,
        clip_eps=0.2,
        entropy_coef=0.01,
        value_coef=0.5,
        normalize=True,
        latent_size=2,
        diversity_latent_samples=diversity_latent_samples,
        latent_coef=latent_coef,
        hidden_size=hidden_size,
        activation_fn=activation_fn,
        m_hidden_size=m_hidden_size,
        m_activation_fn=m_activation_fn,
        m_n_layers=m_n_layers,
        architecture=architecture,
        n_blocks=n_blocks,
    )
    config = cfg.AlgoConfig(
        seed,
        gmp_params,
        cfg.UpdateConfig(
            learning_rate=0.0001,
            learning_rate_annealing=True,
            max_grad_norm=0.5,
            max_buffer_size=256,
            # 256 at the default of 8 envs, scaled linearly with n_envs,
            # never below 32.
            batch_size=max(32, 256 * n_envs // 8),
            n_epochs=5,
            shared_encoder=False,
        ),
        # A single save at the end of the step budget.
        cfg.TrainConfig(n_env_steps=n_env_steps, save_frequency=n_env_steps),
        env_cfg,
    )
    return config, envs
def eval_alt_tasks(
    seed: int,
    gmp: GMP,
    envs: gym.Env | gym.vector.VectorEnv,
    tasks: list[str],
    points: np.ndarray,
):
    """Roll out a trained GMP on alternative tasks, then save and plot the results.

    Outputs are written to ``./results/<gmp.run_name>/``:
    ``successes.npz`` (the rollout results) and ``tasks.png`` (the plot).

    Args:
        seed: Base seed; the rollout is seeded with ``seed + 1``.
        gmp: Trained model; its ``run_name`` selects the output directory.
        envs: Environments used for the rollout.
        tasks: Names of the alternative tasks.
        points: Points passed to ``rollout_vec_env`` (presumably latent-space
            samples -- confirm against rollout_vec_env).
    """
    successes = rollout_vec_env(seed + 1, gmp, envs, points, tasks)

    out_dir = Path("./results").joinpath(gmp.run_name)
    # np.savez does not create parent directories; make sure the run directory
    # exists even if training did not create it.
    out_dir.mkdir(parents=True, exist_ok=True)

    # `successes` is unpacked as keyword arguments, so it must be a mapping
    # with string keys; each entry becomes a named array in the .npz file.
    np.savez(out_dir / "successes.npz", **successes)
    plot_tasks_successes(
        successes,
        show_plot=True,
        save_path=out_dir / "tasks.png",
        # Red/blue/green by index. NOTE(review): an index >= 3 raises
        # IndexError, so this assumes at most three tasks -- confirm.
        color_fn=lambda i: ["r", "b", "g"][i],
    )
def train_and_eval_alt_tasks(
    seed: int,
    envs: gym.Env | gym.vector.VectorEnv,
    config: cfg.AlgoConfig,
    points: np.ndarray,
    *,
    n_eval_envs: int = 100,
):
    """Train a GMP on ``envs``, then evaluate it on the task's alternative tasks.

    ``envs`` is closed once training finishes (or fails), so the caller must
    not reuse it afterwards. The evaluation environments are created here and
    closed after evaluation.

    Args:
        seed: Seed for the evaluation environments; evaluation itself is
            seeded with ``seed + 2``.
        envs: Training environments.
        config: Algorithm configuration; ``train_cfg.n_env_steps`` sets the
            training length and ``env_cfg.task_name`` picks the evaluation tasks.
        points: Points forwarded to ``eval_alt_tasks``.
        n_eval_envs: Number of evaluation environments (default 100, as before).
    """
    try:
        gmp = GMP(config)
        gmp.train(envs, config.train_cfg.n_env_steps, [])
    finally:
        # `del envs` only unbinds this local name: the object stays alive while
        # the caller holds a reference, and any resources it owns (e.g. worker
        # processes of an async vector env) are not released. Close explicitly.
        envs.close()

    eval_envs, tasks = make_vec_task_env(config.env_cfg.task_name, seed, n_eval_envs)
    try:
        eval_alt_tasks(seed + 2, gmp, eval_envs, tasks, points)
    finally:
        eval_envs.close()
def main(_):
    """Train a GMP on the selected task and evaluate it on alternative tasks.

    Entry point for ``absl.app.run``. The positional argv it receives is
    ignored; every setting is read from ``FLAGS``.
    """
    # Model/architecture flags are forwarded to make_config under their own names.
    forwarded_flags = (
        "diversity_latent_samples",
        "latent_coef",
        "hidden_size",
        "activation_fn",
        "m_hidden_size",
        "m_activation_fn",
        "m_n_layers",
        "architecture",
        "n_blocks",
    )
    base_seed = FLAGS.seed
    config, envs = make_config(
        base_seed,
        FLAGS.task,
        **{flag_name: getattr(FLAGS, flag_name) for flag_name in forwarded_flags},
    )

    # Points passed to evaluation: random_ball_numpy(rng, 10_000, 2), presumably
    # 10,000 samples in 2 dimensions to match latent_size=2 in make_config.
    point_rng = np.random.default_rng(base_seed + 1)
    eval_points = random_ball_numpy(point_rng, 10_000, 2)

    train_and_eval_alt_tasks(base_seed + 2, envs, config, eval_points)


if __name__ == "__main__":
    app.run(main)