## Import statements

In [57]:
from datasets import load_dataset
from tokenizers import (
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    processors,
    trainers,
    Tokenizer,
)
from transformers import AutoTokenizer, PreTrainedTokenizerFast

## Data loading

In [58]:
# Training split of the Rotten Tomatoes movie-review corpus; only its "text" column is used below.
dataset = load_dataset(path="rotten_tomatoes", split="train")


def get_training_corpus(data=None, batch_size=1000):
    """Yield batches of raw texts for ``Tokenizer.train_from_iterator``.

    Streaming fixed-size batches avoids materialising the whole corpus as one
    Python list while the BPE trainer consumes it.

    Args:
        data: dataset to read from. It must support ``len()`` and slicing that
            returns a mapping with a ``"text"`` column, as a Hugging Face
            ``Dataset`` does. Defaults to the module-level ``dataset``, so
            existing zero-argument calls keep working.
        batch_size: number of rows per yielded batch.

    Yields:
        list[str]: up to ``batch_size`` texts. The final batch may be shorter.

    Raises:
        ValueError: if ``batch_size`` is not a positive integer. Because this is
            a generator, the error surfaces on the first ``next()``.
    """
    if data is None:
        # Resolved at call time so the notebook-level dataset stays the default.
        data = dataset
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    for start in range(0, len(data), batch_size):
        # Slicing past the end is clamped, so the last batch is simply shorter.
        yield data[start: start + batch_size]["text"]

## Tokenizer initialization and training

In [59]:
# Declared first so the BPE model and the trainer share the same special-token literals.
special_tokens = ["<PAD>", "<EOS>", "<UNK>"]

# Byte-level BPE: text is mapped to byte-level symbols before merges are learned.
# add_prefix_space=True treats the first word like a word preceded by a space,
# which is why the decoded sample in the Testing section starts with a space.
tokenizer = Tokenizer(models.BPE(unk_token=special_tokens[-1]))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)

# Seed the vocabulary with the whole byte-level alphabet so every byte has a
# symbol. Then learn merges until the vocabulary (special tokens included)
# holds 1024 entries.
trainer = trainers.BpeTrainer(
    special_tokens=special_tokens,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    vocab_size=1024,
)
tokenizer.train_from_iterator(get_training_corpus(), trainer=trainer)

# Post-processing and decoding components are attached only after training.
# NOTE(review): this post-processor does not append <EOS> to encodings (see the
# sample tokens below); add it explicitly if downstream models need it.
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
tokenizer.decoder = decoders.ByteLevel()






## Testing

In [60]:
encoding = tokenizer.encode("Let's test this tokenizer.")

# One value per line: tokens, ids, the round-tripped text, and the vocabulary
# size measured two ways (dict length vs. the tokenizer's own count), which should agree.
print(
    encoding.tokens,
    encoding.ids,
    tokenizer.decode(encoding.ids),
    len(tokenizer.get_vocab()),
    tokenizer.get_vocab_size(),
    sep="\n",
)

['Ġ', 'L', 'et', "'s", 'Ġt', 'est', 'Ġthis', 'Ġto', 'k', 'en', 'iz', 'er', '.']
[223, 46, 328, 308, 259, 375, 355, 296, 77, 271, 577, 264, 16]
 Let's test this tokenizer.
1024
1024


## Saving and reloading

In [61]:
# Wrap the raw `tokenizers` object so it can be saved and reloaded through the
# `transformers` API. The special-token strings must match those the trainer
# registered. No CLS/SEP/MASK tokens exist in this vocabulary, so none are declared.
wrapped_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    pad_token="<PAD>",
    eos_token="<EOS>",
    unk_token="<UNK>",
)

# Last expression: the returned tuple of written file paths is shown as the cell output.
wrapped_tokenizer.save_pretrained("../../saved_models/tokenizers/rotten_tomatoes_bpe_style")

('../../saved_models/tokenizers/rotten_tomatoes_bpe_style/tokenizer_config.json',
 '../../saved_models/tokenizers/rotten_tomatoes_bpe_style/special_tokens_map.json',
 '../../saved_models/tokenizers/rotten_tomatoes_bpe_style/tokenizer.json')

In [62]:
# Reload through the generic AutoTokenizer entry point to check that the saved
# directory loads on its own. Then round-trip a string with digits and punctuation.
tok = AutoTokenizer.from_pretrained("../../saved_models/tokenizers/rotten_tomatoes_bpe_style")

tokens = tok.tokenize("Test number 2.!@#$%^&*()")
ids = tok.convert_tokens_to_ids(tokens)
decoded_string = tok.decode(ids)

# One value per line: tokens, ids, decoded text, reloaded vocabulary size.
print(tokens, ids, decoded_string, len(tok.get_vocab()), sep="\n")

['T', 'est', 'Ġn', 'um', 'ber', 'Ġ', '2', '.', '!', '@', '#', '$', '%', '^', '&', '*', '(', ')']
[54, 375, 307, 376, 702, 223, 20, 16, 3, 34, 5, 6, 7, 64, 8, 12, 10, 11]
Test number 2.!@#$%^&*()
1024
