Skip to content

Commit

Permalink
Extract the replay buffer logic
Browse files Browse the repository at this point in the history
This should make it easier to integrate the priority replay logic.
  • Loading branch information
SwamyDev committed Mar 17, 2020
1 parent 704d6b2 commit 4ed221d
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 24 deletions.
16 changes: 8 additions & 8 deletions tests/test_memory.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import numpy as np
import pytest

from udacity_rl.memory import Memory, MemoryRecordError
from udacity_rl.memory import Memory, MemoryRecordError, UniformReplayBuffer


@pytest.fixture
def make_memory():
def factory(batch_size=64, record_size=10000, num_records=0, seed=None):
m = Memory(batch_size, record_size, seed)
m = Memory(batch_size, UniformReplayBuffer(record_size, seed))
for index in range(num_records):
m.record(index=index)
return m
Expand All @@ -25,7 +25,7 @@ def test_memory_is_initially_empty(memory):


def test_sampling_empty_memory_returns_empty_list(memory):
assert memory.sample() == []
assert memory.sample() == ([], None)


def test_recording_experience_increases_length(memory):
Expand All @@ -38,11 +38,11 @@ def test_recording_experience_increases_length(memory):
@pytest.mark.parametrize('batch_size', (1, 3))
def test_sampling_memory_returns_list_of_experiences_if_enough_records_to_fill_a_batch(make_memory, batch_size):
memory = make_memory(batch_size, num_records=batch_size)
assert sorted(memory.sample()) == sorted([np.array([index]) for index in range(batch_size)])
assert sorted(memory.sample()[0]) == sorted([np.array([index]) for index in range(batch_size)])


def test_sampling_memory_returns_empty_list_if_not_enough_records_to_fill_a_batch(make_memory):
assert make_memory(batch_size=2, num_records=1).sample() == []
assert make_memory(batch_size=2, num_records=1).sample() == ([], None)


def test_is_unfilled_indicated_whether_batch_is_incomplete_or_not(make_memory):
Expand All @@ -51,8 +51,8 @@ def test_is_unfilled_indicated_whether_batch_is_incomplete_or_not(make_memory):


def test_sample_randomly_from_record_if_record_exceeds_batch_size(make_memory):
assert (make_memory(batch_size=2, num_records=10, seed=17).sample() !=
make_memory(batch_size=2, num_records=10, seed=42).sample()).any()
assert (make_memory(batch_size=2, num_records=10, seed=17).sample()[0] !=
make_memory(batch_size=2, num_records=10, seed=42).sample()[0]).any()


def test_record_has_fixed_length(make_memory):
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_recording_different_attributes_raises_error(memory):

def test_sampling_memory_returns_numpy_arrays(make_memory):
memory = make_memory(batch_size=2, num_records=2)
sample = memory.sample()
sample, _ = memory.sample()
assert isinstance(sample, np.ndarray)


Expand Down
9 changes: 3 additions & 6 deletions udacity_rl/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from gym import spaces

from udacity_rl.memory import Memory
from udacity_rl.memory import Memory, UniformReplayBuffer


def with_default(cfg, key, default):
Expand Down Expand Up @@ -95,18 +95,15 @@ def configuration(self):
def _with_mem_defaults(cfg):
cfg = with_default(cfg, 'batch_size', 64)
cfg = with_default(cfg, 'record_size', int(1e5))
cfg = with_default(cfg, 'seed', None)
return cfg


def _only_memory_args(mem_cfg):
return {k: mem_cfg[k] for k in mem_cfg if k in Memory.__init__.__code__.co_varnames}


class MemoryAgent(Agent, abc.ABC):
def __init__(self, observation_space, action_space, **kwargs):
super().__init__(observation_space, action_space, **kwargs)
mem_cfg = _with_mem_defaults(kwargs)
self._memory = Memory(**_only_memory_args(mem_cfg))
self._memory = Memory(mem_cfg["batch_size"], UniformReplayBuffer(mem_cfg["record_size"], mem_cfg["seed"]))

def step(self, obs, action, reward, next_obs, done):
if not isinstance(action, (collections.Sequence, np.ndarray)):
Expand Down
53 changes: 43 additions & 10 deletions udacity_rl/memory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import logging
import random
from collections import deque
Expand All @@ -7,19 +8,51 @@
logger = logging.getLogger(__name__)


class Memory:
def __init__(self, batch_size, record_size, seed=None):
class ReplayBuffer(abc.ABC):
@abc.abstractmethod
def append(self, data):
pass

@abc.abstractmethod
def sample(self, size):
pass

@abc.abstractmethod
def __len__(self):
pass


class UniformReplayBuffer(ReplayBuffer):
def __init__(self, record_size, seed=None):
self._record = deque(maxlen=record_size)
self._batch_size = batch_size
self._keys = None
if seed is not None:
random.seed(seed)

self._print_config()

def _print_config(self):
logger.info(f"Uniform Replay Buffer:\n"
f"\tRecord size:\t{self._record.maxlen}\n")

def append(self, data):
self._record.append(data)

def sample(self, size):
return random.sample(self._record, k=size), None

def __len__(self):
return len(self._record)


class Memory:
def __init__(self, batch_size, buffer):
self._buffer = buffer
self._batch_size = batch_size
self._keys = None
self._print_config()

def _print_config(self):
logger.info(f"Memory configuration:\n"
f"\tRecord size:\t{self._record.maxlen}\n"
f"\tBatch size:\t{self._batch_size}\n")

def record(self, **kwargs):
Expand All @@ -29,17 +62,17 @@ def record(self, **kwargs):
if keys != self._keys:
raise MemoryRecordError(f"The recorded value keys are not allowed:\n"
f"expected: {self._keys}\nactual:{keys}")
self._record.append(tuple(kwargs[k] for k in kwargs))
self._buffer.append(tuple(kwargs[k] for k in kwargs))

def sample(self):
if self.is_unfilled():
return []
return [], None

sample = random.sample(self._record, k=self._batch_size)
sample, info = self._buffer.sample(self._batch_size)
if len(sample[0]) > 1:
return self._cast_to_ndarray_tuple(list(zip(*sample)))

return np.array(sample)
return np.array(sample), info

@staticmethod
def _cast_to_ndarray_tuple(attributes):
Expand All @@ -51,7 +84,7 @@ def is_unfilled(self):
return len(self) < self._batch_size

def __len__(self):
return len(self._record)
return len(self._buffer)


class MemoryRecordError(AssertionError):
Expand Down

0 comments on commit 4ed221d

Please sign in to comment.