In [1]:
from torch.utils.data import DataLoader, Dataset
import torch
from torch.nn.utils.rnn import pad_sequence

In [2]:
class MyDataset(Dataset):
    """Toy map-style dataset whose samples are the integers 0 through 9."""

    def __init__(self):
        # Stand-in for real data points: ten consecutive integers.
        self.data = list(range(10))

    def __getitem__(self, index):
        """Return the sample stored at position ``index``."""
        return self.data[index]

    def __len__(self):
        """Return how many samples the dataset holds."""
        return len(self.data)

In [3]:
# Wrap the toy dataset in a loader that yields in-order batches of four samples.
dataset = MyDataset()
loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=False,
)

In [4]:
# Each iteration yields one already-collated batch from the loader.
for mini_batch in loader:
    print(mini_batch)

tensor([0, 1, 2, 3])
tensor([4, 5, 6, 7])
tensor([8, 9])


Under the hood:
- When the DataLoader wants to form a batch, it calls the dataset's `__getitem__` several times to fetch the individual samples, then passes that list of samples to the collate function (`collate_fn`), which determines how the samples are combined into one batch.
- By default, PyTorch uses `default_collate`, which tries to stack everything into tensors automatically.

The collate function = how to combine samples into a batch:
- The collate function defines what to do with a list of items from the dataset
- If you don't pass one, pytorch uses a default version that:
  - Converts lists of numbers into tensors
  - Stacks them along a new batch dimension
  - Works recursively for tuples and dicts

Default collate:

In [5]:
# Three (features, label) samples; every feature tensor has the same length.
data = [
    (torch.tensor([1, 2]), 0),
    (torch.tensor([3, 4]), 1),
    (torch.tensor([5, 6]), 0),
]

# No collate_fn is passed, so the loader relies on its default collation.
loader = DataLoader(data, batch_size=2)

for features_and_labels in loader:
    print(features_and_labels)

[tensor([[1, 2],
        [3, 4]]), tensor([0, 1])]
[tensor([[5, 6]]), tensor([0])]


Need for custom collate functions:
- default stacking fails when samples are not the same shape or type
- For example, when working with variable-length text sequences, images of different sizes, or more complex data structures

