Implement keyboard interrupt and snapshots
I want to keep the state of the agent when I interrupt training, often
because it has already reached a very high score. I also wanted a way of
investigating agent behaviour during training to get a better feel for
why performance crashes early on. Additionally, to not lose well-trained
agents, I added a snapshot system.
SwamyDev committed Mar 11, 2020
1 parent 5d2b702 commit 80fc730
Showing 4 changed files with 139 additions and 33 deletions.
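In short, the change does two things: the training loop is wrapped so that a keyboard interrupt (Ctrl+C) returns the partially trained agent instead of discarding it, and an AgentSnapshot saves the agent whenever the rolling average score exceeds the new --save-at threshold. The following is a minimal sketch of that control flow, not the committed run_session; sketch_training_loop and run_episode are illustrative names, and the episode loop itself is left to the caller:

from collections import deque
from pathlib import Path

from udacity_rl.agents import AgentSnapshot


def sketch_training_loop(agent, run_episode, episodes, save_at=None):
    """run_episode is any callable returning an episode's score (placeholder)."""
    scores_last = deque(maxlen=100)
    # Snapshots go to a fixed temporary path, as in the committed run_session.
    snapshot = AgentSnapshot(agent, save_at, Path("/tmp/agent_snapshot"))
    try:
        for _ in range(episodes):
            scores_last.append(run_episode(agent))
            # Save a snapshot whenever the rolling average beats the current
            # target; AgentSnapshot then raises the target to that average.
            snapshot.new_score(sum(scores_last) / len(scores_last))
    except KeyboardInterrupt:
        pass  # keep the agent trained so far instead of losing it
    return agent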
78 changes: 78 additions & 0 deletions tests/test_agent_snapshot.py
@@ -0,0 +1,78 @@
import pytest
from udacity_rl.agents import AgentSnapshot
from udacity_rl.agents.agent import AgentInterface


class AgentStub(AgentInterface):
@property
def action_size(self):
pass

@property
def observation_space(self):
pass

@property
def action_space(self):
pass

@property
def configuration(self):
pass

def act(self, observation, epsilon=0):
pass

def step(self, obs, action, reward, next_obs, done):
pass

def train(self):
pass

def save(self, save_path):
pass

def load(self, save_path):
pass


class AgentSpy(AgentStub):
def __init__(self):
self.received_save_path = None
self.num_save_calls = 0

def save(self, save_path):
self.received_save_path = save_path
self.num_save_calls += 1
super().save(save_path)


@pytest.fixture
def agent():
return AgentSpy()


def test_agent_snapshot_does_nothing_if_new_score_is_below_target(agent, tmp_path):
snapshot = AgentSnapshot(agent, 10, tmp_path / "snapshot")
snapshot.new_score(10)
assert agent.received_save_path is None


def test_agent_snapshot_saves_agent_if_new_score_is_above_target(agent, tmp_path):
snapshot = AgentSnapshot(agent, 10, tmp_path / "snapshot")
snapshot.new_score(11)
assert agent.received_save_path == tmp_path / "snapshot"


def test_agent_snapshot_updates_new_target(agent, tmp_path):
snapshot = AgentSnapshot(agent, 10, tmp_path / "snapshot")
snapshot.new_score(11)
assert agent.received_save_path == tmp_path / "snapshot"
snapshot.new_score(11)
assert agent.num_save_calls == 1


def test_do_nothing_if_target_is_none(agent, tmp_path):
snapshot = AgentSnapshot(agent, None, tmp_path / "snapshot")
snapshot.new_score(11)
assert agent.num_save_calls == 0
4 changes: 2 additions & 2 deletions tests/test_cli_functions.py
@@ -6,12 +6,12 @@

def test_navigation_training_invocation():
_, scores = run_train_session(GymEnvFactory('gym_quickcheck:random-walk-v0'), AgentFactory('DQN'), 1000, dict(),
None)
None, None)
assert np.mean(scores[-100:]) == approx(-1, abs=0.3)


def test_navigation_run_invocation():
agent, _ = run_train_session(GymEnvFactory('gym_quickcheck:random-walk-v0'), AgentFactory('DQN'), 1000, dict(),
None)
None, None)
scores = run_test_session(agent, GymEnvFactory('gym_quickcheck:random-walk-v0'), 100)
assert np.mean(scores) == approx(-1, abs=0.3)
17 changes: 17 additions & 0 deletions udacity_rl/agents/__init__.py
@@ -1,10 +1,13 @@
import json
import logging
import pickle

from udacity_rl.agents.dqn_agent import DQNAgent
from udacity_rl.agents.ddpg_agent import DDPGAgent
from udacity_rl.agents.maddpg_agent import MADDPGAgent

logger = logging.getLogger(__name__)

_CLASS_MAPPING = {
DQNAgent.__name__: DQNAgent,
DDPGAgent.__name__: DDPGAgent,
@@ -37,3 +40,17 @@ def agent_load(path):
agent = _CLASS_MAPPING[agent_type](obs_space, act_space, **cfg)
agent.load(path)
return agent


class AgentSnapshot:
def __init__(self, agent, target_score, path):
self._agent = agent
self._target = target_score
self._path = path

def new_score(self, score):
if self._target is not None and score > self._target:
logger.info(f"saving agent ({self._path}) snapshot with score: {score}")
agent_save(self._agent, self._path)
self._target = score
logger.info(f"new save threshold: {self._target}")
73 changes: 42 additions & 31 deletions udacity_rl/main.py
@@ -12,7 +12,7 @@
from unityagents import UnityEnvironment

from udacity_rl.adapter import GymAdapter
from udacity_rl.agents import DQNAgent, agent_load, agent_save
from udacity_rl.agents import DQNAgent, agent_load, agent_save, AgentSnapshot
from udacity_rl.agents.ddpg_agent import DDPGAgent
from udacity_rl.agents.maddpg_agent import MADDPGAgent
from udacity_rl.epsilon import EpsilonExpDecay, NoiseFixed
@@ -97,16 +97,18 @@ def environment_session(env_factory, *args, **kwargs):
help="path to store the agent at (default: /tmp/agent_ckpt)")
@click.option('--max-t', default=None, type=click.INT,
help="maximum episode steps (default: None)")
@click.option('--save-at', default=None, type=click.FLOAT,
help="save at average score greater than specified. No snapshots when None. (default: None)")
@click.pass_context
def train(ctx, algorithm, episodes, config, output, max_t):
def train(ctx, algorithm, episodes, config, output, max_t, save_at):
"""
train the agent with the specified algorithm on the environment for the given amount of episodes
"""
cfg = dict()
if config is not None:
cfg = json.load(config)

agent, scores = run_train_session(ctx.obj['env_factory'], AgentFactory(algorithm), episodes, cfg, max_t)
agent, scores = run_train_session(ctx.obj['env_factory'], AgentFactory(algorithm), episodes, cfg, max_t, save_at)
agent_save(agent, Path(output))
plot_scores(scores)

@@ -115,7 +117,7 @@ def _squeeze_box(box):
return Box(box.low[0][0], box.high[0][0], shape=(box.shape[1],))


def run_train_session(env_fac, agent_fac, episodes, config, max_t):
def run_train_session(env_fac, agent_fac, episodes, config, max_t, save_at):
with environment_session(env_fac, train_mode=True) as env:
if 'act_noise_std' in config:
eps_calc = NoiseFixed(config['act_noise_std'])
@@ -126,42 +128,51 @@ def run_train_session(env_fac, agent_fac, episodes, config, max_t):

logger.info(f"Epsilon configuration:\n"
f"\t{eps_calc}\n")

scores = run_session(agent, env, episodes,
train_frequency=config.get('train_frequency', 4),
eps_calc=eps_calc,
max_t=max_t)
max_t=max_t,
save_at=save_at)
return agent, scores


def run_session(agent, env, episodes, train_frequency=None, eps_calc=None, max_t=None):
def run_session(agent, env, episodes, train_frequency=None, eps_calc=None, max_t=None, save_at=None):
step = 0
scores_last = deque(maxlen=100)
scores_all = list()
for episode in range(episodes):
done = False
score = 0
obs = env.reset()
t = 0
while not done and (max_t is None or t < max_t):
action = agent.act(obs, 0 if eps_calc is None else eps_calc.epsilon)
next_obs, reward, done, _ = env.step(action)
agent.step(obs, action, reward, next_obs, done)
obs = next_obs
step += 1
if train_frequency is not None and step % train_frequency == 0:
agent.train()
score += np.mean(reward)

if eps_calc:
eps_calc.update()
scores_last.append(score)
scores_all.append(score)

score_avg = sum(scores_last) / len(scores_last)
reward_msg = f"\rEpisodes ({episode}/{episodes})\tAverage reward: {score_avg :.4f}"
print(reward_msg, end="")
if episode % 100 == 0:
print(reward_msg)
snapshot = AgentSnapshot(agent, save_at, Path("/tmp/agent_snapshot"))
try:
for episode in range(episodes):
done = False
score = 0
obs = env.reset()
t = 0
while not done and (max_t is None or t < max_t):
action = agent.act(obs, 0 if eps_calc is None else eps_calc.epsilon)
next_obs, reward, done, _ = env.step(action)
agent.step(obs, action, reward, next_obs, done)
obs = next_obs
step += 1
if train_frequency is not None and step % train_frequency == 0:
agent.train()
score += np.mean(reward)

if eps_calc:
eps_calc.update()
scores_last.append(score)
scores_all.append(score)

score_avg = sum(scores_last) / len(scores_last)
reward_msg = f"\rEpisodes ({episode}/{episodes})\tAverage reward: {score_avg :.4f}"
print(reward_msg, end="")

if episode % 100 == 0:
print(reward_msg)
snapshot.new_score(score_avg)
except KeyboardInterrupt:
pass

return scores_all

