# 2 Byte-Pair Encoding Tokenizer


## 2.1 The Unicode Standard

Unicode is a text encoding standard that maps characters to integer code points. As of Unicode 16.0 (released
in September 2024), the standard defines 154,998 characters across 168 scripts. For example, the character
“s” has the code point 115 (typically notated as U+0073, where U+ is a conventional prefix and 0073 is 115 in
hexadecimal), and the character “牛” has the code point 29275. In Python, you can use the ord() function
to convert a single Unicode character into its integer representation. The chr() function converts an integer
Unicode code point into a string with the corresponding character.
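
As a small sketch of the notation above, the helper below formats a character's code point in the `U+XXXX` style and checks that `chr()` inverts `ord()`. The name `code_point_label` is made up for this example.

```python
def code_point_label(ch: str) -> str:
    """Format a single character's code point in U+XXXX notation."""
    return f"U+{ord(ch):04X}"

for ch in ("s", "牛"):
    # chr() is the inverse of ord() for any valid code point.
    assert chr(ord(ch)) == ch
    print(repr(ch), ord(ch), code_point_label(ch))
```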

In [1]:
ord('s')

115

In [2]:
chr(0)

'\x00'

In [4]:
print(chr(0))

 


In [5]:
print("this is a test" + chr(0) + "string")

this is a test string


## 2.2 Unicode Encodings

While the Unicode standard defines a mapping from characters to code points (integers), it’s impractical to
train tokenizers directly on Unicode code points, since the vocabulary would be prohibitively large (around
150K items) and sparse (since many characters are quite rare). Instead, we’ll use a Unicode encoding, which
converts a Unicode character into a sequence of bytes. The Unicode standard itself defines three encodings:
UTF-8, UTF-16, and UTF-32, with UTF-8 being the dominant encoding for the Internet (more than 98%
of all webpages).

To encode a Unicode string into UTF-8, we can use the encode() function in Python. To access the
underlying byte values for a Python bytes object, we can iterate over it (e.g., call list()). Finally, we can
use the decode() function to decode a UTF-8 byte string into a Unicode string.

In [6]:
test_string = "hello! こんにちは!"

In [7]:
utf8_encoded = test_string.encode("utf-8")

In [14]:
utf16_encoded = test_string.encode("utf-16")

In [8]:
print(utf8_encoded)

b'hello! \xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf!'


In [15]:
print(utf16_encoded)

b'\xff\xfeh\x00e\x00l\x00l\x00o\x00!\x00 \x00S0\x930k0a0o0!\x00'


In [None]:
list(utf8_encoded)
# Iterating over a bytes object yields each byte as an integer in range(256),
# so list() returns the underlying byte values.

[104,
 101,
 108,
 108,
 111,
 33,
 32,
 227,
 129,
 147,
 227,
 130,
 147,
 227,
 129,
 171,
 227,
 129,
 161,
 227,
 129,
 175,
 33]

In [10]:
print(len(test_string))

13


In [11]:
print(len(utf8_encoded))

23


In [38]:
utf8_encoded

b'hello! \xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf!'

In [12]:
print(utf8_encoded.decode("utf-8"))

hello! こんにちは!
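
To make the size trade-off between the three encodings concrete, here is a minimal sketch that encodes the same test string with each of them. It uses the `-le` codec variants so that no byte-order mark is added; the plain `"utf-16"` codec used earlier prepends one (the `\xff\xfe` at the start of its output).

```python
text = "hello! こんにちは!"

# The "-le" variants fix the byte order, so no byte-order mark is prepended.
sizes = {codec: len(text.encode(codec)) for codec in ("utf-8", "utf-16-le", "utf-32-le")}

for codec, size in sizes.items():
    print(f"{codec:>9}: {size} bytes for {len(text)} characters")
```

UTF-8 is the most compact here because the ASCII characters take one byte each, while UTF-32 always spends four bytes per code point.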


Consider the following (incorrect) function, which is intended to decode a UTF-8 byte string into
a Unicode string. Why is this function incorrect? Provide an example of an input byte string
that yields incorrect results.

```python
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])

>>> decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))
'hello'
```

In [16]:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])

In [None]:
decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))

'hello'

In [None]:
test_decode_utf8_bytes_to_str_wrongode_utf8 = "hello! こんにちは!"

In [30]:
decode_utf8_bytes_to_str_wrong(test_decode_utf8_bytes_to_str_wrongode_utf8.encode("utf-8"))

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe3 in position 0: unexpected end of data

In [31]:
[bytes([b]) for b in test_decode_utf8_bytes_to_str_wrongode_utf8.encode("utf-8")]

[b'h',
 b'e',
 b'l',
 b'l',
 b'o',
 b'!',
 b' ',
 b'\xe3',
 b'\x81',
 b'\x93',
 b'\xe3',
 b'\x82',
 b'\x93',
 b'\xe3',
 b'\x81',
 b'\xab',
 b'\xe3',
 b'\x81',
 b'\xa1',
 b'\xe3',
 b'\x81',
 b'\xaf',
 b'!']

In [32]:
[b for b in test_decode_utf8_bytes_to_str_wrongode_utf8.encode("utf-8")]

[104,
 101,
 108,
 108,
 111,
 33,
 32,
 227,
 129,
 147,
 227,
 130,
 147,
 227,
 129,
 171,
 227,
 129,
 161,
 227,
 129,
 175,
 33]

In [33]:
def decode_utf8_bytes_to_str_correct(bytestring: bytes):
    # Decode the whole byte string at once so multi-byte sequences stay intact.
    return bytestring.decode("utf-8")

In [35]:
test_decode_utf8_bytes_to_str_wrongode_utf8

'hello! こんにちは!'

In [34]:
decode_utf8_bytes_to_str_correct(test_decode_utf8_bytes_to_str_wrongode_utf8.encode("utf-8"))

'hello! こんにちは!'

Answer:

❌ What’s wrong with this function?

- It decodes one byte at a time, as if each byte were an independent UTF-8 character.
- But UTF-8 is a variable-length encoding, where:
  - ASCII characters → 1 byte
  - Other Unicode characters → 2 to 4 bytes

Decoding each byte separately breaks multi-byte sequences. Every byte of a multi-byte sequence has its high bit set (value ≥ 0x80) and is never valid on its own, so any non-ASCII input raises `UnicodeDecodeError` rather than returning corrupted text.

Example input: `"こ".encode("utf-8")`, i.e. `b'\xe3\x81\x93'`. The function fails on the first byte, `0xe3`, with the same "unexpected end of data" error shown above, whereas calling `.decode("utf-8")` on the whole byte string returns `'こ'`.
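
The failure can be reproduced with a single non-ASCII character. The sketch below reuses the function from the exercise; the variable names `sample`, `failed`, and `roundtrip` are only for illustration.

```python
def decode_utf8_bytes_to_str_wrong(bytestring: bytes) -> str:
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])

# One character that UTF-8 encodes as three bytes.
sample = "こ".encode("utf-8")

try:
    decode_utf8_bytes_to_str_wrong(sample)
    failed = False
except UnicodeDecodeError as err:
    # The lead byte 0xe3 announces a three-byte sequence, so it cannot stand alone.
    failed = True
    print(err)

# Decoding the whole byte string at once keeps the sequence intact.
roundtrip = sample.decode("utf-8")
print(roundtrip)
```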


## 2.4 BPE Tokenizer Training

The BPE tokenizer training procedure consists of three main steps.

**Vocabulary initialization** The tokenizer vocabulary is a one-to-one mapping from bytestring token to integer ID. Since we’re training a byte-level BPE tokenizer, our initial vocabulary is simply the set of all bytes. Since there are 256 possible byte values, our initial vocabulary is of size 256.
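
A minimal sketch of that initial vocabulary follows. Using the byte value itself as the token ID is an assumption made here for simplicity; the text only requires that the mapping be one-to-one.

```python
# Map each integer ID to a single-byte token (assumption: ID == byte value).
vocab = {i: bytes([i]) for i in range(256)}

# Inverse mapping, from bytestring token to integer ID.
token_to_id = {token: i for i, token in vocab.items()}

print(len(vocab), vocab[115], token_to_id[b"s"])
```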

**Pre-tokenization** Once you have a vocabulary, you could, in principle, count how often bytes occur next to each other in your text and begin merging them starting with the most frequent pair of bytes. However, this is quite computationally expensive, since we’d have to take a full pass over the corpus each time we merge. In addition, directly merging bytes across the corpus may result in tokens that differ only in punctuation (e.g., dog! vs. dog.). These tokens would get completely different token IDs, even though they are likely to have high semantic similarity (since they differ only in punctuation).

To avoid this, we pre-tokenize the corpus. You can think of this as a coarse-grained tokenization over the corpus that helps us count how often pairs of characters appear. For example, the word 'text' might be a pre-token that appears 10 times. In this case, when we count how often the characters ‘t’ and ‘e’ appear next to each other, we will see that the word ‘text’ has ‘t’ and ‘e’ adjacent and we can increment their count by 10 instead of looking through the corpus. Since we’re training a byte-level BPE model, each pre-token is represented as a sequence of UTF-8 bytes.
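
The counting shortcut described above can be sketched as follows. The pre-token counts are made-up illustration data, not taken from a real corpus: each adjacent byte pair inside a pre-token is credited with that pre-token's frequency, so the corpus itself is never rescanned.

```python
from collections import Counter

# Hypothetical pre-token counts, keyed by UTF-8 bytes.
pretoken_counts = {b"text": 10, b" the": 4}

pair_counts = Counter()
for pretoken, freq in pretoken_counts.items():
    # Split the pre-token into single-byte tokens.
    symbols = [bytes([b]) for b in pretoken]
    # Credit every adjacent pair with the pre-token's frequency.
    for pair in zip(symbols, symbols[1:]):
        pair_counts[pair] += freq

print(pair_counts[(b"t", b"e")])
```

Pairs are only formed inside a pre-token, so nothing is counted across the boundary between `b"text"` and `b" the"`.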

The original BPE implementation of Sennrich et al. [2016] pre-tokenizes by simply splitting on whitespace (i.e., s.split(" ")). In contrast, we’ll use a regex-based pre-tokenizer (used by GPT-2; Radford et al., 2019) from github.com/openai/tiktoken/pull/234/files:

In [1]:
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

### Explanation of Regex Pattern `PAT`

#### Code
```python
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
```

#### What it Does
This regex pattern is used to **tokenize text**, often for NLP applications like language models.

It uses the `regex` module (not `re`) to support Unicode properties like `\p{L}`.

#### Breakdown of the Pattern

##### `'(?:[sdmt]|ll|ve|re)`
- Matches English **contractions**:
  - `'s`, `'d`, `'m`, `'t`, `'ll`, `'ve`, `'re`
- `(?:...)` is a **non-capturing group**.

##### ` ?\p{L}+`
- Matches optional space followed by one or more **letters** (from any language).
- `\p{L}` = any Unicode letter.

##### ` ?\p{N}+`
- Matches optional space followed by one or more **numbers**.
- `\p{N}` = any Unicode numeric character (decimal digits, plus other numerals such as `Ⅷ` or `½`).

##### ` ?[^\s\p{L}\p{N}]+`
- Matches optional space followed by one or more **symbols or punctuation**.
- It excludes whitespace, letters, and numbers.

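##### `\s+(?!\S)`
- Matches a run of whitespace that is **not followed by a non-whitespace character**.
- At the end of the text this consumes all trailing whitespace. Before a word, the regex engine backtracks by one character, so the match stops one character short and the final space is left to attach to the next token (e.g., ` text`).

##### `\s+`
- Matches **any remaining whitespace** that the earlier alternatives could not absorb, such as a single newline or tab directly before a word.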
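
The interaction between the two whitespace alternatives is easier to see on a small input. The sketch below uses a simplified stand-in pattern (`SIMPLE_PAT`, an assumption for this example) that only handles lowercase ASCII letters, so that the standard-library `re` module is enough; the real `PAT` needs the `regex` module.

```python
import re

# Simplified stand-in for PAT: lowercase ASCII letters only.
SIMPLE_PAT = r" ?[a-z]+|\s+(?!\S)|\s+"

tokens = re.findall(SIMPLE_PAT, "a   b\nc  ")
print(tokens)
```

Of the three spaces after `a`, two form their own token and the third attaches to ` b`. The newline before `c` falls through to the final `\s+`, and the two trailing spaces are matched by `\s+(?!\S)`.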

#### Usage Example (with `regex` module)
```python
import regex
text = "Here's an example: 42 tokens, maybe?"
tokens = regex.findall(PAT, text)
print(tokens)
```

#### Summary
- This is a **Unicode-aware tokenizer regex**.
- Useful for processing:
  - Words
  - Numbers
  - Punctuation
  - Contractions
  - Whitespace
- Designed for tasks like LLM training or text preprocessing.

In [2]:
import regex as re
re.findall(PAT, "some text that i'll pre-tokenize")

['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize']

In [6]:
text = "Here's an example: 42 tokens, maybe? 😂"
tokens = re.findall(PAT, text)
print(tokens)

['Here', "'s", ' an', ' example', ':', ' 42', ' tokens', ',', ' maybe', '?', ' 😂']


When using it in your code, however, you should use re.finditer to avoid storing the pre-tokenized words as you construct your mapping from pre-tokens to their counts.

### Explanation: Why Use `re.finditer` Instead of `re.findall`

#### Context

You have a regex pattern (e.g., `PAT`) for tokenizing text, and you're building a mapping from each **pre-token** to its **count** (like a frequency dictionary).

---

#### Key Difference Between `findall` and `finditer`

| Method        | Description                                                                 | Memory Usage        |
|---------------|-----------------------------------------------------------------------------|---------------------|
| `re.findall()`| Returns a list of all matches as strings                                    | **Higher** (loads all matches) |
| `re.finditer()`| Returns an **iterator** yielding match objects one at a time               | **Lower** (streaming) |

---

#### Why Prefer `finditer`

- If you're just going to **count** tokens, you don't need to **store** the entire list.
- `re.finditer` lets you:
  - Stream matches one-by-one
  - Avoid unnecessary memory usage
  - Work better with large datasets

---

#### Example

```python
from collections import defaultdict
import regex  # Must use 'regex' instead of 're' for \p{} support

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
text = "Here's an example: 42 tokens, maybe?"

token_counts = defaultdict(int)

for match in regex.finditer(PAT, text):
    token = match.group()
    token_counts[token] += 1
```

---

#### Summary

- Use `re.finditer()` to **avoid storing all tokens** in memory.
- Especially useful when **constructing token frequency maps**.
- Improves **performance and scalability**.

```python
# Bad (memory-heavy)
tokens = regex.findall(PAT, text)
for token in tokens:
    token_counts[token] += 1

# Good (efficient)
for match in regex.finditer(PAT, text):
    token = match.group()
    token_counts[token] += 1
```

### Answer: Pretoken Function

In [62]:
import regex as re
from collections import defaultdict

def process_text_with_pre_tokenize(text):
    '''
    Pre-tokenize the text using regex to match tokens.
    This function uses a regex pattern to find tokens in the text.
    It returns a dictionary with tokens as keys and their counts as values.
    '''
    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    token_counts = defaultdict(int)
    # Good (efficient)
    for match in re.finditer(PAT, text):
        token = match.group()
        token_counts[token] += 1
    return token_counts


def convert_dict_to_list(tokens_counts):
    '''
    Return the tokens (the dictionary keys) as a list.
    The counts are dropped.
    '''
    return list(tokens_counts.keys())

In [18]:
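example_text = "the cat in the hat"
token_counts = process_text_with_pre_tokenize(example_text)
# Re-key the same counts by each pre-token's UTF-8 bytes.
token_counts_in_UTF8 = defaultdict(int, {token.encode("utf-8"): count for token, count in token_counts.items()})
print(token_counts)
print(token_counts_in_UTF8)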

defaultdict(<class 'int'>, {'the': 1, ' cat': 1, ' in': 1, ' the': 1, ' hat': 1})
defaultdict(<class 'int'>, {b'the': 1, b' cat': 1, b' in': 1, b' the': 1, b' hat': 1})


In [30]:
print(convert_dict_to_list(token_counts))
print(convert_dict_to_list(token_counts_in_UTF8))

['the', ' cat', ' in', ' the', ' hat']
[b'the', b' cat', b' in', b' the', b' hat']


In [46]:
type(convert_dict_to_list(token_counts_in_UTF8)[0])

bytes

In [55]:
def convert_utf8_to_int(tokens_representation):
    if isinstance(tokens_representation[0], str):
        print("Converting str to int")
        return [list(map(ord, string)) for string in tokens_representation]
    if isinstance(tokens_representation[0], bytes):
        print("Converting bytes to int")
        return [list(bytes_item) for bytes_item in tokens_representation]

In [57]:
convert_utf8_to_int(convert_dict_to_list(token_counts_in_UTF8))
convert_utf8_to_int(convert_dict_to_list(token_counts))

Converting bytes to int
Converting str to int


[[116, 104, 101],
 [32, 99, 97, 116],
 [32, 105, 110],
 [32, 116, 104, 101],
 [32, 104, 97, 116]]

### Compute BPE merges
Now that we’ve converted our input text into pre-tokens and represented each pre-token as a sequence of UTF-8 bytes, we can compute the BPE merges (i.e., train the BPE tokenizer). At a high level, the BPE algorithm iteratively counts every pair of bytes and identifies the pair with the highest frequency (“A”, “B”). Every occurrence of this most frequent pair (“A”, “B”) is then merged, i.e., replaced with a new token “AB”. This new merged token is added to our vocabulary; as a result, the final vocabulary after BPE training is the size of the initial vocabulary (256 in our case), plus the number of BPE merge operations performed during training. For efficiency during BPE training, we do not consider pairs that cross pre-token boundaries. When computing merges, deterministically break ties in pair frequency by preferring the lexicographically greater pair. For example, if the pairs (“A”, “B”), (“A”, “C”), (“B”, “ZZ”), and (“BA”, “A”) all have the highest frequency, we’d merge (“BA”, “A”):

>>> max([("A", "B"), ("A", "C"), ("B", "ZZ"), ("BA", "A")])
('BA', 'A')

### Explanation: "Deterministically break ties in pair frequency by preferring the lexicographically greater pair"

#### Context
This typically applies in **Byte Pair Encoding (BPE)** or similar tokenization algorithms, where you:
1. Count how often each pair of symbols appears.
2. Merge the most frequent pair.
3. Repeat.

Sometimes, **two or more pairs** have the **same frequency**. The algorithm needs a way to choose **which one to merge**.

---

#### 🔑 Key Terms

- **Tie in frequency**: Two symbol pairs occur the same number of times.
- **Deterministically**: Always make the same choice given the same input (no randomness).
- **Lexicographically greater**: Dictionary order, e.g. `'z' > 'a'` and `'dog' > 'cat'`. Pairs are compared as tuples: first elements first, then second elements to break any remaining tie.

---

#### 🔸 What It Means

> If multiple symbol pairs have the same frequency:
> → Choose the one that comes **later in alphabetical order**.

---

#### 🧠 Example

Assume these pairs are tied for the highest frequency:

```
(('t', 'h'), 10)
(('h', 'e'), 10)
(('i', 'n'), 10)
```

To break the tie:
- Sort lexicographically:
  ```
  ('h', 'e') < ('i', 'n') < ('t', 'h')
  ```
- Pick the **lexicographically greatest**:
  → `('t', 'h')`

---

#### ✅ Why It’s Important

- Ensures **consistency** in training and inference.
- Prevents randomness that could lead to mismatched tokenization.

---

#### 💡 Summary

> When two symbol pairs are tied in frequency, merge the one that is **alphabetically last**.

In [32]:
max([("A", "B"), ("A", "C"), ("B", "ZZ"), ("BA", "A")])

('BA', 'A')

In [33]:
def break_ties_during_merge_by_lexicographically(list_of_tuples):
    '''
    Break ties during a merge by returning the lexicographically greatest pair.
    Tuples compare element by element: first elements first, then second
    elements if the first are equal.
    '''
    return max(list_of_tuples)

In [34]:
test_of_vocab_pairs = [("A", "B"), ("A", "C"), ("B", "ZZ"), ("BA", "A")]
print(break_ties_during_merge_by_lexicographically(test_of_vocab_pairs))

('BA', 'A')
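
In practice the tie-break only applies among pairs that share the highest count, so frequency and ordering can be combined into a single sort key. The sketch below uses made-up counts and byte-string tokens; with `bytes`, lexicographic order means comparison by byte value.

```python
# Hypothetical pair counts: four pairs tie at 3, one pair is less frequent.
pair_counts = {
    (b"A", b"B"): 3,
    (b"A", b"C"): 3,
    (b"B", b"ZZ"): 3,
    (b"BA", b"A"): 3,
    (b"Z", b"Z"): 1,
}

# Compare by count first, then by the pair itself to break ties.
best = max(pair_counts, key=lambda pair: (pair_counts[pair], pair))
print(best)
```

Note that `(b"Z", b"Z")` is lexicographically greatest overall but is not chosen, because its count is lower.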


### Special tokens
Often, some strings (e.g., <|endoftext|>) are used to encode metadata (e.g., boundaries between documents). When encoding text, it’s often desirable to treat some strings as “special tokens” that should never be split into multiple tokens (i.e., will always be preserved as a single token). For example, the end-of-sequence string <|endoftext|> should always be preserved as a single token (i.e., a single integer ID), so we know when to stop generating from the language model. These special tokens must be added to the vocabulary, so they have a corresponding fixed token ID.
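
One way to honor this during training is sketched below, under two assumptions that the text does not fix: special tokens are appended after the 256 byte tokens (any fixed ID would do), and the corpus is split on the special tokens before pre-tokenization so that no pair is ever counted across them.

```python
import re

special_tokens = ["<|endoftext|>"]

# Start from the byte vocabulary, then give each special token its own fixed ID.
vocab = {i: bytes([i]) for i in range(256)}
for token in special_tokens:
    vocab[len(vocab)] = token.encode("utf-8")

# Escape the tokens so characters like "|" are matched literally.
split_pattern = "|".join(re.escape(token) for token in special_tokens)

corpus = "first document<|endoftext|>second document"
chunks = re.split(split_pattern, corpus)
print(chunks)
```

Each chunk can then be pre-tokenized on its own, which keeps the special token out of the merge statistics.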

Algorithm 1 of Sennrich et al. [2016] contains an inefficient implementation of BPE tokenizer training (essentially following the steps that we outlined above). As a first exercise, it may be useful to implement and test this function to test your understanding.

### TODO: Need to implement this function, for now, it's unclear how to do it.

### Algorithm 1:


In [59]:
example_text_ag1 = """
    low low low low low
    lower lower widest widest widest
    newest newest newest newest newest newest
"""

In [63]:
process_text_with_pre_tokenize(example_text_ag1)

defaultdict(int,
            {'\n   ': 3,
             ' low': 5,
             ' lower': 2,
             ' widest': 3,
             ' newest': 6,
             '\n': 1})

### 🤔 Should We Keep Whitespace Matches in BPE Tokenizer?

#### 📌 Regex in Question
```python
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
```

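#### 🧠 Observation
This pattern **intentionally keeps whitespace**, either as a leading space on a token or as a token of its own:
- `' example'`
- `' cat'`
- `'\n'`

---

#### ✅ Is This Normal in BPE Tokenizers?

**Yes**, this is **normal and intentional** in byte-level BPE tokenizers, especially in models like:
- **GPT-2 / GPT-3 / GPT-4**
- **RoBERTa** (which reuses GPT-2's byte-level BPE)

T5 is a different case: it uses a SentencePiece tokenizer, which marks word boundaries with a `▁` prefix instead of a literal leading space.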

They treat **whitespace as meaningful**:
- Leading spaces (e.g., `' cat'`) are **part of the token**.
- This helps capture context and word boundaries **without needing a special separator**.
- For example:
  - `'cat'` and `' cat'` are different tokens.
  - `'Hello\nWorld'` keeps the newline to preserve formatting.

---

#### 🔍 Why Not Remove Whitespace?

Removing whitespace would:
- Break token alignment between training and inference.
- Change the meaning of tokens and mess up pretraining statistics.
- Lose valuable structure (e.g., indents, newlines, sentence spacing).

---

#### ✅ When Should You Remove Whitespace?

Only consider removing whitespace:
- If you are **preprocessing raw text** for **custom tokenization**.
- If your tokenizer is **not whitespace-sensitive** (rare).
- If you're using a **character-level model** or models that handle spacing differently.

---

#### 🧾 Summary

| 🔍 Behavior                     | ✅ Keep Whitespace |
|-------------------------------|--------------------|
| GPT-style BPE tokenization    | Yes                |
| Word-boundary sensitive model | Yes                |
| Character-level model         | Maybe not          |
| Custom cleaner/tokenizer      | Optional           |

> ✅ **Final Answer**: **Yes**, keeping leading whitespace is **normal** and **recommended** for standard BPE tokenization like GPT-2.


In [68]:
def process_text_with_pre_tokenize(text, PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""):
    '''
    Pre-tokenize the text using regex to match tokens.
    This function uses a regex pattern to find tokens in the text.
    It returns a dictionary with tuple of characters as keys and their counts as values.
    '''
    token_counts = defaultdict(int)
    for match in re.finditer(PAT, text):
        token = match.group()
        char_tuple = tuple(token)  # Convert string token to tuple of characters
        token_counts[char_tuple] += 1
    return dict(token_counts)


The function above ends with `return dict(token_counts)`, which converts the `defaultdict` to a regular `dict`. Here is why that is useful:

---

🧠 Reason

`defaultdict` is a subclass of `dict`, but it has extra behavior: when you index a missing key with `[]`, it automatically creates that key with a default value (0 for `defaultdict(int)`).

This can be undesired in some contexts, such as:
- Returning from a function: a caller that reads a missing key with `[]` silently inserts it, which changes the counts dictionary and its length.
- Debugging or displaying: the printed form includes the `defaultdict(<class 'int'>, ...)` wrapper, which adds noise.

Serialization is not the reason: `json.dumps` and `pickle` treat a `defaultdict(int)` like any other dict.

---

✅ So, converting to dict:

return dict(token_counts)

- Makes the return value a plain Python dictionary (`dict`).
- Removes the automatic “default value” behavior.
- Makes a missing key raise `KeyError` instead of being created silently.

---

🔁 Example

from collections import defaultdict

d = defaultdict(int)
d['a'] += 1
print(d['b'])  # Outputs 0 and adds 'b' to the dict

d2 = dict(d)
print(d2['c'])  # KeyError!

So dict(token_counts) ensures no unexpected key creation occurs later.

---

👇 TL;DR

We use dict(...) at return time to:
- Strip away the special behavior of defaultdict
- Make the function output predictable for callers

In [69]:
example_text_ag1 = """
    low low low low low
    lower lower widest widest widest
    newest newest newest newest newest newest
"""

Test a regex pattern for the assignment example

In [79]:
vocab = process_text_with_pre_tokenize(example_text_ag1, r'\b[a-zA-Z]+\b')
vocab

{('l', 'o', 'w'): 5,
 ('l', 'o', 'w', 'e', 'r'): 2,
 ('w', 'i', 'd', 'e', 's', 't'): 3,
 ('n', 'e', 'w', 'e', 's', 't'): 6}

### Merges 
We first look at every successive pair of bytes and sum the frequency of the words where they appear: {lo: 7, ow: 7, we: 8, er: 2, wi: 3, id: 3, de: 3, es: 9, st: 9, ne: 6, ew: 6}. The pairs ('es') and ('st') are tied, so we take the lexicographically greater pair, ('st'). We would then merge the pre-tokens so that we end up with {(l,o,w): 5, (l,o,w,e,r): 2, (w,i,d,e,st): 3, (n,e,w,e,st): 6}.

In the second round, we see that (e, st) is the most common pair (with a count of 9) and we would merge into {(l,o,w): 5, (l,o,w,e,r): 2, (w,i,d,est): 3, (n,e,w,est): 6}. Continuing this, the sequence of merges we get in the end will be ['s t', 'e st', 'o w', 'l ow', 'w est', 'n e', 'ne west', 'w i', 'wi d', 'wid est', 'low e', 'lowe r'].

If we take 6 merges, we have ['s t', 'e st', 'o w', 'l ow', 'w est', 'n e'] and our vocabulary elements would be [<|endoftext|>, [...256 BYTE CHARS], st, est, ow, low, west, ne].

With this vocabulary and set of merges, the word newest would tokenize as [ne, west].
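
The worked example above can be checked with a short, deliberately inefficient sketch. The function names (`merge_pair`, `train_bpe`, `encode_word`) are chosen for this illustration and are not part of the assignment; the input is the word-frequency table from the example, with each word split into single-byte tokens.

```python
from collections import Counter

def merge_pair(symbols, pair):
    """Replace every adjacent occurrence of `pair` in `symbols` with one merged token."""
    merged, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            merged.append(symbols[i] + symbols[i + 1])
            i += 2  # skip both halves of the merged pair
        else:
            merged.append(symbols[i])
            i += 1
    return tuple(merged)

def train_bpe(word_freqs, num_merges):
    """Return the list of merges learned from a {word: frequency} table."""
    words = {tuple(bytes([b]) for b in w.encode("utf-8")): n for w, n in word_freqs.items()}
    merges = []
    for _ in range(num_merges):
        # Count adjacent pairs inside each word, weighted by word frequency.
        pair_counts = Counter()
        for symbols, n in words.items():
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += n
        if not pair_counts:
            break
        # Highest count wins; ties go to the lexicographically greater pair.
        best = max(pair_counts, key=lambda p: (pair_counts[p], p))
        merges.append(best)
        words = {merge_pair(symbols, best): n for symbols, n in words.items()}
    return merges

def encode_word(word, merges):
    """Tokenize one word by replaying the merges in the order they were learned."""
    symbols = tuple(bytes([b]) for b in word.encode("utf-8"))
    for pair in merges:
        symbols = merge_pair(symbols, pair)
    return list(symbols)

word_freqs = {"low": 5, "lower": 2, "widest": 3, "newest": 6}
merges = train_bpe(word_freqs, num_merges=6)
print([left + b" " + right for left, right in merges])
print(encode_word("newest", merges))
```

Recounting every pair on each iteration is what makes this version slow; it is meant only to confirm the merge order, not as an efficient implementation.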

In [88]:
import regex

example_text_ag1 = """
    low low low low low
    lower lower widest widest widest
    newest newest newest newest newest newest
"""

# This pattern matches one or more Unicode letters
tokens = regex.findall(r"\p{L}+", example_text_ag1)
pure_word_result = "".join(tokens)
print(pure_word_result)

lowlowlowlowlowlowerlowerwidestwidestwidestnewestnewestnewestnewestnewestnewest


In [90]:
indices = list(map(int, pure_word_result.encode("utf-8")))
successive_texts = list(map(str, pure_word_result))
print(indices)
print(successive_texts)

[108, 111, 119, 108, 111, 119, 108, 111, 119, 108, 111, 119, 108, 111, 119, 108, 111, 119, 101, 114, 108, 111, 119, 101, 114, 119, 105, 100, 101, 115, 116, 119, 105, 100, 101, 115, 116, 119, 105, 100, 101, 115, 116, 110, 101, 119, 101, 115, 116, 110, 101, 119, 101, 115, 116, 110, 101, 119, 101, 115, 116, 110, 101, 119, 101, 115, 116, 110, 101, 119, 101, 115, 116, 110, 101, 119, 101, 115, 116]
['l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'e', 'r', 'l', 'o', 'w', 'e', 'r', 'w', 'i', 'd', 'e', 's', 't', 'w', 'i', 'd', 'e', 's', 't', 'w', 'i', 'd', 'e', 's', 't', 'n', 'e', 'w', 'e', 's', 't', 'n', 'e', 'w', 'e', 's', 't', 'n', 'e', 'w', 'e', 's', 't', 'n', 'e', 'w', 'e', 's', 't', 'n', 'e', 'w', 'e', 's', 't', 'n', 'e', 'w', 'e', 's', 't']


In [107]:
# First we need to get the successive pairs of characters
char_counts = defaultdict(int)
for char1, char2 in zip(successive_texts, successive_texts[1:]): 
    char_counts[(char1, char2)] += 1
print(char_counts)

defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', 'l'): 5, ('w', 'e'): 8, ('e', 'r'): 2, ('r', 'l'): 1, ('r', 'w'): 1, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3, ('e', 's'): 9, ('s', 't'): 9, ('t', 'w'): 2, ('t', 'n'): 6, ('n', 'e'): 6, ('e', 'w'): 6})


The usage of zip

```Python
for item in zip([1, 2, 3], ['sugar', 'spice', 'everything nice']):
    print(item)
```    

The expression zip(indices, indices[1:]) creates pairs of consecutive elements from the indices list.

indices is a list (for example, a list of byte values or character codes).
indices[1:] is the same list but shifted left by one (it skips the first element).
When you use zip(indices, indices[1:]), you get an iterator of tuples where each tuple contains two consecutive elements from the original list.

Example:

```Python
indices = [1, 2, 3, 4]
pairs = list(zip(indices, indices[1:]))
print(pairs)
# Output: [(1, 2), (2, 3), (3, 4)]
```
Use case:
This is commonly used to count or process all adjacent pairs in a sequence.

In [95]:
# First we need to get the successive pairs of byte values
int_counts = defaultdict(int)
for index1, index2 in zip(indices, indices[1:]): 
    int_counts[(index1, index2)] += 1
print(int_counts)

defaultdict(<class 'int'>, {(108, 111): 7, (111, 119): 7, (119, 108): 5, (119, 101): 8, (101, 114): 2, (114, 108): 1, (114, 119): 1, (119, 105): 3, (105, 100): 3, (100, 101): 3, (101, 115): 9, (115, 116): 9, (116, 119): 2, (116, 110): 6, (110, 101): 6, (101, 119): 6})


In [108]:
max_pair = max(char_counts, key=lambda x: (char_counts[x], x))
print(max_pair)  # This will print ('s', 't')

('s', 't')


In [109]:
def combine_successive_tokens(tokens):
    # First we need to get the successive pairs of tokens
    counts = defaultdict(int)
    for indices1, indices2 in zip(tokens, tokens[1:]): 
        counts[(indices1, indices2)] += 1
    return counts

def find_max_pair(counts):
    # Find the maximum pair based on the counts
    max_pair = max(counts, key=lambda x: (counts[x], x))
    return max_pair


In [116]:
successive_texts = list(map(str, pure_word_result))

In [117]:
match1, match2 = find_max_pair(combine_successive_tokens(successive_texts))

In [121]:
def merge_tokens(tokens):
    # NOTE: match1 and match2 are globals set by the find_max_pair cell above.
    new_tokens = []
    for i, (indices1, indices2) in enumerate(zip(tokens, tokens[1:])):
        if indices1 == match1 and indices2 == match2:
            new_tokens.append(indices1 + indices2)
        elif i > 1 and tokens[i-1] == match1 and tokens[i] == match2:
            continue
        else:
            new_tokens.append(indices1)

    return new_tokens
        
    

In [123]:
new_tokens = merge_tokens(successive_texts)
print(new_tokens)

['l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'e', 'r', 'l', 'o', 'w', 'e', 'r', 'w', 'i', 'd', 'e', 'st', 'w', 'i', 'd', 'e', 'st', 'w', 'i', 'd', 'e', 'st', 'n', 'e', 'w', 'e', 'st', 'n', 'e', 'w', 'e', 'st', 'n', 'e', 'w', 'e', 'st', 'n', 'e', 'w', 'e', 'st', 'n', 'e', 'w', 'e', 'st', 'n', 'e', 'w', 'e', 'st']


In [125]:
combine_successive_tokens(list(map(str, new_tokens)))

defaultdict(int,
            {('l', 'o'): 7,
             ('o', 'w'): 7,
             ('w', 'l'): 5,
             ('w', 'e'): 8,
             ('e', 'r'): 2,
             ('r', 'l'): 1,
             ('r', 'w'): 1,
             ('w', 'i'): 3,
             ('i', 'd'): 3,
             ('d', 'e'): 3,
             ('e', 'st'): 9,
             ('st', 'w'): 2,
             ('st', 'n'): 6,
             ('n', 'e'): 6,
             ('e', 'w'): 6})

### Now we combine those functions together

At this stage, the goal is to reproduce the worked example from the 2016 paper of Sennrich et al. [2016].

In [None]:
import regex
from collections import defaultdict

# This pattern matches one or more Unicode letters
def get_raw_tokens(text):
    tokens = regex.findall(r"\p{L}+", text)
    return "".join(tokens)

def get_list_of_characters(pure_word_result):
    indices = list(map(str, pure_word_result))
    return indices

def combine_successive_tokens(tokens):
    # First we need to get the successive pairs of tokens
    counts = defaultdict(int)
    for indices1, indices2 in zip(tokens, tokens[1:]): 
        counts[(indices1, indices2)] += 1
    return counts

def find_max_pair(counts):
    # Find the maximum pair based on the counts
    max_pair = max(counts, key=lambda x: (counts[x], x))
    return max_pair

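def merge_tokens(tokens, match1, match2):
    # Known limitations of this naive version:
    # - zip() stops one pair short, so a final token that is not merged is dropped.
    # - `i > 1` skips the check at i == 1, so the second half of a pair merged at
    #   index 0 is appended again (the stray 'ow' after the first 'low' in the output below).
    new_tokens = []
    for i, (indices1, indices2) in enumerate(zip(tokens, tokens[1:])):
        if indices1 == match1 and indices2 == match2:
            new_tokens.append(indices1 + indices2)
        elif i > 1 and tokens[i-1] == match1 and tokens[i] == match2:
            continue
        else:
            new_tokens.append(indices1)

    return new_tokens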
        

def BPE_training_naive_version(text, num_merges = 6):
    # get the raw tokens
    pure_word_result = get_raw_tokens(text)
    # get the list of characters
    successive_texts = get_list_of_characters(pure_word_result)
    
    # iterate over the number of merges
    for i in range(num_merges):
        # combine successive tokens
        counts = combine_successive_tokens(successive_texts)
        # find the max pair
        match1, match2 = find_max_pair(counts)
        print(f"Maximum Match {i+1}: {match1}, {match2}")
        # merge tokens
        successive_texts = merge_tokens(successive_texts, match1, match2)
        counts = combine_successive_tokens(successive_texts)
        print(f"After merge {i+1}: {successive_texts}")
        print(f"Counts: {counts}")
    
    return successive_texts, counts


In [None]:
example_text_ag1 = """
    low low low low low
    lower lower widest widest widest
    newest newest newest newest newest newest
"""

In [147]:
final_result, counts = BPE_training_naive_version(example_text_ag1, num_merges = 6)

Maximum Match 1: s, t
After merge 1: ['l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'e', 'r', 'l', 'o', 'w', 'e', 'r', 'w', 'i', 'd', 'e', 'st', 'w', 'i', 'd', 'e', 'st', 'w', 'i', 'd', 'e', 'st', 'n', 'e', 'w', 'e', 'st', 'n', 'e', 'w', 'e', 'st', 'n', 'e', 'w', 'e', 'st', 'n', 'e', 'w', 'e', 'st', 'n', 'e', 'w', 'e', 'st', 'n', 'e', 'w', 'e', 'st']
Counts: defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', 'l'): 5, ('w', 'e'): 8, ('e', 'r'): 2, ('r', 'l'): 1, ('r', 'w'): 1, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3, ('e', 'st'): 9, ('st', 'w'): 2, ('st', 'n'): 6, ('n', 'e'): 6, ('e', 'w'): 6})
Maximum Match 2: e, st
After merge 2: ['l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'l', 'o', 'w', 'e', 'r', 'l', 'o', 'w', 'e', 'r', 'w', 'i', 'd', 'est', 'w', 'i', 'd', 'est', 'w', 'i', 'd', 'est', 'n', 'e', 'w', 'est', 'n', 'e', 'w', 'est', 'n', 'e', 'w', 'est', 'n', 'e', 'w', 'est', 'n', 'e', 'w', 'est'

In [140]:
print(final_result)
print(counts)

['low', 'ow', 'low', 'low', 'low', 'low', 'low', 'e', 'r', 'low', 'e', 'r', 'w', 'i', 'd', 'est', 'w', 'i', 'd', 'est', 'w', 'i', 'd', 'estne', 'w', 'estne', 'w', 'estne', 'w', 'estne', 'w', 'estne', 'w', 'estne']
defaultdict(<class 'int'>, {('low', 'ow'): 1, ('ow', 'low'): 1, ('low', 'low'): 4, ('low', 'e'): 2, ('e', 'r'): 2, ('r', 'low'): 1, ('r', 'w'): 1, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est'): 3, ('est', 'w'): 2, ('est', 'ne'): 6, ('ne', 'w'): 5, ('w', 'est'): 5})


In [142]:
list(counts.keys())

[('low', 'ow'),
 ('ow', 'low'),
 ('low', 'low'),
 ('low', 'e'),
 ('e', 'r'),
 ('r', 'low'),
 ('r', 'w'),
 ('w', 'i'),
 ('i', 'd'),
 ('d', 'est'),
 ('est', 'w'),
 ('est', 'ne'),
 ('ne', 'w'),
 ('w', 'est')]

### Here is a version that ChatGPT fixed for me


### 🔍 Problem Diagnosis

The core problem lies in how you are merging your tokens and keeping track of them across whole words. Let me explain what went wrong and how to align it with the BPE logic:

#### ❌ Issue 1: Merging across words

Your current implementation processes all characters linearly, without considering word boundaries. For example, you merge `'st'` in the sequence:

```python
['n', 'e', 'w', 'e', 'st', 'n', 'e', 'w', 'e', 'st']
```

But this flattening means BPE is applied **across different words**, which shouldn’t happen.

#### ❌ Issue 2: Tokens are not grouped by words

You converted the input into:

```python
'l', 'o', 'w', 'l', 'o', 'w', ...
```

instead of keeping words like:

```python
['l', 'o', 'w'], ['l', 'o', 'w'], ...
```

This causes merges to cross word boundaries.
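
To make the two issues concrete, here is a small sketch that counts adjacent pairs over a flat character stream and over per-word lists for a two-word input. The helper names `flat_pair_counts` and `per_word_pair_counts` are illustrative only and do not appear elsewhere in this notebook.

```python
from collections import Counter

def flat_pair_counts(text):
    # Flatten all words into one character stream (the buggy approach).
    chars = [ch for word in text.split() for ch in word]
    return Counter(zip(chars, chars[1:]))

def per_word_pair_counts(text):
    # Count pairs inside each word only (the intended approach).
    counts = Counter()
    for word in text.split():
        counts.update(zip(word, word[1:]))
    return counts

text = "low lower"
flat = flat_pair_counts(text)
grouped = per_word_pair_counts(text)

# ('w', 'l') spans the boundary between "low" and "lower".
print(flat[("w", "l")])     # 1
print(grouped[("w", "l")])  # 0
```

The spurious `('w', 'l')` pair is the same kind of entry that shows up in the flat counts printed earlier.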

---

### ✅ Expected Approach

📌 BPE operates **within each word**, and frequencies are tracked **across words**. So the input should be:

```python
[
    ['l', 'o', 'w'],
    ['l', 'o', 'w'],
    ['l', 'o', 'w'],
    ['l', 'o', 'w'],
    ['l', 'o', 'w'],
    ['l', 'o', 'w', 'e', 'r'],
    ['l', 'o', 'w', 'e', 'r'],
    ['w', 'i', 'd', 'e', 's', 't'],
    ['w', 'i', 'd', 'e', 's', 't'],
    ['w', 'i', 'd', 'e', 's', 't'],
    ['n', 'e', 'w', 'e', 's', 't'],
    ...
]
```

Then, you collect counts of adjacent symbol pairs across all words.

---

### ✅ Fixing Your Implementation

Here’s the key change you should make:

#### 1. Update `get_list_of_characters` to preserve word boundaries

```python
def get_list_of_characters(pure_word_result):
    return [list(word) for word in pure_word_result.split()]
```

---

#### 2. Fix `combine_successive_tokens`

```python
def combine_successive_tokens(tokens_list):
    counts = defaultdict(int)
    for word in tokens_list:
        for i in range(len(word) - 1):
            counts[(word[i], word[i+1])] += 1
    return counts
```

---

#### 3. Fix `merge_tokens` to merge in each word

```python
def merge_tokens(tokens_list, match1, match2):
    new_tokens_list = []
    for word in tokens_list:
        new_word = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == match1 and word[i+1] == match2:
                new_word.append(match1 + match2)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_tokens_list.append(new_word)
    return new_tokens_list
```

---

#### 4. Update main loop in `BPE_training_naive_version`

Only minor changes are needed here once the above is fixed.

---

### ✅ Bonus: Add Print Debug Helpers

Add a helper to visualize your current tokens:

```python
def print_tokens(tokens_list):
    print(" ".join(["".join(word) for word in tokens_list]))
```

Call this after each merge to check your progress.

---

### ✅ With These Fixes

Your merge sequence will now match:

```python
['s t', 'e st', 'o w', 'l ow', 'w est', 'n e']
```

And `newest` will become:

```python
['ne', 'west']
```

with 6 merges.
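
As a quick check of that claim, the sketch below replays the six merges on single words. `apply_merges` is a hypothetical helper written for this illustration; it is not part of the training code.

```python
def apply_merges(word, merges):
    # Start from single characters and apply each merge rule in training order.
    symbols = list(word)
    for left, right in merges:
        merged, i = [], 0
        while i < len(symbols):
            if i < len(symbols) - 1 and symbols[i] == left and symbols[i + 1] == right:
                merged.append(left + right)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols

merges = [("s", "t"), ("e", "st"), ("o", "w"), ("l", "ow"), ("w", "est"), ("n", "e")]
print(apply_merges("newest", merges))  # ['ne', 'west']
print(apply_merges("lower", merges))   # ['low', 'e', 'r']
```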

In [None]:
import regex
from collections import defaultdict

def get_raw_tokens(text):
    # Use Unicode letters, then split on whitespace
    tokens = regex.findall(r"\p{L}+", text)
    return " ".join(tokens)

def get_list_of_characters(pure_word_result):
    # Preserve word boundaries
    return [list(word) for word in pure_word_result.split()]

def combine_successive_tokens(tokens_list):
    counts = defaultdict(int)
    # Count pairs only within each pre-tokenized word
    for word in tokens_list:
        for i in range(len(word) - 1):
            counts[(word[i], word[i+1])] += 1
    return counts

def find_max_pair(counts):
    # Deterministic: break ties using lexicographic order
    return max(counts, key=lambda x: (counts[x], x))

def merge_tokens(tokens_list, match1, match2):
    new_tokens_list = []
    for word in tokens_list:
        new_word = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == match1 and word[i+1] == match2:
                new_word.append(match1 + match2)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_tokens_list.append(new_word)
    return new_tokens_list

def print_tokens(tokens_list):
    print("Current tokens:")
    print(tokens_list)
    print(" ".join(["' '".join(word) for word in tokens_list]))
    print()

def BPE_training_naive_version(text, num_merges=6):
    # Step 1: Pre-tokenization and splitting into characters
    pure_word_result = get_raw_tokens(text)
    tokens_list = get_list_of_characters(pure_word_result)

    # Step 2: Perform merges
    for i in range(num_merges):
        counts = combine_successive_tokens(tokens_list)
        print(f"Counts: {counts}")
        if not counts:
            break
        match1, match2 = find_max_pair(counts)
        print(f"Merge {i+1}: ({match1}, {match2})")
        tokens_list = merge_tokens(tokens_list, match1, match2)
        print_tokens(tokens_list)

    return tokens_list

# Example text from Sennrich et al. (2016)
example_text_ag1 = """
    low low low low low
    lower lower widest widest widest
    newest newest newest newest newest newest
"""

# Run BPE training
final_result = BPE_training_naive_version(example_text_ag1, num_merges=6)

Counts: defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', 'e'): 8, ('e', 'r'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3, ('e', 's'): 9, ('s', 't'): 9, ('n', 'e'): 6, ('e', 'w'): 6})
Merge 1: (s, t)
Current tokens:
[['l', 'o', 'w'], ['l', 'o', 'w'], ['l', 'o', 'w'], ['l', 'o', 'w'], ['l', 'o', 'w'], ['l', 'o', 'w', 'e', 'r'], ['l', 'o', 'w', 'e', 'r'], ['w', 'i', 'd', 'e', 'st'], ['w', 'i', 'd', 'e', 'st'], ['w', 'i', 'd', 'e', 'st'], ['n', 'e', 'w', 'e', 'st'], ['n', 'e', 'w', 'e', 'st'], ['n', 'e', 'w', 'e', 'st'], ['n', 'e', 'w', 'e', 'st'], ['n', 'e', 'w', 'e', 'st'], ['n', 'e', 'w', 'e', 'st']]
l' 'o' 'w l' 'o' 'w l' 'o' 'w l' 'o' 'w l' 'o' 'w l' 'o' 'w' 'e' 'r l' 'o' 'w' 'e' 'r w' 'i' 'd' 'e' 'st w' 'i' 'd' 'e' 'st w' 'i' 'd' 'e' 'st n' 'e' 'w' 'e' 'st n' 'e' 'w' 'e' 'st n' 'e' 'w' 'e' 'st n' 'e' 'w' 'e' 'st n' 'e' 'w' 'e' 'st n' 'e' 'w' 'e' 'st

Counts: defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', 'e'): 8, ('e', 'r'): 2, ('w', 'i'): 3, ('i',

In [156]:
final_result

[['low'],
 ['low'],
 ['low'],
 ['low'],
 ['low'],
 ['low', 'e', 'r'],
 ['low', 'e', 'r'],
 ['w', 'i', 'd', 'est'],
 ['w', 'i', 'd', 'est'],
 ['w', 'i', 'd', 'est'],
 ['ne', 'west'],
 ['ne', 'west'],
 ['ne', 'west'],
 ['ne', 'west'],
 ['ne', 'west'],
 ['ne', 'west']]
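
The first merge is `(s, t)` even though `('e', 's')` has the same count of 9. That follows from the tie-breaking key in `find_max_pair`, which can be checked in isolation with a reduced, hand-picked subset of the counts:

```python
counts = {("e", "s"): 9, ("s", "t"): 9, ("w", "e"): 8}

# Same key function as find_max_pair: compare by count first, then by the pair itself.
best = max(counts, key=lambda pair: (counts[pair], pair))
print(best)  # ('s', 't')

# Tuples compare element by element, so ('s', 't') > ('e', 's') because 's' > 'e'.
print(("s", "t") > ("e", "s"))  # True
```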

#### Here is an optimized version of the BPE training code that uses a frequency dictionary to avoid redundant work in `combine_successive_tokens` (renamed to `count_pair_frequencies`):

In [211]:
import regex
from collections import defaultdict, Counter

def get_raw_tokens(text, PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""):
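    # Find all pre-tokens matching PAT, then join them with single spaces.
    # Note: the later .split() drops leading spaces and whitespace-only matches.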
    tokens = regex.findall(PAT, text)
    return " ".join(tokens)

def get_list_of_characters(pure_word_result):
    # Preserve word boundaries
    return [list(word) for word in pure_word_result.split()]

def count_pair_frequencies(tokens_counter):
    counts = defaultdict(int)
    for word, freq in tokens_counter.items():
        for i in range(len(word) - 1):
            pair = (word[i], word[i+1])
            counts[pair] += freq
    return counts

def find_max_pair(counts):
    # Deterministic: break ties using lexicographic order
    return max(counts, key=lambda x: (counts[x], x))

def merge_tokens(tokens_counter, match1, match2):
    new_counter = Counter()
    for word, freq in tokens_counter.items():
        new_word = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == match1 and word[i+1] == match2:
                new_word.append(match1 + match2)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_counter[tuple(new_word)] += freq
    return new_counter

def print_tokens(tokens_counter):
    print("Current tokens:")
    tokens_list = [list(word) for word in tokens_counter.keys()]
    print(" ".join(["' '".join(word) for word in tokens_list]))
    print()

def BPE_training_naive_version(text, num_merges=6, verbose=False, PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""):
    # Step 1: Pre-tokenization and splitting into characters
    new_vocab = []
    pure_word_result = get_raw_tokens(text, PAT)
    tokens_list = get_list_of_characters(pure_word_result)
    tokens_counter = Counter(tuple(word) for word in tokens_list)

    # Step 2: Perform merges
    for i in range(num_merges):
        counts = count_pair_frequencies(tokens_counter)
        if not counts:
            break
        match1, match2 = find_max_pair(counts)
        new_vocab.append(match1+match2)
        tokens_counter = merge_tokens(tokens_counter, match1, match2)
        if verbose:
            print(f"Counts: {counts}")
            print(f"Merge {i+1}: ({match1}, {match2})")
            print_tokens(tokens_counter)

    return tokens_counter, new_vocab


In [214]:
# Example text from Sennrich et al. (2016)
example_text_ag1 = """
    low'll low low low low
    lower lower widest widest widest
    newest newest newest newest newest newest
"""
raw_tokens = get_raw_tokens(example_text_ag1,  PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
list_tokens = get_list_of_characters(raw_tokens)
print(raw_tokens)
print(list_tokens)


     low 'll  low  low  low  low 
     lower  lower  widest  widest  widest 
     newest  newest  newest  newest  newest  newest 

[['l', 'o', 'w'], ["'", 'l', 'l'], ['l', 'o', 'w'], ['l', 'o', 'w'], ['l', 'o', 'w'], ['l', 'o', 'w'], ['l', 'o', 'w', 'e', 'r'], ['l', 'o', 'w', 'e', 'r'], ['w', 'i', 'd', 'e', 's', 't'], ['w', 'i', 'd', 'e', 's', 't'], ['w', 'i', 'd', 'e', 's', 't'], ['n', 'e', 'w', 'e', 's', 't'], ['n', 'e', 'w', 'e', 's', 't'], ['n', 'e', 'w', 'e', 's', 't'], ['n', 'e', 'w', 'e', 's', 't'], ['n', 'e', 'w', 'e', 's', 't'], ['n', 'e', 'w', 'e', 's', 't']]


In [215]:
# Example text from Sennrich et al. (2016)
example_text_ag1 = """
    low'll low low low low
    lower lower widest widest widest
    newest newest newest newest newest newest
"""
raw_tokens = get_raw_tokens(example_text_ag1,  PAT = r"\p{L}+")
list_tokens = get_list_of_characters(raw_tokens)
print(raw_tokens)
print(list_tokens)

low ll low low low low lower lower widest widest widest newest newest newest newest newest newest
[['l', 'o', 'w'], ['l', 'l'], ['l', 'o', 'w'], ['l', 'o', 'w'], ['l', 'o', 'w'], ['l', 'o', 'w'], ['l', 'o', 'w', 'e', 'r'], ['l', 'o', 'w', 'e', 'r'], ['w', 'i', 'd', 'e', 's', 't'], ['w', 'i', 'd', 'e', 's', 't'], ['w', 'i', 'd', 'e', 's', 't'], ['n', 'e', 'w', 'e', 's', 't'], ['n', 'e', 'w', 'e', 's', 't'], ['n', 'e', 'w', 'e', 's', 't'], ['n', 'e', 'w', 'e', 's', 't'], ['n', 'e', 'w', 'e', 's', 't'], ['n', 'e', 'w', 'e', 's', 't']]


In [None]:


# Run BPE training
final_result, new_vocab = BPE_training_naive_version(example_text_ag1, PAT = r"\p{L}+", num_merges=6, verbose=True)

Counts: defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', 'e'): 8, ('e', 'r'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3, ('e', 's'): 9, ('s', 't'): 9, ('n', 'e'): 6, ('e', 'w'): 6})
Merge 1: (s, t)
Current tokens:
l' 'o' 'w l' 'o' 'w' 'e' 'r w' 'i' 'd' 'e' 'st n' 'e' 'w' 'e' 'st

Counts: defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', 'e'): 8, ('e', 'r'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3, ('e', 'st'): 9, ('n', 'e'): 6, ('e', 'w'): 6})
Merge 2: (e, st)
Current tokens:
l' 'o' 'w l' 'o' 'w' 'e' 'r w' 'i' 'd' 'est n' 'e' 'w' 'est

Counts: defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', 'e'): 2, ('e', 'r'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est'): 3, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est'): 6})
Merge 3: (o, w)
Current tokens:
l' 'ow l' 'ow' 'e' 'r w' 'i' 'd' 'est n' 'e' 'w' 'est

Counts: defaultdict(<class 'int'>, {('l', 'ow'): 7, ('ow', 'e'): 2, ('e', 'r'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est'): 3, ('n', 'e'

In [203]:
new_vocab

['st', 'est', 'ow', 'low', 'west', 'ne']

In [204]:
# Example text from Sennrich et al. (2016)
example_text_ag1 = """
    low low low low low
    lower lower widest widest widest
    newest newest newest newest newest newest
"""

# Run BPE training
final_result, new_vocab = BPE_training_naive_version(example_text_ag1, num_merges=6, verbose=True)

Counts: defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', 'e'): 8, ('e', 'r'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3, ('e', 's'): 9, ('s', 't'): 9, ('n', 'e'): 6, ('e', 'w'): 6})
Merge 1: (s, t)
Current tokens:
l' 'o' 'w l' 'o' 'w' 'e' 'r w' 'i' 'd' 'e' 'st n' 'e' 'w' 'e' 'st

Counts: defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', 'e'): 8, ('e', 'r'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3, ('e', 'st'): 9, ('n', 'e'): 6, ('e', 'w'): 6})
Merge 2: (e, st)
Current tokens:
l' 'o' 'w l' 'o' 'w' 'e' 'r w' 'i' 'd' 'est n' 'e' 'w' 'est

Counts: defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', 'e'): 2, ('e', 'r'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est'): 3, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est'): 6})
Merge 3: (o, w)
Current tokens:
l' 'ow l' 'ow' 'e' 'r w' 'i' 'd' 'est n' 'e' 'w' 'est

Counts: defaultdict(<class 'int'>, {('l', 'ow'): 7, ('ow', 'e'): 2, ('e', 'r'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est'): 3, ('n', 'e'

In [205]:
# Example text from Sennrich et al. (2016)
example_text_ag1 = """
    low low low low low
    lower lower widest widest widest
    newest newest newest newest newest newest
"""

# Run BPE training
final_result, new_vocab = BPE_training_naive_version(example_text_ag1, num_merges=6, verbose=False)
print(final_result)
print('new vocabulary is', new_vocab)

Counter({('ne', 'west'): 6, ('low',): 5, ('w', 'i', 'd', 'est'): 3, ('low', 'e', 'r'): 2})
new vocabulary is ['st', 'est', 'ow', 'low', 'west', 'ne']
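
The saving comes from collapsing repeated words into one key each. The sketch below rebuilds the example corpus by hand (skipping the regex step) to show that 16 word occurrences become 4 unique keys, and that frequency-weighted pair counts match the naive ones.

```python
from collections import Counter, defaultdict

words = ["low"] * 5 + ["lower"] * 2 + ["widest"] * 3 + ["newest"] * 6

# Collapse the 16 word occurrences into 4 unique keys with frequencies.
tokens_counter = Counter(tuple(word) for word in words)
print(len(words), len(tokens_counter))  # 16 4

# Weighting each pair by its word's frequency reproduces the naive counts.
pair_counts = defaultdict(int)
for word, freq in tokens_counter.items():
    for left, right in zip(word, word[1:]):
        pair_counts[(left, right)] += freq

print(pair_counts[("s", "t")])  # 9
print(pair_counts[("w", "e")])  # 8
```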


#### Time consumption

In [220]:
import regex
from collections import defaultdict, Counter

def get_raw_tokens(text, PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""):
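    # Find all pre-tokens matching PAT, then join them with single spaces.
    # Note: the later .split() drops leading spaces and whitespace-only matches.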
    tokens = regex.findall(PAT, text)
    return " ".join(tokens)

def get_list_of_characters(pure_word_result):
    # Preserve word boundaries
    return [list(word) for word in pure_word_result.split()]

def count_pair_frequencies(tokens_counter):
    counts = defaultdict(int)
    for word, freq in tokens_counter.items():
        for i in range(len(word) - 1):
            pair = (word[i], word[i+1])
            counts[pair] += freq
    return counts

def find_max_pair(counts):
    # Deterministic: break ties using lexicographic order
    return max(counts, key=lambda x: (counts[x], x))

def merge_tokens(tokens_counter, match1, match2):
    new_counter = Counter()
    for word, freq in tokens_counter.items():
        new_word = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == match1 and word[i+1] == match2:
                new_word.append(match1 + match2)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_counter[tuple(new_word)] += freq
    return new_counter

def print_tokens(tokens_counter):
    print("Current tokens:")
    tokens_list = [list(word) for word in tokens_counter.keys()]
    print(" ".join(["' '".join(word) for word in tokens_list]))
    print()

def BPE_training_naive_version(text, num_merges=6, verbose=False, PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""):
    # Step 1: Pre-tokenization and splitting into characters
    new_vocab = []
    pure_word_result = get_raw_tokens(text, PAT = PAT)
    tokens_list = get_list_of_characters(pure_word_result)
    tokens_counter = Counter(tuple(word) for word in tokens_list)

    # Step 2: Perform merges
    for i in range(num_merges):
        counts = count_pair_frequencies(tokens_counter)
        if not counts:
            break
        match1, match2 = find_max_pair(counts)
        new_vocab.append(match1+match2)
        tokens_counter = merge_tokens(tokens_counter, match1, match2)
        if verbose:
            print(f"Counts: {counts}")
            print(f"Merge {i+1}: ({match1}, {match2})")
            print_tokens(tokens_counter)

    return tokens_counter, new_vocab


In [221]:
GPT2_TOKENIZER_REGEX = \
    r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

In [222]:
def run_bpe(text, num_merges):
    return BPE_training_naive_version(text, num_merges=num_merges, verbose=False, PAT = GPT2_TOKENIZER_REGEX)

In [223]:
import cProfile
import pstats

if __name__ == "__main__":

    file_path = "/Users/ethanj/Documents/CODE/Stanford_CS336/assignment1-basics-main/cs336_basics/text_examples/text_ex1.txt"

    with open(file_path, "r", encoding="utf-8") as f:
        text = f.read()

    profiler = cProfile.Profile()
    profiler.enable()
    final_result, new_vocab = run_bpe(text, 10000)
    profiler.disable()

    stats = pstats.Stats(profiler)
    stats.strip_dirs().sort_stats("cumtime").print_stats(20)  # Top 20 by cumulative time

         78352445 function calls in 9.700 seconds

   Ordered by: cumulative time
   List reduced from 43 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    9.700    9.700 341127449.py:1(run_bpe)
        1    0.152    0.152    9.700    9.700 3203458813.py:46(BPE_training_naive_version)
     3815    5.178    0.001    6.565    0.002 3203458813.py:25(merge_tokens)
     3816    2.066    0.001    2.239    0.001 3203458813.py:13(count_pair_frequencies)
 47713920    0.985    0.000    0.985    0.000 {built-in method builtins.len}
     3815    0.001    0.000    0.736    0.000 3203458813.py:21(find_max_pair)
     3815    0.429    0.000    0.735    0.000 {built-in method builtins.max}
 16214513    0.378    0.000    0.378    0.000 {method 'append' of 'list' objects}
  6723165    0.306    0.000    0.306    0.000 3203458813.py:23(<lambda>)
  7645260    0.187    0.000    0.187    0.000 __init__.py:601(__missing__)
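
The profile attributes most of the time to `merge_tokens`, which rebuilds every word on every merge. One possible refinement, sketched here and not measured, is to pass through unchanged any word that cannot contain the pair. `merge_tokens_skip` is a hypothetical name, and the result should be identical to `merge_tokens`.

```python
from collections import Counter

def merge_tokens_skip(tokens_counter, match1, match2):
    new_counter = Counter()
    for word, freq in tokens_counter.items():
        # Cheap pre-check: a word lacking either symbol cannot contain the pair.
        if match1 not in word or match2 not in word:
            new_counter[word] += freq
            continue
        new_word, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == match1 and word[i + 1] == match2:
                new_word.append(match1 + match2)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_counter[tuple(new_word)] += freq
    return new_counter

tokens_counter = Counter({("l", "o", "w"): 5, ("n", "e", "w", "e", "s", "t"): 6})
result = merge_tokens_skip(tokens_counter, "s", "t")
print(result)
```

Whether this helps on a real corpus depends on how many words each merge touches, so it is worth re-profiling before adopting it.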
    

In [226]:
new_vocab_encoded = [token.encode("utf-8") for token in new_vocab]
print(new_vocab_encoded)



In [None]:
from IPython.display import display

def print_vocab_chunks(vocab, items_per_line=10):
    for i in range(0, len(vocab), items_per_line):
        display(vocab[i:i + items_per_line])

# Example usage:
print_vocab_chunks(new_vocab)

## 2.5 Experimenting with BPE Tokenizer Training

In [None]:
# def get_raw_tokens(text, PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") -> str:
#     # Use Unicode letters, then split on whitespace
#     tokens = regex.findall(PAT, text)
#     return " ".join(tokens)

# def get_list_of_characters(pure_word_result) -> list[bytes]:
#     # Preserve word boundaries
#     return [bytes(word.encode("utf-8")) for word in pure_word_result.split()]

# def count_pair_frequencies(tokens_counter):
#     counts = defaultdict(int)
#     for word, freq in tokens_counter.items():
#         for i in range(len(word) - 1):
#             print("word:", word[i], word[i+1])
#             pair = (word[i], word[i+1])
#             counts[pair] += freq
#     return counts

# def find_max_pair(counts):
#     # Deterministic: break ties using lexicographic order
#     return max(counts, key=lambda x: (counts[x], x))

# def merge_tokens(tokens_counter, match1, match2):
#     new_counter = Counter()
#     for word, freq in tokens_counter.items():
#         new_word = []
#         i = 0
#         print(word)
#         while i < len(word):
#             if i < len(word) - 1 and word[i] == match1 and word[i+1] == match2:
#                 new_word.append((match1, match2))
#                 i += 2
#             else:
#                 new_word.append(word[i])
#                 i += 1
#         new_counter[tuple(new_word)] += freq
#     return new_counter



In [419]:

def get_raw_tokens(text, PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""):
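    # Find all pre-tokens matching PAT, then join them with single spaces.
    # Note: the later .split() drops leading spaces and whitespace-only matches.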
    tokens = regex.findall(PAT, text)
    return " ".join(tokens)

def get_list_of_characters(pure_word_result):
    # Preserve word boundaries
    return [list(word) for word in pure_word_result.split()]

def count_pair_frequencies(tokens_counter):
    counts = defaultdict(int)
    for word, freq in tokens_counter.items():
        for i in range(len(word) - 1):
            pair = (word[i], word[i+1])
            counts[pair] += freq
    return counts

def find_max_pair(counts):
    # Deterministic: break ties using lexicographic order
    return max(counts, key=lambda x: (counts[x], x))

def merge_tokens(tokens_counter, match1, match2):
    new_counter = Counter()
    for word, freq in tokens_counter.items():
        new_word = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == match1 and word[i+1] == match2:
                new_word.append(match1 + match2)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_counter[tuple(new_word)] += freq
    return new_counter


**Problem (train_bpe): BPE Tokenizer Training (15 points)**

**Deliverable:** Write a function that, given a path to an input text file, trains a (byte-level) BPE tokenizer. Your BPE training function should handle (at least) the following input parameters:

- `input_path: str` Path to a text file with BPE tokenizer training data.
- `vocab_size: int` A positive integer that defines the maximum final vocabulary size (including the initial byte vocabulary, vocabulary items produced from merging, and any special tokens).
- `special_tokens: list[str]` A list of strings to add to the vocabulary. These special tokens do not otherwise affect BPE training.

Your BPE training function should return the resulting vocabulary and merges:

- `vocab: dict[int, bytes]` The tokenizer vocabulary, a mapping from int (token ID in the vocabulary) to bytes (token bytes).
- `merges: list[tuple[bytes, bytes]]` A list of BPE merges produced from training. Each list item is a tuple of bytes `(<token1>, <token2>)`, representing that `<token1>` was merged with `<token2>`. The merges should be ordered by order of creation.

To test your BPE training function against our provided tests, you will first need to implement the test adapter at `[adapters.run_train_bpe]`. Then, run `uv run pytest tests/test_train_bpe.py`. Your implementation should be able to pass all tests.

Optionally (this could be a large time-investment), you can implement the key parts of your training method using some systems language, for instance C++ (consider `cppyy` for this) or Rust (using `PyO3`). If you do this, be aware of which operations require copying vs reading directly from Python memory, and make sure to leave build instructions, or make sure it builds using only `pyproject.toml`. Also note that the GPT-2 regex is not well-supported in most regex engines and will be too slow in most that do. We have verified that Oniguruma is reasonably fast and supports negative lookahead, but the `regex` package in Python is, if anything, even faster.
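
Since the deliverable asks for a byte-level tokenizer, each pre-token is split into single bytes rather than Unicode characters. The following is a minimal sketch of that step, assuming the pre-token is already a Python `str`; `to_byte_tuple` is a hypothetical helper name.

```python
def to_byte_tuple(pre_token: str) -> tuple[bytes, ...]:
    # Encode to UTF-8, then wrap every byte value in its own bytes object.
    return tuple(bytes([b]) for b in pre_token.encode("utf-8"))

print(to_byte_tuple(" low"))      # (b' ', b'l', b'o', b'w')
print(len(to_byte_tuple("こ")))   # 3

# Merging two byte tokens is plain bytes concatenation.
print(b"s" + b"t")                # b'st'
```

With this representation a multi-byte character such as `こ` starts out as three separate tokens, and merges are recorded directly as `(bytes, bytes)` pairs.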

In [None]:
from dataclasses import dataclass

GPT2_TOKENIZER_REGEX = \
    r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

@dataclass
class BPETokenizerParams:
    vocab: dict[int, bytes]
    merges: list[tuple[bytes, bytes]]

def train_BPE_v1(texts: str, num_merges: int) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    # Initialize parameters
    merges: list[tuple[bytes, bytes]] = []
    vocab: dict[int, bytes] = {x: bytes([x]) for x in range(256)}
    # Explanation of the difference between the two dictionary comprehensions:
    # ✅ This one is correct: vocab: dict[int, bytes] = {x: bytes([x]) for x in range(256)}
    # It creates a dictionary mapping each integer from 0 to 255 (inclusive)
    # to its corresponding single-byte representation using bytes([x]).
    # bytes([x]) takes a list of integers and returns a bytes object of length 1.

    # ❌ This one is incorrect, although it raises no error:
    # vocab: dict[int, bytes] = {x: bytes(x) for x in range(256)}  # Incorrect
    # bytes(x) creates a bytes object of length x filled with zero bytes,
    # which is not what we want here. For example, bytes(3) returns b'\x00\x00\x00',
    # not b'\x03'.
    # By contrast, bytes([x]) raises a ValueError for x >= 256, since each item must be in range(0, 256).

    # Step 1: Pre-tokenization and splitting into characters
    raw_tokens = get_raw_tokens(texts, GPT2_TOKENIZER_REGEX)
    bytes_indices = get_list_of_characters(raw_tokens)
    tokens_counter = Counter(tuple(word) for word in bytes_indices)


    # Step 2: Perform merges
    for i in range(num_merges):
        counts = count_pair_frequencies(tokens_counter)
        if not counts:
            break
        match1, match2  = find_max_pair(counts)
        new_ID_for_new_token = 256 + i
        bytes_presentation_of_max_pair = (match1+match2).encode('utf8')
        # Encode both halves so that each merge is a (bytes, bytes) pair
        merges.append((match1.encode('utf8'), match2.encode('utf8')))
        # print("merges", merges)
        vocab[new_ID_for_new_token] = bytes_presentation_of_max_pair
        tokens_counter = merge_tokens(tokens_counter, match1, match2)
        print("token_counter", tokens_counter)
        print("Merge", merges)
    
    return vocab, merges
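
The difference between `bytes([x])` and `bytes(x)` is easy to confirm interactively. This short check uses only built-ins:

```python
print(bytes([3]))   # b'\x03'
print(bytes(3))     # b'\x00\x00\x00'

vocab = {x: bytes([x]) for x in range(256)}
print(len(vocab), vocab[115])  # 256 b's'

# bytes([256]) raises ValueError because each item must be in range(0, 256).
try:
    bytes([256])
except ValueError as err:
    print("ValueError:", err)
```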


In [435]:
example_text_ag1 = """
    low low low low low
    lower lower widest widest widest
    newest newest newest newest newest newest
"""

In [440]:
len(vocab)

262

In [438]:
vocab, merges = train_BPE_v1(example_text_ag1, 6)
merges
print(vocab)

token_counter Counter({('n', 'e', 'w', 'e', 'st'): 6, ('l', 'o', 'w'): 5, ('w', 'i', 'd', 'e', 'st'): 3, ('l', 'o', 'w', 'e', 'r'): 2})
Merge [(b's', b't')]
token_counter Counter({('n', 'e', 'w', 'est'): 6, ('l', 'o', 'w'): 5, ('w', 'i', 'd', 'est'): 3, ('l', 'o', 'w', 'e', 'r'): 2})
Merge [(b's', b't'), (b'e', b'st')]
token_counter Counter({('n', 'e', 'w', 'est'): 6, ('l', 'ow'): 5, ('w', 'i', 'd', 'est'): 3, ('l', 'ow', 'e', 'r'): 2})
Merge [(b's', b't'), (b'e', b'st'), (b'o', b'w')]
token_counter Counter({('n', 'e', 'w', 'est'): 6, ('low',): 5, ('w', 'i', 'd', 'est'): 3, ('low', 'e', 'r'): 2})
Merge [(b's', b't'), (b'e', b'st'), (b'o', b'w'), (b'l', b'ow')]
token_counter Counter({('n', 'e', 'west'): 6, ('low',): 5, ('w', 'i', 'd', 'est'): 3, ('low', 'e', 'r'): 2})
Merge [(b's', b't'), (b'e', b'st'), (b'o', b'w'), (b'l', b'ow'), (b'w', b'est')]
token_counter Counter({('ne', 'west'): 6, ('low',): 5, ('w', 'i', 'd', 'est'): 3, ('low', 'e', 'r'): 2})
Merge [(b's', b't'), (b'e', b'st'), 

#### Keep going on Version 2

In [455]:
from dataclasses import dataclass

GPT2_TOKENIZER_REGEX = \
    r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

@dataclass
class BPETokenizerParams:
    vocab: dict[int, bytes]
    merges: list[tuple[bytes, bytes]]

def train_BPE(input_path: str, vocab_size: int, special_tokens: list[str]) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:

    #######################################################################################
    # Read the input text
    with open(input_path, "r", encoding="utf-8") as f:
        texts = f.read()
    
    #######################################################################################
    # Initialize parameters
    vocab: dict[int, bytes] = {}

    # Add special tokens first
    for i, token in enumerate(special_tokens):
        vocab[i] = token.encode("utf-8")

    # Add regular bytes 0-255 after special tokens
    offset = len(special_tokens)

    for x in range(256):
        vocab[offset + x] = bytes([x])
    #######################################################################################
    merges: list[tuple[bytes, bytes]] = []
    #######################################################################################

    raw_tokens = get_raw_tokens(texts, GPT2_TOKENIZER_REGEX)
    bytes_indices = get_list_of_characters(raw_tokens)
    tokens_counter = Counter(tuple(word) for word in bytes_indices)

    # Step 2: Perform merges
    index_tracking = offset
    while len(vocab) < vocab_size:
        counts = count_pair_frequencies(tokens_counter)
        if not counts:
            break
        match1, match2  = find_max_pair(counts)
        new_ID_for_new_token = 256 + index_tracking
        index_tracking += 1
        bytes_presentation_of_max_pair = (match1+match2).encode('utf8')
        merges.append((match1.encode('utf8'), match2.encode('utf8')))
        # print("merges", merges)
        vocab[new_ID_for_new_token] = bytes_presentation_of_max_pair
        tokens_counter = merge_tokens(tokens_counter, match1, match2)
        
    return vocab, merges
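
The ID layout used in this version (special tokens first, then the 256 byte values, then merged tokens) can be checked on its own. The sketch below repeats only the initialization step with one example special token:

```python
special_tokens = ["<|endoftext|>"]

vocab = {i: tok.encode("utf-8") for i, tok in enumerate(special_tokens)}
offset = len(special_tokens)
for x in range(256):
    vocab[offset + x] = bytes([x])

print(vocab[0])               # b'<|endoftext|>'
print(vocab[offset + 115])    # b's'
print(len(vocab))             # 257

# The first merged token takes the next free ID, which equals len(vocab).
next_id = 256 + offset
print(next_id == len(vocab))  # True
```

Note that a reference vocabulary may order its IDs differently, which is why the later comparison looks at sets of keys and values rather than exact positions.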


In [470]:
special_tokens: list[str] = ["<|endoftext|>"] # example

In [462]:
test_txt_file_path = '/Users/ethanj/Documents/CODE/Stanford_CS336/assignment1-basics-main/cs336_basics/text_examples/text_ex1.txt'
vocab, merges = train_BPE(test_txt_file_path, 500, special_tokens)

#### Test functions

In [514]:
import importlib
import cs336_basics.BPE_tokenizer as bpe

importlib.reload(bpe)

<module 'cs336_basics.BPE_tokenizer' from '/Users/ethanj/Documents/CODE/Stanford_CS336/assignment1-basics-main/cs336_basics/BPE_tokenizer.py'>

#### Test training speed

In [515]:
import json
import time

my_BPE_tokenizer = bpe.BPETokenizer()

input_path = '/Users/ethanj/Documents/CODE/Stanford_CS336/assignment1-basics-main/tests/fixtures/'
start_time = time.time()
vocab, merges = my_BPE_tokenizer.train_BPE(
    input_path=input_path+'corpus.en',
    vocab_size=500,
    special_tokens=["<|endoftext|>"],
)
end_time = time.time()
assert end_time - start_time < 1.5

#### Test output result

In [516]:
from __future__ import annotations

import pathlib
from functools import lru_cache

@lru_cache
def gpt2_bytes_to_unicode() -> dict[int, str]:
    """
    Returns a mapping between every possible byte (an integer from 0 to 255) to a
    printable unicode string character representation. This function is taken
    from the GPT-2 code.

    For example, `chr(0)` is `\x00`, which is an unprintable character:

    >>> chr(0)
    '\x00'
    >>> print(chr(0))

    As a result, this function returns a dictionary `d` where `d[0]` returns `Ā`.
    The bytes that are visually printable keep their original string representation [1].
    For example, `chr(33)` returns `!`, and so accordingly `d[33]` returns `!`.
    Note in particular that the space character `chr(32)` becomes `d[32]`, which
    returns 'Ġ'.

    For unprintable characters, the function maps each byte to `chr(256 + n)`, where
    `n` counts the unprintable bytes seen so far. For bytes 0 through 32, `n` equals
    the byte value itself. For example, `ord(" ")` returns `32`, so the space character
    ' ' is shifted to `256 + 32`. Since `chr(256 + 32)` returns `Ġ`, we use that as the
    string representation of the space.

    This function can simplify the BPE implementation and makes it slightly easier to
    manually inspect the generated merges after they're serialized to a file.
    """
    # These 188 integers can be used as-is, since they are not whitespace or control characters.
    # See https://www.ssec.wisc.edu/~tomw/java/unicode.html.
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    # Now get printable representations of the other 68 integers, which do need shifting.
    # Each is mapped to chr(256 + n), where n grows from 0 to 67 in the loop.
    n = 0
    for b in range(2**8):
        if b not in bs:
            # If this integer isn't in our list of visually-representable
            # characters, then map it to the next nice character (offset by 256)
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    characters = [chr(n) for n in cs]
    d = dict(zip(bs, characters))
    return d
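
The properties described in the docstring can be checked directly. The following is a minimal sketch, not the GPT-2 code itself: `byte_to_printable` is a hypothetical, condensed restatement of the function above, included only so that the snippet runs on its own.

```python
def byte_to_printable() -> dict[int, str]:
    """Condensed restatement of the mapping above (hypothetical name)."""
    # Bytes that already have a visible, non-whitespace glyph.
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    mapping = {b: chr(b) for b in printable}
    # Every other byte gets the next unused code point, starting at 256.
    n = 0
    for b in range(256):
        if b not in mapping:
            mapping[b] = chr(256 + n)
            n += 1
    return mapping


d = byte_to_printable()
inverse = {ch: b for b, ch in d.items()}

assert len(d) == 256 and len(inverse) == 256  # one distinct character per byte
assert d[33] == "!"  # printable bytes map to themselves
assert d[0] == "Ā"   # chr(256)
assert d[32] == "Ġ"  # chr(256 + 32)

# Round trip: bytes -> printable string -> bytes
raw = "hello world".encode("utf-8")
shown = "".join(d[b] for b in raw)
assert bytes(inverse[ch] for ch in shown) == raw
print(shown)  # helloĠworld
```

The inverse dictionary plays the same role as `gpt2_byte_decoder` in the comparison cell that follows: it turns each printable character in a serialized token back into its byte value.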


In [517]:
my_BPE_tokenizer = bpe.BPETokenizer()

vocab, merges = my_BPE_tokenizer.train_BPE(
    input_path=input_path+'corpus.en',
    vocab_size=500,
    special_tokens=["<|endoftext|>"],
)

# Path to the reference tokenizer vocab and merges
reference_vocab_path = input_path + "train-bpe-reference-vocab.json"
reference_merges_path = input_path + "train-bpe-reference-merges.txt"

# Compare the learned merges to the expected output merges
gpt2_byte_decoder = {v: k for k, v in gpt2_bytes_to_unicode().items()}
with open(reference_merges_path) as f:
    gpt2_reference_merges = [tuple(line.rstrip().split(" ")) for line in f]
    reference_merges = [
        (
            bytes([gpt2_byte_decoder[token] for token in merge_token_1]),
            bytes([gpt2_byte_decoder[token] for token in merge_token_2]),
        )
        for merge_token_1, merge_token_2 in gpt2_reference_merges
    ]
assert merges == reference_merges

# Compare the vocab to the expected output vocab
with open(reference_vocab_path) as f:
    gpt2_reference_vocab = json.load(f)
    reference_vocab = {
        gpt2_vocab_index: bytes([gpt2_byte_decoder[token] for token in gpt2_vocab_item])
        for gpt2_vocab_item, gpt2_vocab_index in gpt2_reference_vocab.items()
    }
# Rather than checking that the vocabs match exactly (they could have been
# constructed differently), we'll make sure that the vocab keys and values match.
assert set(vocab.keys()) == set(reference_vocab.keys())
assert set(vocab.values()) == set(reference_vocab.values())


AssertionError: 
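
A bare `assert merges == reference_merges` fails without saying where the two lists diverge. One way to get more information is to report the first index at which they differ. The sketch below uses a hypothetical helper `first_mismatch` and toy data (not the actual outputs of this notebook), and assumes each merge is a tuple of two `bytes` objects.

```python
from itertools import zip_longest


def first_mismatch(learned, reference):
    """Return (index, learned_item, reference_item) for the first position
    where the two sequences differ, or None if they are equal.

    If one sequence is shorter, the missing item is reported as None.
    """
    for i, (a, b) in enumerate(zip_longest(learned, reference)):
        if a != b:
            return i, a, b
    return None


# Toy data for illustration only.
toy_reference = [(b" ", b"t"), (b"h", b"e"), (b" t", b"he")]
toy_learned = [(b" ", b"t"), (b"h", b"a"), (b" t", b"he")]

print(first_mismatch(toy_learned, toy_reference))
# (1, (b'h', b'a'), (b'h', b'e'))
print(first_mismatch(toy_reference, toy_reference))
# None
```

In the notebook, calling `first_mismatch(merges, reference_merges)` before the assertion would show the earliest merge that disagrees with the reference.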

In [521]:
# Compare the vocab to the expected output vocab
with open(reference_vocab_path) as f:
    gpt2_reference_vocab = json.load(f)
    reference_vocab = {
        gpt2_vocab_index: bytes([gpt2_byte_decoder[token] for token in gpt2_vocab_item])
        for gpt2_vocab_item, gpt2_vocab_index in gpt2_reference_vocab.items()
    }
# Rather than checking that the vocabs match exactly (they could have been
# constructed differently), we'll make sure that the vocab keys and values match.
assert set(vocab.keys()) == set(reference_vocab.keys())
assert set(vocab.values()) == set(reference_vocab.values())

AssertionError: 
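
When the set comparison fails, the set differences show which byte strings appear on only one side. The sketch below uses a hypothetical helper `vocab_diff` and toy vocabularies. It assumes, as in the cells above, that a vocab maps integer ids to `bytes`.

```python
def vocab_diff(vocab, reference_vocab):
    """Return (only_in_vocab, only_in_reference) as sorted lists of bytes."""
    ours = set(vocab.values())
    theirs = set(reference_vocab.values())
    return sorted(ours - theirs), sorted(theirs - ours)


# Toy vocabularies for illustration only.
toy_vocab = {0: b"<|endoftext|>", 1: b"a", 2: b"b", 3: b"ab"}
toy_reference_vocab = {0: b"<|endoftext|>", 1: b"a", 2: b"b", 3: b" a"}

extra, missing = vocab_diff(toy_vocab, toy_reference_vocab)
print(extra)    # [b'ab']
print(missing)  # [b' a']
```

Applied to `vocab` and `reference_vocab`, the two lists would be a more compact starting point than reading the full `vocab.values()` and `set(reference_vocab.values())` dumps side by side.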

In [525]:
vocab.values()

dict_values([b'<|endoftext|>', b'\x00', b'\x01', b'\x02', b'\x03', b'\x04', b'\x05', b'\x06', b'\x07', b'\x08', b'\t', b'\n', b'\x0b', b'\x0c', b'\r', b'\x0e', b'\x0f', b'\x10', b'\x11', b'\x12', b'\x13', b'\x14', b'\x15', b'\x16', b'\x17', b'\x18', b'\x19', b'\x1a', b'\x1b', b'\x1c', b'\x1d', b'\x1e', b'\x1f', b' ', b'!', b'"', b'#', b'$', b'%', b'&', b"'", b'(', b')', b'*', b'+', b',', b'-', b'.', b'/', b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b':', b';', b'<', b'=', b'>', b'?', b'@', b'A', b'B', b'C', b'D', b'E', b'F', b'G', b'H', b'I', b'J', b'K', b'L', b'M', b'N', b'O', b'P', b'Q', b'R', b'S', b'T', b'U', b'V', b'W', b'X', b'Y', b'Z', b'[', b'\\', b']', b'^', b'_', b'`', b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', b'z', b'{', b'|', b'}', b'~', b'\x7f', b'\x80', b'\x81', b'\x82', b'\x83', b'\x84', b'\x85', b'\x86', b'\x87', b'\x88', b'\x89', b'\x8a', b'\x8b',

In [523]:
set(reference_vocab.values())

{b'\x00',
 b'\x01',
 b'\x02',
 b'\x03',
 b'\x04',
 b'\x05',
 b'\x06',
 b'\x07',
 b'\x08',
 b'\t',
 b'\n',
 b'\x0b',
 b'\x0c',
 b'\r',
 b'\x0e',
 b'\x0f',
 b'\x10',
 b'\x11',
 b'\x12',
 b'\x13',
 b'\x14',
 b'\x15',
 b'\x16',
 b'\x17',
 b'\x18',
 b'\x19',
 b'\x1a',
 b'\x1b',
 b'\x1c',
 b'\x1d',
 b'\x1e',
 b'\x1f',
 b' ',
 b' &',
 b' &#',
 b' (',
 b' )',
 b' ,',
 b' .',
 b' ...',
 b' 1',
 b' 2',
 b' ;',
 b' @',
 b' @-@',
 b' A',
 b' And',
 b' B',
 b' C',
 b' D',
 b' E',
 b' F',
 b' G',
 b' H',
 b' I',
 b' L',
 b' M',
 b' O',
 b' P',
 b' S',
 b' T',
 b' W',
 b' a',
 b' ac',
 b' al',
 b' all',
 b' an',
 b' and',
 b' are',
 b' as',
 b' at',
 b' b',
 b' be',
 b' by',
 b' c',
 b' can',
 b' ch',
 b' com',
 b' comp',
 b' con',
 b' d',
 b' de',
 b' e',
 b' ex',
 b' f',
 b' for',
 b' fr',
 b' from',
 b' g',
 b' h',
 b' have',
 b' he',
 b' in',
 b' is',
 b' it',
 b' l',
 b' m',
 b' man',
 b' my',
 b' n',
 b' ne',
 b' not',
 b' o',
 b' of',
 b' on',
 b' or',
 b' p',
 b' pr',
 b' r',
 b' re',
 b' s',

In [518]:
reference_merges

[(b' ', b't'),
 (b' ', b'a'),
 (b'h', b'e'),
 (b'i', b'n'),
 (b' t', b'he'),
 (b'r', b'e'),
 (b' ', b'o'),
 (b' ', b','),
 (b'e', b'r'),
 (b' ', b's'),
 (b'a', b't'),
 (b' ', b'.'),
 (b'n', b'd'),
 (b'i', b's'),
 (b'o', b'r'),
 (b' ', b'w'),
 (b' ', b'c'),
 (b'o', b'n'),
 (b' ', b'b'),
 (b' ', b'f'),
 (b'o', b'u'),
 (b'i', b't'),
 (b'e', b'n'),
 (b'e', b's'),
 (b' o', b'f'),
 (b' ', b'p'),
 (b'in', b'g'),
 (b' ', b'in'),
 (b'e', b'd'),
 (b'a', b'l'),
 (b' ', b'm'),
 (b' a', b'nd'),
 (b' ', b'd'),
 (b'a', b'n'),
 (b'a', b'r'),
 (b' t', b'o'),
 (b'o', b'm'),
 (b' t', b'h'),
 (b'i', b'c'),
 (b'i', b'on'),
 (b' ', b'h'),
 (b' ', b'l'),
 (b' ', b'y'),
 (b' ', b'e'),
 (b'a', b's'),
 (b'o', b't'),
 (b'i', b'l'),
 (b' ', b'n'),
 (b' ', b'u'),
 (b'en', b't'),
 (b' b', b'e'),
 (b' ', b'&'),
 (b' ', b'is'),
 (b' y', b'ou'),
 (b'o', b's'),
 (b' ', b're'),
 (b'e', b't'),
 (b' f', b'or'),
 (b'u', b't'),
 (b'e', b'l'),
 (b' ', b'g'),
 (b'a', b'y'),
 (b's', b't'),
 (b'o', b'w'),
 (b'l', b'e'),
 (b'c',

In [519]:
merges

[(b'\x1f', b's'),
 (b'\x1f', b'`'),
 (b'g', b'd'),
 (b'h', b'm'),
 (b' t', b'he'),
 (b'q', b'd'),
 (b'\x1f', b'n'),
 (b'\x1f', b'+'),
 (b'd', b'q'),
 (b'\x1f', b'r'),
 (b'`', b's'),
 (b'\x1f', b'-'),
 (b'm', b'c'),
 (b'h', b'r'),
 (b'n', b'q'),
 (b'\x1f', b'v'),
 (b'\x1f', b'b'),
 (b'n', b'm'),
 (b'\x1f', b'a'),
 (b'\x1f', b'e'),
 (b'n', b't'),
 (b'h', b's'),
 (b'd', b'm'),
 (b'd', b'r'),
 (b' o', b'e'),
 (b'\x1f', b'o'),
 (b'in', b'f'),
 (b'\x1f', b'in'),
 (b'd', b'c'),
 (b'`', b'k'),
 (b'\x1f', b'l'),
 (b' a', b'nd'),
 (b'\x1f', b'c'),
 (b'`', b'm'),
 (b'`', b'q'),
 (b' t', b'n'),
 (b'n', b'l'),
 (b' t', b'g'),
 (b'h', b'b'),
 (b'h', b'on'),
 (b'\x1f', b'g'),
 (b'\x1f', b'k'),
 (b'\x1f', b'x'),
 (b'\x1f', b'd'),
 (b'`', b'r'),
 (b'n', b's'),
 (b'h', b'k'),
 (b'\x1f', b'm'),
 (b'\x1f', b't'),
 (b'en', b's'),
 (b' b', b'd'),
 (b'\x1f', b'%'),
 (b'\x1f', b'is'),
 (b' y', b'ou'),
 (b'n', b'r'),
 (b'\x1f', b're'),
 (b'd', b's'),
 (b' f', b'or'),
 (b't', b's'),
 (b'd', b'k'),
 (b'\x1f', b'