# Tokenization for Language Modeling

## Questions to answer
1. Is byte-pair encoding information theory optimal? How do you know if something is optimal? In other words, does an uniform distribution across the vocabulary tokens represent an uniform distribution across a language? Does this question even make sense? If not, what is the right way to think about how well the tokenization represents the prevalence of individual tokens?

2. How do modern language models pick the vocabulary size? 

## Overview

Auto-Regressive Language Models are classifers. They produce a probability distribution of the prediction of what comes next on the space defined by the vocabulary. A naive way is to use characters as this vocabulary. The drawback of this is that a lot of time and resources is spent on learning how to build words. It's a lot better to give the language models the words we already know and use. The drawback with using words however is that it limits the model's ability to learn and use new words. From an information theory perspective, it also doesn't represent the importance of these words in our day to day usage. Tokenization with Byte Pair Encoding is a sweet spot between characters and words where we build our vocabulary with chunks of characters based on how common they are in our language. This gives the language model the most common words so it doesnt need to learn them from scratch alongside the ability to learn and use new words or just words that are not as common. Tokens are chunks of commonly occuring sequences of characters in language. Tokenization is the process of translating text into sequences of tokens and vice versa. 

## Research History

* First defined by Phil Gage for Data Compression in http://www.pennelynn.com/Documents/CUJ/HTML/94HTML/19940045.HTM 
* Byte Pair Encoding was introduced in the context of language modeling for Neural Machine Translation of Rare Words with Subword Units: https://arxiv.org/pdf/1508.07909
    * a
* And operationalized into large language models in GPT-2: https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf
    * With slight modifications: 

* More recently, there was research into how larger models need larger vocabulary. Scaling Laws with Vocabulary: Larger Models Deserve Larger Vocabularies https://arxiv.org/abs/2407.13623
    *  a


## Other Relevant Papers
* Language Models spend a lot of early layer compute translating tokens to meaningful units, and late layers reversing this: https://arxiv.org/pdf/2406.09519.
* (biomedical domain llm trained with a specialized tokenizer) Specialized domains benefit from domain-specific tokenization: https://link.springer.com/chapter/10.1007/978-3-031-41682-8_2 
* Byte Pair Encoding is sub-optimal?: https://arxiv.org/abs/2004.03720 
* Section III C. of Large Language Models: A Survey: https://arxiv.org/pdf/2402.06196 : Brief intros to bpe, wpe, and spe. No substance.
* Tokenization is More Than Compression: https://arxiv.org/abs/2402.18376
* (tokenization introduces a sampling bias) Understanding and Mitigating Tokenization Bias in Language Models: https://arxiv.org/abs/2406.16829
* Pre-GPT4 research on how tokenization affects neural machine translation. How Much Does Tokenization Affect Neural Machine Translation?: https://arxiv.org/abs/1812.08621 
* (Nepali language tokenization??) Can Perplexity Predict Fine-Tuning Performance? An Investigation of Tokenization Effects on Sequential Language Models for Nepali: https://arxiv.org/pdf/2404.18071v1 
* Tokenization Matters! Degrading Large Language Models through Challenging Their Tokenization: https://arxiv.org/pdf/2405.17067
* (Mechanistically, how do LLMs convert such arbitrary groups of tokens into useful higher-level representations?) Token Erasure as a Footprint of Implicit Vocabulary Items in LLMs: https://arxiv.org/abs/2406.20086
* Tries to do away with tokenization altogether. MEGABYTE: Predicting Million-byte Sequences with Multiscale Transformers https://arxiv.org/abs/2305.07185 

## Some known effects of Tokenization

* LLMs can't spell words perfectly.
* Basic arithmetic is difficult in a hilarious way.
* Simple string processing (especially character level) like counting the number of characters, reversing words.
* Worse ability on non-english languages - because tokenization is primarily done on english language
* etc.... 


## Tokenization Playground
https://tiktokenizer.vercel.app/ 

Tokenization is kind of arbitrary. Spaces sometimes change the same word from being one token to multiple tokens. Numbers are tokenized in a way that looks completely random to us. Non-english languages are broken up to more tokens to represent the same sentence -> this is part of why non-english performance is worse. GPT-2 tokenizer struggled with coding partly because each whitespace was a separate token and that was wasteful, the newer models fix this. GPT-4 tokenizer drops token count for same information from GPT-2, cause the vocabulary size is double. And so on...

## Implementation from Scratch
Along the lines of Andrej Karpathy's tutorial on GPT Tokenizer: https://www.youtube.com/watch?v=zduSFxRajkE&t=1430s 

### Unicode
Cant use all of unicode code points (is what they call characters), because 
1. it's too large = 150k 
2. and it keeps changing. 

