In [9]:
%reload_ext autoreload
%autoreload 2

import datasets

dataset = datasets.load_dataset("bentrevett/multi30k")

In [10]:
train, val, test = (dataset['train'], dataset['validation'], dataset['test'])

In [11]:
import spacy

en_nlp = spacy.load("en_core_web_sm")
de_nlp = spacy.load("de_core_news_sm")

eos_token = '<eos>'
sos_token = '<sos>'
max_length = 1000


def tokenize(text):
    en_tokens = [token.text.lower() for token in en_nlp.tokenizer(text['en'])][:max_length]
    de_tokens = [token.text.lower() for token in de_nlp.tokenizer(text['de'])][:max_length]
    en_tokens = [sos_token] + en_tokens + [eos_token]
    de_tokens = [sos_token] + de_tokens + [eos_token]

    text['en_tokens'] = en_tokens
    text['de_tokens'] = de_tokens

    return text

In [12]:
train_data = [tokenize(item) for item in train]
val_data = [tokenize(item) for item in val]
test_data = [tokenize(item) for item in test]

In [13]:
from seq2seq.Vocabulary import Vocabulary

min_freq = 2
unk_token = '<unk>'
pad_token = '<pad>'

special_tokens = [
    unk_token,
    pad_token,
    sos_token,
    eos_token
]

en_vocab = Vocabulary([item['en_tokens'] for item in test_data], special_tokens, min_freq=min_freq)
de_vocab = Vocabulary([item['de_tokens'] for item in test_data], special_tokens, min_freq=min_freq)

In [14]:
len(en_vocab), len(de_vocab)

(546, 473)

In [22]:
import torch
import torch.nn as nn


def encode_texts(item):
    item['en_ids'] = torch.Tensor(en_vocab.encode_seq(item['en_tokens']))
    item['de_ids'] = torch.Tensor(de_vocab.encode_seq(item['de_tokens']))
    return item


train_data = [encode_texts(item) for item in train_data]
val_data = [encode_texts(item) for item in val_data]
test_data = [encode_texts(item) for item in test_data]

In [23]:
def get_collate(pad_index):
    def collate_fn(batch):
        batch_en_ids = [example["en_ids"] for example in batch]
        batch_de_ids = [example["de_ids"] for example in batch]
        batch_en_ids = nn.utils.rnn.pad_sequence(batch_en_ids, padding_value=pad_index)
        batch_de_ids = nn.utils.rnn.pad_sequence(batch_de_ids, padding_value=pad_index)
        batch = {
            "en_ids": batch_en_ids,
            "de_ids": batch_de_ids,
        }
        return batch

    return collate_fn

In [38]:
from torch.utils.data import DataLoader


def get_loader(dataset, batch_size, pad_index, shuffle=False):
    data_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=get_collate(pad_index), shuffle=shuffle)
    return data_loader


batch_size = 64
pad_index = en_vocab.encode('<pad>')
train_loader = get_loader(train_data, batch_size, pad_index, shuffle=True)
val_loader = get_loader(val_data, batch_size, pad_index)
test_loader = get_loader(test_data, batch_size, pad_index)