In [1]:
# %pip install "torchrl>=0.1.1" "tensordict>=0.1.1"  # quote the specifiers: an unquoted `>=` is a shell output redirection

In [2]:
import tensordict

# ??tensordict

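# TensorDict dataloading speed

Benchmarks several ways of iterating over the same 1000-sample dataset with a PyTorch `DataLoader` (batch size 32): storing plain tensors, per-sample TensorDicts, per-sample dicts, or indexing a single TensorDict directly.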

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from tensordict import TensorDict


In [4]:
# Synthetic benchmark data: 1000 samples, each a (50, 2) float tensor drawn from [0, 1).
# The TensorDict's single batch dimension indexes the samples.
a = torch.rand((1000, 50, 2))
td = TensorDict({"a": a}, batch_size=[1000])

### Case 1: store data as tensors, create TensorDicts on the fly

In [5]:
class SimpleDataset(Dataset):
    """Map-style dataset over the items of ``data`` (here: the rows of a tensor).

    The input is eagerly split into a Python list of per-sample items, so
    ``__getitem__`` is a plain list lookup. Per the original note, a list is
    faster to dataload and keeps the comparison with the other cases fair.
    """

    def __init__(self, data):
        # Iterating a tensor yields its slices along dim 0, i.e. one entry per sample.
        self.data = list(data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample
    

dataset = SimpleDataset(td["a"])
dataloader = DataLoader(dataset, batch_size=32, collate_fn=torch.stack)
# Wrap the first stacked tensor batch in a TensorDict to check the resulting batch shape.
first_batch = next(iter(dataloader))
x = TensorDict({"a": first_batch}, batch_size=32)
print(x.shape)

torch.Size([32])


In [6]:
# Time one full pass over the DataLoader, wrapping each stacked batch in a TensorDict (10 runs x 100 loops).
%timeit -r 10 -n 100 for x in dataloader: TensorDict({'a': x}, batch_size=x.shape[0])

522 µs ± 79.3 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)


### Case 2: store data as a list of per-sample TensorDicts and load them directly

In [9]:
class TensorDictDataset(Dataset):
    """Dataset holding a list of per-sample items taken from ``data`` (Case 2).

    Each item produced by iterating ``data`` is stored as-is; with a TensorDict
    source these are expected to be per-sample TensorDicts. Note: later cells in
    this notebook redefine a class with this same name.
    """

    def __init__(self, data):
        self.data = []
        for sample in data:
            self.data.append(sample)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
# Split td into a list of per-sample TensorDicts; torch.stack reassembles each batch.
data = TensorDictDataset(td)
# use collate_fn=torch.stack to avoid StopIteration error
dataloader = DataLoader(dataset=data, batch_size=32, collate_fn=torch.stack)
x = next(iter(dataloader))
print(x.shape)

torch.Size([32])


In [10]:
# Time one full pass over the DataLoader (10 runs x 100 loops); batches are not post-processed.
%timeit -r 10 -n 100 for x in dataloader: pass

1.75 ms ± 58.9 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)


### Case 3: store TensorDict data as per-sample dictionaries and create TensorDicts on the fly in a custom `collate_fn`

In [11]:
class CustomTensorDictDataset(Dataset):
    """Dataset that stores each sample as a plain ``{key: tensor}`` dict (Case 3).

    ``data`` must expose ``.items()`` and ``.shape`` (a TensorDict in this
    notebook). Entry ``i`` of the dataset maps every key to ``value[i]``.
    """

    def __init__(self, data):
        columns = dict(data.items())
        num_samples = data.shape[0]
        self.data = []
        for i in range(num_samples):
            self.data.append({key: column[i] for key, column in columns.items()})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class CustomTensorDictCollate(nn.Module):
    """Collate a list of per-sample ``{key: tensor}`` dicts into one batched TensorDict.

    Keys are taken from the first sample. Every sample is assumed to share
    those keys with stackable tensor shapes. The output batch size is the
    number of samples.
    """

    def __init__(self):
        super().__init__()

    def forward(self, batch):
        stacked = {}
        for key in batch[0].keys():
            stacked[key] = torch.stack([sample[key] for sample in batch])
        return TensorDict(stacked, batch_size=len(batch))
    
collate = CustomTensorDictCollate()
data = CustomTensorDictDataset(td)
dataloader = DataLoader(data, batch_size=32, collate_fn=collate)
x = next(iter(dataloader))
print(x.shape)

torch.Size([32])


In [12]:
# Time one full pass over the DataLoader (10 runs x 100 loops); TensorDicts are built inside collate_fn.
%timeit -r 10 -n 100 for x in dataloader: pass

563 µs ± 63.6 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)


### Case 4: pass the TensorDict directly to the DataLoader (easiest way)

The TensorDict itself is used as the dataset, with an identity `collate_fn`. Reference: https://github.com/pytorch-labs/tensordict/issues/374

In [13]:
# The TensorDict itself is the dataset; the identity collate_fn returns each fetched batch unchanged.
dataloader = DataLoader(td, batch_size=32, collate_fn=lambda batch: batch)
x = next(iter(dataloader))
print(x.shape)

torch.Size([32])


In [15]:
# Time one full pass over the DataLoader (10 runs x 100 loops).
# NOTE(review): the saved output reports "7 runs, 1,000 loops", so it was produced with
# different %timeit settings than this line; re-run the cell to refresh it.
%timeit -r 10 -n 100 for x in dataloader: pass

668 µs ± 1.02 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Case 5: index the TensorDict one sample at a time

In [16]:
class TensorDictDataset(Dataset):
    """Dataset that wraps ``data`` whole and indexes it per sample (Case 5).

    Nothing is copied up front. Every ``__getitem__`` call indexes
    ``self.data`` directly. This redefines the ``TensorDictDataset`` used in Case 2.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return item
    
data = TensorDictDataset(td)
# Each sample is fetched by indexing td individually; torch.stack merges the per-sample
# TensorDicts of a batch (the printed result is a LazyStackedTensorDict).
dataloader = DataLoader(data, batch_size=32, collate_fn=torch.stack)
x = next(iter(dataloader))
print(x)

LazyStackedTensorDict(
    fields={
        a: Tensor(shape=torch.Size([32, 50, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([32]),
    device=None,
    is_shared=False)


In [18]:
# Time one full pass over the DataLoader (10 runs x 100 loops).
# NOTE(review): the saved output reports "7 runs", so it was produced with different
# %timeit settings than this line; re-run the cell to refresh it.
%timeit -r 10 -n 100 for x in dataloader: pass

6.67 ms ± 12.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [25]:
class TensorDictDataset(Dataset):
    """Dataset that fetches a whole batch with a single indexing call (batched variant).

    When a dataset defines ``__getitems__``, ``DataLoader`` passes it the list of
    sampled indices for a batch. ``self.data`` is therefore indexed once per batch
    instead of once per sample. Pair it with an identity ``collate_fn``, since the
    returned object is already batched. This redefines the ``TensorDictDataset``
    used in the earlier cases.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Single-sample access. Without this override, ``Dataset.__getitem__`` raises
        # NotImplementedError, e.g. for ``dataset[i]`` or a DataLoader with batch_size=None.
        return self.data[idx]

    def __getitems__(self, idx):
        # ``idx`` is the list of indices making up one batch; indexing with it
        # returns the batch in one call.
        return self.data[idx]
    
data = TensorDictDataset(td)
# __getitems__ already returns a batched TensorDict, so collation is the identity.
dataloader = DataLoader(data, batch_size=32, collate_fn=lambda batch: batch)
x = next(iter(dataloader))
print(x)

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([32, 50, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([32]),
    device=None,
    is_shared=False)


In [26]:
%timeit for x in dataloader: pass

678 µs ± 1.91 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


---

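Apparently, keeping the raw data as plain tensors (Case 1, ~520 µs per epoch) or per-sample dictionaries (Case 3, ~560 µs) and creating TensorDicts on the fly is the fastest way to load data. These two are within measurement noise of each other.

The other approaches are slower:
- Batched indexing of the TensorDict, either by passing it directly (Case 4) or via `__getitems__` (last cell), comes close at ~670–680 µs.
- Storing per-sample TensorDicts (Case 2) takes ~1.75 ms.
- Per-sample indexing of a TensorDict (Case 5) is by far the slowest at ~6.7 ms.

Note that the Case 4, Case 5 and last-cell outputs report 7 runs, so they were recorded with different `%timeit` settings than the other cases. Re-run them before drawing firm conclusions.

Open questions: why is it not faster to just index TensorDicts instead? And is there a better way?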