updating dqn
Former-commit-id: b6a37f5
ZhitingHu committed Jan 2, 2018
1 parent 5d50ac8 commit ac7f531
Showing 6 changed files with 115 additions and 74 deletions.
48 changes: 23 additions & 25 deletions examples/cartpole.py
@@ -9,7 +9,7 @@
 import numpy as np
 import gym
 
-from texar.agents import NatureDQNAgent
+from texar.agents import DQNAgent
 
 # pylint: disable=invalid-name
 
@@ -18,34 +18,32 @@
 if __name__ == '__main__':
     hparams = {}
     hparams['qnet'] = {
-        'kwargs': {
-            'hparams': {
-                'network_hparams': {
-                    'layers': [
-                        {
-                            'type': 'Dense',
-                            'kwargs': {
-                                'units': 128,
-                                'activation': 'relu'
-                            }
-                        }, {
-                            'type': 'Dense',
-                            'kwargs': {
-                                'units': 128,
-                                'activation': 'relu'
-                            }
-                        }, {
-                            'type': 'Dense',
-                            'kwargs': {
-                                'units': 2
-                            }
-                        }
-                    ]
-                }
-            }
-        }
+        'hparams': {
+            'network_hparams': {
+                'layers': [
+                    {
+                        'type': 'Dense',
+                        'kwargs': {
+                            'units': 128,
+                            'activation': 'relu'
+                        }
+                    }, {
+                        'type': 'Dense',
+                        'kwargs': {
+                            'units': 128,
+                            'activation': 'relu'
+                        }
+                    }, {
+                        'type': 'Dense',
+                        'kwargs': {
+                            'units': 2
+                        }
+                    }
+                ]
+            }
+        }
     }
-    agent = NatureDQNAgent(actions=2, state_shape=(4, ), hparams=hparams)
+    agent = DQNAgent(actions=2, state_shape=(4, ), hparams=hparams)
 
     for i in range(5000):
         reward_sum = 0.0
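Note: the hunk above flattens the 'qnet' hparams, dropping the extra 'kwargs' wrapper, and switches to the renamed DQNAgent class. For context, a minimal sketch of how the updated example could drive the agent against Gym's CartPole; the environment name and the loop body are assumptions (the rest of cartpole.py is collapsed above), and only the hparams layout and the DQNAgent constructor are taken from the diff:

import gym
import numpy as np

from texar.agents import DQNAgent

# Flattened layout from the diff: 'qnet' -> 'hparams' directly,
# with no intermediate 'kwargs' dict.
hparams = {
    'qnet': {
        'hparams': {
            'network_hparams': {
                'layers': [
                    {'type': 'Dense', 'kwargs': {'units': 128, 'activation': 'relu'}},
                    {'type': 'Dense', 'kwargs': {'units': 128, 'activation': 'relu'}},
                    {'type': 'Dense', 'kwargs': {'units': 2}},
                ]
            }
        }
    }
}

env = gym.make('CartPole-v0')  # assumed environment
agent = DQNAgent(actions=2, state_shape=(4,), hparams=hparams)

for i in range(5000):
    observation = env.reset()
    agent.set_initial_state(observation)
    reward_sum = 0.0
    while True:
        action = agent.get_action()          # one-hot action vector
        action_id = int(np.argmax(action))
        observation, reward, done, _ = env.step(action_id)
        agent.perceive(action, reward, done, observation)
        reward_sum += reward
        if done:
            break
    print('episode %d: reward %.1f' % (i, reward_sum))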
8 changes: 7 additions & 1 deletion texar/agents/__init__.py
@@ -2,4 +2,10 @@
 Various RL Agents
 """
 
-from texar.agents.nature_dqn_agent import NatureDQNAgent
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import
+
+from texar.agents.dqn_agent import *
15 changes: 9 additions & 6 deletions texar/agents/agent_base.py
@@ -1,6 +1,6 @@
 #
 """
-Base class for rl agents.
+Base class for RL agents.
 """
 from __future__ import absolute_import
 from __future__ import division
@@ -11,7 +11,10 @@
 
 class AgentBase(object):
     """
-    Base class inherited by rl agents.
+    Base class inherited by RL agents.
+
+    Args:
+        TODO
     """
     def __init__(self, hparams=None):
         self._hparams = HParams(hparams, self.default_hparams())
@@ -28,18 +31,18 @@ def default_hparams():
         }
 
     def set_initial_state(self, observation):
-        """
-        reset the current state
+        """Resets the current state.
+
+        Args:
+            observation: observation in the beginning
         """
         raise NotImplementedError
 
     def perceive(self, action, reward, is_terminal, next_observation):
-        """Perceive from environment
+        """Perceives from environment.
+
         Args:
-            action: A OneHot vector indicate the action
+            action: A one-hot vector indicate the action
             reward: A number indicate the reward
             is_terminal: True iff it is a terminal state
             next_observation: New Observation from environment
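Note: the docstrings above define the contract every agent implements. A purely illustrative subclass (this RandomAgent is invented for the example and is not part of the commit):

import random

import numpy as np

from texar.agents.agent_base import AgentBase


class RandomAgent(AgentBase):
    """Toy agent: satisfies the AgentBase interface with a
    uniform random policy and no learning."""

    def __init__(self, actions, hparams=None):
        AgentBase.__init__(self, hparams=hparams)
        self.actions = actions
        self.current_state = None

    def set_initial_state(self, observation):
        # Reset the current state at the start of an episode.
        self.current_state = observation

    def perceive(self, action, reward, is_terminal, next_observation):
        # A learning agent would store (s, a, r, s') and train here;
        # this toy agent only tracks the state.
        self.current_state = next_observation

    def get_action(self, state=None, action_mask=None):
        # Return a one-hot action vector, as DQNAgent does.
        action = np.zeros(shape=(self.actions,))
        action[random.randrange(self.actions)] = 1.0
        return action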
33 changes: 15 additions & 18 deletions texar/agents/nature_dqn_agent.py → texar/agents/dqn_agent.py
@@ -16,7 +16,11 @@
 
 # pylint: disable=too-many-instance-attributes, too-many-arguments, invalid-name
 
-class NatureDQNAgent(AgentBase):
+__all__ = [
+    "DQNAgent"
+]
+
+class DQNAgent(AgentBase):
     """TODO: docs
     """
     def __init__(self, actions, state_shape,
@@ -35,14 +39,14 @@ def __init__(self, actions, state_shape,
         if self.qnet is None:
             self.qnet = get_instance(
                 self._hparams.qnet.type,
-                self._hparams.qnet.kwargs.todict(),
+                {'hparams': self._hparams.qnet.hparams},
                 module_paths=['texar.modules', 'texar.custom'])
 
         self.replay_memory = replay_memory
         if self.replay_memory is None:
             self.replay_memory = get_instance(
                 self._hparams.replay_memory.type,
-                self._hparams.replay_memory.kwargs.todict(),
+                {'hparams': self._hparams.replay_memory.hparams},
                 module_paths=['texar.core', 'texar.custom'])
 
         # loss & trainer
@@ -67,7 +71,7 @@ def __init__(self, actions, state_shape,
 
         self.exploration = get_instance(
             self._hparams.exploration.type,
-            self._hparams.exploration.kwargs.todict(),
+            {'hparams': self._hparams.exploration.hparams},
             module_paths=['texar.core', 'texar.custom'])
 
         # TODO
@@ -78,30 +82,24 @@
     @staticmethod
     def default_hparams():
         return {
-            'name': 'nature_dqn_agent',
+            'name': 'dqn_agent',
             'batch_size': 32,
             'discount_factor': 0.99,
             'observation_steps': 100,
             'update_period': 100,
             'qnet': {
                 'type': 'NatureQNet',
-                'kwargs': {
-                    'hparams': None
-                }
+                'hparams': None
             },
             'replay_memory': {
                 'type': 'DequeReplayMemory',
-                'kwargs': {
-                    'hparams': None
-                }
+                'hparams': None
             },
             'loss': "l2_loss",
             'optimization': opt.default_optimization_hparams(),
             'exploration': {
-                'type': 'EpsilonDecayExploration',
-                'kwargs': {
-                    'hparams': None
-                }
+                'type': 'EpsilonLinearDecayExploration',
+                'hparams': None
             }
         }
@@ -153,7 +151,7 @@ def train_qnet(self):
     def update_target(self):
         """ Copy the parameters from qnet to target
         """
-        self.sess.run(self.qnet.update_target())
+        self.sess.run(self.qnet.copy_qnet_to_target())
 
     def get_action(self, state=None, action_mask=None):
         if state is None:
@@ -164,7 +162,7 @@ def get_action(self, state=None, action_mask=None):
         qvalue = self.sess.run(self.qnet_qvalue,
                                feed_dict={self.state_input: np.array([state])})
         action = np.zeros(shape=(self.actions,))
-        if random.random() < self.exploration.epsilon():
+        if random.random() < self.exploration.get_epsilon(self.timestep):
             while True:
                 action_id = random.randrange(self.actions)
                 if action_mask[action_id]:
@@ -176,5 +174,4 @@
             action_id = np.argmax(qvalue)
         action[action_id] = 1.0
 
-        self.exploration.add_timestep()
         return action
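Note: get_action now asks the exploration object for epsilon at the agent's current timestep instead of relying on a counter inside the exploration object. The selection logic itself can be read standalone; a self-contained sketch (the q-values and mask below are made up, and masking the greedy branch is an assumption since that part of the method is collapsed above):

import random

import numpy as np


def epsilon_greedy(qvalue, epsilon, action_mask):
    # With probability epsilon, sample a random allowed action;
    # otherwise take the argmax of the q-values.
    actions = len(qvalue)
    action = np.zeros(shape=(actions,))
    if random.random() < epsilon:
        while True:
            action_id = random.randrange(actions)
            if action_mask[action_id]:
                break
    else:
        masked = np.where(np.array(action_mask),
                          np.array(qvalue, dtype=float), -np.inf)
        action_id = int(np.argmax(masked))
    action[action_id] = 1.0
    return action

# With epsilon=0.0 this deterministically one-hot-encodes action 1:
print(epsilon_greedy([0.2, 0.9], epsilon=0.0, action_mask=[True, True]))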
79 changes: 58 additions & 21 deletions texar/core/explorations.py
@@ -1,49 +1,86 @@
 #
+"""
+TODO: docs
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
 from texar.hyperparams import HParams
 
 # pylint: disable=invalid-name
 
-class ExplorationBase:
-    """
+class ExplorationBase(object):
+    """Base class inherited by all exploration classes.
+
+    Args:
     """
     def __init__(self, hparams=None):
         self._hparams = HParams(hparams, self.default_hparams())
 
     @staticmethod
     def default_hparams():
         """Returns a dictionary of default hyperparameters.
+        TODO: docs
         """
         return {
             'name': 'exploration_base'
         }
 
-    @property
-    def epsilon(self):
-        raise NotImplementedError
-
-    def add_timestep(self):
+    def get_epsilon(self, timestep):
+        """Returns the epsilon value.
+
+        Args:
+            timestep (int): The time step.
+
+        Returns:
+            float: the epsilon value.
+        """
         raise NotImplementedError
 
     @property
     def hparams(self):
         """The hyperparameter.
         """
         return self._hparams
 
 
-class EpsilonDecayExploration(ExplorationBase):
+class EpsilonLinearDecayExploration(ExplorationBase):
     """TODO: docs
+
+    Args:
     """
     def __init__(self, hparams=None):
         ExplorationBase.__init__(self, hparams=hparams)
-        self._epsilon = self._hparams.initial_epsilon
-        self.timestep = 0
-        self.initial_epsilon = self._hparams.initial_epsilon
-        self.final_epsilon = self._hparams.final_epsilon
-        self.decay_steps = self._hparams.decay_steps
 
     @staticmethod
     def default_hparams():
         """TODO
         """
         return {
-            'name': 'epsilon_decay_exploration',
+            'name': 'epsilon_linear_decay_exploration',
             'initial_epsilon': 0.1,
             'final_epsilon': 0.0,
-            'decay_steps': 20000
+            'decay_timesteps': 20000,
+            'start_timestep': 0
         }
 
-    def epsilon(self):
-        return self._epsilon
+    def get_epsilon(self, timestep):
+        nsteps = self._hparams.decay_timesteps
+        st = self._hparams.start_timestep
+        et = st + nsteps
 
-    def add_timestep(self):
-        self.timestep += 1
-        self.epsilon_decay()
+        if timestep <= st:
+            return self._hparams.initial_epsilon
+        if timestep > et:
+            return self._hparams.final_epsilon
+
+        r = (timestep - st) * 1.0 / nsteps
+        epsilon = (1 - r) * self._hparams.initial_epsilon + \
+                  r * self._hparams.final_epsilon
 
-    def epsilon_decay(self):
-        if self._epsilon > 0.:
-            self._epsilon -= (self.initial_epsilon - self.final_epsilon) / self.decay_steps
+        return epsilon
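Note: get_epsilon is now a pure function of the timestep: a linear interpolation from initial_epsilon down to final_epsilon over decay_timesteps steps, clamped at both ends. A quick standalone check of the formula with the default hyperparameters above:

def linear_epsilon(timestep, initial=0.1, final=0.0,
                   decay_timesteps=20000, start_timestep=0):
    # Same arithmetic as EpsilonLinearDecayExploration.get_epsilon.
    end = start_timestep + decay_timesteps
    if timestep <= start_timestep:
        return initial
    if timestep > end:
        return final
    r = (timestep - start_timestep) * 1.0 / decay_timesteps
    return (1 - r) * initial + r * final

print(linear_epsilon(0))      # 0.1  (before decay starts)
print(linear_epsilon(10000))  # 0.05 (halfway through the decay)
print(linear_epsilon(25000))  # 0.0  (fully decayed)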
6 changes: 3 additions & 3 deletions texar/modules/q_nets/q_nets.py
@@ -35,7 +35,7 @@ def default_hparams():
             'network_hparams': FeedForwardNetwork.default_hparams()
         }
 
-    def _build(self, inputs):
+    def _build(self, inputs): # pylint: disable=arguments-differ
         qnet_result, target_result = self.qnet(inputs), self.target(inputs)
 
         if not self._built:
@@ -46,8 +46,8 @@ def _build(self, inputs):
 
         return qnet_result, target_result
 
-    def update_target(self):
-        """ Copy the parameters from qnet to target
+    def copy_qnet_to_target(self):
+        """Copy the parameters from qnet to target.
 
         Returns:
             A list of assign tensors
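Note: the rename from update_target to copy_qnet_to_target makes the intent explicit: build assign ops that overwrite the target network's variables with the online q-network's. A minimal TF1-style sketch of the idea; the scope names 'qnet' and 'target' are assumptions, not Texar's actual variable scopes:

import tensorflow as tf


def copy_qnet_to_target(qnet_scope='qnet', target_scope='target'):
    # Collect the trainable variables of both networks and pair them
    # by sorted name, assuming the two networks are structurally
    # identical copies of each other.
    qnet_vars = sorted(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=qnet_scope),
        key=lambda v: v.name)
    target_vars = sorted(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=target_scope),
        key=lambda v: v.name)
    # One assign op per variable pair: target <- qnet.
    return [tf.assign(t, q) for q, t in zip(qnet_vars, target_vars)]

# Typical use, as in DQNAgent.update_target:
#     sess.run(copy_qnet_to_target())
# run once every 'update_period' training steps.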
