-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement keyboard interrupt and snapshots
I want to keep the state when of the agent when I interrupt training. Often because it already reached a very high value. Also I wanted a way of investigating agent behaviour during training to get a better feel for why performance crashes early on. Additionally, to not loose well trained agents I added a snapshot system.
- Loading branch information
Showing
4 changed files
with
139 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import pytest | ||
from udacity_rl.agents import AgentSnapshot | ||
from udacity_rl.agents.agent import AgentInterface | ||
|
||
|
||
class AgentStub(AgentInterface): | ||
@property | ||
def action_size(self): | ||
pass | ||
|
||
@property | ||
def observation_space(self): | ||
pass | ||
|
||
@property | ||
def action_space(self): | ||
pass | ||
|
||
@property | ||
def configuration(self): | ||
pass | ||
|
||
def act(self, observation, epsilon=0): | ||
pass | ||
|
||
def step(self, obs, action, reward, next_obs, done): | ||
pass | ||
|
||
def train(self): | ||
pass | ||
|
||
def save(self, save_path): | ||
pass | ||
|
||
def load(self, save_path): | ||
pass | ||
|
||
|
||
class AgentSpy(AgentStub): | ||
def __init__(self): | ||
self.received_save_path = None | ||
self.num_save_calls = 0 | ||
|
||
def save(self, save_path): | ||
self.received_save_path = save_path | ||
self.num_save_calls += 1 | ||
super().save(save_path) | ||
|
||
|
||
@pytest.fixture | ||
def agent(): | ||
return AgentSpy() | ||
|
||
|
||
def test_agent_snapshot_does_nothing_if_new_score_is_below_target(agent, tmp_path): | ||
snapshot = AgentSnapshot(agent, 10, tmp_path / "snapshot") | ||
snapshot.new_score(10) | ||
assert agent.received_save_path is None | ||
|
||
|
||
def test_agent_snapshot_saves_agent_if_new_score_is_above_target(agent, tmp_path): | ||
snapshot = AgentSnapshot(agent, 10, tmp_path / "snapshot") | ||
snapshot.new_score(11) | ||
assert agent.received_save_path == tmp_path / "snapshot" | ||
|
||
|
||
def test_agent_snapshot_updates_new_target(agent, tmp_path): | ||
snapshot = AgentSnapshot(agent, 10, tmp_path / "snapshot") | ||
snapshot.new_score(11) | ||
assert agent.received_save_path == tmp_path / "snapshot" | ||
snapshot.new_score(11) | ||
assert agent.num_save_calls == 1 | ||
|
||
|
||
def test_do_nothing_if_target_is_none(agent, tmp_path): | ||
snapshot = AgentSnapshot(agent, None, tmp_path / "snapshot") | ||
snapshot.new_score(11) | ||
assert agent.num_save_calls == 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters