Enhancement/train batch function #107

Merged

Changes from 8 commits · 21 commits total

Commits
9c06583
Updated RL docs with latest models
Jun 24, 2020
33be076
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Jun 25, 2020
fdc92f9
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Jun 28, 2020
682bbe6
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Jun 30, 2020
17073bc
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Jun 30, 2020
47e5fa0
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Jul 3, 2020
86b0dee
Added POC for train_batch interface when populating RL datasets
Jul 3, 2020
885f198
Updated other models to use train_batch interface
Jul 3, 2020
896b032
Update tests/datamodules/test_experience_sources.py
djbyrne Jul 6, 2020
2baa02c
Fixing lint errors
Jul 6, 2020
08e71f7
Merge branch 'enhancement/train_batch_function' of https://github.com…
Jul 6, 2020
db18cd8
Fixed linting errors
Jul 6, 2020
0f5ca79
Update pl_bolts/datamodules/experience_source.py
djbyrne Jul 6, 2020
5d9dfa6
Resolved comments
Jul 6, 2020
c3f62ac
req
Borda Jul 8, 2020
577569c
Removed cyclic import of Agents from experience source
Jul 9, 2020
fa658a4
Merge branch 'enhancement/train_batch_function' of https://github.com…
Jul 9, 2020
2292528
Updated reference of Experience to datamodules instead of the rl.common
Jul 9, 2020
13cc727
timeout
Borda Jul 9, 2020
d4c1cc7
Commented out test_dev_dataset to test run times
djbyrne Jul 11, 2020
04f02cd
undo commenting out of test_dev_datasets
djbyrne Jul 11, 2020
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -45,6 +45,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Device is no longer set in the DQN model init
- Moved RL loss function to the losses module
- Moved rl.common.experience to datamodules
- Added a `train_batch` function to the VPG model to generate a batch of data at each step (POC)
- Experience source is no longer initialized with a device; instead the device is passed at each step() (see the sketch below)

### Fixed

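To illustrate the changelog entries above — in particular the device now being passed at each `step()` rather than at construction — here is a minimal sketch of the new calling convention. It is not code from this PR; the `RandomAgent` below is a hypothetical stand-in for the bolts agent classes.

```python
import gym
import torch

from pl_bolts.datamodules import ExperienceSource


class RandomAgent:
    """Hypothetical stand-in for an Agent: ignores the state and samples random actions."""

    def __init__(self, env):
        self.env = env

    def __call__(self, state, device):
        # the device argument is what step() now threads through to the agent
        return self.env.action_space.sample()


env = gym.make("CartPole-v0")
source = ExperienceSource(env, RandomAgent(env))   # no device at construction any more

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
exp, reward, done = source.step(device)            # the device is supplied per step instead
```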
2 changes: 2 additions & 0 deletions pl_bolts/datamodules/__init__.py
@@ -6,3 +6,5 @@
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule
from pl_bolts.datamodules.experience_source import ExperienceSourceDataset, ExperienceSource, \
NStepExperienceSource, EpisodicExperienceStream
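With the experience sources moved out of `pl_bolts.models.rl.common`, downstream code would pick them up from the new location, roughly like this (a usage sketch, not part of the diff):

```python
# new import location introduced by this PR
from pl_bolts.datamodules import (
    ExperienceSource,
    ExperienceSourceDataset,
    NStepExperienceSource,
    EpisodicExperienceStream,
)
```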
200 changes: 200 additions & 0 deletions pl_bolts/datamodules/experience_source.py
@@ -0,0 +1,200 @@
"""
Datamodules for RL models that rely on experiences generated during training

Based on implementations found here: https://github.com/Shmuma/ptan/blob/master/ptan/experience.py
"""
from collections import deque
from typing import Iterable, Callable, Tuple, List
import numpy as np
import torch
from gym import Env
from torch.utils.data import IterableDataset

# Datasets
from pl_bolts.models.rl.common.agents import Agent
from pl_bolts.models.rl.common.memory import Experience


class ExperienceSourceDataset(IterableDataset):
"""
Basic experience source dataset. Takes a generate_batch function that returns an iterator.
The logic for the experience source and how the batch is generated is defined in the Lightning model itself
"""

def __init__(self, generate_batch: Callable):
self.generate_batch = generate_batch

def __iter__(self) -> Iterable:
iterator = self.generate_batch()
return iterator
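To make the docstring concrete: a rough sketch of how a Lightning RL model might plug a `train_batch` generator into `ExperienceSourceDataset` for its train dataloader. The class below only shows the dataloader plumbing and is not the VPG implementation from this PR; `self.device` is normally provided by `LightningModule`.

```python
import torch
from torch.utils.data import DataLoader

from pl_bolts.datamodules import ExperienceSourceDataset


class VPGLikeModel:
    """Sketch of the train_batch interface; stands in for a pl.LightningModule."""

    def __init__(self, source, batch_size: int = 16):
        self.source = source               # e.g. an ExperienceSource built from env + agent
        self.batch_size = batch_size
        self.device = torch.device("cpu")  # LightningModule normally provides this

    def train_batch(self):
        """Generator yielding one transition at a time; Lightning consumes it as a stream."""
        while True:
            exp, reward, done = self.source.step(self.device)
            yield exp.state, exp.action, exp.reward, exp.done, exp.new_state

    def train_dataloader(self) -> DataLoader:
        dataset = ExperienceSourceDataset(self.train_batch)
        return DataLoader(dataset=dataset, batch_size=self.batch_size)
```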

# Experience Sources

class ExperienceSource:
"""
Basic single step experience source

Args:
env: Environment that is being used
agent: Agent being used to make decisions
"""

def __init__(self, env: Env, agent: Agent):
self.env = env
self.agent = agent
self.state = self.env.reset()

def _reset(self) -> None:
"""resets the env and state"""
self.state = self.env.reset()

def step(self, device: torch.device) -> Tuple[Experience, float, bool]:
"""Takes a single step through the environment"""
action = self.agent(self.state, device)
new_state, reward, done, _ = self.env.step(action)
experience = Experience(
state=self.state,
action=action,
reward=reward,
new_state=new_state,
done=done,
)
self.state = new_state

if done:
self.state = self.env.reset()

return experience, reward, done

def run_episode(self, device: torch.device) -> float:
Member commented:
is episode a common RL term for this? Intuitively I would have called this sequence...

@djbyrne (Contributor Author) replied on Jul 6, 2020:
It depends on the task. Most tasks are episodic in some form and will have a termination state denoting the end of the episode. This function was originally used for carrying out a validation episode and is useful for testing.

"""Carries out a single episode and returns the total reward. This is used for testing"""
done = False
total_reward = 0

while not done:
_, reward, done = self.step(device)
total_reward += reward

return total_reward
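As the thread above suggests, `run_episode` is mostly a convenience for evaluation: roll the current policy until the environment reports done and return the summed reward. A small, hedged sketch of how a validation/test hook might use it (`evaluate` is a made-up helper, not part of this PR):

```python
import torch

from pl_bolts.datamodules import ExperienceSource


def evaluate(source: ExperienceSource, n_episodes: int = 10,
             device: torch.device = torch.device("cpu")) -> float:
    """Run a few full episodes with the current policy and return the mean total reward."""
    rewards = [source.run_episode(device) for _ in range(n_episodes)]
    return sum(rewards) / len(rewards)
```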


class NStepExperienceSource(ExperienceSource):
"""Expands upon the basic ExperienceSource by collecting experience across N steps"""

def __init__(self, env: Env, agent: Agent, n_steps: int = 1, gamma: float = 0.99):
super().__init__(env, agent)
self.gamma = gamma
self.n_steps = n_steps
self.n_step_buffer = deque(maxlen=n_steps)

def step(self, device: torch.device) -> Tuple[Experience, float, bool]:
"""
Takes an n-step transition in the environment

Returns:
multi-step Experience, along with the reward and done flag for the step just taken
"""
exp = self.single_step(device)

while len(self.n_step_buffer) < self.n_steps:
self.single_step(device)

reward, next_state, done = self.get_transition_info()
first_experience = self.n_step_buffer[0]
multi_step_experience = Experience(
first_experience.state, first_experience.action, reward, done, next_state
)

return multi_step_experience, exp.reward, exp.done

def single_step(self, device: torch.device) -> Experience:
"""
Takes a single step in the environment and appends it to the n-step buffer

Returns:
Experience
"""
exp, _, _ = super().step(device)
self.n_step_buffer.append(exp)
return exp

def get_transition_info(self) -> Tuple[np.float, np.array, np.int]:
"""
Get the accumulated transition info for the n_step_buffer

Returns:
multi step reward, final observation and done
"""
last_experience = self.n_step_buffer[-1]
final_state = last_experience.new_state
done = last_experience.done
reward = last_experience.reward

# calculate reward
# in reverse order, go through all the experiences up till the first experience
for experience in reversed(list(self.n_step_buffer)[:-1]):
reward_t = experience.reward
new_state_t = experience.new_state
done_t = experience.done

reward = reward_t + self.gamma * reward * (1 - done_t)
final_state, done = (new_state_t, done_t) if done_t else (final_state, done)

return reward, final_state, done
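For reference, the backward pass above is just the discounted n-step return, cut short when an intermediate step is terminal: R = r_1 + γ·r_2 + … + γ^(n-1)·r_n. A tiny self-contained check of that arithmetic (mirroring only the reward part of the loop, on hand-made numbers):

```python
# mirrors the reward accumulation in get_transition_info on hand-made numbers
gamma = 0.99
rewards = [1.0, 1.0, 1.0]      # rewards of the buffered experiences, oldest first
dones = [False, False, False]  # no early termination in this example

ret = rewards[-1]
for r, d in zip(reversed(rewards[:-1]), reversed(dones[:-1])):
    ret = r + gamma * ret * (1 - d)

# 1 + 0.99 * (1 + 0.99 * 1) = 2.9701
assert abs(ret - 2.9701) < 1e-9
```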


class EpisodicExperienceStream(ExperienceSource, IterableDataset):
Member commented:
Same question about wording with episodic

"""
Basic experience stream that iteratively yields the current experience of the agent in the env

Args:
env: Environment that is being used
agent: Agent being used to make decisions
"""

def __init__(self, env: Env, agent: Agent, device: torch.device, episodes: int = 1):
super().__init__(env, agent)
self.episodes = episodes
self.device = device

def __getitem__(self, item):
return item

def __iter__(self) -> List[Experience]:
"""
Steps through the environment until the requested number of episodes is complete

Returns:
Batch of all transitions for the collected episodes
"""
episode_steps, batch = [], []

while len(batch) < self.episodes:
exp = self.step(self.device)
episode_steps.append(exp)

if exp.done:
batch.append(episode_steps)
episode_steps = []

yield batch

def step(self, device: torch.device) -> Experience:
"""Carries out a single step in the environment"""
action = self.agent(self.state, device)
new_state, reward, done, _ = self.env.step(action)
experience = Experience(
state=self.state,
action=action,
reward=reward,
new_state=new_state,
done=done,
)
self.state = new_state

if done:
self.state = self.env.reset()

return experience
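Because `EpisodicExperienceStream` is itself an `IterableDataset`, it can be handed straight to a `DataLoader`; each iteration yields the list of completed episodes. A rough sketch, again with a hypothetical random-action agent standing in for the bolts agents:

```python
import gym
import torch
from torch.utils.data import DataLoader

from pl_bolts.datamodules import EpisodicExperienceStream


class RandomAgent:
    """Hypothetical stand-in agent that ignores the state and samples random actions."""

    def __init__(self, env):
        self.env = env

    def __call__(self, state, device):
        return self.env.action_space.sample()


env = gym.make("CartPole-v0")
stream = EpisodicExperienceStream(env, RandomAgent(env), device=torch.device("cpu"), episodes=2)

# batch_size=None passes each yielded list of episodes through without collation
loader = DataLoader(stream, batch_size=None)
episodes = next(iter(loader))
print(len(episodes), "episodes,", len(episodes[0]), "steps in the first episode")
```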
2 changes: 1 addition & 1 deletion pl_bolts/models/rl/__init__.py
@@ -5,4 +5,4 @@
from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN
from pl_bolts.models.rl.per_dqn_model import PERDQN
from pl_bolts.models.rl.reinforce_model import Reinforce
from pl_bolts.models.rl.vanilla_policy_gradient_model import PolicyGradient
# from pl_bolts.models.rl.vanilla_policy_gradient_model import PolicyGradient
Member commented:
why did you change this one?

@djbyrne (Contributor Author) replied:
I meant to raise an issue with this; some of these imports in the inits are raising errors in my runs. I meant to look into specifically why it was happening. Will update this.

45 changes: 25 additions & 20 deletions pl_bolts/models/rl/common/experience.py
@@ -1,11 +1,15 @@
"""Experience sources to be used as datasets for Ligthning DataLoaders

Based on implementations found here: https://github.com/Shmuma/ptan/blob/master/ptan/experience.py

.. note:: Deprecated, these functions have been moved to pl_bolts.datamodules.experience_source

"""
from collections import deque
from typing import List, Tuple

import numpy as np
import torch
from gym import Env
from torch.utils.data import IterableDataset

@@ -74,19 +78,18 @@ class ExperienceSource:
agent: Agent being used to make decisions
"""

def __init__(self, env: Env, agent: Agent, device):
def __init__(self, env: Env, agent: Agent):
self.env = env
self.agent = agent
self.state = self.env.reset()
self.device = device

def _reset(self) -> None:
"""resets the env and state"""
self.state = self.env.reset()

def step(self) -> Tuple[Experience, float, bool]:
def step(self, device: torch.device) -> Tuple[Experience, float, bool]:
"""Takes a single step through the environment"""
action = self.agent(self.state, self.device)
action = self.agent(self.state, device)
new_state, reward, done, _ = self.env.step(action)
experience = Experience(
state=self.state,
@@ -102,13 +105,13 @@ def step(self) -> Tuple[Experience, float, bool]:

return experience, reward, done

def run_episode(self) -> float:
def run_episode(self, device: torch.device) -> float:
"""Carries out a single episode and returns the total reward. This is used for testing"""
done = False
total_reward = 0

while not done:
_, reward, done = self.step()
_, reward, done = self.step(device)
total_reward += reward

return total_reward
@@ -117,22 +120,23 @@ def run_episode(self) -> float:
class NStepExperienceSource(ExperienceSource):
"""Expands upon the basic ExperienceSource by collecting experience across N steps"""

def __init__(self, env: Env, agent: Agent, device, n_steps: int = 1):
super().__init__(env, agent, device)
def __init__(self, env: Env, agent: Agent, n_steps: int = 1, gamma: float = 0.99):
super().__init__(env, agent)
self.gamma = gamma
self.n_steps = n_steps
self.n_step_buffer = deque(maxlen=n_steps)

def step(self) -> Tuple[Experience, float, bool]:
def step(self, device: torch.device) -> Tuple[Experience, float, bool]:
"""
Takes an n-step in the environment

Returns:
Experience
"""
exp = self.single_step()
exp = self.single_step(device)

while len(self.n_step_buffer) < self.n_steps:
self.single_step()
self.single_step(device)

reward, next_state, done = self.get_transition_info()
first_experience = self.n_step_buffer[0]
@@ -142,18 +146,18 @@ def step(self) -> Tuple[Experience, float, bool]:

return multi_step_experience, exp.reward, exp.done

def single_step(self) -> Experience:
def single_step(self, device: torch.device) -> Experience:
"""
Takes a single step in the environment and appends it to the n-step buffer

Returns:
Experience
"""
exp, _, _ = super().step()
exp, _, _ = super().step(device)
self.n_step_buffer.append(exp)
return exp

def get_transition_info(self, gamma=0.9) -> Tuple[np.float, np.array, np.int]:
def get_transition_info(self) -> Tuple[np.float, np.array, np.int]:
"""
get the accumulated transition info for the n_step_buffer
Args:
@@ -174,7 +178,7 @@ def get_transition_info(self, gamma=0.9) -> Tuple[np.float, np.array, np.int]:
new_state_t = experience.new_state
done_t = experience.done

reward = reward_t + gamma * reward * (1 - done_t)
reward = reward_t + self.gamma * reward * (1 - done_t)
final_state, done = (new_state_t, done_t) if done_t else (final_state, done)

return reward, final_state, done
@@ -189,9 +193,10 @@ class EpisodicExperienceStream(ExperienceSource, IterableDataset):
agent: Agent being used to make decisions
"""

def __init__(self, env: Env, agent: Agent, device, episodes: int = 1):
super().__init__(env, agent, device)
def __init__(self, env: Env, agent: Agent, device: torch.device, episodes: int = 1):
super().__init__(env, agent)
self.episodes = episodes
self.device = device

def __getitem__(self, item):
return item
@@ -206,7 +211,7 @@ def __iter__(self) -> List[Experience]:
episode_steps, batch = [], []

while len(batch) < self.episodes:
exp = self.step()
exp = self.step(self.device)
episode_steps.append(exp)

if exp.done:
@@ -215,9 +220,9 @@ def step(self) -> Experience:

yield batch

def step(self) -> Experience:
def step(self, device: torch.device) -> Experience:
"""Carries out a single step in the environment"""
action = self.agent(self.state, self.device)
action = self.agent(self.state, device)
new_state, reward, done, _ = self.env.step(action)
experience = Experience(
state=self.state,
2 changes: 1 addition & 1 deletion pl_bolts/models/rl/double_dqn_model.py
@@ -74,7 +74,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedD
self.agent.update_epsilon(self.global_step)

# step through environment with agent and add to buffer
exp, reward, done = self.source.step()
exp, reward, done = self.source.step(self.device)
self.buffer.append(exp)

self.episode_reward += reward