# Byte Pair Encoding Tokenizer

This notebook explains how a Byte Pair Encoding (BPE) tokenizer works, using a simple from-scratch implementation.

# Table of contents
1. [Bits and bytes](#introduction)

<a name="introduction"></a>
## Bits and bytes

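Before getting into BPE details, we first need to introduce the notion of bytes. Computers store text as a sequence of bytes (integers from 0 to 255). With UTF-8 encoding, an ASCII character takes exactly one byte, while other characters, such as the Arabic letters below, take several.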


In [36]:
# A plain-ASCII sentence: in UTF-8 every character becomes exactly one byte.
text = "This is some text"
byte_ary = bytearray(text.encode("utf-8"))
print(byte_ary)

bytearray(b'This is some text')


In [37]:
# Each element of a bytearray is already an int in range(256): unpack them into a list.
ids = [*byte_ary]
print(ids)

[84, 104, 105, 115, 32, 105, 115, 32, 115, 111, 109, 101, 32, 116, 101, 120, 116]


In [4]:
# Non-ASCII text: in UTF-8 each Arabic character here takes two bytes.
# Use a separate name so that `text` keeps holding the English sentence
# that the BPE walkthrough below preprocesses; rebinding `text` here would
# make a top-to-bottom run tokenize this Arabic string instead.
arabic_text = "مرحبا كيف حالك؟"
byte_ary = bytearray(arabic_text, "utf-8")
print(byte_ary)

bytearray(b'\xd9\x85\xd8\xb1\xd8\xad\xd8\xa8\xd8\xa7 \xd9\x83\xd9\x8a\xd9\x81 \xd8\xad\xd8\xa7\xd9\x84\xd9\x83\xd8\x9f')


In [5]:
# Same conversion as before; multi-byte characters show up as several ids each.
ids = [*byte_ary]
print(ids)

[217, 133, 216, 177, 216, 173, 216, 168, 216, 167, 32, 217, 131, 217, 138, 217, 129, 32, 216, 173, 216, 167, 217, 132, 217, 131, 216, 159]


In [38]:
vocab = {}

# Preprocess: Replace spaces with 'Ġ'
# Note that Ġ is a particularity of the GPT-2 BPE implementation
# E.g., "Hello world" might be tokenized as ["Hello", "Ġworld"]
# (GPT-4 BPE would tokenize it as ["Hello", " world"])
# A space at index 0 is dropped; every later space becomes 'Ġ'; all other
# characters pass through unchanged.
processed_text = "".join(
    "Ġ" if char == " " else char
    for i, char in enumerate(text)
    if not (char == " " and i == 0)
)
print(f"Processed text: {processed_text}")

Processed text: ThisĠisĠsomeĠtext


In [47]:
# Seed the vocabulary with the 256 code points U+0000-U+00FF
# (ASCII plus Latin-1, i.e. chr(0)..chr(255)), one single-character token each
unique_chars = [chr(code_point) for code_point in range(256)]
print(f'Initial length of unique chars: {len(unique_chars)}')

# Append, in sorted order, any characters of processed_text outside that range
unique_chars += [c for c in sorted(set(processed_text)) if c not in unique_chars]

# Make sure the space marker 'Ġ' (U+0120, outside the first 256) gets an ID
if 'Ġ' not in unique_chars:
    unique_chars.append('Ġ')
print(f'Extended length of unique chars: {len(unique_chars)}')

# Token ID -> character, and character -> token ID
vocab = dict(enumerate(unique_chars))
inverse_vocab = {c: idx for idx, c in vocab.items()}

Initial length of unique chars: 256
Extended length of unique chars: 257


In [48]:
allowed_special = {"<|endoftext|>"}

# Give each special token that is not yet in the vocab the next free ID
for token in sorted(allowed_special):
    if token in inverse_vocab:
        continue
    new_id = len(vocab)
    vocab[new_id] = token
    inverse_vocab[token] = new_id

In [49]:
# Tokenize the processed_text into token IDs
token_ids = [inverse_vocab[char] for char in processed_text]
print(len(processed_text))
print(len(token_ids))

17
17


In [50]:
# Display the IDs; for this text, 256 is the ID that was assigned to 'Ġ' (the space marker)
token_ids

[84,
 104,
 105,
 115,
 256,
 105,
 115,
 256,
 115,
 111,
 109,
 101,
 256,
 116,
 101,
 120,
 116]

In [51]:
from collections import Counter, deque


def find_freq_pair(token_ids, mode="most"):
    """Return the most (or least) frequent adjacent pair of token IDs.

    Args:
        token_ids (Sequence[int]): The token IDs to scan.
        mode (str): "most" for the most frequent pair, "least" for the least.

    Returns:
        tuple[int, int] | None: The selected pair. Ties go to the pair seen
        first. Returns None when there are fewer than two tokens, so callers
        can stop merging with ``if pair_id is None``.

    Raises:
        ValueError: If ``mode`` is not "most" or "least".
    """
    # Validate first so a bad mode is reported even for short inputs
    if mode not in ("most", "least"):
        raise ValueError("Invalid mode. Choose 'most' or 'least'.")

    pairs = Counter(zip(token_ids, token_ids[1:]))
    if not pairs:
        # max()/min() would raise on an empty Counter
        return None

    select = max if mode == "most" else min
    return select(pairs.items(), key=lambda item: item[1])[0]

# Tip: Counter(zip(token_ids, token_ids[1:])) shows how often every adjacent pair occurs.
# Pick the most frequent pair; it becomes the first merge candidate.
pair_id = find_freq_pair(token_ids, mode="most")
pair_id

(105, 115)

In [52]:
def replace_pair(token_ids, pair_id, new_id):
    """Return a new list where every occurrence of ``pair_id`` is replaced by ``new_id``.

    The scan runs left to right and matches do not overlap, so the pair
    (1, 1) in [1, 1, 1] becomes [new_id, 1]. A message is printed for each
    replacement so the merges are visible in the notebook output.
    """
    ids = list(token_ids)
    out = []
    pos = 0
    while pos < len(ids):
        if pos + 1 < len(ids) and (ids[pos], ids[pos + 1]) == pair_id:
            print(f'Replacing {pair_id}')
            out.append(new_id)
            pos += 2  # consume both halves of the pair
        else:
            out.append(ids[pos])
            pos += 1
    return out

In [33]:
# Step through the first iteration of replace_pair by hand
dq = deque(token_ids)
current = dq.popleft()  # removes and returns the first token ID; dq keeps the rest
current

84

In [34]:
# The pair replace_pair compares with pair_id: the popped ID plus the next one (peeked, not removed)
(current, dq[0])

(84, 104)

In [53]:
# The merged token gets the next unused ID, one past the current vocab
new_id = len(vocab)
# replace_pair returns a new list; token_ids itself is left unchanged here
replace_pair(token_ids, pair_id, new_id)

Replacing (105, 115)
Replacing (105, 115)


[84, 104, 258, 256, 258, 256, 115, 111, 109, 101, 256, 116, 101, 120, 116]

In [46]:
# Each replacement turns two IDs into one, so the sequence gets shorter by the number of replacements
len(replace_pair(token_ids, pair_id, new_id))

15

In [None]:
# BPE steps 1-3: Repeatedly find the most frequent adjacent pair, replace it
# with a new token ID, and record the merge.
vocab_size = 270  # desired total vocabulary size (base characters + special tokens + merges)
bpe_merges = {}  # {(token_id1, token_id2): merged_token_id}, in the order merges were learned
for new_id in range(len(vocab), vocab_size):
    # With fewer than two tokens there is no pair left to merge
    if len(token_ids) < 2:
        break
    pair_id = find_freq_pair(token_ids, mode="most")
    if pair_id is None:  # No more pairs to merge. Stopping training.
        break
    token_ids = replace_pair(token_ids, pair_id, new_id)
    bpe_merges[pair_id] = new_id

In [6]:

from collections import Counter, deque
from functools import lru_cache


class BPETokenizerSimple:
    """Minimal character-level Byte Pair Encoding (BPE) tokenizer.

    Training seeds the vocabulary with single characters. It then repeatedly
    merges the most frequent adjacent pair of token IDs into a new token.
    Spaces are represented by the GPT-2 style marker 'Ġ'.
    """

    def __init__(self):
        # Maps token_id to token_str (e.g., {11246: "some"})
        self.vocab = {}
        # Maps token_str to token_id (e.g., {"some": 11246})
        self.inverse_vocab = {}
        # Dictionary of BPE merges: {(token_id1, token_id2): merged_token_id}.
        # train() hands out merged IDs in increasing order, so a smaller merged
        # ID means the merge was learned earlier; encoding uses it as priority.
        self.bpe_merges = {}

    @staticmethod
    def _preprocess(text):
        """Replace spaces with 'Ġ', identically for training and encoding.

        A space at index 0 is dropped; every other space becomes 'Ġ'. All
        other characters, including newlines, are kept unchanged.
        Note that Ġ is a particularity of the GPT-2 BPE implementation:
        "Hello world" might be tokenized as ["Hello", "Ġworld"]
        (GPT-4 BPE would tokenize it as ["Hello", " world"]).
        """
        return "".join(
            "Ġ" if char == " " else char
            for i, char in enumerate(text)
            if char != " " or i != 0
        )

    def train(self, text, vocab_size, allowed_special=frozenset({"<|endoftext|>"})):
        """
        Train the BPE tokenizer from scratch.

        Any vocabulary and merges from a previous call are discarded.

        Args:
            text (str): The training text.
            vocab_size (int): The desired vocabulary size, counting base
                characters and special tokens. Training stops early once
                fewer than two tokens remain.
            allowed_special (Iterable[str] | None): Special tokens to include.
        """
        processed_text = self._preprocess(text)

        # Base vocabulary: the 256 code points U+0000-U+00FF (ASCII plus
        # Latin-1), then any other characters of the text in sorted order,
        # then 'Ġ' if the text did not already contain it.
        unique_chars = [chr(i) for i in range(256)]
        known_chars = set(unique_chars)
        for char in sorted(set(processed_text)):
            if char not in known_chars:
                unique_chars.append(char)
                known_chars.add(char)
        if "Ġ" not in known_chars:
            unique_chars.append("Ġ")

        self.vocab = dict(enumerate(unique_chars))
        self.inverse_vocab = {char: i for i, char in self.vocab.items()}
        # Reset merges as well. Otherwise a second train() call would keep
        # merges whose IDs refer to the previous vocabulary.
        self.bpe_merges = {}

        # Add special tokens in sorted order. Iterating a set of strings can
        # yield a different order between interpreter runs (hash
        # randomization), which would make special-token IDs unreproducible.
        for token in sorted(allowed_special or ()):
            if token not in self.inverse_vocab:
                new_id = len(self.vocab)
                self.vocab[new_id] = token
                self.inverse_vocab[token] = new_id

        # Tokenize the processed_text into single-character token IDs
        token_ids = [self.inverse_vocab[char] for char in processed_text]

        # BPE steps 1-3: Repeatedly find and replace the most frequent pair
        for new_id in range(len(self.vocab), vocab_size):
            pair_id = self.find_freq_pair(token_ids, mode="most")
            if pair_id is None:  # Fewer than two tokens left: nothing to merge.
                break
            token_ids = self.replace_pair(token_ids, pair_id, new_id)
            self.bpe_merges[pair_id] = new_id

        # Build the vocabulary entries for merged tokens. Dict order is merge
        # order, so both halves of every pair are already in self.vocab.
        for (p0, p1), new_id in self.bpe_merges.items():
            merged_token = self.vocab[p0] + self.vocab[p1]
            self.vocab[new_id] = merged_token
            self.inverse_vocab[merged_token] = new_id

    def encode(self, text):
        """
        Encode the input text into a list of token IDs.

        The text is preprocessed exactly like the training text (spaces
        become 'Ġ') and then split into words:
        - each newline is a separate token;
        - a word that followed a space starts with its 'Ġ'.
        A word already in the vocabulary maps directly to its ID; otherwise
        the learned BPE merges are applied to its characters.

        Args:
            text (str): The text to encode.

        Returns:
            List[int]: The list of token IDs.

        Raises:
            ValueError: If a word that is not in the vocabulary contains a
                character that is not in the vocabulary.
        """
        processed_text = self._preprocess(text)

        words = []
        for line_no, line in enumerate(processed_text.split("\n")):
            if line_no > 0:
                words.append("\n")  # keep newlines as separate tokens
            first, *rest = line.split("Ġ")
            if first:
                words.append(first)
            # Consecutive spaces yield bare "Ġ" words, so runs of spaces survive
            words.extend("Ġ" + part for part in rest)

        token_ids = []
        for word in words:
            if word in self.inverse_vocab:
                # word is contained in the vocabulary as is
                token_ids.append(self.inverse_vocab[word])
            else:
                # Fall back to subword tokenization via BPE
                token_ids.extend(self.tokenize_with_bpe(word))

        return token_ids

    def tokenize_with_bpe(self, token):
        """
        Tokenize a single token using BPE merges.

        Merges are applied in the order they were learned (lowest merged ID
        first), the same order train() used. This reproduces the training
        segmentation instead of greedily merging whichever pair comes first
        from the left.

        Args:
            token (str): The token to tokenize.

        Returns:
            List[int]: The list of token IDs after applying BPE.

        Raises:
            ValueError: If a character of ``token`` is not in the vocabulary.
        """
        # Tokenize the token into individual characters (as initial token IDs)
        token_ids = [self.inverse_vocab.get(char, None) for char in token]
        if None in token_ids:
            missing_chars = [char for char, tid in zip(token, token_ids) if tid is None]
            raise ValueError(f"Characters not found in vocab: {missing_chars}")

        while len(token_ids) > 1:
            mergeable = [
                pair for pair in zip(token_ids, token_ids[1:]) if pair in self.bpe_merges
            ]
            if not mergeable:
                break
            # Earliest-learned merge has the smallest merged ID
            best_pair = min(mergeable, key=lambda pair: self.bpe_merges[pair])
            token_ids = self.replace_pair(token_ids, best_pair, self.bpe_merges[best_pair])

        return token_ids

    def decode(self, token_ids):
        """
        Decode a list of token IDs back into a string.

        Every 'Ġ' is turned back into a space. This includes 'Ġ' in the middle
        of merged tokens: training merges across word boundaries and can learn
        tokens such as "is" + "Ġ" -> "isĠ".

        Args:
            token_ids (List[int]): The list of token IDs to decode.

        Returns:
            str: The decoded string.

        Raises:
            ValueError: If a token ID is not in the vocabulary.
        """
        pieces = []
        for token_id in token_ids:
            if token_id not in self.vocab:
                raise ValueError(f"Token ID {token_id} not found in vocab.")
            pieces.append(self.vocab[token_id])
        return "".join(pieces).replace("Ġ", " ")

    def get_special_token_id(self, token):
        """Return the ID of ``token`` (e.g. a special token), or None if unknown.

        Deliberately not memoized. An lru_cache on this method would keep every
        instance alive and return stale IDs after re-training. The dict lookup
        is already cheap.
        """
        return self.inverse_vocab.get(token, None)

    @staticmethod
    def find_freq_pair(token_ids, mode="most"):
        """Return the most/least frequent adjacent pair, or None if there is none.

        Ties go to the pair seen first.

        Raises:
            ValueError: If ``mode`` is not "most" or "least".
        """
        if mode not in ("most", "least"):
            raise ValueError("Invalid mode. Choose 'most' or 'least'.")

        pairs = Counter(zip(token_ids, token_ids[1:]))
        if not pairs:
            # Fewer than two tokens: max()/min() would raise on an empty Counter
            return None

        select = max if mode == "most" else min
        return select(pairs.items(), key=lambda item: item[1])[0]

    @staticmethod
    def replace_pair(token_ids, pair_id, new_id):
        """Return a new list with non-overlapping occurrences of ``pair_id``
        (scanned left to right) replaced by ``new_id``."""
        dq = deque(token_ids)
        replaced = []

        while dq:
            current = dq.popleft()
            if dq and (current, dq[0]) == pair_id:
                replaced.append(new_id)
                # Remove the 2nd token of the pair, 1st was already removed
                dq.popleft()
            else:
                replaced.append(current)

        return replaced