
### DataLoader를 이용해 배치(batch)별로 padding 적용하기
- 샘플마다 길이(size)가 달라 기본 collate로는 하나의 텐서로 묶을 수 없는 경우
- DataLoader의 파라미터 중 collate_fn을 사용해 배치 안에서 가장 긴 샘플에 맞춰 padding하기  
https://python.plainenglish.io/understanding-collate-fn-in-pytorch-f9d1742647d3

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np


In [None]:
class TextDataset(Dataset):
    """Dataset over samples ``X`` with optional labels ``y``.

    Indexing yields ``[X[idx], y[idx]]`` when labels were supplied and
    only ``X[idx]`` when ``y`` is ``None``.
    """

    def __init__(self, X, y=None):
        self.X, self.y = X, y

    def __len__(self):
        # The dataset is exactly as long as its sample container.
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        # Unlabelled dataset: hand back the bare sample.
        if self.y is None:
            return sample
        return [sample, self.y[idx]]

In [None]:
# Toy samples: each record holds a variable-length list of token ids under
# 'tokenized_input' and an integer class under 'label'.
nlp_data = [
    {'tokenized_input': [1, 4, 5, 9, 3, 2],
     'label':0},
    {'tokenized_input': [1, 7, 3, 14, 48, 7, 23, 154, 2],
     'label':0},
    {'tokenized_input': [1, 30, 67, 117, 21, 15, 2],
     'label':1},
    {'tokenized_input': [1, 17, 2],
     'label':0}
]

# No collate_fn is passed here, so the DataLoader uses its default collation
# with a batch size of 2.
loader = DataLoader(nlp_data, 2)
# Intentional failure demo: the recorded output of this cell is a RuntimeError,
# presumably because the token lists within a batch differ in length.
batch = next(iter(loader))

RuntimeError: ignored

In [None]:
from torch.nn.utils.rnn import pad_sequence
from pprint import pprint

def custom_collate(data):
    """Collate variable-length samples into one zero-padded batch.

    Args:
        data: Sequence of dicts, each with a ``'tokenized_input'`` list of
            token ids and a ``'label'`` value.

    Returns:
        Dict with ``'tokenized_input'`` (tensor padded with zeros to the
        longest sequence in this batch, batch dimension first) and
        ``'label'`` (1-D tensor of the labels).
    """
    token_seqs = [torch.tensor(sample['tokenized_input']) for sample in data]
    label_values = [sample['label'] for sample in data]

    return {
        # pad_sequence right-pads every sequence to the batch maximum.
        'tokenized_input': pad_sequence(token_seqs, batch_first=True),
        'label': torch.tensor(label_values),
    }

# Same data as before, but batches are now assembled by custom_collate so each
# batch is padded to its own longest sequence.
loader = DataLoader(
    nlp_data,
    batch_size=2,
    shuffle=False,
    collate_fn=custom_collate,
)
# Pull the two batches by hand to inspect how each one was padded.
iter_loader = iter(loader)
batch = next(iter_loader)
print(batch)
batch2 = next(iter_loader)
print(batch2)

{'tokenized_input': tensor([[  1,   4,   5,   9,   3,   2,   0,   0,   0],
        [  1,   7,   3,  14,  48,   7,  23, 154,   2]]), 'label': tensor([0, 0])}
{'tokenized_input': tensor([[  1,  30,  67, 117,  21,  15,   2],
        [  1,  17,   2,   0,   0,   0,   0]]), 'label': tensor([1, 0])}


### 클래스로 custom collate를 만들어서 사용해보기
- 참조
    - https://www.kaggle.com/code/kunwar31/pytorch-pad-sequences-per-batch/notebook

In [None]:
class TextDataset(Dataset):
    """Dataset pairing samples ``X`` with optional targets ``Y``.

    ``dataset[idx]`` is ``[X[idx], Y[idx]]`` when targets exist, otherwise
    just ``X[idx]`` (the unlabelled / test case).
    """

    def __init__(self, X, Y=None):
        self.X, self.Y = X, Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        has_targets = self.Y is not None
        return [self.X[idx], self.Y[idx]] if has_targets else self.X[idx]



In [None]:
from torch.nn.utils.rnn import pad_sequence

