# BPE

Ref: [Byte Pair Encoding (BPE) Tokenizer From Scratch](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb)

## Main idea

BPE's idea is to *convert text into an integer representation (token IDs)* for LLM training. It breaks down words that aren't in its predefined vocabulary into smaller subword units or even individual characters to handle the out-of-vocabulary words.

BPE 是一种基于频率的贪心算法，通过反复合并出现频率最高的字符对/子词对，以生成子词单元。

- [Original BPE tokenizer](https://github.com/openai/gpt-2/blob/master/src/encoder.py)
- [Original BPE paper](http://www.pennelynn.com/Documents/CUJ/HTML/94HTML/19940045.HTM)

### Bits and bytes

Consider converting text into a byte array (BPE stands for "byte" pair encoding after all):

In [2]:
text = "Here are some texts"
byte_arr = bytearray(text, 'utf-8')
byte_arr

bytearray(b'Here are some texts')

When we call `list()` on a bytearray object, each byte is treated as an individual element, and the result is a list of *integers corresponding to the byte values*:

In [4]:
ids = list(byte_arr)
print(ids)

[72, 101, 114, 101, 32, 97, 114, 101, 32, 115, 111, 109, 101, 32, 116, 101, 120, 116, 115]


- This would be a valid way to convert text into a token ID representation that we need for the embedding layer of an LLM
- However, the downside of this approach is that it is creating one ID for each character (that's a lot of IDs for a short text!)
- I.e., this means for a 17-character input text, we have to use 17 token IDs as input to the LLM:

这种将 text 转换成 ids 表示的方法可以用于 LLM 的 嵌入层，但是为每个字符创建一个 id 的量很大。

In [5]:
# 19 个字符 -> 19 个 id
print("Number of characters:", len(text))
print("Number of token IDs:", len(ids))

Number of characters: 19
Number of token IDs: 19


而实际上，LLM 嵌入的词编码有一个 vocabulary list，对其中的每个词/子词进行编码。如使用 `tiktoken`：

In [9]:
import tiktoken

gpt2_tokenizer = tiktoken.get_encoding('gpt2')
gpt2_tokenizer.encode(text), len(gpt2_tokenizer.encode(text))

([4342, 389, 617, 13399], 4)

A byte consists of 8 bits, there are $2^8=256$ possible values for a single byte, ranging from 0 to 255.

1 个字节对应 8 个比特，所有对于一个字节可以有 256 种表示（0～255）。

A BPE tokenizer usually uses these 256 values as its first 256 single-character tokens; 

一个 BPE 分词器通常将这 256 个值作为其前 256 个单字符标记，即，字符的 ASCII 表（或扩展版本，如 UTF-8）中的前 256 个值，每个值对应一个字符。

one could visually check this by running the following code:

In [11]:
for i in range(300):
    decoded = gpt2_tokenizer.decode([i])
    print(f'{i}: {decoded}')

0: !
1: "
2: #
3: $
4: %
5: &
6: '
7: (
8: )
9: *
10: +
11: ,
12: -
13: .
14: /
15: 0
16: 1
17: 2
18: 3
19: 4
20: 5
21: 6
22: 7
23: 8
24: 9
25: :
26: ;
27: <
28: =
29: >
30: ?
31: @
32: A
33: B
34: C
35: D
36: E
37: F
38: G
39: H
40: I
41: J
42: K
43: L
44: M
45: N
46: O
47: P
48: Q
49: R
50: S
51: T
52: U
53: V
54: W
55: X
56: Y
57: Z
58: [
59: \
60: ]
61: ^
62: _
63: `
64: a
65: b
66: c
67: d
68: e
69: f
70: g
71: h
72: i
73: j
74: k
75: l
76: m
77: n
78: o
79: p
80: q
81: r
82: s
83: t
84: u
85: v
86: w
87: x
88: y
89: z
90: {
91: |
92: }
93: ~
94: �
95: �
96: �
97: �
98: �
99: �
100: �
101: �
102: �
103: �
104: �
105: �
106: �
107: �
108: �
109: �
110: �
111: �
112: �
113: �
114: �
115: �
116: �
117: �
118: �
119: �
120: �
121: �
122: �
123: �
124: �
125: �
126: �
127: �
128: �
129: �
130: �
131: �
132: �
133: �
134: �
135: �
136: �
137: �
138: �
139: �
140: �
141: �
142: �
143: �
144: �
145: �
146: �
147: �
148: �
149: �
150: �
151: �
152: �
153: �
154: �
155: �
156: �
157: �
158:

### Building the vocabulary

The goal of the BPE tokenization algorithm is to *build a vocabulary of commonly occurring subwords* like 298: ent (which can be found in entangle, entertain, enter, entrance, entity, ..., for example), or even complete words like

```python
'''
318: is
617: some
1212: This
2420: text
'''
```


### BPE algorithm outline

1. Identify frequent pairs

- In each iteration, scan the text to find the *most commonly occurring* pair of bytes (or characters)
- 找到出现最频繁的字节对（或字符对）

2. Replace and record

- Replace that pair with a new placeholder ID (one not already in use, e.g., if we start with 0...255, the first placeholder would be 256)  用一个新的占位符 ID（尚未使用的 ID，例如，如果起始范围是 0…255，第一个占位符 ID 将是 256）替换该字符对；

- Record this mapping in a lookup table  将这种映射关系记录在一个查找表中；

- The size of the lookup table is a hyperparameter, also called "vocabulary size" (for GPT-2, that's 50,257)  查找表的大小是一个超参数，也称为“词汇表大小”（如 GPT-2 的词汇表大小为 50,257）

3. Repeat until no gains

- Keep repeating steps 1 and 2, continually merging the most frequent pairs  不断重复步骤 1 和步骤 2，持续合并最频繁的字符对
- Stop when no further compression is possible (e.g., no pair occurs more than once)  当无法进行进一步压缩时（例如，没有任何字符对的出现次数超过 1 次），停止迭代


4. Decompression (decoding)

- To restore the original text, reverse the process by substituting each ID with its corresponding pair, using the lookup table
- 为了还原原始文本，反向执行上述过程，即使用查找表将每个 ID 替换为其对应的字符对

## A simple BPE implementation

1. Split the input text into individual bytes

2. Repeatedly find & replace (merge) adjacent tokens (pairs) when they match any pair in the learned BPE merges (from highest to lowest "rank," i.e., in the order they were learned)

3. Continue merging until no more merges can be applied

4. The final list of token IDs is the encoded output

In [12]:
from collections import Counter, deque
from functools import lru_cache
import json

In [27]:
class BPETokenizer:
    def __init__(self):
        # token_ids -> token_str
        self.vocab = {}

        # token_str -> token_ids
        self.inverse_vocab = {}

        # Dictionary of BPE merges
        # {(token_id1, id2): merged_token_id}
        self.bpe_merges = {}

    def train(self, text, vocab_size, allowed_special={"<|endoftext|>"}):
        """
        Train the BPE tokenizer from scratch.

        Args:
            text (str): The training text.
            vocab_size (int): The desired vocabulary size.
            allowed_special (set): A set of special tokens to include.
        """

        # Preprocess: Replace spaces with 'Ġ'
        # Note that Ġ is a particularity of the GPT-2 BPE implementation
        # E.g., "Hello world" might be tokenized as ["Hello", "Ġworld"]
        # (GPT-4 BPE would tokenize it as ["Hello", " world"])
        processed_text = []
        for i, char in enumerate(text):
            if char == ' ' and i != 0:
                processed_text.append("Ġ")
            if char != ' ':
                processed_text.append(char)
        processed_text = ''.join(processed_text)

        # Initialize vocab with unique characters, including 'Ġ' if present
        # Start with the first 256 ASCII characters
        unique_chars = [chr(i) for i in range(256)]

        # extend with chars from processed_text
        unique_chars.append(c for c in sorted(set(processed_text)) if c not in unique_chars)

        # Optionally, ensure 'Ġ' is included if it is relevant to your text processing
        if "Ġ" not in unique_chars: unique_chars.append("Ġ")

        # create the vocab and inverse vocab
        self.vocab = {i: c for i, c in enumerate(unique_chars)}
        self.inverse_vocab = {c: i for i, c in self.vocab.items()}

        # add allowed special tokens
        if allowed_special:
            for token in allowed_special:
                if token not in self.inverse_vocab:
                    new_id = len(self.vocab)
                    self.vocab[new_id] = token
                    self.inverse_vocab[token] = new_id

        # Tokenize the processed_text into token IDs
        token_ids = [self.inverse_vocab[char] for char in processed_text]

        # BPE step 1-3
        for new_id in range(len(self.vocab), vocab_size):
            pair_id = self.find_freq_pair(token_ids, mode='most')
            if pair_id is None:
                break  # end

            token_ids = self.replace_pair(token_ids, pair_id, new_id)
            self.bpe_merges[pair_id] = new_id

        # Build the vocabulary with merged tokens
        for (p0, p1), new_id in self.bpe_merges.items():
            merged_token = self.vocab[p0] + self.vocab[p1]
            self.vocab[new_id] = merged_token
            self.inverse_vocab[merged_token] = new_id

    def load_vocab_and_merges_from_openai(self, vocab_path, bpe_merges_path):
        """
        Load pre-trained vocabulary and BPE merges from OpenAI's GPT-2 files.

        Args:
            vocab_path (str): Path to the vocab file (GPT-2 calls it 'encoder.json').
            bpe_merges_path (str): Path to the bpe_merges file  (GPT-2 calls it 'vocab.bpe').
        """
        # load vocabulary
        with open(vocab_path, 'r', encoding='utf-8') as f:
            loaded_vocab = json.load(f)
            # loaded_vocab maps token_str to token_id
            self.vocab = {int(v): k for k, v in loaded_vocab.items()}  # token_id: token_str
            self.inverse_vocab = {k: int(v) for k, v in loaded_vocab.items()}  # token_str: token_id

        # load bpe merges
        with open(bpe_merges_path, "r", encoding="utf-8") as file:
            lines = file.readlines()
            # Skip header line if present
            if lines and lines[0].startswith("#"):
                lines = lines[1:]

            for rank, line in enumerate(lines):
                pair = tuple(line.strip().split())
                if len(pair) != 2:
                    print(f"Line {rank+1} has more than 2 entries: {line.strip()}")
                    continue
                token1, token2 = pair
                if token1 in self.inverse_vocab and token2 in self.inverse_vocab:
                    token_id1 = self.inverse_vocab[token1]
                    token_id2 = self.inverse_vocab[token2]
                    merged_token = token1 + token2
                    if merged_token in self.inverse_vocab:
                        merged_token_id = self.inverse_vocab[merged_token]
                        self.bpe_merges[(token_id1, token_id2)] = merged_token_id
                        # print(f"Loaded merge: '{token1}' + '{token2}' -> '{merged_token}' (ID: {merged_token_id})")
                    else:
                        print(f"Merged token '{merged_token}' not found in vocab. Skipping.")
                else:
                    print(f"Skipping pair {pair} as one of the tokens is not in the vocabulary.")

    def encode(self, text):
        """
        Encode the input text into a list of token IDs.

        Args:
            text (str): The text to encode.

        Returns:
            List[int]: The list of token IDs.
        """
        tokens = []

        # spilt text into tokens, keeping newlines intact
        words = text.replace('\n', ' \n ').split()  # Ensure '\n' is treated as a separate token

        for i, word in enumerate(words):
            if i > 0 and not word.startswith('\n'):
                tokens.append("Ġ" + word)  # Add 'Ġ' to words that follow a space or newline
            else:
                tokens.append(word)

        token_ids = []
        for token in tokens:
            if token in self.inverse_vocab:
                token_id = self.inverse_vocab[token]
                token_ids.append(token_id)
            else:
                # Attempt to handle subword tokenization via BPE
                sub_token_ids = self.tokenize_with_bpe(token)
                token_ids.extend(sub_token_ids)

        return token_ids
    
    def tokenize_with_bpe(self, token):
        """
        Tokenize a single token using BPE merges.

        Args:
            token (str): The token to tokenize.

        Returns:
            List[int]: The list of token IDs after applying BPE.
        """
        # Tokenize the token into individual characters (as initial token IDs)
        token_ids = [self.inverse_vocab.get(char, None) for char in token]
        if None in token_ids:
            missing_chars = [char for char, tid in zip(token, token_ids) if tid is None]
            raise ValueError(f"Characters not found in vocab: {missing_chars}")

        can_merge = True
        while can_merge and len(token_ids) > 1:
            can_merge = False
            new_tokens = []
            i = 0
            while i < len(token_ids) - 1:
                pair = (token_ids[i], token_ids[i + 1])
                if pair in self.bpe_merges:
                    merged_token_id = self.bpe_merges[pair]
                    new_tokens.append(merged_token_id)
                    # Uncomment for educational purposes:
                    # print(f"Merged pair {pair} -> {merged_token_id} ('{self.vocab[merged_token_id]}')")
                    i += 2  # Skip the next token as it's merged
                    can_merge = True
                else:
                    new_tokens.append(token_ids[i])
                    i += 1
            if i < len(token_ids):
                new_tokens.append(token_ids[i])
            token_ids = new_tokens

        return token_ids
    
    def decode(self, token_ids):
        """
        Decode a list of token IDs back into a string.

        Args:
            token_ids (List[int]): The list of token IDs to decode.

        Returns:
            str: The decoded string.
        """
        decoded_string = ""
        for token_id in token_ids:
            if token_id not in self.vocab:
                raise ValueError(f"Token ID {token_id} not found in vocab.")
            token = self.vocab[token_id]
            if token.startswith("Ġ"):
                # Replace 'Ġ' with a space
                decoded_string += " " + token[1:]
            else:
                decoded_string += token
        return decoded_string
    
    def save_vocab_and_merges(self, vocab_path, bpe_merges_path):
        """
        Save the vocabulary and BPE merges to JSON files.

        Args:
            vocab_path (str): Path to save the vocabulary.
            bpe_merges_path (str): Path to save the BPE merges.
        """
        # Save vocabulary
        with open(vocab_path, "w", encoding="utf-8") as file:
            json.dump({k: v for k, v in self.vocab.items()}, file, ensure_ascii=False, indent=2)

        # Save BPE merges as a list of dictionaries
        with open(bpe_merges_path, "w", encoding="utf-8") as file:
            merges_list = [{"pair": list(pair), "new_id": new_id}
                           for pair, new_id in self.bpe_merges.items()]
            json.dump(merges_list, file, ensure_ascii=False, indent=2)

    def load_vocab_and_merges(self, vocab_path, bpe_merges_path):
        """
        Load the vocabulary and BPE merges from JSON files.

        Args:
            vocab_path (str): Path to the vocabulary file.
            bpe_merges_path (str): Path to the BPE merges file.
        """
        # Load vocabulary
        with open(vocab_path, "r", encoding="utf-8") as file:
            loaded_vocab = json.load(file)
            self.vocab = {int(k): v for k, v in loaded_vocab.items()}
            self.inverse_vocab = {v: int(k) for k, v in loaded_vocab.items()}

        # Load BPE merges
        with open(bpe_merges_path, "r", encoding="utf-8") as file:
            merges_list = json.load(file)
            for merge in merges_list:
                pair = tuple(merge['pair'])
                new_id = merge['new_id']
                self.bpe_merges[pair] = new_id


    @lru_cache(maxsize=None)
    def get_special_token_id(self, token):
        return self.inverse_vocab.get(token, None)

    @staticmethod
    def find_freq_pair(token_ids, mode="most"):
        pairs = Counter(zip(token_ids, token_ids[1:]))

        if mode == "most":
            return max(pairs.items(), key=lambda x: x[1])[0]
        elif mode == "least":
            return min(pairs.items(), key=lambda x: x[1])[0]
        else:
            raise ValueError("Invalid mode. Choose 'most' or 'least'.")

    @staticmethod
    def replace_pair(token_ids, pair_id, new_id):
        dq = deque(token_ids)
        replaced = []

        while dq:
            current = dq.popleft()
            if dq and (current, dq[0]) == pair_id:
                replaced.append(new_id)
                # Remove the 2nd token of the pair, 1st was already removed
                dq.popleft()
            else:
                replaced.append(current)

        return replaced

## Training, encoding, and decoding

In [15]:
import os
import urllib.request

if not os.path.exists("the-verdict.txt"):
    url = ("https://raw.githubusercontent.com/rasbt/"
           "LLMs-from-scratch/main/ch02/01_main-chapter-code/"
           "the-verdict.txt")
    file_path = "the-verdict.txt"
    urllib.request.urlretrieve(url, file_path)

with open("the-verdict.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [16]:
tokenizer = BPETokenizer()
tokenizer.train(text, vocab_size=1000, allowed_special={"<|endoftext|>"})

In [17]:
# print(tokenizer.vocab)
print(len(tokenizer.vocab))

1000


In [18]:
print(len(tokenizer.bpe_merges))

741


In [20]:
input_text = "Jack embraced beauty through art and life."
token_ids = tokenizer.encode(input_text)
print(token_ids)
print("Number of characters:", len(input_text))
print("Number of token IDs:", len(token_ids))

[425, 257, 655, 532, 303, 312, 257, 297, 97, 466, 121, 596, 842, 116, 288, 467, 257, 327, 973, 46]
Number of characters: 42
Number of token IDs: 20


In [21]:
print(token_ids)
print(tokenizer.decode(token_ids))

[425, 257, 655, 532, 303, 312, 257, 297, 97, 466, 121, 596, 842, 116, 288, 467, 257, 327, 973, 46]
Jack embraced beauty through art and life.


In [22]:
for token_id in token_ids:
    print(f"{token_id} -> {tokenizer.decode([token_id])}")

425 -> Jack
257 ->  
655 -> em
532 -> br
303 -> ac
312 -> ed
257 ->  
297 -> be
97 -> a
466 -> ut
121 -> y
596 ->  through
842 ->  ar
116 -> t
288 ->  a
467 -> nd
257 ->  
327 -> li
973 -> fe
46 -> .


In [24]:
tokenizer.decode(tokenizer.encode("This is some text."))

'This is some text.'

### saving and loading the tokenizer

In [28]:
# Save trained tokenizer
tokenizer.save_vocab_and_merges(vocab_path="vocab.json", bpe_merges_path="bpe_merges.txt")
# Load tokenizer
tokenizer2 = BPETokenizer()
tokenizer2.load_vocab_and_merges(vocab_path="vocab.json", bpe_merges_path="bpe_merges.txt")

TypeError: Object of type generator is not JSON serializable