In [23]:
class VariableLengthDataset(Dataset):
    """Toy dataset whose samples are integer sequences of lengths 2, 4 and 5."""

    def __init__(self):
        self.data = [
            [1, 2],
            [1, 2, 3, 4],
            [1, 2, 3, 4, 5],
        ]

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sequence at position ``idx`` converted to a tensor."""
        sample = self.data[idx]
        return torch.tensor(sample)

In [24]:
# Rebind `dataset` to the variable-length dataset and show each raw sample.
dataset = VariableLengthDataset()
for sample in dataset:
    print(sample)

tensor([1, 2])
tensor([1, 2, 3, 4])
tensor([1, 2, 3, 4, 5])


In [25]:
# No collate_fn is passed here, so batches of two samples go through the
# loader's default collation.
loader = DataLoader(dataset, batch_size=2)

In [26]:
# Default collation cannot stack tensors of different lengths, so iterating
# this loader raises RuntimeError. The error is the point of the demo: catch
# it and print the message so "Restart Kernel -> Run All" continues past
# this cell instead of halting here.
try:
    for batch in loader:
        print(batch)
except RuntimeError as err:
    print(f"{type(err).__name__}: {err}")

RuntimeError: stack expects each tensor to be equal size, but got [2] at entry 0 and [4] at entry 1

Tensors above have different sizes and therefore cannot be stacked

Custom collate function:
- we can define how to combine variable-length sequences - for instance, pad them to the same length:

In [27]:
def pad_collate(batch):
    """Collate variable-length tensors into a single zero-padded tensor.

    Args:
        batch: list of tensors of possibly different lengths, as handed over
            by the DataLoader for one batch.

    Returns:
        A tensor with the batch dimension first, in which shorter sequences
        are padded with 0 up to the longest length in the batch.
    """
    # Show the raw, un-collated samples so the effect of padding is visible.
    print(batch)
    return pad_sequence(batch, padding_value=0, batch_first=True)

In [28]:
# Same dataset and batch size as before, but batches are now assembled by
# pad_collate instead of the default collation.
loader = DataLoader(dataset, batch_size=2, collate_fn=pad_collate)

In [29]:
# pad_collate prints the raw samples; this loop prints the padded result.
for padded_batch in loader:
    print(padded_batch)

[tensor([1, 2]), tensor([1, 2, 3, 4])]
tensor([[1, 2, 0, 0],
        [1, 2, 3, 4]])
[tensor([1, 2, 3, 4, 5])]
tensor([[1, 2, 3, 4, 5]])


Another collate function example (sequences paired with labels):

In [30]:
# (sequence, label) samples whose sequences have lengths 3, 2 and 1.
data = [
    (torch.tensor([1, 2, 3]), 0),
    (torch.tensor([4, 5]), 1),
    (torch.tensor([6]), 0),
]

In [34]:
def nlp_collate(batch):
    """Collate (sequence, label) samples into a padded batch plus a label tensor.

    Args:
        batch: list of ``(sequence_tensor, label)`` pairs for one batch.

    Returns:
        A ``(padded, labels)`` tuple: ``padded`` has the batch dimension first
        with shorter sequences padded with 0, and ``labels`` is a tensor built
        from the per-sample labels in the same order.
    """
    # Show the raw samples so the un-collated structure is visible.
    print(batch)
    # Transpose the list of pairs into one tuple of sequences and one of labels.
    token_seqs, targets = zip(*batch)
    return (
        pad_sequence(token_seqs, batch_first=True, padding_value=0),
        torch.tensor(targets),
    )

In [35]:
# Batches of two (sequence, label) samples, assembled by nlp_collate.
loader = DataLoader(data, batch_size=2, collate_fn=nlp_collate)

In [36]:
# Inspect only the first batch: the padded sequences, then their labels.
for padded_seqs, batch_labels in loader:
    print(padded_seqs)
    print(batch_labels)
    break

[(tensor([1, 2, 3]), 0), (tensor([4, 5]), 1)]
tensor([[1, 2, 3],
        [4, 5, 0]])
tensor([0, 1])


In [51]:
# Plain-Python (sequence, label) pairs used to explore how zip(*...) behaves.
z = [
    ([1, 2, 3], 0),
    ([4, 5], 1),
    ([6, 7, 8], 0),
]

In [52]:
# zip(*z) transposes the list of (sequence, label) pairs into two tuples:
# one holding every sequence and one holding every label.
sequences,labels = zip(*z)

In [53]:
# Bare expression: the notebook displays the tuple of sequences.
sequences

([1, 2, 3], [4, 5], [6, 7, 8])

In [54]:
# Bare expression: the notebook displays the tuple of labels.
labels

(0, 1, 0)

In [50]:
# zip(*z) yields one tuple per field (first all sequences, then all labels),
# each holding one entry per sample in z. A fixed two-name unpacking therefore
# only works when z has exactly two samples; star-unpacking keeps this cell
# working for any number of samples while still printing the first entry of
# each yielded tuple.
for seq, *rest in zip(*z):
    print(seq)

[1, 2, 3]
0


In [42]:
# Print only the first tuple produced by zip(*z), i.e. all the sequences.
# NOTE(review): the loop variable `i` stays bound in the notebook namespace
# after this cell, and a later cell reads `i` — keep the name unchanged.
for i in zip(*z):
    print(i)
    break

([1, 2, 3], [4, 5], [6, 7, 8])


In [61]:
# Two equal-length lists used to show how zip pairs elements by position.
a = [1, 2, 3, 4]
b = [5, 6, 7, 8]

In [62]:
# zip(a, b) yields one (a[k], b[k]) pair per position; with four elements in
# each list that is four pairs, unpacked here into four names.
a_,b_,c_,d_ = zip(a,b)

In [63]:
# a_ is the first pair produced by zip(a, b).
print(a_)

(1, 5)


In [60]:
# Iterate the zipped pairs directly. The previous version looped over a
# leftover variable `i` from an earlier cell (hidden kernel state) and used
# `_` as the loop variable, which overwrites IPython's "last output" name.
for pair in zip(a, b):
    print(pair)

(1, 5)
(2, 6)
(3, 7)
(4, 8)
