Doc2 (#40)
* docs

* refactored rl docs
williamFalcon committed Jun 20, 2020
1 parent 3593251 commit ae780bc
Showing 19 changed files with 123 additions and 69 deletions.
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -6,6 +6,7 @@ pandoc
docutils
sphinxcontrib-fulltoc
sphinxcontrib-mockautodoc
gym
git+https://github.com/PytorchLightning/lightning_sphinx_theme.git
# pip_shims
sphinx-autodoc-typehints
5 changes: 0 additions & 5 deletions docs/source/conf.py
@@ -49,11 +49,6 @@


# -- Project documents -------------------------------------------------------

# export the documentation
with open('intro.rst', 'w') as fp:
    fp.write(pl_bolts.__long_doc__)

# export the READme
with open(os.path.join(PATH_ROOT, 'README.md'), 'r') as fp:
    readme = fp.read()
4 changes: 1 addition & 3 deletions docs/source/index.rst
@@ -5,9 +5,6 @@
PyTorchLightning-Bolts documentation
====================================

.. include:: intro.rst

.. toctree::
   :maxdepth: 1
   :name: start
@@ -50,6 +47,7 @@ PyTorchLightning-Bolts documentation

   autoencoders
   gans
   rl
   self_supervised_models

.. toctree::
17 changes: 0 additions & 17 deletions docs/source/intro.rst

This file was deleted.

69 changes: 69 additions & 0 deletions docs/source/rl.rst
@@ -0,0 +1,69 @@
Reinforcement Learning
======================
This module is a collection of common RL approaches implemented in Lightning.

---------

Module authors
--------------

Contributions by: `Donal Byrne <https://github.com/djbyrne>`_

- DQN
- Double DQN
- Dueling DQN
- Noisy DQN
- Prioritized Experience Replay DQN
- NStep DQN
- Reinforce
- Policy Gradient

------------

DQN Models
----------
The following models are based on DQN.

Deep-Q-Network (DQN)
^^^^^^^^^^^^^^^^^^^^
DQN model introduced in `Playing Atari with Deep Reinforcement Learning <https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf>`_.
Paper authors: Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, Ioannis Antonoglou, Daan Wierstra, Martin Riedmiller.

Original implementation by: `Donal Byrne <https://github.com/djbyrne>`_.

Example::

    from pytorch_lightning import Trainer

    from pl_bolts.models.rl import DQN

    model = DQN()
    trainer = Trainer()
    trainer.fit(model)

.. autoclass:: pl_bolts.models.rl.DQN
    :noindex:

Double DQN
^^^^^^^^^^
Double DQN model introduced in `Deep Reinforcement Learning with Double Q-learning <https://arxiv.org/abs/1509.06461>`_.
Paper authors: Hado van Hasselt, Arthur Guez, David Silver.

Original implementation by: `Donal Byrne <https://github.com/djbyrne>`_.

Example::

    from pytorch_lightning import Trainer

    from pl_bolts.models.rl import DoubleDQN

    model = DoubleDQN()
    trainer = Trainer()
    trainer.fit(model)

.. autoclass:: pl_bolts.models.rl.DoubleDQN
    :noindex:

--------------

Policy Gradient Models
----------------------
The following models are based on policy gradients.

Policy Gradient
^^^^^^^^^^^^^^^
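
The Policy Gradient entry is left as a stub in this commit. For orientation, here is a minimal usage sketch in the style of the DQN example above, assuming the PolicyGradient class exported from pl_bolts.models.rl below and the hparams-driven construction used in the integration tests (argument names are taken from tests/models/test_rl/integration/test_policy_models.py; fast_dev_run is only there to keep the run short)::

    import argparse

    from pytorch_lightning import Trainer

    from pl_bolts.models.rl import PolicyGradient
    from pl_bolts.models.rl.common import cli

    # Build hparams the same way the integration test does
    parser = argparse.ArgumentParser(add_help=False)
    parser = cli.add_base_args(parent=parser)
    parser = PolicyGradient.add_model_specific_args(parser)
    hparams = parser.parse_args([
        "--algo", "PolicyGradient",
        "--episode_length", "100",
        "--env", "CartPole-v0",
    ])

    model = PolicyGradient(hparams)
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model)
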
8 changes: 8 additions & 0 deletions pl_bolts/models/rl/__init__.py
@@ -0,0 +1,8 @@
from pl_bolts.models.rl.vanilla_policy_gradient.model import PolicyGradient
from pl_bolts.models.rl.reinforce.model import Reinforce
from pl_bolts.models.rl.per_dqn.model import PERDQN
from pl_bolts.models.rl.noisy_dqn.model import NoisyDQN
from pl_bolts.models.rl.n_step_dqn.model import NStepDQN
from pl_bolts.models.rl.dueling_dqn.model import DuelingDQN
from pl_bolts.models.rl.dqn.model import DQN
from pl_bolts.models.rl.double_dqn.model import DoubleDQN
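
For reference, this flattened __init__.py is what lets the new docs above and the tests below import every model directly from the package root; a sketch of the resulting import surface (class names taken verbatim from the lines above)::

    from pl_bolts.models.rl import (
        DQN,
        DoubleDQN,
        DuelingDQN,
        NStepDQN,
        NoisyDQN,
        PERDQN,
        PolicyGradient,
        Reinforce,
    )
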
2 changes: 1 addition & 1 deletion pl_bolts/models/rl/common/experience.py
@@ -6,10 +6,10 @@
from typing import List, Tuple

import numpy as np
from gym import Env
from torch.utils.data import IterableDataset
from pl_bolts.models.rl.common.agents import Agent
from pl_bolts.models.rl.common.memory import Experience, Buffer
from gym import Env


class RLDataset(IterableDataset):
4 changes: 2 additions & 2 deletions pl_bolts/models/rl/double_dqn/model.py
@@ -7,10 +7,10 @@
from typing import Tuple
import torch
import torch.nn as nn
from pl_bolts.models.rl.dqn.model import DQNLightning
from pl_bolts.models.rl.dqn.model import DQN


class DoubleDQNLightning(DQNLightning):
class DoubleDQN(DQN):
    """ Double DQN Model """

    def loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
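
The rename above preserves the subclass-and-override pattern: DoubleDQN reuses DQN's training loop and replaces only loss. A minimal sketch of that pattern, assuming DQN.loss exists with the signature shown in this hunk (the ClippedLossDQN name and the clipping itself are illustrative only, not part of this commit)::

    from typing import Tuple

    import torch

    from pl_bolts.models.rl import DQN


    class ClippedLossDQN(DQN):
        """Hypothetical DQN variant: same training loop, different loss."""

        def loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
            # Delegate the TD-error computation to the parent class, then clamp
            # the result; the clamp is purely to show where a variant hooks in.
            base_loss = super().loss(batch)
            return torch.clamp(base_loss, max=1.0)
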
2 changes: 1 addition & 1 deletion pl_bolts/models/rl/dqn/model.py
@@ -20,7 +20,7 @@
from pl_bolts.models.rl.common.networks import CNN


class DQNLightning(pl.LightningModule):
class DQN(pl.LightningModule):
    """ Basic DQN Model """

    def __init__(self, hparams: argparse.Namespace) -> None:
4 changes: 2 additions & 2 deletions pl_bolts/models/rl/dueling_dqn/model.py
@@ -4,10 +4,10 @@


from pl_bolts.models.rl.common.networks import DuelingCNN
from pl_bolts.models.rl.dqn.model import DQNLightning
from pl_bolts.models.rl.dqn.model import DQN


class DuelingDQNLightning(DQNLightning):
class DuelingDQN(DQN):
    """ Dueling DQN Model """

    def build_networks(self) -> None:
4 changes: 2 additions & 2 deletions pl_bolts/models/rl/n_step_dqn/model.py
@@ -10,10 +10,10 @@
from pl_bolts.models.rl.common.agents import ValueAgent
from pl_bolts.models.rl.common.experience import NStepExperienceSource
from pl_bolts.models.rl.common.memory import ReplayBuffer
from pl_bolts.models.rl.dqn.model import DQNLightning
from pl_bolts.models.rl.dqn.model import DQN


class NStepDQNLightning(DQNLightning):
class NStepDQN(DQN):
    """ NStep DQN Model """

    def __init__(self, hparams: argparse.Namespace) -> None:
4 changes: 2 additions & 2 deletions pl_bolts/models/rl/noisy_dqn/model.py
@@ -7,10 +7,10 @@
import torch

from pl_bolts.models.rl.common.networks import NoisyCNN
from pl_bolts.models.rl.dqn.model import DQNLightning
from pl_bolts.models.rl.dqn.model import DQN


class NoisyDQNLightning(DQNLightning):
class NoisyDQN(DQN):
    """ Noisy DQN Model """

    def build_networks(self) -> None:
4 changes: 2 additions & 2 deletions pl_bolts/models/rl/per_dqn/model.py
@@ -12,10 +12,10 @@
from pl_bolts.models.rl.common.agents import ValueAgent
from pl_bolts.models.rl.common.experience import ExperienceSource, PrioRLDataset
from pl_bolts.models.rl.common.memory import PERBuffer
from pl_bolts.models.rl.dqn.model import DQNLightning
from pl_bolts.models.rl.dqn.model import DQN


class PERDQNLightning(DQNLightning):
class PERDQN(DQN):
    """ PER DQN Model """

    def __init__(self, hparams):
2 changes: 1 addition & 1 deletion pl_bolts/models/rl/reinforce/model.py
@@ -24,7 +24,7 @@
from pl_bolts.models.rl.common.wrappers import ToTensor


class ReinforceLightning(pl.LightningModule):
class Reinforce(pl.LightningModule):
    """ Basic DQN Model """

    def __init__(self, hparams: argparse.Namespace) -> None:
4 changes: 2 additions & 2 deletions pl_bolts/models/rl/vanilla_policy_gradient/model.py
@@ -24,8 +24,8 @@
from pl_bolts.models.rl.common.wrappers import ToTensor


class VPGLightning(pl.LightningModule):
    """ VPG Model """
class PolicyGradient(pl.LightningModule):
    """ PolicyGradient Model """

    def __init__(self, hparams: argparse.Namespace) -> None:
        super().__init__()
14 changes: 7 additions & 7 deletions tests/models/test_rl/integration/test_policy_models.py
@@ -3,18 +3,18 @@
import pytorch_lightning as pl

from pl_bolts.models.rl.common import cli
from pl_bolts.models.rl.reinforce.model import ReinforceLightning
from pl_bolts.models.rl.vanilla_policy_gradient.model import VPGLightning
from pl_bolts.models.rl.reinforce.model import Reinforce
from pl_bolts.models.rl.vanilla_policy_gradient.model import PolicyGradient


class TestPolicyModels(TestCase):

    def setUp(self) -> None:
        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = cli.add_base_args(parent=parent_parser)
        parent_parser = VPGLightning.add_model_specific_args(parent_parser)
        parent_parser = PolicyGradient.add_model_specific_args(parent_parser)
        args_list = [
            "--algo", "vpg",
            "--algo", "PolicyGradient",
            "--episode_length", "100",
            "--env", "CartPole-v0"
        ]
@@ -29,14 +29,14 @@ def setUp(self) -> None:

    def test_reinforce(self):
        """Smoke test that the DQN model runs"""
        model = ReinforceLightning(self.hparams)
        model = Reinforce(self.hparams)
        result = self.trainer.fit(model)

        self.assertEqual(result, 1)

    def test_vpg(self):
    def test_PolicyGradient(self):
        """Smoke test that the Double DQN model runs"""
        model = VPGLightning(self.hparams)
        model = PolicyGradient(self.hparams)
        result = self.trainer.fit(model)

        self.assertEqual(result, 1)
26 changes: 13 additions & 13 deletions tests/models/test_rl/integration/test_value_models.py
@@ -3,20 +3,20 @@
import pytorch_lightning as pl

from pl_bolts.models.rl.common import cli
from pl_bolts.models.rl.double_dqn.model import DoubleDQNLightning
from pl_bolts.models.rl.dqn.model import DQNLightning
from pl_bolts.models.rl.dueling_dqn.model import DuelingDQNLightning
from pl_bolts.models.rl.n_step_dqn.model import NStepDQNLightning
from pl_bolts.models.rl.noisy_dqn.model import NoisyDQNLightning
from pl_bolts.models.rl.per_dqn.model import PERDQNLightning
from pl_bolts.models.rl.double_dqn.model import DoubleDQN
from pl_bolts.models.rl.dqn.model import DQN
from pl_bolts.models.rl.dueling_dqn.model import DuelingDQN
from pl_bolts.models.rl.n_step_dqn.model import NStepDQN
from pl_bolts.models.rl.noisy_dqn.model import NoisyDQN
from pl_bolts.models.rl.per_dqn.model import PERDQN


class TestValueModels(TestCase):

    def setUp(self) -> None:
        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = cli.add_base_args(parent=parent_parser)
        parent_parser = DQNLightning.add_model_specific_args(parent_parser)
        parent_parser = DQN.add_model_specific_args(parent_parser)
        args_list = [
            "--algo", "dqn",
            "--warm_start_steps", "100",
@@ -34,42 +34,42 @@ def setUp(self) -> None:

    def test_dqn(self):
        """Smoke test that the DQN model runs"""
        model = DQNLightning(self.hparams)
        model = DQN(self.hparams)
        result = self.trainer.fit(model)

        self.assertEqual(result, 1)

    def test_double_dqn(self):
        """Smoke test that the Double DQN model runs"""
        model = DoubleDQNLightning(self.hparams)
        model = DoubleDQN(self.hparams)
        result = self.trainer.fit(model)

        self.assertEqual(result, 1)

    def test_dueling_dqn(self):
        """Smoke test that the Dueling DQN model runs"""
        model = DuelingDQNLightning(self.hparams)
        model = DuelingDQN(self.hparams)
        result = self.trainer.fit(model)

        self.assertEqual(result, 1)

    def test_noisy_dqn(self):
        """Smoke test that the Noisy DQN model runs"""
        model = NoisyDQNLightning(self.hparams)
        model = NoisyDQN(self.hparams)
        result = self.trainer.fit(model)

        self.assertEqual(result, 1)

    def test_per_dqn(self):
        """Smoke test that the PER DQN model runs"""
        model = PERDQNLightning(self.hparams)
        model = PERDQN(self.hparams)
        result = self.trainer.fit(model)

        self.assertEqual(result, 1)

    def test_n_step_dqn(self):
        """Smoke test that the N Step DQN model runs"""
        model = NStepDQNLightning(self.hparams)
        model = NStepDQN(self.hparams)
        result = self.trainer.fit(model)

        self.assertEqual(result, 1)
8 changes: 4 additions & 4 deletions tests/models/test_rl/unit/test_reinforce.py
@@ -11,8 +11,8 @@
from pl_bolts.models.rl.common.experience import EpisodicExperienceStream
from pl_bolts.models.rl.common.networks import MLP
from pl_bolts.models.rl.common.wrappers import ToTensor
from pl_bolts.models.rl.dqn.model import DQNLightning
from pl_bolts.models.rl.reinforce.model import ReinforceLightning
from pl_bolts.models.rl.dqn.model import DQN
from pl_bolts.models.rl.reinforce.model import Reinforce


class TestReinforce(TestCase):
@@ -28,15 +28,15 @@ def setUp(self) -> None:

        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = cli.add_base_args(parent=parent_parser)
        parent_parser = DQNLightning.add_model_specific_args(parent_parser)
        parent_parser = DQN.add_model_specific_args(parent_parser)
        args_list = [
            "--algo", "dqn",
            "--warm_start_steps", "500",
            "--episode_length", "100",
        ]
        self.hparams = parent_parser.parse_args(args_list)

        self.model = ReinforceLightning(self.hparams)
        self.model = Reinforce(self.hparams)

    def test_loss(self):
        """Test the reinforce loss function"""
