In [1]:
import torch
import tensordict as td
from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage

import numpy as np


In [43]:
def create_replay_buffer(prb=False, buffer_size=100000, buffer_scratch_dir="/tmp/", device="cpu",make_replay_buffer=3):
    
    if prb:
            buffer = TensorDictPrioritizedReplayBuffer(
            alpha=0.7,
            beta=0.5,
            pin_memory=False,
            prefetch=make_replay_buffer,
            storage=LazyMemmapStorage(
                buffer_size,
                scratch_dir=buffer_scratch_dir,
                device=device,
            ),
        )
    else:
        replay_buffer = TensorDictReplayBuffer(
            pin_memory=False,
            prefetch=make_replay_buffer,
            storage=LazyMemmapStorage(
                buffer_size,
                scratch_dir=buffer_scratch_dir,
                device=device,
            ),
        )
    return replay_buffer

In [44]:
buffer = create_replay_buffer(prb=False, buffer_size=1000, buffer_scratch_dir="/tmp/", device="cpu",make_replay_buffer=3)

In [45]:
transition = td.TensorDict({"observation": torch.from_numpy(np.random.uniform(0, 1, (1, 5))).float().to("cpu"),
                            "action": torch.from_numpy(np.random.uniform(0, 1, (1, 1))).float().to("cpu"),
                            "reward": torch.from_numpy(np.random.uniform(0, 1, (1, 1))).float().to("cpu"),
                            "next_observation": torch.from_numpy(np.random.uniform(0, 1, (1, 5))).float().to("cpu"),
                            "done": torch.from_numpy(np.random.uniform(0, 1, (1, 1))).bool().to("cpu")}, batch_size=1)
                            

In [46]:
transition

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next_observation: Tensor(shape=torch.Size([1, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([1, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)

In [47]:
transition["action"]

tensor([[0.3452]])

In [48]:
buffer.extend(transition)

Creating a MemmapStorage...
The storage is being created: 
	action: /tmp/action.memmap, 0.3814697265625 Mb of storage (size: torch.Size([1000, 1])).
	done: /tmp/done.memmap, 0.095367431640625 Mb of storage (size: torch.Size([1000, 1])).
	index: /tmp/index.memmap, 0.3814697265625 Mb of storage (size: torch.Size([1000])).
	next_observation: /tmp/next_observation.memmap, 1.9073486328125 Mb of storage (size: torch.Size([1000, 5])).
	observation: /tmp/observation.memmap, 1.9073486328125 Mb of storage (size: torch.Size([1000, 5])).
	reward: /tmp/reward.memmap, 0.762939453125 Mb of storage (size: torch.Size([1000, 1])).


array([0])

In [36]:
buffer.__len__()

1

In [28]:
batch = buffer.sample(batch_size=2)
print(batch)

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.bool, is_shared=False),
        index: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
        next_observation: Tensor(shape=torch.Size([2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)


In [18]:
batch["action"]

tensor([0.6947])