class MyCollator(object):
    """Callable ``collate_fn`` that pads variable-length samples per batch.

    Args:
        test: When ``False`` (default) every batch item is a
            ``[sample, target]`` pair. When ``True`` every batch item is a
            bare sample with no target.
    """

    def __init__(self, test=False):
        self.test = test

    def __call__(self, batch):
        """Pad the samples in ``batch`` to the longest one.

        Args:
            batch: List of items produced by the dataset for one batch.

        Returns:
            ``[data, target]`` in training mode, ``[data]`` in test mode.
            ``data`` is the zero-padded tensor with the batch dimension first.
        """
        if self.test:
            # pad_sequence only accepts tensors, so convert anything that is
            # not one yet (e.g. NumPy arrays or lists). Items that are already
            # tensors pass through untouched.
            data = [
                item if isinstance(item, torch.Tensor) else torch.tensor(item)
                for item in batch
            ]
            return [pad_sequence(data, batch_first=True)]

        data = [torch.tensor(item[0]) for item in batch]
        target = torch.tensor([item[1] for item in batch])
        return [pad_sequence(data, batch_first=True), target]



In [None]:
sample_size = 1024

# Sequence lengths: normal draws (mean 200, std 50) truncated to int32.
sizes = np.random.normal(loc=200, scale=50, size=(sample_size, )).astype(np.int32)
# A draw far below the mean would truncate to a negative length and make
# np.ones raise ValueError, so clamp every length to at least 1.
sizes = np.maximum(sizes, 1)
# One all-ones float array per sample, each with its own length.
X = [np.ones(length) for length in sizes]
# Random binary targets: uniform [0, 1) draws rounded to 0.0 or 1.0.
Y = np.random.rand(sample_size).round()
# print(sizes)
# print(X)
# print(Y)

In [None]:
# Longest generated sequence length; no padded batch can be wider than this.
sizes.max()

371

In [None]:
# Number of samples per batch, passed to the DataLoader below.
batch_size = 128
# Labelled dataset: indexing returns [sample, label].
dataset = TextDataset(X, Y)
# Unlabelled dataset: no labels are given, so indexing returns the bare sample.
test_dataset = TextDataset(X)

In [None]:
# Training-mode collator: expects [sample, label] items from the dataset.
collate = MyCollator()
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate,
)

# Iterate a few epochs and print the batch shapes; the padded width changes
# from batch to batch because each batch is padded to its own longest sample.
for epochs in range(3):
    print('epochs :', epochs)
    for x, y in loader:
        print(x.shape, y.shape)

epochs : 0
torch.Size([128, 340]) torch.Size([128])
torch.Size([128, 317]) torch.Size([128])
torch.Size([128, 320]) torch.Size([128])
torch.Size([128, 324]) torch.Size([128])
torch.Size([128, 320]) torch.Size([128])
torch.Size([128, 298]) torch.Size([128])
torch.Size([128, 316]) torch.Size([128])
torch.Size([128, 371]) torch.Size([128])
epochs : 1
torch.Size([128, 340]) torch.Size([128])
torch.Size([128, 309]) torch.Size([128])
torch.Size([128, 320]) torch.Size([128])
torch.Size([128, 371]) torch.Size([128])
torch.Size([128, 317]) torch.Size([128])
torch.Size([128, 317]) torch.Size([128])
torch.Size([128, 290]) torch.Size([128])
torch.Size([128, 317]) torch.Size([128])
epochs : 2
torch.Size([128, 320]) torch.Size([128])
torch.Size([128, 320]) torch.Size([128])
torch.Size([128, 371]) torch.Size([128])
torch.Size([128, 324]) torch.Size([128])
torch.Size([128, 300]) torch.Size([128])
torch.Size([128, 300]) torch.Size([128])
torch.Size([128, 317]) torch.Size([128])
torch.Size([128, 340]) torch.Size([128])

### 후기

- Dataset, DataLoader, Custom Collate의 관계  
    DataLoader가 Dataset에서 batch_size만큼 샘플을 꺼내 리스트로 묶어 collate_fn(Custom Collate)에 넘기고, collate_fn이 pad_sequence로 길이를 맞춰 하나의 배치 텐서로 만들어 반환한다.

- pad_sequence는 텐서의 리스트만 받으므로, collate_fn 안에서 데이터가 텐서인지 확인하고 아니라면 먼저 torch.tensor로 변환할 것ㅠㅠ
