
RL models clean up #112

Closed
Commits (22 total; the diff below shows changes from 16 commits):
9c06583  Updated RL docs with latest models (Jun 24, 2020)
33be076  Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (Jun 25, 2020)
fdc92f9  Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (Jun 28, 2020)
682bbe6  Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (Jun 30, 2020)
17073bc  Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (Jun 30, 2020)
d05db21  Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (Jul 8, 2020)
8cde396  Updated RL docs with latest models (Jun 24, 2020)
96aaa97  Merge branch 'master' of https://github.com/djbyrne/pytorch-lightning… (djbyrne, Jul 11, 2020)
885be16  Cleaned up avg_reward calculation (djbyrne, Jul 11, 2020)
00a8547  Refactored DQN to use train_batch structure (djbyrne, Jul 12, 2020)
0aca98d  Merge branch 'master' into enhancement/rl_models_clean_up (djbyrne, Jul 12, 2020)
cfd139e  Cleaned up VPG metrics (djbyrne, Jul 12, 2020)
2741c5b  Refactore double dqn to use train_batch structure (djbyrne, Jul 12, 2020)
ad54460  Refactored noisy dqn to use train_batch structure (djbyrne, Jul 12, 2020)
164c7b4  Refactored per dqn to use train_batch structure (djbyrne, Jul 12, 2020)
407ff94  Updated docstrings (djbyrne, Jul 12, 2020)
6df878f  format (Borda, Jul 12, 2020)
44e0006  Apply suggestions from code review (Borda, Jul 12, 2020)
4f3d164  typo (Borda, Jul 13, 2020)
0333ebb  Merge branch 'enhancement/rl_models_clean_up' of https://github.com/d… (Borda, Jul 13, 2020)
2e18e19  Fixed pep8 errors (djbyrne, Jul 14, 2020)
79e7e5c  Fixed flake8 errors (djbyrne, Jul 14, 2020)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -48,6 +48,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Moved rl.common.experience to datamodules
- train_batch function to VPG model to generate batch of data at each step (POC)
- Experience source no longer gets initialized with a device, instead the device is passed at each step()
- Refactored RL models to use train_batch structure, with the exception of REINFORCE

### Fixed

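For readers skimming the changelog entries above, here is a minimal, illustrative sketch of the train_batch pattern they refer to: an iterable dataset wraps a generator callable so the DataLoader pulls freshly generated experience on each iteration. The `ExperienceSourceDataset` and `train_batch` below are stand-ins written for this note (only the names mirror the diff), not the library's implementation.

```python
from typing import Callable, Iterator, Tuple

import torch
from torch.utils.data import DataLoader, IterableDataset


class ExperienceSourceDataset(IterableDataset):
    """Stand-in iterable dataset that defers to a batch-generating callable."""

    def __init__(self, generate_batch: Callable[[], Iterator[Tuple]]):
        self.generate_batch = generate_batch

    def __iter__(self) -> Iterator[Tuple]:
        return self.generate_batch()


def train_batch() -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
    # Stand-in for a model's train_batch: yield (state, reward) pairs.
    for _ in range(8):
        yield torch.rand(4), torch.rand(1)


loader = DataLoader(ExperienceSourceDataset(train_batch), batch_size=4)
for states, rewards in loader:
    print(states.shape, rewards.shape)  # torch.Size([4, 4]) torch.Size([4, 1])
```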
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/experience_source.py
@@ -94,7 +94,7 @@ def step(self, device: torch.device) -> Tuple[Experience, float, bool]:
Takes an n-step in the environment

Returns:
Experience
Experience, undiscounted reward, done
"""
exp = self.n_step(device)

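The docstring change above spells out the three-value return of `step()`. A tiny illustrative consumer follows; the stub source is hypothetical and only models the documented return signature, it is not the repo's ExperienceSource.

```python
from collections import namedtuple

import torch

# Field names mirror the Experience tuple used in this repo; the stub below is
# hypothetical and exists only to show the (Experience, reward, done) contract.
Experience = namedtuple("Experience", "state action reward done new_state")


class StubSource:
    def step(self, device: torch.device):
        state = torch.zeros(4, device=device)
        exp = Experience(state, 0, 1.0, False, state)
        return exp, 1.0, False  # Experience, undiscounted reward, done


source = StubSource()
episode_reward = 0.0
exp, reward, done = source.step(torch.device("cpu"))
episode_reward += reward
if done:
    episode_reward = 0.0
```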
17 changes: 0 additions & 17 deletions pl_bolts/models/rl/double_dqn_model.py
@@ -71,30 +71,13 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedDict:
Returns:
Training loss and log metrics
"""
self.agent.update_epsilon(self.global_step)

# step through environment with agent and add to buffer
exp, reward, done = self.source.step(self.device)
self.buffer.append(exp)

self.episode_reward += reward
self.episode_steps += 1

# calculates training loss
loss = double_dqn_loss(batch, self.net, self.target_net)

if self.trainer.use_dp or self.trainer.use_ddp2:
loss = loss.unsqueeze(0)

if done:
self.total_reward = self.episode_reward
self.reward_list.append(self.total_reward)
self.avg_reward = sum(self.reward_list[-100:]) / 100
self.episode_count += 1
self.episode_reward = 0
self.total_episode_steps = self.episode_steps
self.episode_steps = 0

# Soft update of target network
if self.global_step % self.sync_rate == 0:
self.target_net.load_state_dict(self.net.state_dict())
89 changes: 62 additions & 27 deletions pl_bolts/models/rl/dqn_model.py
@@ -8,15 +8,14 @@

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from pl_bolts.datamodules.experience_source import Experience, ExperienceSource, ExperienceSourceDataset
from pl_bolts.losses.rl import dqn_loss
from pl_bolts.models.rl.common import wrappers, cli
from pl_bolts.models.rl.common.agents import ValueAgent
from pl_bolts.models.rl.common.experience import ExperienceSource, RLDataset
from pl_bolts.models.rl.common.memory import ReplayBuffer
from pl_bolts.models.rl.common.networks import CNN

@@ -38,6 +37,9 @@ def __init__(
replay_size: int = 100000,
warm_start_size: int = 10000,
num_samples: int = 500,
avg_reward_len: int = 100,
min_episode_reward: int = -21,

**kwargs,
):
"""
@@ -76,6 +78,10 @@ def __init__(
warm_start_size: how many random steps through the environment to be carried out at the start of
training to fill the buffer with a starting point
num_samples: the number of samples to pull from the dataset iterator and feed to the DataLoader
avg_reward_len: how many episodes to take into account when calculating the avg reward
min_episode_reward: the minimum score that can be achieved in an episode. Used for filling the avg buffer
before training begins

.. note::
This example is based on:
@@ -110,6 +116,7 @@ def __init__(
eps_end=eps_end,
eps_frames=eps_last_frame,
)
self.source = ExperienceSource(self.env, self.agent)

# Hyperparameters
self.sync_rate = sync_rate
@@ -128,10 +135,20 @@ def __init__(
self.episode_count = 0
self.episode_steps = 0
self.total_episode_steps = 0

self.total_steps = 0
self.reward_sum = 0

self.avg_reward_len = avg_reward_len

self.reward_list = []
for _ in range(100):
self.reward_list.append(-21)
self.avg_reward = -21
for _ in range(avg_reward_len):
self.reward_list.append(torch.tensor(min_episode_reward, device=self.device))
self.avg_reward = 0


self.buffer = ReplayBuffer(self.replay_size)
self.populate(self.warm_start_size)

def populate(self, warm_start: int) -> None:
"""Populates the buffer with initial experience"""
@@ -159,6 +176,45 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
output = self.net(x)
return output

def train_batch(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Contains the logic for generating a new batch of data to be passed to the DataLoader

Returns:
yields an Experience tuple containing the state, action, reward, done and next_state.
"""
self.agent.update_epsilon(self.total_steps)

# take a step in the env
exp, reward, done = self.source.step(self.device)
self.episode_steps += 1
self.total_steps += 1

self.reward_sum += exp.reward
self.episode_reward += reward

# gather the experience data
self.buffer.append(exp)

if done:
# tracking metrics
self.episode_count += 1
self.reward_list.append(self.total_reward)
self.avg_reward = sum(self.reward_list[-self.avg_reward_len:]) / self.avg_reward_len
self.total_reward = self.episode_reward
self.total_episode_steps = self.episode_steps

self.logger.experiment.add_scalar("reward", self.total_reward, self.total_steps)

# reset metrics
self.episode_reward = 0
self.episode_steps = 0

states, actions, rewards, dones, new_states = self.buffer.sample(self.batch_size)

for idx, _ in enumerate(dones):
yield states[idx], actions[idx], rewards[idx], dones[idx], new_states[idx]

def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedDict:
"""
Carries out a single step through the environment to update the replay buffer.
@@ -171,29 +227,12 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedDict:
Returns:
Training loss and log metrics
"""
self.agent.update_epsilon(self.global_step)

# step through environment with agent and add to buffer
exp, reward, done = self.source.step(self.device)
self.buffer.append(exp)

self.episode_reward += reward
self.episode_steps += 1

# calculates training loss
loss = dqn_loss(batch, self.net, self.target_net)

if self.trainer.use_dp or self.trainer.use_ddp2:
loss = loss.unsqueeze(0)

if done:
self.total_reward = self.episode_reward
self.reward_list.append(self.total_reward)
self.avg_reward = sum(self.reward_list[-100:]) / 100
self.episode_count += 1
self.episode_reward = 0
self.total_episode_steps = self.episode_steps
self.episode_steps = 0

# Soft update of target network
if self.global_step % self.sync_rate == 0:
@@ -244,11 +283,7 @@ def configure_optimizers(self) -> List[Optimizer]:

def prepare_data(self) -> None:
"""Initialize the Replay Buffer dataset used for retrieving experiences"""
self.source = ExperienceSource(self.env, self.agent)
self.buffer = ReplayBuffer(self.replay_size)
self.populate(self.warm_start_size)

self.dataset = RLDataset(self.buffer, self.sample_len)
self.dataset = ExperienceSourceDataset(self.train_batch)

def train_dataloader(self) -> DataLoader:
"""Get train loader"""
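To show where the new constructor arguments land, a hedged usage sketch: the import path matches this file's location, while the env id and trainer settings are illustrative assumptions, not verified against a release.

```python
import pytorch_lightning as pl

from pl_bolts.models.rl.dqn_model import DQN

# avg_reward_len / min_episode_reward are the arguments added in this diff;
# the env id and max_steps below are only examples.
model = DQN(
    "PongNoFrameskip-v4",
    avg_reward_len=100,      # episodes in the rolling reward average
    min_episode_reward=-21,  # value used to pre-fill the reward buffer
)
trainer = pl.Trainer(max_steps=10_000)
trainer.fit(model)
```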
2 changes: 1 addition & 1 deletion pl_bolts/models/rl/n_step_dqn_model.py
@@ -72,7 +72,7 @@ def __init__(

"""
super().__init__(env, gpus, eps_start, eps_end, eps_last_frame, sync_rate, gamma, learning_rate,
batch_size, replay_size, warm_start_size, num_samples)
batch_size, replay_size, warm_start_size, num_samples, **kwargs)

self.source = NStepExperienceSource(
self.env, self.agent, n_steps=n_steps
30 changes: 7 additions & 23 deletions pl_bolts/models/rl/noisy_dqn_model.py
@@ -76,42 +76,26 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedDict:
Returns:
Training loss and log metrics
"""
# step through environment with agent and add to buffer
exp, reward, done = self.source.step(self.device)
self.buffer.append(exp)

self.episode_reward += reward
self.episode_steps += 1

# calculates training loss
loss = dqn_loss(batch, self.net, self.target_net)

if self.trainer.use_dp or self.trainer.use_ddp2:
loss = loss.unsqueeze(0)

if done:
self.total_reward = self.episode_reward
self.reward_list.append(self.total_reward)
self.avg_reward = sum(self.reward_list[-100:]) / 100
self.episode_count += 1
self.episode_reward = 0
self.total_episode_steps = self.episode_steps
self.episode_steps = 0

# Soft update of target network
if self.global_step % self.sync_rate == 0:
self.target_net.load_state_dict(self.net.state_dict())

log = {
"total_reward": torch.tensor(self.total_reward).to(self.device),
"avg_reward": torch.tensor(self.avg_reward),
"total_reward": self.total_reward,
"avg_reward": self.avg_reward,
"train_loss": loss,
"episode_steps": torch.tensor(self.total_episode_steps),
"episode_steps": self.total_episode_steps,
}
status = {
"steps": torch.tensor(self.global_step).to(self.device),
"avg_reward": torch.tensor(self.avg_reward),
"total_reward": torch.tensor(self.total_reward).to(self.device),
"steps": self.global_step,
"avg_reward": self.avg_reward,
"total_reward": self.total_reward,
"episodes": self.episode_count,
"episode_steps": self.episode_steps,
"epsilon": self.agent.epsilon,
@@ -120,7 +104,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedDict:
return OrderedDict(
{
"loss": loss,
"avg_reward": torch.tensor(self.avg_reward),
"avg_reward": self.avg_reward,
"log": log,
"progress_bar": status,
}
64 changes: 46 additions & 18 deletions pl_bolts/models/rl/per_dqn_model.py
@@ -59,6 +59,52 @@ class PERDQN(DQN):
.. note:: Currently only supports CPU and single GPU training with `distributed_backend=dp`

"""
def train_batch(self) -> Tuple[Tuple, torch.Tensor, torch.Tensor]:
"""
Contains the logic for generating a new batch of data to be passed to the DataLoader

Returns:
yields an Experience tuple.
"""
self.agent.update_epsilon(self.total_steps)

# take a step in the env
exp, reward, done = self.source.step(self.device)
self.episode_steps += 1
self.total_steps += 1

self.reward_sum += exp.reward
self.episode_reward += reward

# gather the experience data
self.buffer.append(exp)

if done:
# tracking metrics
self.episode_count += 1
self.reward_list.append(self.total_reward)
self.avg_reward = sum(self.reward_list[-self.avg_reward_len:]) / self.avg_reward_len
self.total_reward = self.episode_reward
self.total_episode_steps = self.episode_steps

self.logger.experiment.add_scalar("reward", self.total_reward, self.total_steps)

# reset metrics
self.episode_reward = 0
self.episode_steps = 0
Review comment on lines +93 to +94:

Borda (Member): Shall this rather be at the beginning than at the end?

djbyrne (Contributor, Author): I'm not sure I understand.

Borda (Member): You reset episode_steps and the other metrics when the episode is done... wouldn't it be more logical to reset them before you start instead?

djbyrne (Contributor, Author): Ah, I see. The done flag is a local variable retrieved after taking a step on line 72, so the done check must come after that; I think it makes more sense to do the reset at the end.

samples, indices, weights = self.buffer.sample(self.sample_size)

states, actions, rewards, dones, new_states = samples

for idx, _ in enumerate(dones):
yield (
states[idx],
actions[idx],
rewards[idx],
dones[idx],
new_states[idx],
), indices[idx], weights[idx]

def training_step(self, batch, _) -> OrderedDict:
"""
@@ -76,15 +122,6 @@ def training_step(self, batch, _) -> OrderedDict:

indices = indices.cpu().numpy()

self.agent.update_epsilon(self.global_step)

# step through environment with agent and add to buffer
exp, reward, done = self.source.step(self.device)
self.buffer.append(exp)

self.episode_reward += reward
self.episode_steps += 1

# calculates training loss
loss, batch_weights = per_dqn_loss(samples, weights, self.net, self.target_net)

@@ -94,15 +131,6 @@ def training_step(self, batch, _) -> OrderedDict:
if self.trainer.use_dp or self.trainer.use_ddp2:
loss = loss.unsqueeze(0)

if done:
self.total_reward = self.episode_reward
self.reward_list.append(self.total_reward)
self.avg_reward = sum(self.reward_list[-100:]) / 100
self.episode_count += 1
self.episode_reward = 0
self.total_episode_steps = self.episode_steps
self.episode_steps = 0

# Soft update of target network
if self.global_step % self.sync_rate == 0:
self.target_net.load_state_dict(self.net.state_dict())
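PERDQN.train_batch above yields nested ((state, action, reward, done, new_state), index, weight) items. The stand-alone sketch below shows how PyTorch's default collate batches that nested shape; the generator and dataset wrapper here are stand-ins written for this note, not the library's classes.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class GeneratorDataset(IterableDataset):
    """Stand-in iterable dataset that defers to a generator callable."""

    def __init__(self, generate):
        self.generate = generate

    def __iter__(self):
        return self.generate()


def fake_per_batch():
    # Same nested yield shape as PERDQN.train_batch: (experience, index, weight).
    for idx in range(8):
        experience = (
            torch.rand(4),        # state
            torch.tensor(1),      # action
            torch.tensor(0.5),    # reward
            torch.tensor(False),  # done
            torch.rand(4),        # new_state
        )
        yield experience, torch.tensor(idx), torch.tensor(1.0)


loader = DataLoader(GeneratorDataset(fake_per_batch), batch_size=4)
for batch in loader:
    (states, actions, rewards, dones, new_states), indices, weights = batch
    print(states.shape, indices.shape, weights.shape)  # [4, 4], [4], [4]
```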
10 changes: 6 additions & 4 deletions pl_bolts/models/rl/reinforce_model.py
@@ -30,7 +30,7 @@ class Reinforce(pl.LightningModule):
""" Basic REINFORCE Policy Model """

def __init__(self, env: str, gamma: float = 0.99, lr: float = 1e-4, batch_size: int = 32,
batch_episodes: int = 4, **kwargs) -> None:
batch_episodes: int = 4, avg_reward_len=100, **kwargs) -> None:
"""
PyTorch Lightning implementation of `REINFORCE
<https://papers.nips.cc/paper/
@@ -59,6 +59,7 @@ def __init__(self, env: str, gamma: float = 0.99, lr: float = 1e-4, batch_size:
lr: learning rate
batch_size: size of minibatch pulled from the DataLoader
batch_episodes: how many episodes to rollout for each batch of training
avg_reward_len: how many episodes to take into account when calculating the avg reward

.. note::
This example is based on:
@@ -92,10 +93,11 @@ def __init__(self, env: str, gamma: float = 0.99, lr: float = 1e-4, batch_size:
self.episode_count = 0
self.episode_steps = 0
self.total_episode_steps = 0
self.avg_reward_len = avg_reward_len

self.reward_list = []
for _ in range(100):
self.reward_list.append(0)
for _ in range(avg_reward_len):
self.reward_list.append(torch.tensor(0, device=self.device))
self.avg_reward = 0

def build_networks(self) -> None:
@@ -254,7 +256,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedDict:
# get avg reward over the batched episodes
self.episode_reward = sum(batch_rewards) / len(batch)
self.reward_list.append(self.episode_reward)
self.avg_reward = sum(self.reward_list) / len(self.reward_list)
self.avg_reward = sum(self.reward_list[-self.avg_reward_len:]) / self.avg_reward_len

# calculates training loss
loss = self.loss(batch_qvals, batch_states, batch_actions)
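The sliding-window average changed above, shown in isolation as a pure-Python sketch of the same arithmetic; the episode returns are made up for illustration.

```python
avg_reward_len = 100
min_episode_reward = -21

# Seed the buffer so the average is defined before any episode finishes,
# then keep averaging the most recent `avg_reward_len` entries.
reward_list = [min_episode_reward] * avg_reward_len
avg_reward = 0.0
for episode_reward in (-20, -18, -15):  # illustrative episode returns
    reward_list.append(episode_reward)
    avg_reward = sum(reward_list[-avg_reward_len:]) / avg_reward_len

print(avg_reward)  # -20.9
```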