diff --git a/tests/test_agent_snapshot.py b/tests/test_agent_snapshot.py
new file mode 100644
index 0000000..cff2ed0
--- /dev/null
+++ b/tests/test_agent_snapshot.py
@@ -0,0 +1,78 @@
+import pytest
+from udacity_rl.agents import AgentSnapshot
+from udacity_rl.agents.agent import AgentInterface
+
+
+class AgentStub(AgentInterface):
+    @property
+    def action_size(self):
+        pass
+
+    @property
+    def observation_space(self):
+        pass
+
+    @property
+    def action_space(self):
+        pass
+
+    @property
+    def configuration(self):
+        pass
+
+    def act(self, observation, epsilon=0):
+        pass
+
+    def step(self, obs, action, reward, next_obs, done):
+        pass
+
+    def train(self):
+        pass
+
+    def save(self, save_path):
+        pass
+
+    def load(self, save_path):
+        pass
+
+
+class AgentSpy(AgentStub):
+    def __init__(self):
+        self.received_save_path = None
+        self.num_save_calls = 0
+
+    def save(self, save_path):
+        self.received_save_path = save_path
+        self.num_save_calls += 1
+        super().save(save_path)
+
+
+@pytest.fixture
+def agent():
+    return AgentSpy()
+
+
+def test_agent_snapshot_does_nothing_if_new_score_is_below_target(agent, tmp_path):
+    snapshot = AgentSnapshot(agent, 10, tmp_path / "snapshot")
+    snapshot.new_score(10)
+    assert agent.received_save_path is None
+
+
+def test_agent_snapshot_saves_agent_if_new_score_is_above_target(agent, tmp_path):
+    snapshot = AgentSnapshot(agent, 10, tmp_path / "snapshot")
+    snapshot.new_score(11)
+    assert agent.received_save_path == tmp_path / "snapshot"
+
+
+def test_agent_snapshot_updates_new_target(agent, tmp_path):
+    snapshot = AgentSnapshot(agent, 10, tmp_path / "snapshot")
+    snapshot.new_score(11)
+    assert agent.received_save_path == tmp_path / "snapshot"
+    snapshot.new_score(11)
+    assert agent.num_save_calls == 1
+
+
+def test_do_nothing_if_target_is_none(agent, tmp_path):
+    snapshot = AgentSnapshot(agent, None, tmp_path / "snapshot")
+    snapshot.new_score(11)
+    assert agent.num_save_calls == 0
diff --git a/tests/test_cli_functions.py b/tests/test_cli_functions.py
index b45abd7..5a2ab65 100644
--- a/tests/test_cli_functions.py
+++ b/tests/test_cli_functions.py
@@ -6,12 +6,12 @@
 
 def test_navigation_training_invocation():
     _, scores = run_train_session(GymEnvFactory('gym_quickcheck:random-walk-v0'), AgentFactory('DQN'), 1000, dict(),
-                                  None)
+                                  None, None)
     assert np.mean(scores[-100:]) == approx(-1, abs=0.3)
 
 
 def test_navigation_run_invocation():
     agent, _ = run_train_session(GymEnvFactory('gym_quickcheck:random-walk-v0'), AgentFactory('DQN'), 1000, dict(),
-                                 None)
+                                 None, None)
     scores = run_test_session(agent, GymEnvFactory('gym_quickcheck:random-walk-v0'), 100)
     assert np.mean(scores) == approx(-1, abs=0.3)
diff --git a/udacity_rl/agents/__init__.py b/udacity_rl/agents/__init__.py
index 390a77f..58f6391 100644
--- a/udacity_rl/agents/__init__.py
+++ b/udacity_rl/agents/__init__.py
@@ -1,10 +1,13 @@
 import json
+import logging
 import pickle
 
 from udacity_rl.agents.dqn_agent import DQNAgent
 from udacity_rl.agents.ddpg_agent import DDPGAgent
 from udacity_rl.agents.maddpg_agent import MADDPGAgent
 
