### Jupyter Notebook Settings

In [None]:
# Inject a CSS rule so the notebook container spans the full browser width.
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
# NOTE(review): this rebinds the name `display` from the function imported
# above to the `IPython.display` module, so later `display(...)` calls would
# fail — confirm this is intended.
import IPython.display as display

### Libraries

In [None]:
import re
import heapq
import struct
from tqdm import tqdm
from collections import Counter, defaultdict
from threading import Lock
from concurrent.futures import ThreadPoolExecutor, as_completed

### Convert BPE Vocabulary to a Binary Radix Tree
##### Each token is passed through the special-character escaping function so that it works properly in C++
##### The original BPE markers are kept: "_" for a prefix and "</w>" for end of word

In [None]:
def escape_special_characters(token):
    """Escape a vocabulary token for use inside a C++ string literal.

    Backslashes, double quotes, single quotes, newlines, carriage returns
    and tabs are replaced by their backslash escape sequences. Finally the
    literal text ``<space>`` is replaced by a single space character.

    Args:
        token: The raw token string.

    Returns:
        The escaped token string.
    """
    # Order matters: the backslash must be handled first, otherwise the
    # backslashes introduced by the later escapes would be doubled.
    replacements = (
        ('\\', '\\\\'),
        ('"', '\\"'),
        ("'", "\\'"),
        ('\n', '\\n'),
        ('\r', '\\r'),
        ('\t', '\\t'),
        ('<space>', ' '),
    )
    for raw, escaped in replacements:
        token = token.replace(raw, escaped)
    return token

class RadixNode:
    """A single node of the radix tree."""

    def __init__(self):
        # Maps an edge label (a string fragment) to the child node reached
        # by consuming that fragment.
        self.children = {}
        # Token ID of the word that ends at this node; -1 means none.
        self.tokenID = -1


class RadixTree:
    """Radix (compressed prefix) tree mapping tokens to integer token IDs.

    A leading ``'_'`` on a word is always stored as its own one-character
    edge directly below the root; the rest of the word is stored beneath it.

    Binary layout produced by ``serialize_to_binary`` (pre-order, recursive)::

        node  := int tokenID, int num_children, child*
        child := int label_byte_length, label bytes (UTF-8), node

    Integers use the ``struct`` format ``'i'`` without a byte-order prefix,
    i.e. the native byte order and size of the machine running this code.
    """

    # Pre-compiled format shared by the serializer and the deserializer.
    _INT = struct.Struct('i')

    def __init__(self):
        self.root = RadixNode()

    @staticmethod
    def _common_prefix_length(left, right):
        """Return the number of leading characters shared by both strings."""
        length = 0
        for a, b in zip(left, right):
            if a != b:
                break
            length += 1
        return length

    def insert(self, word, tokenID):
        """Insert ``word`` and associate it with ``tokenID``.

        Existing edges are split when ``word`` shares only part of an edge
        label. Inserting a word that is already present overwrites its ID.

        Args:
            word: The token string to insert.
            tokenID: The integer ID stored on the node where the word ends.
        """
        node = self.root

        # Handle words that start with '_': the marker gets a dedicated edge.
        if word.startswith('_'):
            prefix_marker = '_'
            word = word[1:]  # Remove the underscore for further processing
            if prefix_marker not in node.children:
                node.children[prefix_marker] = RadixNode()
            node = node.children[prefix_marker]

        while word:
            for prefix in node.children:
                common_length = self._common_prefix_length(prefix, word)
                if common_length == 0:
                    continue

                if common_length < len(prefix):
                    # Partial overlap: split the edge at the shared part and
                    # hang the old child below the new intermediate node.
                    # Mutating the dict here is safe only because the loop
                    # is left via ``break`` before the iterator advances.
                    existing_child = node.children.pop(prefix)
                    new_node = RadixNode()
                    new_node.children[prefix[common_length:]] = existing_child
                    node.children[prefix[:common_length]] = new_node
                    node = new_node
                else:
                    # The whole edge label matches: just follow it.
                    node = node.children[prefix]

                word = word[common_length:]
                break
            else:
                # No edge shares a first character: add the remainder as-is.
                node.children[word] = RadixNode()
                node = node.children[word]
                word = ""

        # Set tokenID at the end of the word
        node.tokenID = tokenID

    def traverse(self, word):
        """Look up ``word`` and return its token ID.

        Only whole edge labels are followed, so a lookup that would end in
        the middle of an edge is reported as not found.

        Args:
            word: The token string to look up.

        Returns:
            The stored token ID, or ``None`` when the word is not in the
            tree (including when its node carries the -1 placeholder).
        """
        node = self.root

        # Handle words that start with '_'
        if word.startswith('_'):
            prefix_marker = '_'
            word = word[1:]  # Remove the underscore for further processing
            if prefix_marker in node.children:
                node = node.children[prefix_marker]
            else:
                return None

        while word:
            for prefix, child in node.children.items():
                if word.startswith(prefix):
                    word = word[len(prefix):]
                    node = child
                    break
            else:
                return None
        return node.tokenID if node.tokenID != -1 else None

    def serialize_to_binary(self):
        """Serialize the whole tree to ``bytes`` (layout in the class docs).

        Returns:
            The binary representation of the tree, starting at the root.
        """
        pack_int = self._INT.pack

        def node_to_binary(node):
            chunks = [pack_int(node.tokenID), pack_int(len(node.children))]
            for prefix, child in node.children.items():
                # The length field must count encoded BYTES, not characters:
                # a non-ASCII label is longer once encoded, and using the
                # character count would truncate it and corrupt the stream.
                encoded = prefix.encode('utf-8')
                chunks.append(pack_int(len(encoded)))
                chunks.append(encoded)
                chunks.append(node_to_binary(child))
            return b''.join(chunks)

        return node_to_binary(self.root)

    def deserialize_from_binary(self, binary_data):
        """Replace this tree's content with the tree encoded in ``binary_data``.

        Args:
            binary_data: Bytes previously produced by ``serialize_to_binary``.

        Raises:
            struct.error: If the data ends before a complete field is read.
            UnicodeDecodeError: If an edge label is not valid UTF-8.
        """
        read_int = self._INT.unpack_from
        int_size = self._INT.size

        def binary_to_node(data, offset=0):
            node = RadixNode()
            node.tokenID = read_int(data, offset)[0]
            offset += int_size
            num_children = read_int(data, offset)[0]
            offset += int_size
            for _ in range(num_children):
                prefix_len = read_int(data, offset)[0]
                offset += int_size
                prefix = struct.unpack_from(f'{prefix_len}s', data, offset)[0].decode('utf-8')
                offset += prefix_len
                # The child returns the offset just past its own subtree.
                child, offset = binary_to_node(data, offset)
                node.children[prefix] = child
            return node, offset

        self.root, _ = binary_to_node(binary_data)

    def save_to_file(self, filename):
        """Write the serialized tree to ``filename`` in binary mode."""
        binary_data = self.serialize_to_binary()
        with open(filename, 'wb') as file:
            file.write(binary_data)

    def load_from_file(self, filename):
        """Read ``filename`` and load the serialized tree it contains."""
        with open(filename, 'rb') as file:
            binary_data = file.read()
        self.deserialize_from_binary(binary_data)

