In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import pad
from torch.utils.data.distributed import DistributedSampler

import sentencepiece as spm
import pandas as pd

class TFDataset(Dataset):
    """Parallel-sentence dataset tokenized on the fly with SentencePiece.

    Every item is a ``(src, tar)`` pair of 1-D integer id tensors. Each
    sequence is the tokenizer's BOS id, the encoded sentence, then the EOS id.
    Sequences are returned unpadded.

    Args:
        bpm_model: Path to a SentencePiece model file.
        tsv_file: Path to a tab-separated file containing ``src`` and ``tar``
            columns (other columns are ignored).
    """

    def __init__(self, bpm_model, tsv_file):
        sp = spm.SentencePieceProcessor()
        sp.load(bpm_model)

        self.sp = sp
        self.bos_id = sp.bos_id()  # original author's note: 1
        self.eos_id = sp.eos_id()  # original author's note: 2

        # Read every cell as a literal string. With pandas' defaults, sentences
        # such as "NA", "null" or "None" (and empty cells) are converted to
        # float NaN, which the tokenizer cannot encode.
        self.tsv_file = pd.read_csv(
            tsv_file,
            delimiter='\t',
            usecols=['src', 'tar'],
            dtype=str,
            keep_default_na=False,
        )

    def __len__(self):
        """Return the number of sentence pairs (original author's note: 250k)."""
        return len(self.tsv_file)

    def __getitem__(self, idx):
        """Return the ``idx``-th (positional) pair as ``(src_ids, tar_ids)`` tensors."""
        # Select by column name: `usecols` does not reorder columns, so a
        # positional lookup would silently swap src/tar if the file lists
        # `tar` before `src`.
        src_sent = self.tsv_file['src'].iloc[idx]
        tar_sent = self.tsv_file['tar'].iloc[idx]
        src_encoded = [self.bos_id] + self.sp.encode_as_ids(src_sent) + [self.eos_id]
        tar_encoded = [self.bos_id] + self.sp.encode_as_ids(tar_sent) + [self.eos_id]

        return torch.tensor(src_encoded), torch.tensor(tar_encoded)

def collate_fn(batch, max_pad=128, pad_id=0):
    """Collate variable-length id sequences into fixed-width batch tensors.

    Args:
        batch: List of ``(src_tensor, tar_tensor)`` pairs of 1-D id tensors.
        max_pad: Output width. Shorter sequences are right-padded to this
            length; longer ones are cut to their first ``max_pad`` ids (any
            trailing ids, including a final EOS, are dropped).
        pad_id: Fill value for padded positions. Defaults to 0, the value
            previously hard-coded here.
            NOTE(review): confirm 0 is the intended pad id for the tokenizer
            model in use.

    Returns:
        Tuple ``(src, tar)`` of tensors, each of shape ``(len(batch), max_pad)``.
    """

    def _fit(seq):
        # Truncate explicitly rather than relying on a negative pad width,
        # then pad at the end of the sentence up to max_pad.
        seq = seq[:max_pad]
        return pad(seq, (0, max_pad - len(seq)), value=pad_id)

    # e.g. three sequences of width 128 stack into a tensor of size [3, 128]
    src = torch.stack([_fit(src_seq) for src_seq, _ in batch])
    tar = torch.stack([_fit(tar_seq) for _, tar_seq in batch])

    return (src, tar)

def create_dataloader(bpm_model, tsv_file, is_distributed=False, batch_size=128):
    """Build the training ``DataLoader`` over a ``TFDataset``.

    Args:
        bpm_model: Path to the SentencePiece model, forwarded to ``TFDataset``.
        tsv_file: Path to the TSV corpus, forwarded to ``TFDataset``.
        is_distributed: When truthy, batches are drawn through a
            ``DistributedSampler``; otherwise the loader shuffles itself.
        batch_size: Number of sentence pairs per batch.

    Returns:
        A ``DataLoader`` yielding ``(src, tar)`` batches built by ``collate_fn``.
    """
    dataset = TFDataset(bpm_model, tsv_file)
    sampler = DistributedSampler(dataset) if is_distributed else None

    # DataLoader rejects shuffle=True together with a sampler, so shuffle only
    # when no sampler was created. Deriving this from `sampler` (instead of
    # `is_distributed is False`) keeps shuffling on for falsy non-bool flags
    # such as 0 or None.
    train_dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=sampler is None,
        sampler=sampler,
        collate_fn=collate_fn
    )
    return train_dataloader

In [2]:
import os
# Kaggle input locations: the SentencePiece model file and the TSV corpus
# that are passed to create_dataloader below.
model_path = '/kaggle/input/wmt14-bpe-250k/bpe_250k.model'
tsv_path = '/kaggle/input/wmt14-train-250k/train_df_250k.tsv'

In [3]:
import pandas as pd
# Load only the `src`/`tar` columns of the TSV into a DataFrame for inspection.
data = pd.read_csv(tsv_path, delimiter='\t', usecols=['src', 'tar'])

In [4]:
# Preview the first rows of the loaded sentence pairs.
data.head()

Unnamed: 0,src,tar
0,Resumption of the session,Wiederaufnahme der Sitzungsperiode
1,I declare resumed the session of the European ...,"Ich erkläre die am Freitag, dem 17. Dezember u..."
2,"Although, as you will have seen, the dreaded '...","Wie Sie feststellen konnten, ist der gefürchte..."
3,You have requested a debate on this subject in...,Im Parlament besteht der Wunsch nach einer Aus...
4,"In the meantime, I should like to observe a mi...",Heute möchte ich Sie bitten - das ist auch der...


In [5]:
# Smoke test: build the loader with its defaults, pull a single batch, and
# show both tensors followed by their shapes (one item per line).
train_dataloader = create_dataloader(bpm_model=model_path, tsv_file=tsv_path)
src, tar = next(iter(train_dataloader))
print(src, tar, src.shape, tar.shape, sep='\n')

tensor([[    1,   879,    58,  ...,     0,     0,     0],
        [    1,   594,  5468,  ...,     0,     0,     0],
        [    1,  4891, 31943,  ...,     0,     0,     0],
        ...,
        [    1,   594,    63,  ...,     0,     0,     0],
        [    1,   186,    99,  ...,     0,     0,     0],
        [    1,   266,  1990,  ...,     0,     0,     0]])
tensor([[    1,  1330,   463,  ...,     0,     0,     0],
        [    1,   600,   171,  ...,     0,     0,     0],
        [    1,  4080, 31943,  ...,     0,     0,     0],
        ...,
        [    1,  3269,   171,  ...,     0,     0,     0],
        [    1, 11709,   171,  ...,     0,     0,     0],
        [    1,  1323,    51,  ...,     0,     0,     0]])
torch.Size([128, 128])
torch.Size([128, 128])
