### Jupyter Notebook Settings

In [1]:
# Stretch the notebook container to the full browser width.
# `display` and `HTML` come from the public `IPython.display` module; importing
# them from `IPython.core.display` is deprecated and triggers a warning.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
import IPython.display as display

  from IPython.core.display import display, HTML


### Libraries

In [None]:
import re
import heapq
import struct
from tqdm import tqdm
from collections import Counter, defaultdict
from threading import Lock
from concurrent.futures import ThreadPoolExecutor, as_completed

### Radix Tree Class

In [None]:
# Node type used by RadixTree
class RadixNode:
    """One node of the radix tree.

    Instance attributes (both set in ``__init__``):

    * ``children`` - dict mapping an edge label (a string) to the child node
      reached through that edge; starts empty.
    * ``tokenID`` - token ID stored at this node; starts at ``-1``, which
      ``RadixTree.traverse`` treats as "no token here".
    """

    def __init__(self):
        # Fresh dict per instance so nodes never share their edge table.
        self.children, self.tokenID = {}, -1

class RadixTree:
    """Radix (compressed prefix) tree mapping words to integer token IDs.

    Binary layout written by ``serialize_to_binary`` and read back by
    ``deserialize_from_binary`` (pre-order, one record per node)::

        int32  tokenID         (-1 when the node carries no token)
        int32  num_children
        num_children times:
            int32  label_len   (length of the edge label in UTF-8 bytes)
            bytes  label       (UTF-8 encoded, no terminator)
            node   child       (same layout, recursively)

    Every integer is a 4-byte little-endian signed value.
    """

    # Explicit little-endian int32: the reader advances by a fixed 4 bytes per
    # integer, so the size and byte order must not depend on the host platform.
    _INT32 = struct.Struct('<i')

    def __init__(self):
        self.root = RadixNode()

    @staticmethod
    def _common_prefix_length(left, right):
        """Return how many leading characters ``left`` and ``right`` share."""
        length = 0
        for a, b in zip(left, right):
            if a != b:
                break
            length += 1
        return length

    def insert(self, word, tokenID):
        """Insert ``word`` and store ``tokenID`` on the node where it ends.

        Inserting a word that already exists overwrites its token ID. An
        empty ``word`` stores the token ID on the root node.
        """
        node = self.root
        while word:
            # Look for the first edge sharing at least one leading character
            # with the remaining part of the word.
            shared_label = None
            shared_len = 0
            for label in node.children:
                shared_len = self._common_prefix_length(label, word)
                if shared_len:
                    shared_label = label
                    break

            if shared_label is None:
                # No edge matches: the whole remainder becomes a new leaf edge.
                leaf = RadixNode()
                node.children[word] = leaf
                node = leaf
                break

            if shared_len < len(shared_label):
                # Partial match: split the edge at the divergence point. The
                # new intermediate node owns the old child under the label's
                # unmatched tail.
                existing_child = node.children.pop(shared_label)
                middle = RadixNode()
                middle.children[shared_label[shared_len:]] = existing_child
                node.children[shared_label[:shared_len]] = middle
                node = middle
            else:
                # The whole edge label matched: descend.
                node = node.children[shared_label]

            word = word[shared_len:]
        node.tokenID = tokenID

    def traverse(self, word):
        """Return the token ID stored for ``word``, or ``None``.

        ``None`` is returned both when no path spells ``word`` and when the
        path exists but its final node has no token (``tokenID == -1``).
        """
        node = self.root
        while word:
            found = False
            for prefix, child in node.children.items():
                if word.startswith(prefix):
                    word = word[len(prefix):]
                    node = child
                    found = True
                    break
            if not found:
                return None
        return node.tokenID if node.tokenID != -1 else None

    def serialize_to_binary(self):
        """Return the tree encoded as ``bytes`` (layout in the class docstring)."""
        pack_int = self._INT32.pack

        def node_to_binary(node):
            parts = [pack_int(node.tokenID), pack_int(len(node.children))]
            for prefix, child in node.children.items():
                # Encode first and store the *byte* length: for non-ASCII
                # labels the UTF-8 byte count exceeds the character count.
                label = prefix.encode('utf-8')
                parts.append(pack_int(len(label)))
                parts.append(label)
                parts.append(node_to_binary(child))
            return b''.join(parts)

        return node_to_binary(self.root)

    def save_to_file(self, filename):
        """Write the serialized tree to ``filename`` (binary mode, overwrites)."""
        binary_data = self.serialize_to_binary()
        with open(filename, 'wb') as file:
            file.write(binary_data)

    def deserialize_from_binary(self, binary_data):
        """Replace ``self.root`` with the tree decoded from ``binary_data``.

        Raises ``struct.error`` if the buffer is too short for a field it
        tries to read.
        """
        unpack_int = self._INT32.unpack_from
        int_size = self._INT32.size

        def binary_to_node(data, offset=0):
            # Returns the decoded node and the offset just past its record.
            node = RadixNode()
            node.tokenID = unpack_int(data, offset)[0]
            offset += int_size
            num_children = unpack_int(data, offset)[0]
            offset += int_size
            for _ in range(num_children):
                prefix_len = unpack_int(data, offset)[0]
                offset += int_size
                prefix = struct.unpack_from(f'{prefix_len}s', data, offset)[0].decode('utf-8')
                offset += prefix_len
                child, offset = binary_to_node(data, offset)
                node.children[prefix] = child
            return node, offset

        self.root, _ = binary_to_node(binary_data)

    def get_word_from_id(self, token_id):
        """Return the word whose node stores ``token_id``.

        Delegates to the module-level ``find_word_by_node_id`` helper and
        returns whatever that helper returns.
        """
        return find_word_by_node_id(self.root, token_id)

    def load_from_file(self, filename):
        """Read ``filename`` (binary mode) and rebuild the tree from it."""
        with open(filename, 'rb') as file:
            binary_data = file.read()
        self.deserialize_from_binary(binary_data)

