Commit

reviewed replay_memory code
Former-commit-id: 95a432d
ZhitingHu committed Jan 5, 2018
1 parent 3a51ab0 commit d8603e6
Showing 3 changed files with 27 additions and 8 deletions.
4 changes: 2 additions & 2 deletions texar/agents/dqn_agent.py
@@ -104,7 +104,7 @@ def set_initial_state(self, observation):
         self.current_state = np.array(observation)
 
     def perceive(self, action, reward, is_terminal, next_observation):
-        self.replay_memory.push({
+        self.replay_memory.add({
            'state': self.current_state,
            'action': action,
            'reward': reward,
@@ -121,7 +121,7 @@ def train_qnet(self):
         """ Train the Q-Network
         :return:
         """
-        minibatch = self.replay_memory.sample(self.batch_size)
+        minibatch = self.replay_memory.get(self.batch_size)
         state_batch = np.array([data['state'] for data in minibatch])
         action_batch = np.array([data['action'] for data in minibatch])
         reward_batch = np.array([data['reward'] for data in minibatch])
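
The rename only touches the call sites: get(batch_size) still returns a list of transition dicts, which train_qnet unpacks field by field into numpy arrays. A minimal sketch of that contract, with a hand-built minibatch standing in for self.replay_memory.get(...); shapes and values are illustrative only:

# Sketch only: the fake minibatch below stands in for
# self.replay_memory.get(batch_size); keys mirror what perceive() stores.
import numpy as np

minibatch = [
    {'state': np.zeros(4), 'action': 1, 'reward': 0.0},
    {'state': np.ones(4), 'action': 0, 'reward': 1.0},
]
state_batch = np.array([data['state'] for data in minibatch])    # shape (2, 4)
action_batch = np.array([data['action'] for data in minibatch])  # shape (2,)
reward_batch = np.array([data['reward'] for data in minibatch])  # shape (2,)
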
1 change: 0 additions & 1 deletion texar/core/explorations.py
@@ -77,7 +77,6 @@ def get_epsilon(self, timestep):
             return self._hparams.initial_epsilon
         if timestep > et:
             return self._hparams.final_epsilon
-
         r = (timestep - st) * 1.0 / nsteps
         epsilon = (1 - r) * self._hparams.initial_epsilon + \
                   r * self._hparams.final_epsilon
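
For context, the fragment above implements a linear epsilon decay: before the schedule starts the method returns initial_epsilon, after it ends it returns final_epsilon, and in between it interpolates with r = (timestep - st) / nsteps. A self-contained sketch of that schedule; st, nsteps, et = st + nsteps and the epsilon defaults below are stand-ins for values the real method reads from hparams, not texar's defaults:

# Illustrative sketch of the decay computed in get_epsilon(); st, nsteps
# and the epsilon values below are assumed placeholders, not texar defaults.
def linear_epsilon(timestep, st=0, nsteps=20000,
                   initial_epsilon=0.1, final_epsilon=0.0):
    et = st + nsteps
    if timestep <= st:
        return initial_epsilon
    if timestep > et:
        return final_epsilon
    r = (timestep - st) * 1.0 / nsteps
    return (1 - r) * initial_epsilon + r * final_epsilon

print(linear_epsilon(0))       # 0.1   (decay not started)
print(linear_epsilon(10000))   # 0.05  (halfway through)
print(linear_epsilon(30000))   # 0.0   (decay finished)
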
30 changes: 25 additions & 5 deletions texar/core/replay_memories.py
@@ -1,27 +1,46 @@
-from texar.hyperparams import HParams
+#
+"""
+TODO: docs
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
 
 from collections import deque
 import random
 
+from texar.hyperparams import HParams
+
 
 class ReplayMemoryBase(object):
+    """TODO: docs
+    """
     def __init__(self, hparams=None):
         self._hparams = HParams(hparams, self.default_hparams())
 
-    def push(self, element):
+    def add(self, element):
+        """TODO: docs
+        """
         raise NotImplementedError
 
-    def sample(self, size):
+    def get(self, size):
+        """TODO: docs
+        """
         raise NotImplementedError
 
     @staticmethod
     def default_hparams():
+        """Returns a dictionary of default hyperparameters.
+        """
         return {
             'name': 'replay_memory'
         }
 
 
 class DequeReplayMemory(ReplayMemoryBase):
+    """TODO: docs
+    """
     def __init__(self, hparams=None):
         ReplayMemoryBase.__init__(self, hparams)
         self.deque = deque()
@@ -34,10 +53,11 @@ def default_hparams():
             'capacity': 80000
         }
 
-    def push(self, element):
+    def add(self, element):
         self.deque.append(element)
         if len(self.deque) > self.capacity:
             self.deque.popleft()
 
-    def sample(self, size):
+    #TODO(zhiting): is it okay to have stand alone random generator ?
+    def get(self, size):
         return random.sample(self.deque, size)
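
A standalone sketch (plain Python, no texar imports) of the behavior DequeReplayMemory implements after this change: add() appends and evicts the oldest element once the capacity is exceeded, and get() draws a uniform random sample. The capacity of 5 and the stored keys are illustrative only:

# Standalone illustration of the bounded deque above; capacity and the
# transition keys are made-up values for the example.
import random
from collections import deque

capacity = 5
memory = deque()

def add(element):
    memory.append(element)
    if len(memory) > capacity:
        memory.popleft()

def get(size):
    return random.sample(memory, size)

for t in range(8):
    add({'state': [t], 'action': t % 2, 'reward': float(t)})

print([e['state'][0] for e in memory])  # [3, 4, 5, 6, 7]: oldest entries evicted
print(len(get(3)))                      # 3 randomly drawn transitions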
