# ByteTok: _A simple tokenizer_


The dataset used here is from Hugging Face: [Sci-Fi-Books-gutenberg](https://huggingface.co/datasets/stevez80/Sci-Fi-Books-gutenberg).
To access the dataset, you may need to sign up for a Hugging Face account and generate an access token. See [docs](https://huggingface.co/docs) for more info.



In [1]:
from datasets import load_dataset

# Load the "train" split of the Sci-Fi-Books-gutenberg dataset from the Hugging Face Hub.
ds = load_dataset("stevez80/Sci-Fi-Books-gutenberg", split="train")

In [None]:
# Preview characters 1000-1999 of the first example's "text" field.
ds[0]["text"][1000:2000]

"b'\\xef\\xbb\\xbfThe Project Gutenberg eBook of Frankenstein; Or, The Modern Prometheus\\r\\n    \\r\\nThis ebook is for the use of anyone anywhere in the United States and\\r\\nmost other parts of the world at no cost and with almost no restrictions\\r\\nwhatsoever. You may copy it, give it away or re-use it under the terms\\r\\nof the Project Gutenberg License included with this ebook or online\\r\\nat www.gutenberg.org. If you are not located in the United States,\\r\\nyou will have to check the laws of the country where you are located\\r\\nbefore using this eBook.\\r\\n\\r\\nTitle: Frankenstein; Or, The Modern Prometheus\\r\\n\\r\\n\\r\\nAuthor: Mary Wollstonecraft Shelley\\r\\n\\r\\nRelease date: October 1, 1993 [eBook #84]\\r\\n                Most recently updated: December 2, 2022\\r\\n\\r\\nLanguage: English\\r\\n\\r\\nCredits: Judith Boss, Christy Phillips, Lynn Hanninen and David Meltzer. HTML version by Al Haines.\\r\\n        Further corrections by Menno de Leeuw.\\r\\n\

## UTF encoding

A Unicode code point is a single numeric value in the Unicode range (e.g., U+0041 for A, U+1F600 for 😀).

The byte sequence is an encoding of that code point (e.g., UTF-8 uses 1–4 bytes to encode a single code point).

Watch the first 20 minutes of [this video](https://youtu.be/vpSkBV5vydg) for an explanation of UTF-8.


### Text as a sequence of raw bytes


In [None]:
# loading 1000 examples here for faster training
# The first 1000 "text" entries are concatenated into one string, which is then
# encoded to UTF-8 bytes.
# NOTE(review): judging by the recorded outputs, each "text" entry looks like the
# repr of a Python bytes object (it begins with the characters b'\xef\xbb\xbf and
# spells out \r\n as literal backslash sequences). If so, the tokenizer is trained
# on those escape characters rather than on the real bytes; confirm, and consider
# converting the entries back to real text first.
tokens = "".join(ds[:1000]["text"]).encode("utf-8")
print(f"{tokens[:100]=}")
print(f"\n{len(tokens)=}")
print(f"\n{type(tokens)=}")

tokens[:100]=b"b'\\xef\\xbb\\xbfThe Project Gutenberg eBook of Frankenstein; Or, The Modern Prometheus\\r\\n    \\r\\nThis"

len(tokens)=479771850

type(tokens)=<class 'bytes'>


### Convert each byte to integer


In [4]:
# Iterating over a bytes object yields ints in range(256), so this turns the raw
# bytes into a list holding one integer per byte.
tokens = list(tokens)

print(f"{tokens[:100]=}")
print(f"\n{len(tokens)=}")
print(f"\n{type(tokens)=}")

tokens[:100]=[98, 39, 92, 120, 101, 102, 92, 120, 98, 98, 92, 120, 98, 102, 84, 104, 101, 32, 80, 114, 111, 106, 101, 99, 116, 32, 71, 117, 116, 101, 110, 98, 101, 114, 103, 32, 101, 66, 111, 111, 107, 32, 111, 102, 32, 70, 114, 97, 110, 107, 101, 110, 115, 116, 101, 105, 110, 59, 32, 79, 114, 44, 32, 84, 104, 101, 32, 77, 111, 100, 101, 114, 110, 32, 80, 114, 111, 109, 101, 116, 104, 101, 117, 115, 92, 114, 92, 110, 32, 32, 32, 32, 92, 114, 92, 110, 84, 104, 105, 115]

len(tokens)=479771850

type(tokens)=<class 'list'>


## Byte Pair Encoding


### Encoding



#### Get frequency of all byte pairs


In [5]:
type BpFreqStore = list[tuple[tuple[int, int], int]]


def get_bp_freqs(toks: list[int]) -> BpFreqStore:
    pairs = []
    ranks = {}

    for a, b in zip(toks, toks[1:]):
        pairs.append((a, b))

    for pair in pairs:
        ranks[pair] = ranks.get(pair, 0) + 1

    # sorted frequency of pairs
    return sorted(ranks.items(), key=lambda x: x[1], reverse=True)

In [None]:
# NOTE: despite its name, `tok_ids` holds the ((left, right), count) entries
# returned by get_bp_freqs here, not token ids.
tok_ids = get_bp_freqs(tokens)
tok_ids[:10]

[((101, 32), 11315346),
 ((114, 92), 9407734),
 ((92, 114), 9006460),
 ((92, 110), 9006037),
 ((32, 116), 8732423),
 ((92, 120), 7811235),
 ((116, 104), 7779215),
 ((104, 101), 7698448),
 ((32, 97), 6368815),
 ((100, 32), 6253081),
 ((116, 32), 6103084),
 ((115, 32), 5790313),
 ((105, 110), 5449142),
 ((101, 114), 5071126),
 ((110, 32), 4914592),
 ((97, 110), 4672785),
 ((114, 101), 4199520),
 ((44, 32), 4033344),
 ((32, 115), 4018963),
 ((32, 111), 4007100),
 ((32, 119), 3942099),
 ((101, 110), 3940221),
 ((111, 110), 3598047),
 ((110, 100), 3477258),
 ((101, 100), 3463858),
 ((32, 104), 3395825),
 ((32, 32), 3348817),
 ((111, 117), 3259171),
 ((97, 116), 3238309),
 ((116, 101), 3157138),
 ((114, 32), 3044284),
 ((111, 114), 2984189),
 ((110, 92), 2931151),
 ((32, 105), 2824544),
 ((121, 32), 2804834),
 ((105, 116), 2729187),
 ((46, 32), 2720414),
 ((110, 103), 2643219),
 ((110, 116), 2611704),
 ((104, 97), 2555397),
 ((97, 114), 2547164),
 ((101, 115), 2521434),
 ((116, 111), 2511306

In [7]:
# Pick the entry with the highest count; each entry is ((left, right), count),
# so index 1 of the lambda's argument is the count.
top_freq_pair = max(tok_ids, key=lambda pair: pair[1])
top_freq_pair

((101, 32), 11315346)

#### Encode consecutive byte pairs with new token id

In [8]:
def merge(tok_ids: list[int], replace_pair: tuple[int, int], new_id: int):
    """Return a copy of ``tok_ids`` with every ``replace_pair`` collapsed.

    The sequence is scanned left to right; each non-overlapping occurrence
    of the two ids in ``replace_pair`` is replaced by the single id
    ``new_id``. The input list is left untouched.

    Args:
        tok_ids: Token ids to scan.
        replace_pair: The two adjacent ids to look for.
        new_id: The id emitted in place of each matched pair.

    Returns:
        A new list of token ids.
    """
    first, second = replace_pair
    last = len(tok_ids) - 1
    out = []
    pos = 0
    while pos <= last:
        current = tok_ids[pos]
        # a match needs a right-hand neighbour, hence the `pos < last` guard
        if pos < last and current == first and tok_ids[pos + 1] == second:
            out.append(new_id)
            pos += 2
            continue
        out.append(current)
        pos += 1
    return out


print(merge([5, 4, 32, 32, 2, 5, 6], (32, 32), 99))

[5, 4, 99, 2, 5, 6]


#### BPE training

Note that the training dataset is huge: training on the whole dataset takes about an hour on my laptop, so it is better to train on a subset of it.

In [9]:
vocab_size = 276
byte_size = 256
# a byte takes 256 distinct values (0-255), so ids 0-255 are the raw bytes
# and new merged tokens are numbered from 256 upwards
# diff gives how many merges required to reach new vocab_size
num_merges = vocab_size - byte_size

# merge() builds a new list and leaves its input untouched, so the (very
# large) token list does not need to be copied before training
tok_ids = tokens
merges = {}  # (int, int) -> int, in the order the merges were learned

for merge_idx in range(num_merges):
    # find frequency of all byte pairs
    bp_freqs = get_bp_freqs(tok_ids)
    if not bp_freqs:
        # fewer than two tokens remain, so there is nothing left to merge
        break
    # get_bp_freqs sorts by descending count: the first entry is the most
    # frequent pair
    merge_pair, _ = bp_freqs[0]
    new_id = byte_size + merge_idx
    tok_ids = merge(tok_ids, merge_pair, new_id)
    merges[merge_pair] = new_id
    print(f"merged {merge_pair} -> {new_id}")

merged (101, 32) -> 256
merged (114, 92) -> 257
merged (257, 110) -> 258
merged (92, 258) -> 259
merged (116, 104) -> 260
merged (92, 120) -> 261
merged (100, 32) -> 262
merged (116, 32) -> 263
merged (115, 32) -> 264
merged (105, 110) -> 265
merged (101, 114) -> 266
merged (97, 110) -> 267
merged (44, 32) -> 268
merged (101, 110) -> 269
merged (111, 110) -> 270
merged (260, 256) -> 271
merged (32, 32) -> 272
merged (111, 117) -> 273
merged (111, 114) -> 274
merged (121, 32) -> 275


#### Analyze text compression ratio

In [10]:
# Compare the byte-level length with the length after merging; the ratio is the
# average number of raw bytes covered by one token.
print(f"{len(tokens)=}")
print(f"{len(tok_ids)=}")
print(f"Compression ratio={len(tokens) / len(tok_ids)}")

len(tokens)=479771850
len(tok_ids)=366930472
Compression ratio=1.307527955868435


### Decoding

In [11]:
def decode(ids: list[int], merges: dict[tuple[int, int], int]) -> str:
    """Turn a sequence of token ids back into text.

    Args:
        ids: Token ids to decode.
        merges: Mapping of ``(left_id, right_id)`` to the merged token id.
            It is walked in insertion order, so every merge must come after
            the merges that produced its two components.

    Returns:
        The UTF-8 decoding of the concatenated token bytes. Byte sequences
        that are not valid UTF-8 are substituted (``errors="replace"``)
        instead of raising.
    """
    # ids 0-255 each stand for the single byte with that value
    vocab = {byte: bytes((byte,)) for byte in range(256)}

    # a merged id expands to the bytes of its left part followed by its right part
    for pair, merged_id in merges.items():
        left, right = pair
        vocab[merged_id] = vocab[left] + vocab[right]

    raw = bytearray()
    for tok in ids:
        raw += vocab[tok]
    return raw.decode("utf-8", errors="replace")


dec_text = decode([127], merges)
dec_text

'\x7f'

In [12]:
def encode(text: str, merges: dict[tuple[int, int], int]) -> list[int]:
    """Tokenize ``text`` by replaying the learned merges on its UTF-8 bytes.

    Args:
        text: Text to tokenize. It is encoded with ``errors="replace"``, so
            characters that cannot be encoded are substituted, not rejected.
        merges: Mapping of ``(left_id, right_id)`` to the merged token id,
            as produced during training.

    Returns:
        The list of token ids. Text that encodes to fewer than two bytes is
        returned as its raw byte values.
    """
    # get byte representation of text
    tok_ids = list(text.encode("utf-8", errors="replace"))
    # encode statistically frequent byte pairs learned from training
    while len(tok_ids) > 1:
        # only the distinct adjacent pairs matter here, so counting and
        # sorting them by frequency on every pass is unnecessary
        candidates = set(zip(tok_ids, tok_ids[1:]))
        # pick the pair with the lowest merge id, i.e. the earliest learned
        # merge: later merges may be built from the output of earlier ones
        pair_to_merge = min(candidates, key=lambda p: merges.get(p, float("inf")))
        # no byte pairs to merge
        if pair_to_merge not in merges:
            break
        tok_ids = merge(tok_ids, pair_to_merge, merges[pair_to_merge])

    return tok_ids

In [13]:
enc = encode(
    "ÌïúÍµ≠ ÏÇ¨ÌöåÎäî Îπ†Î•¥Í≤å Î≥ÄÌôîÌïòÍ≥† ÏûàÎã§. Í∏∞Ïà†Ïùò Î∞úÏ†ÑÏùÄ ÏÇ¨ÎûåÎì§Ïùò ÏùºÏÉÅÏÉùÌôúÎøêÎßå ÏïÑÎãàÎùº ÏùºÌïòÎäî Î∞©ÏãùÍ≥º ÏÜåÌÜµÌïòÎäî Î∞©ÏãùÏóêÎèÑ ÌÅ∞ ÏòÅÌñ•ÏùÑ ÎØ∏ÏπòÍ≥† ÏûàÎã§. Ïä§ÎßàÌä∏Ìè∞Í≥º Ïù∏ÌÑ∞ÎÑ∑Ïùò Î≥¥Í∏âÏúºÎ°ú Ï†ïÎ≥¥Ïóê Ï†ëÍ∑ºÌïòÎäî ÏÜçÎèÑÎäî Ïù¥Ï†ÑÍ≥º ÎπÑÍµêÌï† Ïàò ÏóÜÏùÑ Ï†ïÎèÑÎ°ú Îπ®ÎùºÏ°åÏúºÎ©∞, Ïù¥Îäî ÍµêÏú°, Í≤ΩÏ†ú, Î¨∏Ìôî Ï†ÑÎ∞òÏóê Í±∏Ï≥ê ÏÉàÎ°úÏö¥ Í∏∞ÌöåÎ•º ÎßåÎì§Ïñ¥ ÎÇ¥Í≥† ÏûàÎã§. ÎèôÏãúÏóê Ïù¥Îü¨Ìïú Î≥ÄÌôîÎäî Í∞úÏù∏ÏóêÍ≤å Îçî ÎßéÏùÄ ÏÑ†ÌÉùÍ≥º Ï±ÖÏûÑÏùÑ ÏöîÍµ¨ÌïúÎã§.",
    merges,
)
enc

[237,
 149,
 156,
 234,
 181,
 173,
 32,
 236,
 130,
 172,
 237,
 154,
 140,
 235,
 138,
 148,
 32,
 235,
 185,
 160,
 235,
 165,
 180,
 234,
 178,
 140,
 32,
 235,
 179,
 128,
 237,
 153,
 148,
 237,
 149,
 152,
 234,
 179,
 160,
 32,
 236,
 158,
 136,
 235,
 139,
 164,
 46,
 32,
 234,
 184,
 176,
 236,
 136,
 160,
 236,
 157,
 152,
 32,
 235,
 176,
 156,
 236,
 160,
 132,
 236,
 157,
 128,
 32,
 236,
 130,
 172,
 235,
 158,
 140,
 235,
 147,
 164,
 236,
 157,
 152,
 32,
 236,
 157,
 188,
 236,
 131,
 129,
 236,
 131,
 157,
 237,
 153,
 156,
 235,
 191,
 144,
 235,
 167,
 140,
 32,
 236,
 149,
 132,
 235,
 139,
 136,
 235,
 157,
 188,
 32,
 236,
 157,
 188,
 237,
 149,
 152,
 235,
 138,
 148,
 32,
 235,
 176,
 169,
 236,
 139,
 157,
 234,
 179,
 188,
 32,
 236,
 134,
 140,
 237,
 134,
 181,
 237,
 149,
 152,
 235,
 138,
 148,
 32,
 235,
 176,
 169,
 236,
 139,
 157,
 236,
 151,
 144,
 235,
 143,
 132,
 32,
 237,
 129,
 176,
 32,
 236,
 152,
 129,
 237,
 150,
 165,
 236,
 157,
 132,
 3

In [14]:
decode(enc, merges)

'ÌïúÍµ≠ ÏÇ¨ÌöåÎäî Îπ†Î•¥Í≤å Î≥ÄÌôîÌïòÍ≥† ÏûàÎã§. Í∏∞Ïà†Ïùò Î∞úÏ†ÑÏùÄ ÏÇ¨ÎûåÎì§Ïùò ÏùºÏÉÅÏÉùÌôúÎøêÎßå ÏïÑÎãàÎùº ÏùºÌïòÎäî Î∞©ÏãùÍ≥º ÏÜåÌÜµÌïòÎäî Î∞©ÏãùÏóêÎèÑ ÌÅ∞ ÏòÅÌñ•ÏùÑ ÎØ∏ÏπòÍ≥† ÏûàÎã§. Ïä§ÎßàÌä∏Ìè∞Í≥º Ïù∏ÌÑ∞ÎÑ∑Ïùò Î≥¥Í∏âÏúºÎ°ú Ï†ïÎ≥¥Ïóê Ï†ëÍ∑ºÌïòÎäî ÏÜçÎèÑÎäî Ïù¥Ï†ÑÍ≥º ÎπÑÍµêÌï† Ïàò ÏóÜÏùÑ Ï†ïÎèÑÎ°ú Îπ®ÎùºÏ°åÏúºÎ©∞, Ïù¥Îäî ÍµêÏú°, Í≤ΩÏ†ú, Î¨∏Ìôî Ï†ÑÎ∞òÏóê Í±∏Ï≥ê ÏÉàÎ°úÏö¥ Í∏∞ÌöåÎ•º ÎßåÎì§Ïñ¥ ÎÇ¥Í≥† ÏûàÎã§. ÎèôÏãúÏóê Ïù¥Îü¨Ìïú Î≥ÄÌôîÎäî Í∞úÏù∏ÏóêÍ≤å Îçî ÎßéÏùÄ ÏÑ†ÌÉùÍ≥º Ï±ÖÏûÑÏùÑ ÏöîÍµ¨ÌïúÎã§.'

### Optimizations



#### Pattern matching: reduce unnecessary tokens

In [38]:
import regex as re

# GPT-2-style pre-tokenization pattern. Alternatives, tried in order:
#   1. contractions ('s, 'm, 'd, 't, 'll, 've, 're), matched case-insensitively
#   2. a run of letters, optionally preceded by one space
#   3. a run of at most three digits, optionally preceded by one space
#   4. a run of characters that are neither whitespace, letters nor digits,
#      optionally preceded by one space
#   5. whitespace that is not followed by a non-space character
#   6. any remaining whitespace
# NOTE(review): the case-insensitive contractions and the three-digit cap do not
# appear in the pattern published with GPT-2; confirm this variant is intended.
pat = re.compile(
    r"""'(?i:[smdt]|ll|ve|re)| ?\p{L}+| ?\p{N}{1,3}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)

samples = (
    "Hello world 1234!",
    "Hello world 1234! !!",
    "Hello world 1234! !!   a              ",
    "HELLo'S world 1234! !!   a              ",
)
for sample in samples:
    print(pat.findall(sample))

['Hello', ' world', ' 123', '4', '!']
['Hello', ' world', ' 123', '4', '!', ' !!']
['Hello', ' world', ' 123', '4', '!', ' !!', '  ', ' a', '              ']
['HELLo', "'S", ' world', ' 123', '4', '!', ' !!', '  ', ' a', '              ']