### Functions

In [None]:
# Function to escape special characters in a string for C++ string literals
def escape_special_characters(token):
    """Return ``token`` with backslash escapes applied and ``<space>`` expanded.

    Backslash, double quote, single quote, newline, carriage return and tab
    are each replaced by their two-character backslash escape. Afterwards
    every literal ``<space>`` placeholder becomes one space character.
    """
    # One translation pass is equivalent to the chained replacements: none of
    # the escaped characters appears in the output of another escape except
    # the backslash, and a single pass never re-visits its own output.
    escape_table = str.maketrans({
        '\\': '\\\\',
        '"': '\\"',
        "'": "\\'",
        '\n': '\\n',
        '\r': '\\r',
        '\t': '\\t',
    })
    escaped = token.translate(escape_table)
    return escaped.replace('<space>', ' ')  # Convert <space> to whitespace

# Function to clean the input phrase
def clean_input(phrase):
    """Normalise ``phrase``: trim, lowercase, then drop non-ASCII characters.

    The steps run in that order, so characters are filtered after
    lowercasing.
    """
    normalized = phrase.strip().lower()
    # Keep only characters with a code point below 128.
    return ''.join(filter(str.isascii, normalized))

# Function to clean the word, removing `_` and `</w>` if applicable
def clean_word(word):
    """Strip the BPE markers from ``word`` and escape the result.

    All leading underscores are removed unless the word is exactly ``'_'``,
    every ``</w>`` occurrence is deleted, and the remainder is passed through
    ``escape_special_characters``.
    """
    # A lone underscore is kept as-is; otherwise leading underscores go.
    without_prefix = word if word == '_' else word.lstrip('_')
    without_markers = without_prefix.replace('</w>', '')
    return escape_special_characters(without_markers)

def clean_decoded_phrase(decoded_phrase):
    """Turn a marker-annotated, space-separated phrase back into plain text.

    Each whitespace-separated token loses one leading ``_`` and one trailing
    ``</w>`` when present. If the token that follows is exactly one of
    ``! ? . , : ;`` it is glued onto the current token and consumed. The
    resulting tokens are joined with single spaces.
    """
    attachable = ['!', '?', '.', ',', ':', ';']
    tokens = decoded_phrase.split()
    total = len(tokens)
    output = []
    position = 0
    while position < total:
        token = tokens[position]
        position += 1  # `position` now points at the following token

        # Compute the slice bounds that drop the markers, if any.
        start = 1 if token.startswith('_') else 0
        stop = len(token) - 4 if token.endswith('</w>') else len(token)
        token = token[start:stop]

        # Attach a directly following punctuation token and skip over it.
        if position < total and tokens[position] in attachable:
            token += tokens[position]
            position += 1

        output.append(token)

    return ' '.join(output)