Solution: encodings like
* utf-8 (1-4 bytes, most common, backwards compatible with the older ascii encoding)
* utf-16
* utf-32 (fixed length but a bit wasteful)
More on unicode: https://www.reedbeta.com/blog/programmers-intro-to-unicode/


If we use utf-8 as the entire vocabulary (character level - ish), there's only 256 elements. Makes context lengths really long when we represent 
english language in it. Transformers (specifically the Attention layers) scale quadratic with context length. We can do better. 
Solution: Byte Pair Encoding Algorithm. 

Side: Feeding raw bytes makes language models end to end, and potentially mitigates some of the quirks of tokenization? But it doesnt work as far as we know. MEGABYTE studies if this can be done - need to modify architecture, hasnt been verified/reproduced/studied by enough people.

In [14]:
# unicode standard
print("h: ", ord("h"))
print("µ: ", ord('µ')) 

print("hello world bytes in utf-8:", list("hello world".encode("utf-8")))
print("hello world bytes in utf-32 (wasteful!):", list("hello world".encode("utf-32")))

h:  104
µ:  181
hello world bytes in utf-8: [104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100]
hello world bytes in utf-32 (wasteful!): [255, 254, 0, 0, 104, 0, 0, 0, 101, 0, 0, 0, 108, 0, 0, 0, 108, 0, 0, 0, 111, 0, 0, 0, 32, 0, 0, 0, 119, 0, 0, 0, 111, 0, 0, 0, 114, 0, 0, 0, 108, 0, 0, 0, 100, 0, 0, 0]


### Byte Pair Encoding 
(AKA Digram coding)
Wikipedia page has good explanation: https://en.wikipedia.org/wiki/Byte_pair_encoding 

Iteratively:
* Find the pair of tokens that occur most frequently
* Replace that pair with a single token that we replace their occurence with it. 
* Until there are no pairs of bytes that occur more than once. In practice: the stopping condition is when a vocabulary of desired size is obtained. Vocabulary size is typically a hyperparameter.

Increases vocabulary size, but compresses sequence length. 

In [110]:
# example text: wikipedia's bpe article with some random unicode at staart
global_text = "Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 ˜ßµœπœ∫¥¨©∂ø ^ _ . º ' ` ˛ \xFF Byte pair encoding[1][2] (also known as digram coding)[3] is an algorithm, first described in 1994 by Philip Gage for encoding strings of text into tabular form for use in downstream modeling.[4] Its modification is notable as the large language model tokenizer with an ability to combine both tokens that encode single characters (including single digits or single punctuation marks) and those that encode whole words (even the longest compound words).[5][6][7] This modification, in the first step, assumes all unique characters to be an initial set of 1-character long n-grams (i.e. initial 'tokens'). Then, successively the most frequent pair of adjacent characters is merged into a new, 2-character long n-gram and all instances of the pair are replaced by this new token. This is repeated until a vocabulary of prescribed size is obtained. Note that new words can always be constructed from final vocabulary tokens and initial-set characters.[8] This algorithmic approach have been extended from spoken language to sign language in recent years.[9]All the unique tokens found in a corpus are listed in a token vocabulary, the size of which, in the case of GPT-3.5 and GPT-4, is 100256.. The difference between the modified and the original algorithm is that the original algorithm does not merge the most frequent pair of bytes of data, but replaces them by a new byte that was not contained in the initial dataset. A lookup table of the replacements is required to rebuild the initial dataset. The algorithm is effective for tokenization because it has low computational overhead and remains consistent and reliable. Original algorithm: The original algorithm operates by iteratively replacing the most common contiguous sequences of characters in a target text with unused 'placeholder' bytes. The iteration ends when no sequences can be found, leaving the target text effectively compressed. Decompression can be performed by reversing this process, querying known placeholder terms against their corresponding denoted sequence, using a lookup table. In the original paper, this lookup table is encoded and stored alongside the compressed text. Example: Suppose the data to be encoded is aaabdaaabac The byte pair 'aa' occurs most often, so it will be replaced by a byte that is not used in the data, such as 'Z'. Now there is the following data and replacement table: ZabdZabac, Z=aa, Then the process is repeated with byte pair 'ab', replacing it with 'Y': ZYdZYac, Y=ab, Z=aa. The only literal byte pair left occurs only once, and the encoding might stop here. Alternatively, the process could continue with recursive byte pair encoding, replacing 'ZY' with 'X': XdXac, X=ZY, Y=ab, Z=aa. This data cannot be compressed further by byte pair encoding because there are no pairs of bytes that occur more than once. To decompress the data, simply perform the replacements in the reverse order."
global_tokens = list(map(int, global_text.encode("utf-8"))) 

