Integrated core_rl modules into bolts (#39)
* Integrated core_rl modules into bolts

  What changed:
  - Added commons dir with agents, experience, memory, cli, networks and wrapper modules
  - Added several key models that utilise the common library
  - Added unit tests for models and common
  - Added integration smoke tests for models

* Fixed linting issues

* Fixed linting error

Co-authored-by: Donal <donal.byrne@xperi.com>
Showing 43 changed files with 4,010 additions and 0 deletions.
The first new file adds the agents module (129 lines):
""" | ||
Agent module containing classes for Agent logic | ||
Based on the implementations found here: https://github.com/Shmuma/ptan/blob/master/ptan/agent.py | ||
""" | ||
from random import randint | ||
|
||
import numpy as np | ||
import torch | ||
from torch import nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class Agent: | ||
"""Basic agent that always returns 0""" | ||
|
||
def __init__(self, net: nn.Module): | ||
self.net = net | ||
|
||
def __call__(self, state: torch.Tensor, device: str) -> int: | ||
""" | ||
Using the given network, decide what action to carry | ||
Args: | ||
state: current state of the environment | ||
device: device used for current batch | ||
Returns: | ||
action | ||
""" | ||
return 0 | ||
|
||
|
||
class ValueAgent(Agent): | ||
"""Value based agent that returns an action based on the Q values from the network""" | ||
|
||
def __init__( | ||
self, | ||
net: nn.Module, | ||
action_space: int, | ||
eps_start: float = 1.0, | ||
eps_end: float = 0.2, | ||
eps_frames: float = 1000, | ||
): | ||
super().__init__(net) | ||
self.action_space = action_space | ||
self.eps_start = eps_start | ||
self.epsilon = eps_start | ||
self.eps_end = eps_end | ||
self.eps_frames = eps_frames | ||
|
||
def __call__(self, state: torch.Tensor, device: str) -> int: | ||
""" | ||
Takes in the current state and returns the action based on the agents policy | ||
Args: | ||
state: current state of the environment | ||
device: the device used for the current batch | ||
Returns: | ||
action defined by policy | ||
""" | ||
|
||
if np.random.random() < self.epsilon: | ||
action = self.get_random_action() | ||
else: | ||
action = self.get_action(state, device) | ||
|
||
return action | ||
|
||
def get_random_action(self) -> int: | ||
"""returns a random action""" | ||
action = randint(0, self.action_space - 1) | ||
|
||
return action | ||
|
||
def get_action(self, state: torch.Tensor, device: torch.device): | ||
""" | ||
Returns the best action based on the Q values of the network | ||
Args: | ||
state: current state of the environment | ||
device: the device used for the current batch | ||
Returns: | ||
action defined by Q values | ||
""" | ||
if not isinstance(state, torch.Tensor): | ||
state = torch.tensor([state]) | ||
|
||
if device.type != "cpu": | ||
state = state.cuda(device) | ||
|
||
q_values = self.net(state) | ||
_, action = torch.max(q_values, dim=1) | ||
return int(action.item()) | ||
|
||
def update_epsilon(self, step: int) -> None: | ||
""" | ||
Updates the epsilon value based on the current step | ||
Args: | ||
step: current global step | ||
""" | ||
self.epsilon = max(self.eps_end, self.eps_start - (step + 1) / self.eps_frames) | ||
|
||
|
||
class PolicyAgent(Agent): | ||
"""Policy based agent that returns an action based on the networks policy""" | ||
|
||
def __call__(self, state: torch.Tensor, device: str) -> int: | ||
""" | ||
Takes in the current state and returns the action based on the agents policy | ||
Args: | ||
state: current state of the environment | ||
device: the device used for the current batch | ||
Returns: | ||
action defined by policy | ||
""" | ||
if device.type != "cpu": | ||
state = state.cuda(device) | ||
|
||
# get the logits and pass through softmax for probability distribution | ||
probabilities = F.softmax(self.net(state)) | ||
prob_np = probabilities.data.cpu().numpy() | ||
|
||
# take the numpy values and randomly select action based on prob distribution | ||
action = np.random.choice(len(prob_np), p=prob_np) | ||
|
||
return action |
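As a quick illustration of how a training loop might drive these agents, here is a minimal sketch. It is not part of the commit: the tiny `nn.Sequential` Q-network and the 4-observation/2-action sizes are arbitrary stand-ins for whatever the commit's networks module provides.

import torch
from torch import nn

# arbitrary stand-in for a Q-network from the commit's networks module
q_net = nn.Sequential(nn.Linear(4, 128), nn.ReLU(), nn.Linear(128, 2))

agent = ValueAgent(net=q_net, action_space=2, eps_start=1.0, eps_end=0.2, eps_frames=1000)

device = torch.device("cpu")
state = torch.rand(1, 4)  # batch containing a single observation

for global_step in range(3):
    action = agent(state, device)      # epsilon-greedy: random while epsilon is high
    agent.update_epsilon(global_step)  # linear decay from eps_start towards eps_end

With these defaults, epsilon falls linearly from 1.0 to 0.2 over the first 1,000 steps and then stays at 0.2, so the greedy argmax in `get_action` is used roughly 80% of the time thereafter.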
The second new file adds the cli module with arguments shared across models (58 lines):
"""Contains generic arguments used for all models""" | ||
|
||
import argparse | ||
|
||
|
||
def add_base_args(parent) -> argparse.ArgumentParser: | ||
""" | ||
Adds arguments for DQN model | ||
Note: these params are fine tuned for Pong env | ||
Args: | ||
parent | ||
""" | ||
arg_parser = argparse.ArgumentParser(parents=[parent]) | ||
|
||
arg_parser.add_argument( | ||
"--algo", type=str, default="dqn", help="algorithm to use for training" | ||
) | ||
arg_parser.add_argument( | ||
"--batch_size", type=int, default=32, help="size of the batches" | ||
) | ||
arg_parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") | ||
arg_parser.add_argument( | ||
"--env", type=str, default="PongNoFrameskip-v4", help="gym environment tag" | ||
) | ||
arg_parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") | ||
arg_parser.add_argument( | ||
"--episode_length", type=int, default=500, help="max length of an episode" | ||
) | ||
arg_parser.add_argument( | ||
"--max_episode_reward", | ||
type=int, | ||
default=18, | ||
help="max episode reward in the environment", | ||
) | ||
arg_parser.add_argument( | ||
"--max_steps", type=int, default=500000, help="max steps to train the agent" | ||
) | ||
arg_parser.add_argument( | ||
"--n_steps", | ||
type=int, | ||
default=4, | ||
help="how many steps to unroll for each update", | ||
) | ||
arg_parser.add_argument( | ||
"--gpus", type=int, default=1, help="number of gpus to use for training" | ||
) | ||
arg_parser.add_argument( | ||
"--seed", type=int, default=123, help="seed for training run" | ||
) | ||
arg_parser.add_argument( | ||
"--backend", | ||
type=str, | ||
default="dp", | ||
help="distributed backend to be used by lightning", | ||
) | ||
return arg_parser |
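A sketch of the intended usage, where a model-specific parser is passed in as the parent. The `--sync_rate` flag below is a hypothetical model-specific argument, not something defined in this commit.

import argparse

# the parent must be created with add_help=False so that the child
# parser built by add_base_args can define its own -h/--help
parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser.add_argument(
    "--sync_rate", type=int, default=1000, help="hypothetical model-specific arg"
)

parser = add_base_args(parent_parser)
args = parser.parse_args(["--lr", "1e-3", "--gpus", "0"])
print(args.lr, args.gpus, args.sync_rate)  # 0.001 0 1000

Because `argparse.ArgumentParser(parents=[parent])` copies the parent's arguments into the new parser, each model can layer its own flags on top of this common base without redefining them.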