# Training Tokenizers

In [None]:
import json
import enum
import copy
import numpy as np
from pathlib import Path
from datasets import load_dataset
from transformers import PreTrainedTokenizerFast
from tokenizers import Tokenizer, decoders, models, normalizers, pre_tokenizers, trainers, processors

@enum.unique
class Case(enum.Enum):
  """Letter-case variant of a token, stored as the third field of a multi-token tuple.

  Iteration order (LOWER, UPPER, TITLE) matters: vocabulary augmentation loops
  over the members in definition order when assigning new ids.
  """
  LOWER = 0  # token is equal to its lower-cased form
  UPPER = 1  # token built with str.upper()
  TITLE = 2  # token built with str.title()

@enum.unique
class Whitespace(enum.Enum):
  """Whether a token carries a leading space marker ('▁'); second field of a multi-token tuple."""
  NONE = 0   # token has no leading '▁'
  SPACE = 1  # token starts with '▁'

# Directory that receives every saved tokenizer; created if it does not exist yet.
Path('tokenizers').mkdir(exist_ok=True)

# Languages to process; each name is passed to load_dataset as the dataset configuration.
languages = [
  'arabic',
  'azerbaijani',
  'chinese',
  'english',
  'farsi',
  'german',
  'hebrew',
  'hindi',
  'korean',
  'spanish',
  'turkish',
  'vietnamese',
]

Lengths
```
{'arabic': 1712361,
 'azerbaijani': 1715809,
 'chinese': 2879487,
 'english': 2635469,
 'farsi': 132568,
 'german': 282059,
 'hebrew': 20686,
 'hindi': 40027,
 'korean': 2632457,
 'spanish': 4058317,
 'turkish': 1810342,
 'vietnamese': 2493325}
```

## Vanilla Tokenizers

Each token is mapped to a single integer id. These tokenizers serve as the baseline for benchmarking.

In [None]:
# Set to True to force retraining even when a cached `vanilla_{lang}.json` exists.
RETRAIN_TOK = False

# Train (or load) one plain BPE tokenizer per language and save it under
# tokenizers/{lang}_vanilla in the Hugging Face format.
for lang in languages:
  dataset = load_dataset("Gabrui/multilingual_TinyStories", lang)

  special_toks = {'pad_token':"<|PAD|>", 'bos_token':"<|BOS|>", 'eos_token':"<|EOS|>", 'unk_token':"<|UNK|>"}
  def batch_iterator(batch_size=1000):
      """Yield the "story" column of the train split in lists of `batch_size` texts."""
      tok_dataset = dataset['train'].select_columns("story")
      for batch in tok_dataset.iter(batch_size):
          yield batch["story"]

  tokenizer = Tokenizer(models.BPE(unk_token='<|UNK|>'))
  tokenizer.normalizer = normalizers.NFKC()
  tokenizer.pre_tokenizer = pre_tokenizers.Sequence([pre_tokenizers.UnicodeScripts(),
                                                     pre_tokenizers.Digits(individual_digits=True), pre_tokenizers.Metaspace(prepend_scheme='never')])
  # Special tokens must be registered before token_to_id() is queried below.
  tokenizer.add_special_tokens(list(special_toks.values()))
  tokenizer.post_processor = processors.TemplateProcessing(
      single="<|BOS|> $A <|EOS|>",
      pair="<|BOS|> $A <|EOS|> <|BOS|>:1 $B:1 <|EOS|>:1",
      special_tokens=[
          ("<|BOS|>", tokenizer.token_to_id("<|BOS|>")),
          ("<|EOS|>", tokenizer.token_to_id("<|EOS|>")),
      ],
  )
  tokenizer.decoder = decoders.Metaspace(prepend_scheme='never')

  cache_path = Path(f'vanilla_{lang}.json')
  if RETRAIN_TOK or not cache_path.exists():
    trainer = trainers.BpeTrainer(vocab_size=15000, special_tokens=list(special_toks.values()), show_progress=True) # min_frequency=20,
    tokenizer.train_from_iterator(batch_iterator(), trainer, length=len(dataset['train']))
    # NOTE(review): saving to the cache path is disabled, so the cached branch below
    # is only taken when the file was put there by other means.
    # tokenizer.save(f'vanilla_{lang}.json')
  else:
    # Reuse the cached tokenizer; this replaces the freshly configured object above.
    tokenizer = Tokenizer.from_file(str(cache_path))

  fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
  fast_tokenizer.add_special_tokens(special_toks)
  fast_tokenizer.save_pretrained(f'tokenizers/{lang}_vanilla')