def load_bpe_vocab(filename):
    """Read a vocabulary file and return its lines as a list of tokens.

    The file is decoded as UTF-8 explicitly, so the result does not depend
    on the platform's default text encoding.

    Args:
        filename: Path of the vocabulary text file, one token per line.

    Returns:
        A list with one string per line, line terminators removed. The
        position of a token in the list is used by callers as its token ID.
    """
    with open(filename, 'r', encoding='utf-8') as file:
        return file.read().splitlines()

def clean_and_insert_vocab_into_tree(vocab, radix_tree):
    """Escape every vocabulary token and insert it into ``radix_tree``.

    Args:
        vocab: Sequence of raw token strings; a token's position in the
            sequence becomes its token ID.
        radix_tree: Tree object whose ``insert(word, tokenID)`` is called
            once per token.
    """
    progress = tqdm(vocab, desc="Inserting vocab into Radix Tree")
    for token_id, raw_token in enumerate(progress):
        radix_tree.insert(escape_special_characters(raw_token), token_id)

def test_radix_tree(radix_tree, test_words):
    """Look up each word in ``radix_tree`` and report what was found.

    Every word is escaped with ``escape_special_characters`` before the
    lookup; the result is keyed by the original, unescaped word.

    Args:
        radix_tree: Tree object providing ``traverse(word)``.
        test_words: Iterable of words to look up.

    Returns:
        A dict mapping each word to ``{"found": bool, "node": value}``,
        where ``value`` is whatever ``traverse`` returned (``None`` when
        the word was not found).
    """
    def lookup(word):
        token_id = radix_tree.traverse(escape_special_characters(word))
        return {"found": token_id is not None, "node": token_id}

    return {word: lookup(word) for word in test_words}

# Read the vocabulary; each line of the file is one token and the line's
# position is used as its token ID when the tree is built below.
filename = 'bpe-vocab_bookcorpus-30p_25000-merges_cleaned.txt'
vocab = load_bpe_vocab(filename)

# Start from an empty Radix Tree
radix_tree = RadixTree()

# Escape every token and insert it into the Radix Tree
clean_and_insert_vocab_into_tree(vocab, radix_tree)

# Persist the serialized Radix Tree to a .bin file
radix_tree.save_to_file('bpe-vocab_bookcorpus-30p_25000-merges_cleaned_radix-tree.bin')

