@@ -7,10 +7,10 @@


class ExperienceReplay(object):
def __init__(self, env, history_size=4, batch_size=32, memory_size=1000000):
def __init__(self, env, action_repeat=4, batch_size=32, memory_size=1000000):
dims = list(env.dims)

self.history_size = history_size
self.action_repeat = action_repeat
self.batch_size = batch_size
self.memory_size = memory_size

@@ -20,8 +20,8 @@ def __init__(self, env, history_size=4, batch_size=32, memory_size=1000000):
self.terminals = np.empty(self.memory_size, dtype=np.bool)

# pre-allocate prestates and poststates for minibatch
self.prestates = np.empty([self.batch_size, self.history_size] + dims, dtype=np.float16)
self.poststates = np.empty([self.batch_size, self.history_size] + dims, dtype=np.float16)
self.prestates = np.empty([self.batch_size, self.action_repeat] + dims, dtype=np.float16)
self.poststates = np.empty([self.batch_size, self.action_repeat] + dims, dtype=np.float16)

self.count = 0
self.current = 0
@@ -35,15 +35,15 @@ def add(self, screen, reward, action, terminal):
self.current = (self.current + 1) % self.memory_size

def sample(self):
assert self.count >= self.history_size, 'Add more data'
assert self.count >= self.action_repeat, 'Add more data'

indexes = []
while len(indexes) < self.batch_size:
while True:
index = random.randint(self.history_size, self.count - 1)
if index >= self.current and index - self.history_size < self.current:
index = random.randint(self.action_repeat, self.count - 1)
if index >= self.current and index - self.action_repeat < self.current:
continue
if self.terminals[(index - self.history_size):index].any():
if self.terminals[(index - self.action_repeat):index].any():
continue
break

@@ -59,21 +59,21 @@ def sample(self):

def retrieve(self, index=None):
    """
    Return the `action_repeat` most recent screens ending at `index`.

    Args:
        index: position (into the circular screen buffer) of the newest
            screen to include. When None it defaults to
            min(count, memory_size); NOTE(review): that default then wraps
            to 0 via `index % self.count` once the buffer has entries,
            which selects the *oldest* window rather than the newest —
            confirm this is intended before relying on the default.

    Returns:
        numpy array of `action_repeat` stacked screens, oldest first.
    """
    if index is None:
        index = min(self.count, self.memory_size)

    # Wrap into the range of entries actually stored so far.
    index = index % self.count
    if index >= self.action_repeat - 1:
        # Fast path: the whole window is contiguous in the buffer.
        return self.screens[(index - (self.action_repeat - 1)):(index + 1), ...]
    else:
        # Window crosses the circular-buffer boundary: gather the
        # indices modulo count and use fancy indexing (this copies).
        indexes = [(index - i) % self.count for i in reversed(range(self.action_repeat))]
        return self.screens[indexes, ...]

@property
def available(self):
    """True once at least `action_repeat` transitions are stored,
    i.e. there is enough history to assemble one stacked state."""
    return self.count >= self.action_repeat


if __name__ == '__main__':

Large diffs are not rendered by default.

@@ -1,42 +1,43 @@
import numpy as np

import threading
from environment import Environment
from replay import ExperienceReplay
import tensorflow as tf


def test_memory_replay():
    """Smoke-test ExperienceReplay end to end.

    Fills the replay buffer from one random Breakout episode, samples a
    minibatch, and prints whether consecutive frames overlap the way a
    sliding history window should (poststates == prestates shifted by one
    frame within each stacked state).
    """
    env = Environment('Breakout-v0')
    print('env.dims:', env.dims)
    replay = ExperienceReplay(env)

    env.reset()
    count = 0
    while True:
        count += 1
        action = env.random_action()
        screen, reward, done, info = env.step(action)
        replay.add(screen, reward, action, done)
        if done:
            break

    prestates, actions, rewards, poststates, terminals = replay.sample()

    print('prestates:', prestates.shape)
    print('actions:', actions)
    print('rewards:', rewards)
    print('poststates:', poststates.shape)
    print('terminals:', terminals)
    print('count:', count)

    for i in range(replay.batch_size - 1):
        # BUG FIX: ExperienceReplay renamed `history_size` to
        # `action_repeat`; this loop still read the old attribute and
        # would raise AttributeError.
        for j in range(replay.action_repeat):
            if (j + 1) % replay.action_repeat != 0:
                # Frame j+1 of the prestate should equal frame j of the poststate.
                print(np.array_equal(prestates[i][j + 1], poststates[i][j]), (j + 1) % replay.action_repeat)
            else:
                # Last frame wraps: compare against the first prestate frame.
                print(np.array_equal(prestates[i][0], poststates[i][j]), (j + 1) % replay.action_repeat)

#
# def test_memory_replay():
# env = Environment('Breakout-v0')
# print 'env.dims:', env.dims
# replay = ExperienceReplay(env)
#
# env.reset()
# count = 0
# while True:
# count += 1
# action = env.random_action()
# screen, reward, done, info = env.step(action)
# replay.add(screen, reward, action, done)
# if done:
# break
#
# prestates, actions, rewards, poststates, terminals = replay.sample()
#
# print 'prestates:', prestates.shape
# print 'actions:', actions
# print 'rewards:', rewards
# print 'poststates:', poststates.shape
# print 'terminals:', terminals
# print 'count:', count
#
# for i in range(replay.batch_size - 1):
# for j in range(replay.history_size):
#
# if (j + 1) % replay.history_size != 0:
# print np.array_equal(prestates[i][j + 1], poststates[i][j]), (j + 1) % replay.history_size
# else:
# print np.array_equal(prestates[i][0], poststates[i][j]), (j + 1) % replay.history_size
#

def test_argmax():
input = tf.placeholder('float32', [10])
@@ -5,4 +5,5 @@
env = environment.Environment('Breakout-v0')
replay = replay.ExperienceReplay(env)
agent = agent.Agent(env, replay)
agent.restore()
agent.train()