Integrated core_rl modules into bolts (#39)
* Integrated core_rl modules into bolts

What Changed:
- Added common dir with agents, experience, memory, cli, networks, and wrapper modules
- Added several key models that utilise the common library
- Added unit tests for models and common
- Added integration smoke tests for models

* Fixed linting issues

* Fixed linting error

Co-authored-by: Donal <donal.byrne@xperi.com>
djbyrne and Donal committed Jun 20, 2020
1 parent 2a8218b commit 9254626
Showing 43 changed files with 4,010 additions and 0 deletions.
Empty file added pl_bolts/models/rl/__init__.py
Empty file.
129 changes: 129 additions & 0 deletions pl_bolts/models/rl/common/agents.py
@@ -0,0 +1,129 @@
"""
Agent module containing classes for Agent logic
Based on the implementations found here: https://github.com/Shmuma/ptan/blob/master/ptan/agent.py
"""
from random import randint

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F


class Agent:
    """Basic agent that always returns 0"""

    def __init__(self, net: nn.Module):
        self.net = net

    def __call__(self, state: torch.Tensor, device: torch.device) -> int:
        """
        Using the given network, decide what action to carry out
        Args:
            state: current state of the environment
            device: device used for current batch
        Returns:
            action
        """
        return 0


class ValueAgent(Agent):
    """Value-based agent that returns an action based on the Q values from the network"""

    def __init__(
        self,
        net: nn.Module,
        action_space: int,
        eps_start: float = 1.0,
        eps_end: float = 0.2,
        eps_frames: float = 1000,
    ):
        super().__init__(net)
        self.action_space = action_space
        self.eps_start = eps_start
        self.epsilon = eps_start
        self.eps_end = eps_end
        self.eps_frames = eps_frames

    def __call__(self, state: torch.Tensor, device: torch.device) -> int:
        """
        Takes in the current state and returns the action based on the agent's policy
        Args:
            state: current state of the environment
            device: the device used for the current batch
        Returns:
            action defined by policy
        """

        if np.random.random() < self.epsilon:
            action = self.get_random_action()
        else:
            action = self.get_action(state, device)

        return action

    def get_random_action(self) -> int:
        """returns a random action"""
        action = randint(0, self.action_space - 1)

        return action

    def get_action(self, state: torch.Tensor, device: torch.device):
        """
        Returns the best action based on the Q values of the network
        Args:
            state: current state of the environment
            device: the device used for the current batch
        Returns:
            action defined by Q values
        """
        if not isinstance(state, torch.Tensor):
            state = torch.tensor([state])

        if device.type != "cpu":
            state = state.cuda(device)

        q_values = self.net(state)
        _, action = torch.max(q_values, dim=1)
        return int(action.item())

    def update_epsilon(self, step: int) -> None:
        """
        Updates the epsilon value based on the current step
        Args:
            step: current global step
        """
        self.epsilon = max(self.eps_end, self.eps_start - (step + 1) / self.eps_frames)


class PolicyAgent(Agent):
    """Policy-based agent that returns an action based on the network's policy"""

    def __call__(self, state: torch.Tensor, device: torch.device) -> int:
        """
        Takes in the current state and returns the action based on the agent's policy
        Args:
            state: current state of the environment
            device: the device used for the current batch
        Returns:
            action defined by policy
        """
        if device.type != "cpu":
            state = state.cuda(device)

        # get the logits and pass through softmax for a probability distribution
        probabilities = F.softmax(self.net(state), dim=-1)
        # drop the batch dimension (a single state is assumed) before sampling
        prob_np = probabilities.data.cpu().numpy().squeeze()

        # randomly select an action weighted by the probability distribution
        action = np.random.choice(len(prob_np), p=prob_np)

        return action
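
For context, a minimal usage sketch of the agents above (not part of this commit): it assumes a toy fully connected Q-network, the classic Gym CartPole environment and its pre-0.26 step/reset API, all purely illustrative, and drives ValueAgent's epsilon-greedy policy for a few steps.

import gym
import torch
from torch import nn

from pl_bolts.models.rl.common.agents import ValueAgent

# illustrative Q-network: any nn.Module mapping states to per-action values works
net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))

env = gym.make("CartPole-v0")
agent = ValueAgent(net, action_space=env.action_space.n, eps_start=1.0, eps_end=0.02, eps_frames=1000)
device = torch.device("cpu")

state = env.reset()
for step in range(100):
    # add a batch dimension so the network sees shape [1, obs_size]
    action = agent(torch.as_tensor(state, dtype=torch.float32).unsqueeze(0), device)
    state, reward, done, _ = env.step(action)
    # anneal epsilon towards eps_end as training progresses
    agent.update_epsilon(step)
    if done:
        state = env.reset()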
58 changes: 58 additions & 0 deletions pl_bolts/models/rl/common/cli.py
@@ -0,0 +1,58 @@
"""Contains generic arguments used for all models"""

import argparse


def add_base_args(parent) -> argparse.ArgumentParser:
    """
    Adds base arguments shared by the RL models
    Note: these params are fine-tuned for the Pong environment
    Args:
        parent: parent argument parser to extend
    """
    arg_parser = argparse.ArgumentParser(parents=[parent])

    arg_parser.add_argument(
        "--algo", type=str, default="dqn", help="algorithm to use for training"
    )
    arg_parser.add_argument(
        "--batch_size", type=int, default=32, help="size of the batches"
    )
    arg_parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
    arg_parser.add_argument(
        "--env", type=str, default="PongNoFrameskip-v4", help="gym environment tag"
    )
    arg_parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
    arg_parser.add_argument(
        "--episode_length", type=int, default=500, help="max length of an episode"
    )
    arg_parser.add_argument(
        "--max_episode_reward",
        type=int,
        default=18,
        help="max episode reward in the environment",
    )
    arg_parser.add_argument(
        "--max_steps", type=int, default=500000, help="max steps to train the agent"
    )
    arg_parser.add_argument(
        "--n_steps",
        type=int,
        default=4,
        help="how many steps to unroll for each update",
    )
    arg_parser.add_argument(
        "--gpus", type=int, default=1, help="number of gpus to use for training"
    )
    arg_parser.add_argument(
        "--seed", type=int, default=123, help="seed for training run"
    )
    arg_parser.add_argument(
        "--backend",
        type=str,
        default="dp",
        help="distributed backend to be used by lightning",
    )
    return arg_parser
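
A brief usage sketch for add_base_args (illustrative, not part of this commit): the parent parser must be created with add_help=False so its help option does not clash with the child parser built via parents=[parent]; the --sync_rate flag below is a hypothetical model-specific argument.

import argparse

from pl_bolts.models.rl.common.cli import add_base_args

# model-specific flags live on the parent parser; add_help=False avoids a
# duplicate -h/--help when the child parser is built with parents=[parent]
parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser.add_argument("--sync_rate", type=int, default=1000, help="hypothetical model-specific flag")

parser = add_base_args(parent_parser)
args = parser.parse_args(["--env", "PongNoFrameskip-v4", "--batch_size", "64"])
print(args.env, args.batch_size, args.lr, args.sync_rate)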