## Multidimensional Tokenizers

Each token is mapped to a tuple of integers. Currently we model only if the token starts with a space and what kind of case it has (lower case, upper case, or title case), resulting in the tuple: `(token_id, has_whitespace, case)`.
- `token_id`: the id of the originally trained token (lower case, without space);
- `has_whitespace`: true (`1`) if the token starts with a space;
- `case`: indicates if the token is lower case, upper case, or title case.

For a simpler implementation, the `tokenizers` library is used to train on 'stemmed' (lower case, without space) data, and the trained tokenizer is then manually augmented with the extended vocabulary.

In [None]:
# For every language: train a BPE tokenizer on "stemmed" text, expand its
# vocabulary with whitespace/case variants, verify the result, and save both the
# extended tokenizer and the id -> (token_id, has_whitespace, case) table.
for lang in languages:
  print(lang)
  dataset = load_dataset("Gabrui/multilingual_TinyStories", lang)

  special_tokens = ["<|PAD|>", "<|BOS|>", "<|EOS|>", "<|UNK|>"]
  special_toks = {'pad_token':"<|PAD|>", 'bos_token':"<|BOS|>", 'eos_token':"<|EOS|>", 'unk_token':"<|UNK|>"}
  def batch_iterator(batch_size=1000):
      """Yield the "story" column of the train split in lists of `batch_size` texts."""
      tok_dataset = dataset['train'].select_columns("story")
      for batch in tok_dataset.iter(batch_size):
          yield batch["story"]


  # Training configuration: text is lower-cased before the stem vocabulary is learned.
  tokenizer = Tokenizer(models.BPE(unk_token='<|UNK|>'))
  tokenizer.normalizer = normalizers.Sequence([normalizers.Lowercase(), normalizers.NFKC()])
  tokenizer.pre_tokenizer = pre_tokenizers.Sequence([pre_tokenizers.BertPreTokenizer(), pre_tokenizers.UnicodeScripts(),
                                                     pre_tokenizers.Digits(individual_digits=True), pre_tokenizers.Metaspace(prepend_scheme='never')])
  tokenizer.decoder = decoders.Metaspace(prepend_scheme='never')
  tokenizer.add_special_tokens(list(special_toks.values()))
  trainer = trainers.BpeTrainer(vocab_size=15000, special_tokens=special_tokens, show_progress=True)
  tokenizer.train_from_iterator(batch_iterator(), trainer, length=len(dataset['train']))

  # For inference it should be lossless
  tokenizer.normalizer = normalizers.NFKC()
  tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(prepend_scheme='never')
  tokenizer.post_processor = processors.TemplateProcessing(
      single="<|BOS|> $A <|EOS|>",
      pair="<|BOS|> $A <|EOS|> <|BOS|>:1 $B:1 <|EOS|>:1",
      special_tokens=[
          ("<|BOS|>", tokenizer.token_to_id("<|BOS|>")),
          ("<|EOS|>", tokenizer.token_to_id("<|EOS|>")),
      ],
  )
  # Work on the serialized form so vocab and merges can be rewritten by hand.
  tokenizer_data = json.loads(tokenizer.to_str())
  vocab_stem = tokenizer_data['model']['vocab']
  merge_stem = tokenizer_data['model']['merges']

  # The code below relies on `items` being ordered by id with id == list index.
  items = list(vocab_stem.items())
  # vocab: extended token string -> extended id
  vocab = {k: v for k, v in items[:len(special_tokens)]}
  # id2multi: extended id -> (stem token id, Whitespace, Case)
  id2multi = {v: (v, Whitespace.NONE, Case.LOWER) for _, v in items[:len(special_tokens)]}
  merge = []    # merge rules of the extended tokenizer
  strange = []  # single characters whose upper-case form has more than one character
  cid = len(special_tokens)  # next free extended id
  n0 = len(special_tokens)
  assert cid == items[n0][-1]

  # Initializing vocab and checking charset
  for tok, tid in items[n0:]:
    if len(tok) > 1:
      # First merged token reached: its id marks the end of the single-character range.
      t0 = tid
      break
    vocab[tok] = cid
    id2multi[cid] = (tid, Whitespace.NONE, Case.LOWER)
    cid += 1
    # Adding uppercase for single characters
    if tok != tok.upper() and vocab.get(tok.upper()) is None:
      if len(tok.upper()) > 1:
        strange.append((tok, cid, tok.upper()))
        continue
      vocab[tok.upper()] = cid
      id2multi[cid] = (tid, Whitespace.NONE, Case.UPPER)
      cid += 1
  else:
    # No merged token in the stem vocabulary: next free stem id follows the last character.
    t0 = tid + 1

  # Adding whitespaces characters to vocab
  # They take stem ids t0 .. t0+n_re-1; every merged stem id is shifted by n_re below,
  # so stem ids stay unique and contiguous. (Previously t0 was one too large, which
  # left stem id t0-1 unused and gave '\n' the same stem id as the first merged token.)
  re_add = ['▁', '\t', '\n']
  for i, tok in enumerate(re_add):
    vocab[tok] = cid
    id2multi[cid] = (t0+i, Whitespace.NONE, Case.LOWER)
    cid += 1
  n_re = len(re_add)

  # Adding merge with space
  # Every single character (but not the whitespace tokens just added) gets a '▁'-prefixed variant.
  for tok, nid in list(vocab.items())[n0:-n_re]:
    merge.append(['▁', tok])
    vocab[f'▁{tok}'] = cid
    id2multi[cid] = (id2multi[nid][0], Whitespace.SPACE, id2multi[nid][-1])
    cid += 1

  # Adding the rest, all combinations if not repeated
  # merge_stem[i] is the rule that produced the i-th merged token, items[t0 + i].
  assert len(merge_stem) == len(vocab_stem) - t0
  for i, (tok, tid) in enumerate(items[t0:]):
    for case in Case:
      for wspc in Whitespace:
        rtok = tok
        r0, r1 = merge_stem[i]
        if case is Case.UPPER:
          rtok = tok.upper()
          r0, r1 = r0.upper(), r1.upper()
        elif case is Case.TITLE:
          rtok = tok.title()
          r0 = r0.title()
        # Skip case variants that do not change the token (e.g. caseless scripts).
        if rtok == tok and case is not Case.LOWER:
          continue
        if wspc is Whitespace.SPACE:
          rtok = f'▁{rtok}'
          r0 = f'▁{r0}'
        if vocab.get(rtok) is not None:
          continue
        merge.append([r0, r1])
        vocab[rtok] = cid
        id2multi[cid] = (tid+n_re, wspc, case)
        cid += 1

  # Verify if everything is all right
  i = 0
  for (tok, cid), (cid2, tupl) in zip(vocab.items(), id2multi.items()):
    # Extended ids must be contiguous and identical in both tables.
    assert cid == i
    i += 1
    assert cid == cid2
    tid, wspc, case = tupl
    if tid < n0:
      continue
    if wspc is Whitespace.SPACE:
      assert tok[0] == '▁'
    if case is Case.LOWER:
      assert tok.lower() == tok
    elif case is Case.UPPER:
      assert tok.upper() == tok
      assert tok.lower() != tok
    else:
      assert tok.title() == tok
      assert tok.lower() != tok
    # Recover the stem string and check it maps back to the recorded stem id.
    tok_orig = (tok[1:] if wspc is Whitespace.SPACE else tok).lower()
    if tid < t0:
      assert vocab_stem[tok_orig] == tid
    elif tid >= t0 + n_re:
      if vocab_stem.get(tok_orig) is None:
        # lower() of the cased variant may not round-trip; compare after case folding both ways.
        tok_orig_real = items[tid - n_re][0]
        assert tok_orig.upper().lower() == tok_orig_real.lower().upper().lower()
        tok_orig = tok_orig_real
      assert vocab_stem[tok_orig] == tid - n_re or tok_orig.upper().lower() == items[tid - n_re][0].upper().lower()
  # The token produced by merge rule i must have an extended id greater than i.
  for i, mer in enumerate(merge):
    assert vocab[''.join(mer)] > i
  # Multi-character upper-case forms of single characters must have entered the vocab.
  for orig, i, upper in strange:
    assert vocab[upper]

  # Persist the extended tokenizer and the id -> (token_id, has_whitespace, case) table.
  id2multi_arr = np.array([(v[0], v[1].value, v[2].value) for k, v in id2multi.items()])
  tokenizer_data_extended = copy.deepcopy(tokenizer_data)
  tokenizer_data_extended['model']['vocab'] = vocab
  tokenizer_data_extended['model']['merges'] = merge
  tokenizer_extended = Tokenizer.from_str(json.dumps(tokenizer_data_extended))
  fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_extended)
  fast_tokenizer.add_special_tokens(special_toks)
  fast_tokenizer.save_pretrained(f'tokenizers/{lang}_multi')
  np.savetxt(f'tokenizers/{lang}_multi/multi.txt', id2multi_arr, fmt='%d')