In [None]:
!pip install torchrl>=0.1.1 tensordict>=0.1.1

# TensorDict data loading speed

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from tensordict import TensorDict


In [2]:
# Toy data: 1000 samples, each a 50 x 2 tensor of uniform random values.
a = torch.rand(1000, 50, 2)
# Wrap the tensor under key "a"; only the leading dimension (1000) is the batch dimension.
td = TensorDict(
    {"a": a},
    batch_size=a.shape[:1],
)

### Case 1: store data as tensors, create TensorDicts on the fly

In [3]:
class SimpleDataset(Dataset):
    """Dataset that serves items straight from an indexable container.

    Args:
        data: any object supporting ``len()`` and integer indexing
            (here, a tensor whose first dimension enumerates the samples).
    """

    def __init__(self, data):
        # Keep a reference only; nothing is copied or transformed.
        self.data = data

    def __getitem__(self, idx):
        """Return ``data[idx]`` unchanged."""
        sample = self.data[idx]
        return sample

    def __len__(self):
        """Return the number of items, as reported by ``len`` on the wrapped container."""
        return len(self.data)
    

# Case 1: the dataset yields plain tensors; each collated batch is wrapped
# in a TensorDict only after it leaves the DataLoader.
dataset = SimpleDataset(td["a"])
dataloader = DataLoader(dataset, batch_size=32)
x = TensorDict(
    {"a": next(iter(dataloader))},
    batch_size=32,
)
print(x.shape)

torch.Size([32])


In [4]:
# Time one full pass over `dataloader`, wrapping every tensor batch in a TensorDict.
# batch_size is taken from each batch's first dimension rather than hard-coded.
%timeit for x in dataloader: TensorDict({'a': x}, batch_size=x.shape[0])

1.24 ms ± 5.71 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Case 2: store data as TensorDicts and directly load them

In [5]:
class TensorDictDataset(Dataset):
    """Dataset that materialises the items of an iterable into a list up front.

    Args:
        data: any iterable; it is iterated once at construction time and each
            yielded item becomes one sample (here, the elements obtained by
            iterating over a TensorDict).
    """

    def __init__(self, data):
        # Iterate eagerly so that __getitem__ is a plain list lookup.
        self.data = list(data)

    def __getitem__(self, idx):
        """Return the pre-extracted sample stored at position ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Return the number of samples collected at construction time."""
        return len(self.data)
    
# Case 2: the dataset already holds one TensorDict per sample.
data = TensorDictDataset(td)
# Collate with torch.stack explicitly; this avoids the StopIteration error
# otherwise hit when batching these samples.
dataloader = DataLoader(
    data,
    batch_size=32,
    collate_fn=torch.stack,
)
x = next(iter(dataloader))
print(x.shape)

torch.Size([32])


In [6]:
# Time one full pass over `dataloader`; the loop body is empty, so only the
# loading and collation work is measured.
%timeit for x in dataloader: pass

1.73 ms ± 5.12 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Case 3: store TensorDict data as dictionaries and create TensorDicts on the fly with collate_fn

In [7]:
class CustomTensorDictDataset(Dataset):
    """Dataset that pre-splits a keyed container into one plain dict per sample.

    Args:
        data: object exposing ``items()`` (key/value pairs) and ``shape``;
            ``shape[0]`` gives the number of samples and every value is
            indexed along its first dimension (here, a TensorDict).
    """

    def __init__(self, data):
        n_samples = data.shape[0]
        samples = []
        for row in range(n_samples):
            # One dict per sample: the same keys, each value sliced at `row`.
            samples.append({name: column[row] for name, column in data.items()})
        self.data = samples

    def __getitem__(self, idx):
        """Return the dict holding sample ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Return the number of per-sample dicts."""
        return len(self.data)


class CustomTensorDictCollate(nn.Module):
    """Collate callable that turns a list of per-sample dicts into one TensorDict.

    No constructor arguments are needed, so ``nn.Module.__init__`` is inherited
    as is.
    """

    def forward(self, batch):
        """Stack the samples key by key and wrap the result in a TensorDict.

        Args:
            batch: non-empty sequence of dicts; the keys are read from the
                first dict and looked up in every other one.

        Returns:
            A TensorDict whose batch size is ``len(batch)``.
        """
        stacked = {}
        for key in batch[0].keys():
            stacked[key] = torch.stack([sample[key] for sample in batch])
        return TensorDict(stacked, batch_size=len(batch))
    
# Case 3: samples are stored as plain dicts; the custom collate function
# assembles every batch into a TensorDict.
data = CustomTensorDictDataset(td)
dataloader = DataLoader(
    data,
    batch_size=32,
    collate_fn=CustomTensorDictCollate(),
)
x = next(iter(dataloader))
print(x.shape)

torch.Size([32])


In [8]:
# Time one full pass over `dataloader`; the loop body is empty, so only the
# loading and collation work is measured.
%timeit for x in dataloader: pass

573 µs ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


---

Apparently, splitting the data into dictionaries and creating TensorDicts on the fly in the `collate_fn` is the fastest way to load data... but why is it not faster to just index TensorDicts directly? And is there a better way?