print(global_text)
print(len(global_text))
print("---------------------------------")
print(global_tokens)
print(len(global_tokens))
# tokens are longer than text cause the special characters need more than one byte. 


Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 ˜ßµœπœ∫¥¨©∂ø ^ _ . º ' ` ˛ ÿ Byte pair encoding[1][2] (also known as digram coding)[3] is an algorithm, first described in 1994 by Philip Gage for encoding strings of text into tabular form for use in downstream modeling.[4] Its modification is notable as the large language model tokenizer with an ability to combine both tokens that encode single characters (including single digits or single punctuation marks) and those that encode whole words (even the longest compound words).[5][6][7] This modification, in the first step, assumes all unique characters to be an initial set of 1-character long n-grams (i.e. initial 'tokens'). Then, successively the most frequent pair of adjacent characters is merged into a new, 2-character long n-gram and all instances of the pair are replaced by this new token. This is repeated until a vocabulary of prescribed size is obtained. Note that new words can always be constructed from final vocabulary tokens and initial-set

In [52]:
def get_stats(ids):
    """Find the most common pair"""
    counts = {}
    for ch_pair in zip(ids[:-1], ids[1:]):
        counts[ch_pair] = counts.get(ch_pair, 0) + 1
    return counts

stats = get_stats(tokens)
sorted_counts = sorted(((v,k) for k, v in stats.items()), reverse=True)
print(sorted_counts)

# fully decoding with .decode like below to convert to strings doesnt work when the encoding is variable length.
# sorted_counts_ch = [(v, bytes(list(k)).decode("utf-8")) for v, k in sorted_counts]
sorted_counts_ch = [(v, (chr(k[0]), chr(k[1]))) for v, k in sorted_counts]
print(sorted_counts_ch)

# lot of words end with e or s, and start with t. h commonly follows t -> the is the most common english word. 

