In [4]:
import datasets

In [5]:
# IWSLT 2017 English–Chinese parallel corpus; later cells read its
# 'train', 'test' and 'validation' splits.
dataset = datasets.load_dataset(path="iwslt2017", name="iwslt2017-en-zh")

In [5]:
from itertools import chain
from nltk.tokenize import word_tokenize
from tqdm import tqdm

# Splits whose English side contributes to the vocabulary. All three are kept
# so the vocab size reported below is reproduced; if this vocabulary is used to
# build the Tokeniser's word index, restrict it to ("train",) so test/validation
# text does not leak into the model's vocabulary.
VOCAB_SPLITS = ("train", "test", "validation")


def build_vocab(sentences, tokenize=word_tokenize):
    """Return the set of distinct, non-empty tokens produced by ``tokenize``.

    Tokens are kept as-is (no lowercasing), so "The" and "the" are distinct.

    Args:
        sentences: Iterable of raw sentence strings.
        tokenize: Callable mapping one string to a list of token strings.

    Returns:
        A ``set`` of token strings.
    """
    tokens = set()
    for sentence in sentences:
        tokens.update(tok for tok in tokenize(sentence) if tok)
    return tokens


english_sentences = (
    translation["en"]
    for translation in chain.from_iterable(
        dataset[split]["translation"] for split in VOCAB_SPLITS
    )
)
# chain() has no len(), so give tqdm the total explicitly to get a percentage/ETA.
vocab = build_vocab(
    tqdm(english_sentences, total=sum(len(dataset[split]) for split in VOCAB_SPLITS))
)

240694it [00:12, 19643.28it/s]


#### Vocab size, using NLTK

We use this vocabulary to assign each word its index in the Tokeniser.

The Tokeniser, when encoding, will output:
  - index: int
    The index of the word in the learned vocabulary.
  - attention_mask: 1
    The base attention mask (1) for the attention layer of the model.
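  - token_type: 0 | 1
    Marks whether the token is a padding token; this is needed because all sentences are padded/truncated to a fixed context length. (Note: Hugging Face tokenizers instead signal padding by setting `attention_mask` to 0.)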

The Tokeniser, when decoding, will take:
  - index: int
    The predicted token (index) that the model outputs.
  - ...

  Given these inputs, the Tokeniser should map each index back to its token in the learned vocabulary and output the corresponding (translated) token.

In [6]:
# Number of distinct (case-sensitive) NLTK tokens in `vocab`; left as the
# cell's last expression so Jupyter renders it as the cell output.
len(vocab)

72289


In [1]:
from transformers import MarianTokenizer

# Pretrained English→Chinese Marian checkpoint; its tokenizer serves as the
# reference ("desired behaviour") for the custom Tokeniser below.
MARIAN_CHECKPOINT = 'Helsinki-NLP/opus-mt-en-zh'
tokenizer = MarianTokenizer.from_pretrained(MARIAN_CHECKPOINT)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# English side of the first training pair. Index the row first: indexing the
# column first (dataset['train']['translation']) can materialise every training
# pair just to read one.
sent = dataset['train'][0]['translation']['en']

### Desired Behavior

In [8]:
# Reference encoding from the pretrained tokenizer: a dict holding
# 'input_ids' and 'attention_mask' for `sent`.
tokenizer(text=sent)

{'input_ids': [2547, 39, 194, 961, 2, 24873, 6, 309, 41, 23, 24, 9869, 13, 1330, 16042, 9, 55, 3, 1727, 9, 733, 9, 58, 3147, 10782, 27, 28, 23, 100, 7045, 5488, 6, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [11]:
# The same ids as the 'input_ids' above, returned as a PyTorch tensor with a
# leading batch dimension of size 1.
token_ids = tokenizer.encode(text=sent, return_tensors='pt')
token_ids

tensor([[ 2547,    39,   194,   961,     2, 24873,     6,   309,    41,    23,
            24,  9869,    13,  1330, 16042,     9,    55,     3,  1727,     9,
           733,     9,    58,  3147, 10782,    27,    28,    23,   100,  7045,
          5488,     6,     0]])