# Function to split the phrase into words and symbols
def tokenize_phrase(phrase):
    """Split ``phrase`` into word tokens and single-character symbol tokens.

    A token is either a maximal run of word characters or one character that
    is neither a word character nor whitespace. Whitespace is discarded.
    """
    token_pattern = re.compile(r'\w+|[^\w\s]', re.UNICODE)
    return token_pattern.findall(phrase)

# Function to load the BPE vocabulary and insert it into the Radix Tree
def load_and_insert_vocab_into_tree(filename, radix_tree):
    """Read a vocabulary file and insert every line into ``radix_tree``.

    Each line is cleaned with ``clean_word`` and inserted with its zero-based
    line index as the token ID.

    Args:
        filename: Path of the text file holding one vocabulary entry per line.
        radix_tree: Object exposing ``insert(word, tokenID)``.
    """
    # Decode explicitly as UTF-8 instead of the platform default, so the same
    # file yields the same entries on every machine. This also matches the
    # UTF-8 encoding used when the tree is serialized.
    # NOTE(review): assumes the vocabulary file is UTF-8 encoded - confirm.
    with open(filename, 'r', encoding='utf-8') as file:
        vocab = file.read().splitlines()

    for idx, word in enumerate(vocab):
        radix_tree.insert(clean_word(word), idx)

# Function to encode node IDs to binary and hex
def encode_node_ids(node_ids):
    """Pack ``node_ids`` as consecutive native C ints.

    Returns:
        A ``(bytes, str)`` tuple: the packed bytes and their hexadecimal
        string representation.
    """
    ids = list(node_ids)
    # One pack call with a repeat count yields the same bytes as packing each
    # ID separately and concatenating.
    packed = struct.pack(f'{len(ids)}i', *ids)
    return packed, packed.hex()
        
# Function to encode a phrase into token IDs using the Radix Tree
def phrase_to_node_ids(phrase, radix_tree):
    """Map every token of ``phrase`` to its token ID.

    The phrase is split with ``tokenize_phrase``; each token is escaped with
    ``escape_special_characters`` and looked up via ``radix_tree.traverse``.

    Raises:
        ValueError: If a token's lookup returns ``None``.
    """
    collected = []
    for token in tokenize_phrase(phrase):
        found_id = radix_tree.traverse(escape_special_characters(token))
        if found_id is None:
            raise ValueError(f"Token '{token}' not found in the Radix Tree.")
        collected.append(found_id)
    return collected

def node_ids_to_phrase(node_ids, radix_tree_deserialized):
    """Convert token IDs back into a space-separated, marker-annotated phrase.

    Each ID is resolved with ``radix_tree_deserialized.get_word_from_id``.
    Purely alphanumeric words are wrapped as ``_word</w>``; anything else
    (symbols, punctuation) is emitted unchanged.

    Raises:
        ValueError: If an ID cannot be resolved to a word (the lookup
            returned ``None``).
    """
    decoded_words = []
    for node_id in node_ids:
        word = radix_tree_deserialized.get_word_from_id(node_id)
        if word is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(f"Node ID {node_id} not found in the Radix Tree.")

        # An alphanumeric word can never start with '_' (underscore is not
        # alphanumeric), so the prefix marker is always added here.
        if word.isalnum():
            word = f"_{word}</w>"
        decoded_words.append(word)

    return ' '.join(decoded_words)

# Helper function to find the word by node ID
def find_word_by_node_id(node, target_id, prefix=''):
    """Depth-first search for the node whose ``tokenID`` equals ``target_id``.

    Args:
        node: Node to start from; must expose ``tokenID`` and a ``children``
            dict mapping edge labels to child nodes.
        target_id: Token ID to look for.
        prefix: Text accumulated on the path leading to ``node``.

    Returns:
        ``prefix`` plus the edge labels on the path to the first matching
        node in pre-order, or ``None`` when no node matches.
    """
    if node.tokenID == target_id:
        return prefix
    for child_prefix, child_node in node.children.items():
        result = find_word_by_node_id(child_node, target_id, prefix + child_prefix)
        # Compare against None explicitly: an empty string is a valid match
        # and must not be mistaken for "not found".
        if result is not None:
            return result
    return None