[(89, (101, 32)), (68, (32, 116)), (62, (115, 32)), (59, (116, 104)), (53, (105, 110)), (44, (104, 101)), (44, (32, 97)), (40, (101, 110)), (39, (114, 101)), (38, (116, 101)), (38, (116, 32)), (38, (32, 105)), (37, (100, 32)), (35, (110, 32)), (31, (101, 114)), (30, (32, 98)), (29, (32, 111)), (28, (110, 103)), (27, (99, 111)), (27, (44, 32)), (26, (111, 114)), (26, (105, 116)), (26, (97, 116)), (26, (32, 99)), (25, (97, 108)), (24, (101, 100)), (24, (97, 99)), (23, (105, 115)), (23, (101, 115)), (22, (111, 110)), (22, (97, 110)), (21, (114, 32)), (21, (99, 101)), (20, (116, 97)), (19, (115, 101)), (19, (108, 97)), (19, (97, 114)), (19, (32, 115)), (19, (32, 114)), (19, (32, 112)), (18, (116, 105)), (18, (110, 99)), (18, (108, 32)), (18, (97, 98)), (18, (46, 32)), (17, (116, 111)), (17, (115, 116)), (17, (104, 97)), (17, (98, 121)), (16, (121, 32)), (16, (110, 100)), (16, (32, 100)), (15, (240, 159)), (15, (114, 105)), (15, (111, 100)), (15, (103, 32)), (15, (32, 108)), (15, (32, 102))

### Try to implement this algorithm without following Andrej Karpathy. It's like a leetcode puzzle.

In [73]:
from collections import namedtuple
Frequency = namedtuple("Frequency", ["pair", "count"])

In [143]:
def get_most_frequent(ids):
    """Given a sequence of token ids, find the most frequent"""
    most_frequent = Frequency(pair=None, count=0)
    counts = {}
    for pair in zip(ids[:-1], ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
        if counts[pair] > most_frequent.count:
            most_frequent = Frequency(pair=pair, count=counts[pair])
    return most_frequent

def merge_tokens(ids, id1, id2, id_new):
    """Merge tokens 'id1'+'id2' in the sequence 'ids' into 'id_new', in place."""
    if len(ids) < 2:
        return
    write_i = 0
    read_i = 0
    while read_i < len(ids):
        if read_i == len(ids)-1:
            ids[write_i] = ids[read_i]
        elif ids[read_i] == id1 and ids[read_i+1] == id2:
            ids[write_i] = id_new
            read_i += 1 # skip next token because its merged
        else:
            ids[write_i] = ids[read_i]
        read_i += 1
        write_i += 1
    ids[:] = ids[:write_i]

In [144]:
def test_get_most_frequent(ids, expected_most_frequent):
    most_frequent = get_most_frequent(ids)
    assert most_frequent.pair == expected_most_frequent.pair, f"pair mismatch: expected {expected_most_frequent.pair}, got {most_frequent.pair}"
    assert most_frequent.count == expected_most_frequent.count, f"count mismatch: expected {expected_most_frequent.count}, got {most_frequent.count}"

most_frequent_test_cases = [
    ([], Frequency(None, 0)),
    ([1], Frequency(None, 0)),
    ([1, 2, 3, 4, 5], Frequency((1,2), 1)),
    ([1, 2, 1, 2, 3, 4], Frequency((1, 2), 2)),
    ([1, 2, 3, 4, 1, 2, 3, 4, 1, 2], Frequency((1, 2), 3)),
    ([1, 1, 1, 1], Frequency((1, 1), 3)),
    ([1000000, 2000000, 1000000, 2000000], Frequency((1000000, 2000000), 2)),
    ([1, 3, 3, 2, 2, 3, 3, 3], Frequency((3, 3), 3)),
    ([1, 2, 3, 2, 3, 1], Frequency((2, 3), 2)),
    ([1, 1, 2, 2, 1, 1, 3, 3, 1, 1], Frequency((1, 1), 3))
]
print("\nTESTING get_most_frequent ...")
fail_count = 0
for i, most_frequent_test in enumerate(most_frequent_test_cases):
    print(f"test case {i}: {most_frequent_test}")
    try:
        test_get_most_frequent(*most_frequent_test)
        print("\tPASS")
    except Exception as e:
        fail_count += 1
        print(f"xxxxxxx FAIL. {str(e)}")

print("-----------------------------------")
print(f"{len(most_frequent_test_cases) - fail_count} PASSED, {fail_count} FAILED.")
print("-----------------------------------")


TESTING get_most_frequent ...
test case 0: ([], Frequency(pair=None, count=0))
	PASS
test case 1: ([1], Frequency(pair=None, count=0))
	PASS
test case 2: ([1, 2, 3, 4, 5], Frequency(pair=(1, 2), count=1))
	PASS
test case 3: ([1, 2, 1, 2, 3, 4], Frequency(pair=(1, 2), count=2))
	PASS
test case 4: ([1, 2, 3, 4, 1, 2, 3, 4, 1, 2], Frequency(pair=(1, 2), count=3))
	PASS
test case 5: ([1, 1, 1, 1], Frequency(pair=(1, 1), count=3))
	PASS
test case 6: ([1000000, 2000000, 1000000, 2000000], Frequency(pair=(1000000, 2000000), count=2))
	PASS
test case 7: ([1, 3, 3, 2, 2, 3, 3, 3], Frequency(pair=(3, 3), count=3))
	PASS
test case 8: ([1, 2, 3, 2, 3, 1], Frequency(pair=(2, 3), count=2))
	PASS
test case 9: ([1, 1, 2, 2, 1, 1, 3, 3, 1, 1], Frequency(pair=(1, 1), count=3))
	PASS
-----------------------------------
10 PASSED, 0 FAILED.
-----------------------------------


In [145]:
def test_merge_tokens(ids, id1, id2, id_new, expected_ids):
    merge_tokens(ids, id1, id2, id_new)
    assert len(ids) == len(expected_ids), f"expected {len(expected_ids)} long, got {len(ids)} long"
    for i in range(len(expected_ids)):
        # can do check list equal, looping through items makes error message more useful
        assert ids[i] == expected_ids[i], f"Merged tokens mismatch at index i, expected {expected_ids[i]}, got {ids[i]}"

merge_token_tests = [
    ([1, 2, 3, 4], 2, 3, 5, [1, 5, 4]),
    ([1, 2, 3, 2, 3, 4], 2, 3, 5, [1, 5, 5, 4]),
    ([1, 2, 3, 4], 1, 2, 5, [5, 3, 4]),
    ([1, 2, 3, 4], 3, 4, 5, [1, 2, 5]),
    ([1, 2, 4, 3], 2, 3, 5, [1, 2, 4, 3]),
    ([], 1, 2, 3, []),
    ([1], 1, 2, 3, [1]),
    ([1, 2, 1, 2, 1, 2], 1, 2, 3, [3, 3, 3]),
    ([1, 1, 1, 1], 1, 1, 2, [2, 2]),
    ([1, 2, 3, 4, 5], 5, 6, 7, [1, 2, 3, 4, 5])
]
print("\nTESTING merge_tokens...")
fail_count = 0
for i, merge_token_test in enumerate(merge_token_tests):
    print(f"test case {i} {merge_token_test[0]} --> {merge_token_test[1]}+{merge_token_test[2]}={merge_token_test[3]} --> {merge_token_test[4]}")
    try:
        test_merge_tokens(*merge_token_test)
        print("\tPASS")
    except Exception as e:
        fail_count += 1
        print(f"xxxxxx FAIL. {str(e)}")

print("-----------------------------------")
print(f"{len(most_frequent_test_cases) - fail_count} PASSED, {fail_count} FAILED.")
print("-----------------------------------")




TESTING merge_tokens...
test case 0 [1, 2, 3, 4] --> 2+3=5 --> [1, 5, 4]
	PASS
test case 1 [1, 2, 3, 2, 3, 4] --> 2+3=5 --> [1, 5, 5, 4]
	PASS
test case 2 [1, 2, 3, 4] --> 1+2=5 --> [5, 3, 4]
	PASS
test case 3 [1, 2, 3, 4] --> 3+4=5 --> [1, 2, 5]
	PASS
test case 4 [1, 2, 4, 3] --> 2+3=5 --> [1, 2, 4, 3]
	PASS
test case 5 [] --> 1+2=3 --> []
	PASS
test case 6 [1] --> 1+2=3 --> [1]
	PASS
test case 7 [1, 2, 1, 2, 1, 2] --> 1+2=3 --> [3, 3, 3]
	PASS
test case 8 [1, 1, 1, 1] --> 1+1=2 --> [2, 2]
	PASS
test case 9 [1, 2, 3, 4, 5] --> 5+6=7 --> [1, 2, 3, 4, 5]
	PASS
-----------------------------------
10 PASSED, 0 FAILED.
-----------------------------------


In [146]:
def run_bpe(tokens, vocab_size, target_vocab_size, early_stop_no_repeat=False):
    merged_tokens_map = {} # remember how the new tokens map to characters
    def token_to_char(id):
        if id < 256:
            return chr(id)
        else:
            return merged_tokens_map[id]

    while vocab_size < target_vocab_size:
        frequent = get_most_frequent(tokens)
        if frequent.count < 2 and early_stop_no_repeat:
            # no more multiple occurences left
            break
        merge_tokens(tokens, frequent.pair[0], frequent.pair[1], vocab_size)
        merged_tokens_map[vocab_size] = token_to_char(frequent.pair[0]) + token_to_char(frequent.pair[1])
        vocab_size += 1
    return merged_tokens_map

In [147]:
tokens = global_tokens[:]
vocab_size = 256 # because we used utf-8
target_vocab_size = 384
post_utf_map = run_bpe(tokens, vocab_size, target_vocab_size, early_stop_no_repeat=True)
print(f"Global tokens compressed from {len(global_tokens)} to {len(tokens)} tokens, compression ratio = {len(global_tokens)/len(tokens):.2f} ")
print(f"New vocab size = {vocab_size}, New tokens: ", post_utf_map)

Global tokens compressed from 3055 to 1505 tokens, compression ratio = 2.03 
New vocab size = 256, New tokens:  {256: 'e ', 257: 's ', 258: 'th', 259: 'in', 260: 'en', 261: 't ', 262: 'd ', 263: 're', 264: 'er', 265: 'the ', 266: ' a', 267: 'co', 268: ', ', 269: 'ac', 270: 'or', 271: 'is ', 272: 'ing', 273: 'ed ', 274: 'ta', 275: 'al', 276: ' b', 277: '. ', 278: 'to', 279: 'ti', 280: 'on', 281: 'ar', 282: ' the ', 283: 'ð\x9f', 284: 'an', 285: 'ith', 286: 'ir', 287: 'of', 288: 'pl', 289: 'ter', 290: 'Th', 291: 'y ', 292: 'ab', 293: 'pa', 294: 'yt', 295: 'plac', 296: 'cod', 297: 'pair', 298: 'ce', 299: 'ken', 300: 'm ', 301: 'qu', 302: 'mo', 303: 'ch', 304: 'ig', 305: 'ing ', 306: 'encod', 307: 'no', 308: 'com', 309: 'da', 310: 'replac', 311: 'token', 312: 'un', 313: 'se', 314: 'st ', 315: 'yte ', 316: 'pair ', 317: 'data', 318: 'si', 319: 'ð\x9f\x85', 320: 'â\x80', 321: 'ð\x9f\x87', 322: 'gor', 323: 'gorith', 324: 'of ', 325: 'char', 326: 'charac', 327: 'character', 328: ' a ', 329: 'a

In [154]:
# plot target vocab size vs compression ratio
target_vocab_sizes = [vocab_size + i for i in range(1, 1024, 16)]
compression_ratios = []
for target_size in target_vocab_sizes:
    tokens = global_tokens[:]
    run_bpe(tokens, vocab_size, target_size, early_stop_no_repeat=True)
    compression_ratios.append(len(global_tokens)/(1.0*len(tokens)))

In [157]:
## plot generated with Claude, an attempt to get familiar with plotly instead of using matplotlib.

import plotly.graph_objects as go

# Create the figure with a specific width and height
fig = go.Figure()

# Define colors
point_color = '#e74c3c'  # Original dark red color for data points
line_color = '#f7a79e'   # Lighter version of the red for the line

# Add the dotted line plot
fig.add_trace(go.Scatter(
    x=target_vocab_sizes,
    y=compression_ratios,
    mode='lines',
    line=dict(color=line_color, width=2, dash='dot'),
    name='Trend'
))

# Add markers at data points
fig.add_trace(go.Scatter(
    x=target_vocab_sizes,
    y=compression_ratios,
    mode='markers',
    marker=dict(color=point_color, size=8),
    name='Data Points'
))

# Customize the layout
fig.update_layout(
    title={
        'text': 'Compression Ratio of BPE Algorithm',
        'y':0.95,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top',
        'font': dict(size=20, color='#333333')
    },
    xaxis_title={
        'text': "New Vocabulary Size",
        'font': dict(size=14, color='#333333')
    },
    yaxis_title={
        'text': "Compression Ratio<br>(Original / merged length)",
        'font': dict(size=14, color='#333333')
    },
    plot_bgcolor='#f8f8f8',
    paper_bgcolor='#f8f8f8',
    hovermode='x unified',
    hoverlabel=dict(bgcolor="white", font_size=12),
    xaxis=dict(
        showgrid=True,
        gridcolor='lightgrey',
        tickfont=dict(size=12),
        zeroline=False
    ),
    yaxis=dict(
        showgrid=True,
        gridcolor='lightgrey',
        tickfont=dict(size=12),
        zeroline=False
    ),
    width=700,  # Set the width of the plot
    height=500,  # Set the height of the plot
    margin=dict(l=80, r=40, t=80, b=60),  # Adjust margins
    legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="left",
        x=0.01
    )
)

# Show the plot
fig.show()

In [187]:
tokens = global_tokens[:]
vocab_size = 256 # because we used utf-8
target_vocab_size = 384
post_utf_map = run_bpe(tokens, vocab_size, target_vocab_size, early_stop_no_repeat=True)
print(f"Global tokens compressed from {len(global_tokens)} to {len(tokens)} tokens, compression ratio = {len(global_tokens)/len(tokens):.2f} ")
print(f"New vocab size = {vocab_size}, New tokens: ", post_utf_map)

Global tokens compressed from 3055 to 1505 tokens, compression ratio = 2.03 
New vocab size = 256, New tokens:  {256: 'e ', 257: 's ', 258: 'th', 259: 'in', 260: 'en', 261: 't ', 262: 'd ', 263: 're', 264: 'er', 265: 'the ', 266: ' a', 267: 'co', 268: ', ', 269: 'ac', 270: 'or', 271: 'is ', 272: 'ing', 273: 'ed ', 274: 'ta', 275: 'al', 276: ' b', 277: '. ', 278: 'to', 279: 'ti', 280: 'on', 281: 'ar', 282: ' the ', 283: 'ð\x9f', 284: 'an', 285: 'ith', 286: 'ir', 287: 'of', 288: 'pl', 289: 'ter', 290: 'Th', 291: 'y ', 292: 'ab', 293: 'pa', 294: 'yt', 295: 'plac', 296: 'cod', 297: 'pair', 298: 'ce', 299: 'ken', 300: 'm ', 301: 'qu', 302: 'mo', 303: 'ch', 304: 'ig', 305: 'ing ', 306: 'encod', 307: 'no', 308: 'com', 309: 'da', 310: 'replac', 311: 'token', 312: 'un', 313: 'se', 314: 'st ', 315: 'yte ', 316: 'pair ', 317: 'data', 318: 'si', 319: 'ð\x9f\x85', 320: 'â\x80', 321: 'ð\x9f\x87', 322: 'gor', 323: 'gorith', 324: 'of ', 325: 'char', 326: 'charac', 327: 'character', 328: ' a ', 329: 'a

In [180]:
def decode(ids, token_map):
    """Given token ids, return the raw text"""
    token_map.update({i: chr(i) for i in range(256)})
    text = [""] * len(ids)
    for i, id in enumerate(ids):
        text[i] = token_map[id]
    return "".join(text)

def encode(text, token_map):
    """Given raw text, return the token ids"""
    # find out longest token so we can start search here.
    len_longest_token = 1
    for t in token_map.values():
        if len(t) > len_longest_token:
            len_longest_token = len(t)

    inverse_token_map = {v:k for k,v in token_map.items()}
    inverse_token_map.update({chr(i): i for i in range(256)})
    tokens = []
    i = 0
    while i < len(text):
        # search for tokens of length of the longest token and keep
        # reducing the substring length to search until a matching token is found
        search_substr_end = min(i + len_longest_token, len(text)) 
        token_added = False
        while search_substr_end > i:
            if text[i:search_substr_end] in inverse_token_map:
                tokens.append(inverse_token_map[text[i:search_substr_end]])
                token_added = True
                i = search_substr_end
                break                
            else:
                search_substr_end -= 1
        assert token_added, f"Unknown token {text[i]} found at index {i}, unexpected!"
    return tokens

In [171]:
def test_decode(ids, token_map, expected_text):
    text = decode(ids, token_map)
    assert text == expected_text, f"Expected {expected_text}, got {text}."

def test_encode(text, token_map, expected_ids):
    ids = encode(text, token_map)
    assert ids == expected_ids, f"Expected {expected_ids}, got {ids}"

test_cases_decode = [
    ([], {}, ""),  # Empty input
    ([65, 66, 67], {}, "ABC"),  # Basic ASCII characters
    ([256, 65], {256: "Hello "}, "Hello A"),  # Mixed token and ASCII
    ([256, 257, 258], {256: "Hello", 257: " world", 258: "!"}, "Hello world!"),  # All tokens
    ([256, 32, 257], {256: "Hello", 257: "world!"}, "Hello world!"),  # Tokens with space
    ([256, 257, 67], {256: "こん", 257: "にちは"}, "こんにちはC"),  # Non-ASCII characters
    ([256, 257, 258], {256: "a", 257: "b", 258: "c"}, "abc"),  # Single-char tokens
    ([65, 256, 66, 257], {256: "BC", 257: "DE"}, "ABCBDE"),  # Interleaved ASCII and tokens
    ([256, 257, 256], {256: "repeat", 257: " "}, "repeat repeat"),  # Repeated tokens
    ([256, 257, 258, 259], {256: "a", 257: "b", 258: "c", 259: "defg"}, "abcdefg"),  # Mixed length tokens
    ([256, 65, 257, 66], {256: "Hello", 257: "World"}, "HelloAWorldB")  # Tokens with interspersed ASCII
]
print("\nTESTING decode...")
fail_count = 0
for i, t_d in enumerate(test_cases_decode):
    print(f"test case {i}: {t_d[0]} -> {t_d[1]} -> {t_d[2]}")
    try:
        test_decode(*t_d)
        print("\tPASS")
    except Exception as e:
        fail_count += 1
        print(f"...... FAIL. {str(e)}")
print("-------------------------")
print(f"{len(test_cases_decode) - fail_count} PASS, {fail_count} FAIL.")
print("-------------------------")


TESTING decode...
test case 0: [] -> {} -> 
	PASS
test case 1: [65, 66, 67] -> {} -> ABC
	PASS
test case 2: [256, 65] -> {256: 'Hello '} -> Hello A
	PASS
test case 3: [256, 257, 258] -> {256: 'Hello', 257: ' world', 258: '!'} -> Hello world!
	PASS
test case 4: [256, 32, 257] -> {256: 'Hello', 257: 'world!'} -> Hello world!
	PASS
test case 5: [256, 257, 67] -> {256: 'こん', 257: 'にちは'} -> こんにちはC
	PASS
test case 6: [256, 257, 258] -> {256: 'a', 257: 'b', 258: 'c'} -> abc
	PASS
test case 7: [65, 256, 66, 257] -> {256: 'BC', 257: 'DE'} -> ABCBDE
	PASS
test case 8: [256, 257, 256] -> {256: 'repeat', 257: ' '} -> repeat repeat
	PASS
test case 9: [256, 257, 258, 259] -> {256: 'a', 257: 'b', 258: 'c', 259: 'defg'} -> abcdefg
	PASS
test case 10: [256, 65, 257, 66] -> {256: 'Hello', 257: 'World'} -> HelloAWorldB
	PASS
-------------------------
11 PASS, 0 FAIL.
-------------------------


In [184]:
test_cases_encode = [
    ("", {}, []),  # Empty input
    ("ABC", {}, [65, 66, 67]),  # Basic ASCII characters
    ("Hello A", {256: "Hello "}, [256, 65]),  # Mixed token and ASCII
    ("Hello world!", {256: "Hello", 257: " world"}, [256, 257, 33]),  # All tokens
    ("Hello world!", {256: "Hello", 257: "world!"}, [256, 32, 257]),  # Tokens with space
    ("こんにちはC", {256: "こん", 257: "にちは"}, [256, 257, 67]),  # Non-ASCII characters
    ("abc", {256: "ab", 257: "bc", 258: "ca"}, [256, 99]),  # Single-char tokens
    ("ABCBDE", {256: "BC", 257: "DE"}, [65, 256, 66, 257]),  # Interleaved ASCII and tokens
    ("repeat repeat", {256: "repeat"}, [256, 32, 256]),  # Repeated tokens
    ("abcdefg", {256: "ab", 259: "defg"}, [256, 99, 259]),  # Mixed length tokens
    ("HelloAWorldB", {256: "Hello", 257: "World"}, [256, 65, 257, 66]),  # Tokens with interspersed ASCII
    ("TestNoMatch", {256: "Other"}, [84, 101, 115, 116, 78, 111, 77, 97, 116, 99, 104]),  # No matching tokens
    ("ABCabc", {256: "ABC", 257: "abc"}, [256, 257]),  # Case-sensitive tokens
    ("ABCABCABC", {256: "ABC"}, [256, 256, 256]),  # Repeated token
    ("A B C", {256: "A B C", 257: "A B"}, [256])  # Longest match priority
]

print("\nTESTING encode...")
fail_count = 0
for i, t_e in enumerate(test_cases_encode):
    print(f"test case {i}: {t_e[0]} -> {t_e[1]} -> {t_e[2]}")
    try:
        test_encode(*t_e)
        print("\tPASS")
    except Exception as e:
        fail_count += 1
        print(f"xxxxxx FAIL. {str(e)}")
print("-------------------------")
print(f"{len(test_cases_encode) - fail_count} PASS, {fail_count} FAIL.")
print("-------------------------")



TESTING encode...
test case 0:  -> {} -> []
	PASS
test case 1: ABC -> {} -> [65, 66, 67]
	PASS
test case 2: Hello A -> {256: 'Hello '} -> [256, 65]
	PASS
test case 3: Hello world! -> {256: 'Hello', 257: ' world'} -> [256, 257, 33]
	PASS
test case 4: Hello world! -> {256: 'Hello', 257: 'world!'} -> [256, 32, 257]
	PASS
test case 5: こんにちはC -> {256: 'こん', 257: 'にちは'} -> [256, 257, 67]
	PASS
test case 6: abc -> {256: 'ab', 257: 'bc', 258: 'ca'} -> [256, 99]
	PASS
test case 7: ABCBDE -> {256: 'BC', 257: 'DE'} -> [65, 256, 66, 257]
	PASS
test case 8: repeat repeat -> {256: 'repeat'} -> [256, 32, 256]
	PASS
test case 9: abcdefg -> {256: 'ab', 259: 'defg'} -> [256, 99, 259]
	PASS
test case 10: HelloAWorldB -> {256: 'Hello', 257: 'World'} -> [256, 65, 257, 66]
	PASS
test case 11: TestNoMatch -> {256: 'Other'} -> [84, 101, 115, 116, 78, 111, 77, 97, 116, 99, 104]
	PASS
test case 12: ABCabc -> {256: 'ABC', 257: 'abc'} -> [256, 257]
	PASS
test case 13: ABCABCABC -> {256: 'ABC'} -> [256, 256, 256]

In [188]:
print(decode(encode(global_text, post_utf_map), post_utf_map) == global_text)

AssertionError: Unknown token Ｕ found at index 0, unexpected!

In [190]:
print(decode(tokens, post_utf_map))

ï¼µï½ï½ï½ï½ï½ï½! ð¤ððððððâ½ ðºâð³âð®âð¨âð´âð©âðª! ð ËÃÂµÅÏÅâ«Â¥Â¨Â©âÃ¸ ^ _ . Âº ' ` Ë Ã¿ Byte pair encoding[1][2] (also known as digram coding)[3] is an algorithm, first described in 1994 by Philip Gage for encoding strings of text into tabular form for use in downstream modeling.[4] Its modification is notable as the large language model tokenizer with an ability to combine both tokens that encode single characters (including single digits or single punctuation marks) and those that encode whole words (even the longest compound words).[5][6][7] This modification, in the first step, assumes all unique characters to be an initial set of 1-character long n-grams (i.e. initial 'tokens'). Then, successively the most frequent pair of adjacent characters is merged into a new, 2-character long n-gram and all instances of the pair are replaced by this new token. This is repeated until a vocabulary of prescribed size is obtained. 

**Differences in Andrej Karpathy's *naive* implementation**
* Same split of helper functions: get_most_common and merge
* In get most common: he builds a dictionary of counts, and uses python's max() function with the count key.
* The merge logic is effectively the same, he uses one pointer, I use two pointers which I found more intuitive.
* The iterative loop is the same too, except he does a for loop.
* For encoding and decoding, he works in bytes, I work directly on strings. The prod models work in bytes, so when the llm outputs a non-valid utf-8 byte sequence, we can replace it with an error character. Working in strings means we can decode erroneously for the special chars whose representations are multiple bytes. We can see that error when we encode and decode the entire global string. 

## Other Tokenizers
* WordPieceEncoding: used in BERT and Electra. Takes all alphabets from input data to make sure nothing will be UNK (unknown), also tries to for tokens based on their frequency.
* SentencePieceEncoding: Others assume white-space is the defacto word separated, which may not always be true. This addresses that issue? 
* Maximum Prefix Encoding: ?? 