In [1]:
import torch
from tensordict import TensorDict
# Seed torch's global RNG so the random buffer.sample(...) draws further down
# produce the same outputs on every Restart & Run All.
# NOTE(review): assumes torchrl's default sampler draws from torch's global
# generator — TODO confirm for the installed torchrl version.
SEED = 0
torch.manual_seed(SEED)

T, F = True, False  # short aliases for True / False

# Basics

## TensorDict
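- `source` = 딕트 (a dict of tensors; values may themselves be TensorDicts)
- `batch_size` = [배치사이즈] — the leading dims shared by every entry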

In [3]:
bsize = 5
# Every entry shares the leading batch dim `bsize`; trailing dims may differ per key.
x = TensorDict(
    dict(
        key1=torch.zeros(bsize, 3),
        key2=torch.zeros(bsize, 5, 6, dtype=torch.bool),
    ),
    batch_size=[bsize],
)
x

TensorDict(
    fields={
        key1: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        key2: Tensor(shape=torch.Size([5, 5, 6]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

In [19]:
x.batch_size                    # .batch_size attribute: the leading dims shared by every entry

torch.Size([5])

In [20]:
# A string key via [] hands back the very same tensor object as .get(key).
assert x.get('key1') is x['key1']

x['key1']                         # [key] -> tensor

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

### Operations along the batch dim(s)
- x[인덱싱] / x.gather
- reshape/view
- permute
- (un)squeeze/expand
- unbind/split
- torch.stack/cat

In [15]:
x[2]                            # integer index along the batch dim -> new TensorDict with that dim removed

TensorDict(
    fields={
        key1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        key2: Tensor(shape=torch.Size([5, 6]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

In [4]:
# Two TensorDicts with identical structure: one all-zeros, one all-ones.
x1 = TensorDict({'key1': torch.zeros(bsize, 1),
                 'key2': torch.zeros(bsize, 5, 6, dtype=torch.bool)},
                batch_size=[bsize])
x2 = TensorDict({'key1': torch.ones(bsize, 1),
                 'key2': torch.ones(bsize, 5, 6, dtype=torch.bool)},
                batch_size=[bsize])

# Stacking adds a new leading batch dim: batch_size [5] -> [2, 5].
x = torch.stack([x1, x2], dim=0)

print(x.batch_size, x.get('key1').shape, x.get('key1'))

torch.Size([2, 5]) torch.Size([2, 5, 1]) tensor([[[0.],
         [0.],
         [0.],
         [0.],
         [0.]],

        [[1.],
         [1.],
         [1.],
         [1.],
         [1.]]])


In [7]:
x.view(-1).batch_size, x.view(-1)['key1'].size()  # flatten all batch dims into one; feature dims untouched

(torch.Size([10]), torch.Size([10, 1]))

In [8]:
x.permute(1, 0).batch_size, x.permute(1, 0)['key1'].size()  # swap the two batch dims; feature dims stay trailing

(torch.Size([5, 2]), torch.Size([5, 2, 1]))

In [10]:
x.expand(3, *x.batch_size).batch_size, x.expand(3, *x.batch_size)['key1'].size()  # prepend a new batch dim of size 3

(torch.Size([3, 2, 5]), torch.Size([3, 2, 5, 1]))

In [13]:
# Nested TensorDict: a TensorDict can itself be a value. Here the inner
# batch_size [bsize, 2] extends the outer batch_size [bsize].
x = TensorDict(
    dict(
        key1=torch.zeros(bsize, 3),
        key2=TensorDict(dict(subkey1=torch.zeros(bsize, 2, 1)), batch_size=[bsize, 2]),
    ),
    batch_size=[bsize],
)
x

TensorDict(
    fields={
        key1: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        key2: TensorDict(
            fields={
                subkey1: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([5, 2]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

## Replay Buffers

In [2]:
from torchrl.data import ReplayBuffer

buffer = ReplayBuffer()         # default: size 1000, ListStorage

print(len(buffer))              # 0 — nothing stored yet

buffer.extend(range(2000))      # extend / add
print(len(buffer))              # capped at the storage size (1000), not 2000

  Referenced from: <CFED5F8E-EC3F-36FD-AAA3-2C6C7F8D3DD9> /opt/homebrew/Caskroom/miniconda/base/envs/py311/lib/python3.11/site-packages/torchvision/image.so
  warn(


0
1000


In [3]:
from torchrl.data import LazyMemmapStorage, LazyTensorStorage, ListStorage

size = 100

# Needed for non-numerical data: an identity collate_fn hands the sampled items
# back as a plain Python list (see the output below). The default collate
# presumably expects numerical data — TODO confirm against torchrl docs.
buffer_list = ReplayBuffer(
    storage=ListStorage(size),
    collate_fn=lambda batch: batch,
)

buffer_list.extend(['a', 0, 'b'])

buffer_list.sample(3)

['b', 'a', 'a']

In [4]:
from tensordict import TensorDict  # redundant with the top import cell; harmless

buffer_lazytensor = ReplayBuffer(storage=LazyTensorStorage(size))

# A tuple key ('b', 'c') creates a nested entry: data['b'] is a sub-TensorDict holding 'c'.
data = TensorDict(
    {
        'a': torch.arange(12).reshape(3, 4),
        ('b', 'c'): torch.arange(15).reshape(3, 5),
    },
    batch_size=[3],
)
buffer_lazytensor.extend(data)
len(buffer_lazytensor)

3

In [5]:
# Draw 5 rows at random from the 3 stored ones, so repeats are expected.
sample = buffer_lazytensor.sample(batch_size=5)
sample

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([5, 5]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)

In [6]:
sample.get('a')                 # sampled rows of entry 'a'

tensor([[ 4,  5,  6,  7],
        [ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [ 0,  1,  2,  3]])

In [8]:
sample['b']['c']                # same tensor as the tuple-key form sample['b', 'c']

tensor([[ 5,  6,  7,  8,  9],
        [ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [ 0,  1,  2,  3,  4]])

In [10]:
from torchrl.data import TensorDictReplayBuffer

# batch_size is fixed at construction, so .sample() below takes no argument.
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(size),
    batch_size=12,
)

buffer.extend(data)
print(len(buffer))

sample = buffer.sample()
print(sample)
# 'index' is not in `data`; the buffer adds it to each sample
# (presumably the storage position of each sampled row).
sample.get('index')

3
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([12, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([12, 5]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([12]),
            device=cpu,
            is_shared=False),
        index: Tensor(shape=torch.Size([12]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([12]),
    device=cpu,
    is_shared=False)


tensor([2, 0, 1, 2, 2, 1, 0, 2, 2, 2, 0, 2])