In [81]:
import torch
import spacy
import random
from torchtext.datasets import Multi30k
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data import get_tokenizer, to_map_style_dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

In [83]:
tokenizer_de = get_tokenizer('spacy','de_core_news_sm')
tokenizer_en = get_tokenizer('spacy','en_core_web_sm')

In [84]:
train_dataset, val_dataset, test_dataset = Multi30k(root='data')
train_dataset = to_map_style_dataset(train_dataset)
val_dataset = to_map_style_dataset(val_dataset)
test_dataset = to_map_style_dataset(test_dataset)

In [85]:
len(train_dataset), len(val_dataset), len(test_dataset)

(29000, 1014, 1000)

In [86]:
def yiled_token(dataset, tokenizer):
    index = 0 if type(tokenizer.keywords['spacy']) == spacy.lang.de.German else 1
    for items in dataset:
        yield tokenizer(items[index].lower())

In [87]:
vocab_de = build_vocab_from_iterator(yiled_token(train_dataset, tokenizer_de), min_freq=2, specials=['<pad>','<unk>','<bos>','<eos>'])
vocab_en = build_vocab_from_iterator(yiled_token(train_dataset, tokenizer_en), min_freq=2, specials=['<pad>','<unk>','<bos>','<eos>'])
vocab_de.set_default_index(1)
vocab_en.set_default_index(1)

In [88]:
def transform2token(dataset):
    length = len(dataset)
    for i in range(length):
        src, trg = dataset._data[i]
        src = tokenizer_de(src)
        trg = tokenizer_en(trg)
        src = [vocab_de['<bos>']] + [vocab_de[x.lower()] for x in src] + [vocab_de['<eos>']]
        trg = [vocab_en['<bos>']] + [vocab_en[x.lower()] for x in trg] + [vocab_en['<eos>']]
        dataset._data[i] = (torch.LongTensor(src), torch.LongTensor(trg))
    return dataset

In [89]:
len(vocab_de),len(vocab_en)

(7853, 5893)

In [90]:
train_dataset = transform2token(train_dataset)
val_dataset = transform2token(val_dataset)
test_dataset = transform2token(test_dataset)

In [91]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 8

In [92]:
def collate_batch(batch):
    src_list, trg_list = [], []
    for src, trg in batch:
        src_list.append(src)
        trg_list.append(trg)
    return pad_sequence(src_list), pad_sequence(trg_list)

In [93]:
def batch_sample(dataset, batch_size):
    indices = [(i, len(s[0])) for i, s in enumerate(dataset)]
    random.shuffle(indices)
    pooled_indices = []
    for i in range(0, len(indices), batch_size*100):
        pooled_indices.extend(sorted(indices[i:i + batch_size * 100], key=lambda x: x[1]))
    pooled_indices = [x[0] for x in pooled_indices]
    for i in range(0, len(pooled_indices), batch_size):
        yield pooled_indices[i:i + batch_size]

In [94]:
train_dataloader = DataLoader(train_dataset,batch_sampler=batch_sample(train_dataset, batch_size),collate_fn=collate_batch)

In [154]:
for src, trg in train_dataloader:
    pass

tensor([[   2,    2,    2,    2,    2,    2,    2,    2],
        [   5,    5,    8,    5,    5,    5,    5,    5],
        [   1,   66,   16,   13,    1,   13,   13, 4906],
        [ 116,   25,  159,    7, 2010,   12,   29,   66],
        [  60,    7,   21,    6,   23,    6,    7,   25],
        [  21,    6,   75,   82,   39,   87,   15,   11],
        [   6,  102,   10,   79,  207,   62,  436, 1984],
        [1282,   40,   93,  404,    9,   21,  151,   31],
        [   7,   37, 5977,    8,   64,   14, 1114,   12],
        [   6,    8,    7,  916,  139,  747,    7,   33],
        [3877,  168,   15,   10,    1,   46,   77,  779],
        [ 117,    1,   81,    5,   28,  556,  112,   44],
        [   4,    4,    4,  587,  756,  117,  194,  395],
        [   3,    3,    3,    4,    4,    4,    4,    4],
        [   0,    0,    0,    3,    3,    3,    3,    3]])
torch.Size([15, 8])
tensor([[   2,    2,    2,    2,    2,    2,    2,    2],
        [  21,    4,    4,    4,    4,    4,    4, 