In [1]:
import sys
sys.path.append("../src")
from config import *
from plugin_write_and_run import *

In [2]:
%%write_and_run ../src/replay_buffer.py
import numpy as np
import sys
sys.path.append("../src")
from config import *

In [3]:
%%write_and_run -a ../src/replay_buffer.py

class ReplayBuffer():
    """Fixed-capacity circular store of transitions with uniform random sampling.

    Each record is a (state, action, reward, next_state, done) tuple kept in
    preallocated NumPy arrays. Once ``buffer_capacity`` records have been
    written, new records overwrite the oldest ones.
    """

    def __init__(self, env, buffer_capacity=BUFFER_CAPACITY, batch_size=BATCH_SIZE):
        """Preallocate storage sized from the environment's spaces.

        Args:
            env: Object exposing ``observation_space.shape`` and
                ``action_space.shape``; each stored state/action row takes
                exactly that shape.
            buffer_capacity: Maximum number of records kept at once.
            batch_size: Number of records ``get_minibatch`` tries to return.

        Raises:
            ValueError: If ``buffer_capacity`` is smaller than 1.
        """
        if buffer_capacity < 1:
            # A zero-sized buffer would otherwise only fail later, with a
            # ZeroDivisionError inside add_record.
            raise ValueError("buffer_capacity must be at least 1")

        self.env = env
        self.buffer_capacity = buffer_capacity
        self.batch_size = batch_size
        # Total number of records ever written (keeps growing past capacity).
        self.buffer_counter = 0

        # Use the complete space shape rather than only its first dimension so
        # multi-dimensional observations/actions get correctly sized rows; for
        # 1-D spaces this is the same layout as (capacity, shape[0]).
        state_shape = tuple(env.observation_space.shape)
        action_shape = tuple(env.action_space.shape)

        self.states = np.zeros((self.buffer_capacity,) + state_shape)
        self.actions = np.zeros((self.buffer_capacity,) + action_shape)
        self.rewards = np.zeros((self.buffer_capacity))
        self.next_states = np.zeros((self.buffer_capacity,) + state_shape)
        self.dones = np.zeros((self.buffer_capacity), dtype=bool)

    def add_record(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest record when full."""
        # Wrap around once the counter reaches the capacity
        # (e.g. 1 % 100 == 1 and 101 % 100 == 1), so the oldest slot is reused.
        index = self.buffer_counter % self.buffer_capacity

        self.states[index] = state
        self.actions[index] = action
        self.rewards[index] = reward
        self.next_states[index] = next_state
        self.dones[index] = done

        # Count every write, including overwrites.
        self.buffer_counter += 1

    def get_minibatch(self):
        """Return a uniformly sampled batch of stored records.

        Indices are drawn without replacement from the filled part of the
        buffer. If fewer than ``batch_size`` records are stored, all of them
        are returned (in random order) instead of failing.

        Returns:
            Tuple ``(state, action, reward, next_state, done)`` of NumPy
            arrays whose first dimension is the number of sampled records.

        Raises:
            ValueError: If the buffer is empty and ``batch_size`` is positive.
        """
        # Only slots that have actually been written may be sampled.
        record_range = min(self.buffer_counter, self.buffer_capacity)
        if record_range == 0 and self.batch_size > 0:
            raise ValueError("Cannot sample a minibatch from an empty replay buffer")

        # Sampling without replacement cannot draw more items than exist.
        sample_size = min(self.batch_size, record_range)
        batch_indices = np.random.choice(record_range, sample_size, replace=False)

        state = self.states[batch_indices]
        action = self.actions[batch_indices]
        reward = self.rewards[batch_indices]
        next_state = self.next_states[batch_indices]
        done = self.dones[batch_indices]

        return state, action, reward, next_state, done

In [4]:
import gym

In [5]:
# ENV_NAME is not defined in this notebook's cells; presumably it comes from
# the star imports in the first cell — confirm.
env = gym.make(ENV_NAME)

In [6]:
# Buffer storage is sized from env's observation/action spaces; capacity and
# batch size fall back to the constructor defaults.
rb = ReplayBuffer(env)

In [7]:
# Nothing has been recorded yet, so the write counter starts at zero.
rb.buffer_counter

0

In [8]:
# Reset before stepping; the returned observation is displayed but not kept.
env.reset()

array([0.99282967, 0.11953765, 0.26821029])

In [9]:
# A single-element action array used for the one environment step below.
action = np.array([-0.5])

In [10]:
# Unpacks four values from env.step; the last one is discarded.
state, reward, done, _ = env.step(action)

In [11]:
# Smoke test: write the same transition 1000 times, reusing `state` as the
# next_state argument as well.
for i in range(1000):
    rb.add_record(state, action, reward, state, done)

In [12]:
# The counter advances by one per add_record call.
rb.buffer_counter

1000

In [13]:
# Every written slot holds the `done` value from the single step above.
rb.dones

array([False, False, False, ..., False, False, False])

In [14]:
state, action, reward, next_state, done = rb.get_minibatch()

In [15]:
state.shape

(64, 3)

In [16]:
action.shape

(64, 1)

In [17]:
rb.buffer_counter

1000

In [18]:
reward.shape

(64,)