In [109]:
from collections import defaultdict, deque, namedtuple

from functools import partial

import numpy as np

from energy_py.agents.memory import Memory

Experience = namedtuple('Experience', ['observation',
                                       'action',
                                       'reward',
                                       'next_observation',
                                       'done'])

class HistoryMemory(Memory):
    def __init__(self, size, n_step, obs_shape, action_shape):
        
        super().__init__(size, obs_shape, action_shape)
        self.n_step = n_step
        
        self.buffer = defaultdict(partial(deque, maxlen=self.n_step))
    
        self.experiences = deque(maxlen=self.size)
        
    def remember(self, o, a, r, next_o, done):
        
        self.buffer['observation'].append(
            np.array(o).reshape(1, *self.shapes['observations']))
        
        self.buffer['action'].append(
            np.array(a).reshape(1, *self.shapes['actions']))        
        
        self.buffer['reward'].append(
            np.array(r).reshape(1, *self.shapes['rewards']))
        
        self.buffer['next_observation'].append(
            np.array(next_o).reshape(1, *self.shapes['next_observations']))
        
        self.buffer['done'].append(
            np.array(done).reshape(1, *self.shapes['done']))
        
        if len(self.buffer['observation']) >= self.n_step:
            
            experience_dims = []
            for field in Experience._fields:
                buffer = self.buffer[field]
                hist = np.array(buffer).reshape(1, self.n_step, *buffer[0].shape[1:])
                experience_dims.append(hist)  
                
            experience = Experience(*experience_dims)
            self.experiences.append(experience)
    
    def sample(self, batch_size):
        """
        """
        experiences = self.experiences
        out = {}
        for field in Experience._fields:
            arr = np.array([getattr(e, field) for e in experiences])
            out[field] = arr.reshape(-1, self.n_step, *self.shapes[field])
            
        return out

In [110]:
mem = HistoryMemory(10, n_step=5, obs_shape=(4,), action_shape=(1,))

import gym

e = gym.make('CartPole-v0')
o = e.reset()
for step in range(10):
    act = e.action_space.sample()
    next_o, r, done, _ = e.step(act)
    mem.remember(o, act, r, next_o, done)
    o = next_o

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [111]:
sample = mem.sample(2)

KeyError: 'observation'

In [108]:
#  what we actually want from an n_step return memory is
train_obs = [exp[0] for exp in sample['observation']]

In [104]:
sample['observation'].shape

(6, 1, 5, 4)