def find_children_by_node_id(node, target_id):
    """Return the ``children`` dict of the first node storing ``target_id``.

    Nodes are visited in pre-order starting at ``node`` (a node before its
    children, children in dict order). Returns ``None`` when no node matches.
    """
    # Explicit stack instead of recursion; children are pushed in reverse so
    # that they are popped in their original dict order.
    pending = [node]
    while pending:
        current = pending.pop()
        if current.tokenID == target_id:
            return current.children
        pending.extend(reversed(list(current.children.values())))
    return None

def test_radix_tree(radix_tree, test_words):
    """Look up each of ``test_words`` in ``radix_tree`` and report the outcome.

    Every word is escaped with ``escape_special_characters`` before the
    lookup. Returns a dict keyed by the original (unescaped) word whose
    values are ``{"found": bool, "node_id": int | None}``.
    """
    def lookup(word):
        node_id = radix_tree.traverse(escape_special_characters(word))
        return {"found": node_id is not None, "node_id": node_id}

    return {word: lookup(word) for word in test_words}

### Convert BPE Vocabulary to a Binary Radix Tree
##### The special-character escaping function is applied to every entry so that it works properly in C++
##### The BPE markers are stripped before insertion: the leading "_" prefix marker (a standalone "_" is kept) and the "</w>" end-of-word marker are removed

In [None]:
# Input: BPE vocabulary text file (read line by line when building the tree).
bpe_vocab_filepath = (
    '/Users/ciprian/Desktop/Projects/Microcontroller Tokenizer/microcontroller-tokenizer/data/'
    'bpe-vocab_bookcorpus-small_25000-merges_cleaned-v2.txt'
)
# Output: serialized Radix Tree, stored in the same directory.
radix_tree_filepath = (
    '/Users/ciprian/Desktop/Projects/Microcontroller Tokenizer/microcontroller-tokenizer/data/'
    'bpe-vocab_bookcorpus-small_25000-merges_cleaned-v2_radix-tree_no-markers.bin'
)

In [None]:
# Load the BPE vocabulary and create the Radix Tree
radix_tree = RadixTree()

# Every vocabulary line is cleaned with clean_word() and inserted with its
# zero-based line index as the token ID.
load_and_insert_vocab_into_tree(bpe_vocab_filepath, radix_tree)

# Save the Radix tree to a .bin file
radix_tree.save_to_file(radix_tree_filepath)

### Visualise Radix Tree

In [None]:
def visualize_radix_tree(node, prefix='', depth=0, max_depth=3):
    """
    Print an indented outline of part of the Radix Tree.

    One line is printed per visited node, in pre-order: two spaces of
    indentation per depth level, the text accumulated on the path to the
    node, and either its token ID or "No TokenID" when the ID is -1.

    Args:
    - node: The RadixNode to start from.
    - prefix: Text accumulated on the path leading to `node`.
    - depth: Depth assigned to `node`.
    - max_depth: Deepest level that is still printed (inclusive).
    """
    # Explicit stack of (node, accumulated text, depth) instead of recursion.
    pending = [(node, prefix, depth)]
    while pending:
        current, label, level = pending.pop()
        if level > max_depth:
            continue

        if current.tokenID == -1:
            node_info = "No TokenID"
        else:
            node_info = f"TokenID: {current.tokenID}"
        print(f"{'  ' * level}{label} ({node_info})")

        # Push children reversed so they are popped in their dict order.
        for child_label, child in reversed(list(current.children.items())):
            pending.append((child, label + child_label, level + 1))

# Example usage to visualize the first few levels of the Radix Tree
# The tree is reloaded from the serialized file rather than reusing the
# in-memory `radix_tree`, so the output reflects what was written to disk.
radix_tree_deserialized = RadixTree()
radix_tree_deserialized.load_from_file(radix_tree_filepath)

