  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [296]:
import numpy
# SumTree
# a binary tree data structure where the parent’s value is the sum of its children
class SumTree:
    """Binary sum tree holding per-sample priorities for proportional sampling.

    The tree lives in a flat array of length ``2 * capacity - 1``:

    - node ``i`` has children ``2*i + 1`` and ``2*i + 2``;
    - the root (index 0) holds the sum of all priorities;
    - the last ``capacity`` entries are the leaves.

    Data slot ``k`` corresponds to leaf index ``k + capacity - 1``. Payloads are
    kept in ``data``, which is written as a ring buffer.

    Priority changes are gated by ``pending_idx``. ``update`` only modifies a
    leaf whose index is pending, i.e. one currently being written by ``add`` or
    previously returned by ``get``. Any other index is silently ignored, and
    each pending index can be updated once.
    """

    write = 0  # class-level default; shadowed by the instance attribute below

    def __init__(self, capacity):
        self.capacity = capacity
        # Internal nodes (partial sums) followed by the leaf priorities.
        self.tree = numpy.zeros(2 * capacity - 1)
        # Stored payloads, one per leaf, overwritten in ring-buffer order.
        self.data = numpy.zeros(capacity, dtype=object)
        self.write = 0  # next data slot that add() will overwrite
        self.n_entries = 0  # number of filled slots, capped at capacity
        self.pending_idx = set()  # leaf indices whose priority may be updated

    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of node ``idx``, up to and including the root.

        The walk stops at the root. This matters when ``capacity == 1``, where the
        single leaf *is* the root and has no ancestors to update.
        """
        while idx != 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def _retrieve(self, idx, s):
        """Descend from node ``idx`` to the leaf whose cumulative range contains ``s``.

        At each node, go left if ``s`` is at most the left subtree's sum.
        Otherwise subtract that sum from ``s`` and go right.
        Returns the tree index of the leaf reached.
        """
        size = len(self.tree)
        while True:
            left = 2 * idx + 1
            if left >= size:
                return idx
            if s <= self.tree[left]:
                idx = left
            else:
                # Rebind rather than use -= so a caller-supplied array is never mutated.
                s = s - self.tree[left]
                idx = left + 1

    def total(self):
        """Return the sum of all leaf priorities (the root value)."""
        return self.tree[0]

    def add(self, p, data):
        """Store ``data`` in the next ring-buffer slot with priority ``p``.

        Once the buffer is full, this overwrites the oldest slot. Writing a slot
        also consumes any update still pending for that leaf from an earlier
        ``get``.
        """
        idx = self.write + self.capacity - 1
        # Mark the leaf pending so the update() below is allowed to apply.
        self.pending_idx.add(idx)

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, idx, p):
        """Set leaf ``idx``'s priority to ``p`` if ``idx`` is pending; otherwise do nothing."""
        if idx not in self.pending_idx:
            return
        self.pending_idx.remove(idx)
        change = p - self.tree[idx]
        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        """Sample the leaf selected by cumulative priority ``s``.

        Callers presumably draw ``s`` from ``[0, total()]``; TODO confirm.
        Returns ``(tree_idx, priority, data_idx)``. The leaf is marked pending,
        so a later ``update(tree_idx, ...)`` may re-prioritise it.
        """
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1
        self.pending_idx.add(idx)
        return (idx, self.tree[idx], dataIdx)

In [408]:
tree = SumTree(capacity=10)  # small tree for interactive experiments

In [412]:
tree.add(p=3, data=("s2", "a2", "r2"))  # placeholder payload with priority 3

In [4]:
from hashlib import md5
import gym
import mujoco_py
from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
import numpy as np

# Four vectorized CartPole environments.
env = make_vec_env("CartPole-v1", n_envs=4)
from buffer import PrioritizedReplayBuffer
# Take the env count from the VecEnv itself so the buffer cannot drift out of sync
# with it. NOTE(review): the meaning of the positional 10192 and 1 depends on the
# project-local PrioritizedReplayBuffer signature; 10192 may be a typo for 8192 or
# 10240. Confirm.
buffer = PrioritizedReplayBuffer(10192, 1, env.observation_space, env.action_space, env.num_envs)

In [5]:
obs = env.reset()
for _ in range(10000):
    # Fixed per-env actions; one entry per vectorized env.
    actions = np.array([1, 0, 0, 1])
    next_obs, rewards, dones, infos = env.step(actions)
    # Actions are reshaped to (n_envs, 1) for the buffer. NOTE(review): confirm this
    # layout (and the omission of next_obs/dones) matches PrioritizedReplayBuffer.add.
    buffer.add(obs, actions[:, None], rewards)
    # Advance to the new observation; otherwise every stored transition would reuse
    # the very first reset() observation.
    obs = next_obs
    

AssertionError: array([1]) (<class 'numpy.ndarray'>) invalid

In [3]:
# Display the buffer's n_envs attribute. It should equal the number of vectorized
# envs passed to PrioritizedReplayBuffer when it was constructed.
buffer.n_envs

4