### Visualise Radix Tree

In [None]:
def visualize_radix_tree(node, prefix='', depth=0, max_depth=3):
    """
    Print an indented outline of the Radix Tree rooted at ``node``.

    One line is printed per node, indented by two spaces per level, showing
    the text accumulated along the path and the node's token ID (or
    "No TokenID" when the node holds -1). Nodes deeper than ``max_depth``
    are not printed.

    Args:
    - node: The RadixNode to print; its children are printed recursively.
    - prefix: The text accumulated on the path from the start node.
    - depth: The depth of ``node`` relative to the start node.
    - max_depth: The deepest level that is still printed.
    """
    if depth <= max_depth:
        label = "No TokenID" if node.tokenID == -1 else f"TokenID: {node.tokenID}"
        print(f"{'  ' * depth}{prefix} ({label})")

        for edge, subtree in node.children.items():
            visualize_radix_tree(subtree, prefix + edge, depth + 1, max_depth)

# Example usage: reload the tree from the .bin file written earlier, so the
# outline below reflects what was actually serialized.
radix_tree_deserialized = RadixTree()
radix_tree_deserialized.load_from_file('bpe-vocab_bookcorpus-30p_25000-merges_cleaned_radix-tree.bin')

# Print the root and everything down to depth 4 (max_depth=4)
visualize_radix_tree(radix_tree_deserialized.root, max_depth=4)

### Convert the binary file to a C header (.h)
#### This can also be achieved by using the command `xxd -i bpe_vocab_small_radix_tree.bin > bpe_vocab_small_radix_tree.h` in the terminal

In [None]:
def binary_to_c_header(binary_file_path, header_file_path, array_name):
    """Write the bytes of a binary file as a C array inside a header file.

    The generated header has an include guard named after the upper-cased
    ``array_name``, a ``const unsigned char`` array holding every byte of
    the input as a hex literal, a closing ``};`` and the matching
    ``#endif``, so it is a complete header.

    Args:
        binary_file_path: Path of the binary file to embed.
        header_file_path: Path of the header file to create or overwrite.
        array_name: C identifier used for the array and the include guard.
    """
    bytes_per_line = 12  # keeps the generated rows at a readable width
    guard = f"{array_name.upper()}_H"

    with open(binary_file_path, "rb") as binary_file:
        binary_data = binary_file.read()

    # One text row per group of bytes; rows are joined below in one write
    # instead of issuing a separate write call for every byte.
    rows = [
        ", ".join(f"0x{byte:02x}" for byte in binary_data[start:start + bytes_per_line])
        for start in range(0, len(binary_data), bytes_per_line)
    ]

    parts = [
        f"#ifndef {guard}\n",
        f"#define {guard}\n\n",
        f"const unsigned char {array_name}[] = {{\n    ",
        ",\n    ".join(rows),
        # Close the initializer and the include guard.
        "\n};\n\n",
        f"#endif /* {guard} */\n",
    ]

    with open(header_file_path, "w") as header_file:
        header_file.write("".join(parts))

# Regenerate the C header file from a serialized tree.
# NOTE(review): these paths differ from the .bin file saved by the tree-building
# cell above — confirm that "simplified_radix_tree.bin" is the intended input.
binary_file_path = "simplified_radix_tree.bin"
header_file_path = "simplified_radix_tree.h"
array_name = "simplified_radix_tree"  # also used to derive the include guard

binary_to_c_header(binary_file_path, header_file_path, array_name)

### Testing the Binary Tree to find given words

In [None]:
# Load the Radix tree from the .bin file
radix_tree_deserialized = RadixTree()
radix_tree_deserialized.load_from_file('bpe-vocab_bookcorpus-30p_25000-merges_cleaned_radix-tree.bin')

# Look up a few sample tokens: whole words, a "_"-prefixed token, a space and
# a punctuation mark. Each is escaped the same way as at insertion time.
test_words = ["rocket</w>", "rocketman</w>", "exasperated</w>", "hello</w>", "_am</w>", " ", "."]
results = test_radix_tree(radix_tree_deserialized, test_words)

# Bare expression so the notebook displays the result dict as cell output
results

### Testing the Binary Radix Tree to encode and decode a given phrase

In [None]:
def load_bpe_vocabulary(vocab_file_path):
    """Load a BPE vocabulary file for encoding and decoding phrases.

    Each line is stripped of surrounding whitespace and a trailing
    ``</w>`` end-of-word marker is removed. Other tokens, including
    ``<space>``, are kept unchanged. The file is decoded as UTF-8
    explicitly, so the result does not depend on the platform's default
    text encoding.

    Args:
        vocab_file_path: Path of the vocabulary text file, one token per line.

    Returns:
        A ``(vocab_dict, vocabulary)`` tuple. ``vocabulary`` is the list of
        processed tokens in file order. ``vocab_dict`` maps each token to
        its index in that list. When several lines yield the same token
        (for example ``the`` and ``the</w>``), the last index wins.
    """
    with open(vocab_file_path, "r", encoding="utf-8") as file:
        vocabulary = [line.strip() for line in file]
    # Remove the `</w>` suffix; all other tokens are kept as they are
    vocabulary = [token[:-4] if token.endswith("</w>") else token for token in vocabulary]
    vocab_dict = {token: idx for idx, token in enumerate(vocabulary)}
    return vocab_dict, vocabulary