# Visualize the first 3 levels of the Radix Tree starting from the root
# (depths 0, 1 and 2: max_depth is inclusive)
visualize_radix_tree(radix_tree_deserialized.root, max_depth=2)

### Convert the binary file to a C header (.h)
#### This can also be achieved by using the command `xxd -i bpe_vocab_small_radix_tree.bin > bpe_vocab_small_radix_tree.h` in the terminal

In [None]:
def binary_to_c_header(binary_file_path, header_file_path, array_name):
    """Write the contents of a binary file as a C header defining a byte array.

    The generated header contains an include guard named after
    ``array_name`` (upper-cased, with an ``_H`` suffix), the array
    ``const unsigned char <array_name>[]`` with 12 bytes per line, the
    constant ``const unsigned int <array_name>_len`` holding the byte count,
    and the closing ``#endif``.

    Args:
        binary_file_path: Path of the binary file to read.
        header_file_path: Path of the header file to create or overwrite.
        array_name: C identifier used for the array and the include guard.
    """
    with open(binary_file_path, "rb") as binary_file:
        binary_data = binary_file.read()

    guard = f"{array_name.upper()}_H"
    bytes_per_row = 12  # same row width as `xxd -i`

    # Format the data as rows of comma-separated hex literals.
    rows = [
        ", ".join(f"0x{byte:02x}" for byte in binary_data[start:start + bytes_per_row])
        for start in range(0, len(binary_data), bytes_per_row)
    ]

    lines = [
        f"#ifndef {guard}",
        f"#define {guard}",
        "",
        f"const unsigned char {array_name}[] = {{",
        "    " + ",\n    ".join(rows),
        # Close the initializer and the include guard; without these the
        # header is not valid C.
        "};",
        "",
        f"const unsigned int {array_name}_len = {len(binary_data)};",
        "",
        f"#endif /* {guard} */",
        "",
    ]

    with open(header_file_path, "w") as header_file:
        header_file.write("\n".join(lines))

# Regenerate the C header file
# NOTE(review): these are relative paths to "simplified_radix_tree.bin", not
# the `radix_tree_filepath` written earlier in this notebook - confirm that
# this is the intended input file.
binary_file_path = "simplified_radix_tree.bin"
header_file_path = "simplified_radix_tree.h"
# Used both as the C array identifier and (upper-cased) in the include guard
array_name = "simplified_radix_tree"

binary_to_c_header(binary_file_path, header_file_path, array_name)

### Testing the Binary Tree to find given words

In [None]:
# Load the Radix tree from the .bin file
radix_tree_deserialized = RadixTree()
radix_tree_deserialized.load_from_file(radix_tree_filepath)

# Look up 8 sample strings (words, a single space, punctuation and a number);
# each result records whether the string was found and its token ID
test_words = ["rocket", "rocketman", "exasperated", "hello", "am", " ", ".", "153"]
results = test_radix_tree(radix_tree_deserialized, test_words)

# Bare expression so the notebook shows `results` as the cell output
results

### Testing the Binary Radix Tree to encode and decode a given phrase

In [None]:
# Example usage of the new system
radix_tree_deserialized = RadixTree()
radix_tree_deserialized.load_from_file(radix_tree_filepath)

phrase = "  Hello There 5! How are You?  "
# Encode: normalise the text (trim, lowercase, ASCII only), then map every
# token to its token ID; phrase_to_node_ids raises ValueError for a token
# that is not in the tree
cleaned_phrase = clean_input(phrase)
node_ids = phrase_to_node_ids(cleaned_phrase, radix_tree_deserialized)
binary_representation, hex_representation = encode_node_ids(node_ids)
# Decode: resolve the IDs back to marker-annotated words, then strip the
# markers and re-attach punctuation
decoded_phrase = node_ids_to_phrase(node_ids, radix_tree_deserialized)
cleaned_decoded_phrase = clean_decoded_phrase(decoded_phrase)

print(f"Original Input: {phrase}")
print(f"Cleaned Input: {cleaned_phrase}")
print(f"Node IDs: {node_ids}")
print(f"Hex Encoding: {hex_representation}")
print(f"Binary Encoding: {binary_representation}")
print(f"Decoded Words: {decoded_phrase}")
print(f"Cleaned Decoded Phrase: {cleaned_decoded_phrase}")