+logger = logging.getLogger(__name__)
+
 _CLASS_MAPPING = {
     DQNAgent.__name__: DQNAgent,
     DDPGAgent.__name__: DDPGAgent,
@@ -37,3 +40,17 @@
     agent = _CLASS_MAPPING[agent_type](obs_space, act_space, **cfg)
     agent.load(path)
     return agent
+
+
+class AgentSnapshot:
+    def __init__(self, agent, target_score, path):
+        self._agent = agent
+        self._target = target_score
+        self._path = path
+
+    def new_score(self, score):
+        if self._target is not None and score > self._target:
logger.info(f"saving agent ({self._path}) snapshot with score: {score}") + agent_save(self._agent, self._path) + self._target = score + logger.info(f"new save threshold: {self._target}") diff --git a/udacity_rl/main.py b/udacity_rl/main.py index a5c0cae..bb713d9 100644 --- a/udacity_rl/main.py +++ b/udacity_rl/main.py @@ -12,7 +12,7 @@ from unityagents import UnityEnvironment from udacity_rl.adapter import GymAdapter -from udacity_rl.agents import DQNAgent, agent_load, agent_save +from udacity_rl.agents import DQNAgent, agent_load, agent_save, AgentSnapshot from udacity_rl.agents.ddpg_agent import DDPGAgent from udacity_rl.agents.maddpg_agent import MADDPGAgent from udacity_rl.epsilon import EpsilonExpDecay, NoiseFixed @@ -97,8 +97,10 @@ def environment_session(env_factory, *args, **kwargs): help="path to store the agent at (default: /tmp/agent_ckpt)") @click.option('--max-t', default=None, type=click.INT, help="maximum episode steps (default: None)") +@click.option('--save-at', default=None, type=click.FLOAT, + help="save at average score greater than specified. No snapshots when None. (default: None)") @click.pass_context -def train(ctx, algorithm, episodes, config, output, max_t): +def train(ctx, algorithm, episodes, config, output, max_t, save_at): """ train the agent with the specified algorithm on the environment for the given amount of episodes """ @@ -106,7 +108,7 @@ def train(ctx, algorithm, episodes, config, output, max_t): if config is not None: cfg = json.load(config) - agent, scores = run_train_session(ctx.obj['env_factory'], AgentFactory(algorithm), episodes, cfg, max_t) + agent, scores = run_train_session(ctx.obj['env_factory'], AgentFactory(algorithm), episodes, cfg, max_t, save_at) agent_save(agent, Path(output)) plot_scores(scores) @@ -115,7 +117,7 @@ def _squeeze_box(box): return Box(box.low[0][0], box.high[0][0], shape=(box.shape[1],)) -def run_train_session(env_fac, agent_fac, episodes, config, max_t): +def run_train_session(env_fac, agent_fac, episodes, config, max_t, save_at): with environment_session(env_fac, train_mode=True) as env: if 'act_noise_std' in config: eps_calc = NoiseFixed(config['act_noise_std']) @@ -126,42 +128,51 @@ def run_train_session(env_fac, agent_fac, episodes, config, max_t): logger.info(f"Epsilon configuration:\n" f"\t{eps_calc}\n") + scores = run_session(agent, env, episodes, train_frequency=config.get('train_frequency', 4), eps_calc=eps_calc, - max_t=max_t) + max_t=max_t, + save_at=save_at) return agent, scores -def run_session(agent, env, episodes, train_frequency=None, eps_calc=None, max_t=None): +def run_session(agent, env, episodes, train_frequency=None, eps_calc=None, max_t=None, save_at=None): step = 0 scores_last = deque(maxlen=100) scores_all = list() - for episode in range(episodes): - done = False - score = 0 - obs = env.reset() - t = 0 - while not done and (max_t is None or t < max_t): - action = agent.act(obs, 0 if eps_calc is None else eps_calc.epsilon) - next_obs, reward, done, _ = env.step(action) - agent.step(obs, action, reward, next_obs, done) - obs = next_obs - step += 1 - if train_frequency is not None and step % train_frequency == 0: - agent.train() - score += np.mean(reward) - - if eps_calc: - eps_calc.update() - scores_last.append(score) - scores_all.append(score) - - score_avg = sum(scores_last) / len(scores_last) - reward_msg = f"\rEpisodes ({episode}/{episodes})\tAverage reward: {score_avg :.4f}" - print(reward_msg, end="") - if episode % 100 == 0: - print(reward_msg) + snapshot = AgentSnapshot(agent, save_at, 
Path("/tmp/agent_snapshot")) + try: + for episode in range(episodes): + done = False + score = 0 + obs = env.reset() + t = 0 + while not done and (max_t is None or t < max_t): + action = agent.act(obs, 0 if eps_calc is None else eps_calc.epsilon) + next_obs, reward, done, _ = env.step(action) + agent.step(obs, action, reward, next_obs, done) + obs = next_obs + step += 1 + if train_frequency is not None and step % train_frequency == 0: + agent.train() + score += np.mean(reward) + + if eps_calc: + eps_calc.update() + scores_last.append(score) + scores_all.append(score) + + score_avg = sum(scores_last) / len(scores_last) + reward_msg = f"\rEpisodes ({episode}/{episodes})\tAverage reward: {score_avg :.4f}" + print(reward_msg, end="") + + if episode % 100 == 0: + print(reward_msg) + snapshot.new_score(score_avg) + except KeyboardInterrupt: + pass + return scores_all