def clean_input(phrase):
    """Normalize a phrase before tokenization.

    Keeps only characters whose code point lies in the printable ASCII
    range 32..126 (which also drops tabs, newlines and other control
    characters), lower-cases what remains, and trims leading and trailing
    whitespace.

    Args:
        phrase: The raw input string.

    Returns:
        The cleaned string.
    """
    printable = (ch for ch in phrase if 0x20 <= ord(ch) <= 0x7E)
    # Only ASCII characters are left, so lower-casing the joined string
    # gives the same result as lower-casing character by character.
    return ''.join(printable).lower().strip()

def tokenize_bpe(phrase, vocab_dict):
    """Greedily tokenize ``phrase`` into token IDs using ``vocab_dict``.

    Spaces are first replaced by the ``<space>`` token text. At each
    position the longest substring present in ``vocab_dict`` is taken. If
    no substring starting at a position matches, that character is
    reported on stdout and skipped.

    Args:
        phrase: The (already cleaned) input string.
        vocab_dict: Mapping from token text to token ID.

    Returns:
        The list of token IDs in order of appearance.
    """
    encoded = phrase.replace(" ", "<space>")  # Spaces become the <space> token
    total = len(encoded)
    token_ids = []
    position = 0
    while position < total:
        # Try the longest candidate first; the first hit is the longest match.
        for end in range(total, position, -1):
            candidate = encoded[position:end]
            if candidate in vocab_dict:
                token_ids.append(vocab_dict[candidate])
                position = end  # Continue right after the matched token
                break
        else:
            # Debugging: report the character that could not be matched
            print(f"No match for: {encoded[position]}")
            position += 1  # Skip characters not found in the vocabulary
    return token_ids

def tokens_to_binary(tokens):
    """Encode token IDs as consecutive 24-bit big-endian integers.

    Args:
        tokens: Iterable of non-negative integer token IDs.

    Returns:
        A ``bytearray`` with three bytes per token.

    Raises:
        ValueError: If a token ID needs more than 24 bits. Without this
            check the most significant byte would be dropped silently and
            the ID would decode to a different token.
        struct.error: If a token ID is negative or does not fit in 32 bits.
    """
    binary_output = bytearray()
    for token_id in tokens:
        packed = struct.pack('>I', token_id)  # 4 bytes, big-endian
        if packed[0]:
            raise ValueError(f"token ID {token_id} does not fit in 24 bits")
        binary_output += packed[1:]  # Drop the (zero) top byte for 24-bit
    return binary_output

def decode_bpe(binary_output, vocabulary):
    """Decode 24-bit big-endian token IDs back into text.

    Args:
        binary_output: Bytes holding three bytes per token ID.
        vocabulary: Sequence of token strings indexed by token ID.

    Returns:
        The concatenated token strings with every ``<space>`` replaced by
        a regular space.
    """
    pieces = [
        # Pad each 3-byte group to 4 bytes so it can be read as a 32-bit int
        vocabulary[struct.unpack('>I', b'\x00' + binary_output[start:start + 3])[0]]
        for start in range(0, len(binary_output), 3)
    ]
    return ''.join(pieces).replace("<space>", " ")

# Load the vocabulary
# NOTE(review): this path differs from the vocabulary file used to build the
# Radix Tree earlier in the notebook — confirm it is the intended file.
vocab_file_path = "bpe_vocab_cleaned.txt"
vocab_dict, vocabulary = load_bpe_vocabulary(vocab_file_path)

# Input phrase from user (blocks until a line is entered)
phrase = input("Enter a phrase to tokenize: ")

# Clean the input string
cleaned_phrase = clean_input(phrase)

# Tokenize the phrase
tokens = tokenize_bpe(cleaned_phrase, vocab_dict)
print("Token IDs:", tokens)

# Convert tokens to binary and show each byte as 8 bits
if tokens:
    binary_output = tokens_to_binary(tokens)
    print("Binary output:", ' '.join(format(byte, '08b') for byte in binary_output))
else:
    print("No tokens matched.")

# Decode the binary output back to a phrase; `binary_output` only exists when
# at least one token matched, hence the same guard as above
if tokens:
    decoded_phrase = decode_bpe(binary_output, vocabulary)
    print("Decoded phrase:", decoded_phrase)
else:
    print("Nothing to decode.")