### Jupyter Notebook Settings

In [None]:
# Notebook display settings: widen the Jupyter container to the full browser width.
# `display`/`HTML` are imported from the public `IPython.display` module; importing
# `display` from `IPython.core.display` is deprecated since IPython 7.14.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
# NOTE: this rebinds the name `display` from the function to the module, so any
# later cell must call `display.display(...)` rather than `display(...)`.
import IPython.display as display

### Libraries

In [None]:
import struct
import os

### Convert an Existing .bin File to a C Hex Header

In [None]:
def bin_to_hex(bin_file_path, output_header_path, array_name="radix_tree_hex", bytes_per_line=16):
    """Embed the bytes of a binary file in a C/C++ header as a ``uint8_t`` array.

    The generated header contains ``#pragma once``, ``#include <cstdint>``,
    ``const uint8_t {array_name}[]`` holding every byte of *bin_file_path*,
    and ``const unsigned int {array_name}_len`` holding the byte count.

    Args:
        bin_file_path: Path of the binary file to embed.
        output_header_path: Path of the header file to create or overwrite.
        array_name: C identifier for the array (``_len`` is appended for the
            length constant).
        bytes_per_line: Number of array elements written per source line.

    Raises:
        ValueError: If the binary file is empty (an unsized array with an
            empty initializer list is ill-formed C++), or if
            ``bytes_per_line`` is less than 1.
    """
    if bytes_per_line < 1:
        raise ValueError("bytes_per_line must be at least 1")

    with open(bin_file_path, 'rb') as bin_file:
        binary_data = bin_file.read()

    if not binary_data:
        raise ValueError(f"{bin_file_path!r} is empty; cannot emit a zero-length C array")

    # Build every line up front and write once, instead of one write() per byte.
    # The trailing comma after the last element is legal in C/C++ initializers.
    lines = [
        "    " + ", ".join(f"0x{byte:02x}" for byte in binary_data[start:start + bytes_per_line]) + ","
        for start in range(0, len(binary_data), bytes_per_line)
    ]

    with open(output_header_path, 'w', encoding='utf-8') as header_file:
        # Guard against redefinition errors when the header is included twice.
        header_file.write('#pragma once\n')
        header_file.write('#include <cstdint>\n\n')
        header_file.write(f'const uint8_t {array_name}[] = {{\n')
        header_file.write('\n'.join(lines))
        header_file.write('\n};\n')
        header_file.write(f'const unsigned int {array_name}_len = {len(binary_data)};\n')

# One-off conversion of the serialized radix-tree .bin file into a C header.
_DATA_DIR = '/Users/ciprian/Desktop/Projects/Microcontroller Tokenizer/microcontroller-tokenizer/data'
bin_file_path = f'{_DATA_DIR}/bpe-vocab_bookcorpus-small_25000-merges_cleaned-v2_radix-tree_no-markers.bin'
output_header_path = f'{_DATA_DIR}/bpe-vocab_bookcorpus-small_25000-merges_cleaned-v2_hex-radix-tree_no-markers.h'

bin_to_hex(bin_file_path=bin_file_path, output_header_path=output_header_path)

In [None]:
### Create the Radix Tree Directly in Hex (serialization and deserialization are done in hex format)

In [None]:
# Input BPE vocabulary and output path for the hex-serialized radix tree.
# NOTE: despite the `.h` extension, `RadixTree.save_to_hex_file` writes a bare
# hex string to this path; the formatted C header is generated from it later.
_DATA_DIR = '/Users/ciprian/Desktop/Projects/Microcontroller Tokenizer/microcontroller-tokenizer/data'
bpe_vocab_filepath = f'{_DATA_DIR}/bpe-vocab_bookcorpus-30p_25000-merges_cleaned-v2.txt'
radix_tree_filepath = f'{_DATA_DIR}/bpe-vocab_bookcorpus-30p_25000-merges_cleaned-v2_hex-radix-tree_no-markers.h'

In [None]:
class RadixNode:
    """A single node of the radix (compressed prefix) tree.

    Attributes:
        children: dict mapping an edge label (a non-empty substring) to the
            child ``RadixNode`` reached through that edge.
        tokenID: vocabulary index of the token ending at this node, or -1 when
            no token ends here.
    """

    def __init__(self):
        self.children = {}
        self.tokenID = -1


class RadixTree:
    """Radix tree mapping vocabulary strings to token IDs, with hex (de)serialization.

    Binary layout used by ``serialize_to_hex`` / ``deserialize_from_hex``. All
    integers are 4-byte little-endian signed ints; nodes are written depth-first::

        node  := tokenID:int32  num_children:int32  child*
        child := prefix_len:int32  prefix:bytes[prefix_len] (UTF-8)  node
    """

    # Explicit little-endian, standard-size int so the format does not depend on
    # the host that produced it (byte-identical to native 'i' on x86/ARM hosts).
    _INT = struct.Struct('<i')

    def __init__(self):
        self.root = RadixNode()

    def insert(self, word, tokenID):
        """Insert *word* and store *tokenID* at the node where it ends.

        Edges that share a prefix with *word* are split, so sibling edges never
        start with the same character. Re-inserting a word overwrites its
        tokenID; inserting ``""`` sets the root's tokenID.
        """
        node = self.root
        while word:
            for prefix in node.children:
                # Length of the common prefix of this edge label and the rest of word.
                common_length = 0
                for i in range(min(len(prefix), len(word))):
                    if prefix[i] != word[i]:
                        break
                    common_length += 1

                if common_length > 0:
                    if common_length < len(prefix):
                        # Split the edge: prefix -> prefix[:common] + prefix[common:].
                        existing_child = node.children.pop(prefix)
                        new_node = RadixNode()
                        new_node.children[prefix[common_length:]] = existing_child
                        node.children[prefix[:common_length]] = new_node
                        node = new_node
                    else:
                        node = node.children[prefix]

                    word = word[common_length:]
                    # Leave the loop immediately: node.children may have been mutated.
                    break
            else:
                # No edge shares a first character with word: add a new leaf edge.
                node.children[word] = RadixNode()
                node = node.children[word]
                word = ""
        node.tokenID = tokenID

    def serialize_to_hex(self):
        """Return the whole tree serialized (see class docstring) as a lowercase hex string."""
        int_pack = self._INT.pack

        def node_to_bytes(node, out):
            out.append(int_pack(node.tokenID))
            out.append(int_pack(len(node.children)))
            for prefix, child in node.children.items():
                # The stored length must be the UTF-8 *byte* count. len(prefix) is
                # the character count, which truncates multi-byte characters.
                encoded = prefix.encode('utf-8')
                out.append(int_pack(len(encoded)))
                out.append(encoded)
                node_to_bytes(child, out)

        parts = []
        node_to_bytes(self.root, parts)
        return b''.join(parts).hex()

    def save_to_hex_file(self, filename):
        """Write ``serialize_to_hex()`` to *filename* as plain text (no C syntax)."""
        hex_data = self.serialize_to_hex()
        with open(filename, 'w', encoding='ascii') as file:
            file.write(hex_data)

    def traverse(self, word):
        """Return the tokenID stored for exactly *word*, or None if it is not a token."""
        node = self.root
        while word:
            found = False
            for prefix, child in node.children.items():
                if word.startswith(prefix):
                    word = word[len(prefix):]
                    node = child
                    found = True
                    break
            if not found:
                return None
        return node.tokenID if node.tokenID != -1 else None

    def deserialize_from_hex(self, hex_data):
        """Replace the tree with the one encoded in *hex_data*.

        Raises:
            ValueError: If *hex_data* is not valid hex, is truncated, or holds
                a prefix that is not valid UTF-8.
        """
        # Decode exactly once. Re-decoding the whole string in every recursive
        # call made loading quadratic in the size of the tree.
        byte_data = bytes.fromhex(hex_data)
        total = len(byte_data)
        int_size = self._INT.size
        unpack_int = self._INT.unpack_from

        def read_int(offset, what):
            if offset + int_size > total:
                raise ValueError(f"Unexpected end of data while reading {what}")
            return unpack_int(byte_data, offset)[0], offset + int_size

        def read_node(offset):
            tokenID, offset = read_int(offset, "tokenID")
            num_children, offset = read_int(offset, "num_children")

            node = RadixNode()
            node.tokenID = tokenID

            for _ in range(num_children):
                prefix_len, offset = read_int(offset, "prefix_len")
                if prefix_len < 0 or offset + prefix_len > total:
                    raise ValueError("Unexpected end of data while reading prefix")
                prefix = byte_data[offset:offset + prefix_len].decode('utf-8')
                offset += prefix_len

                child, offset = read_node(offset)
                node.children[prefix] = child

            return node, offset

        self.root, _ = read_node(0)

    def load_from_hex_file(self, filename):
        """Load the tree from a plain hex file written by ``save_to_hex_file``."""
        with open(filename, 'r', encoding='ascii') as file:
            hex_data = file.read().strip()  # Drop a trailing newline or padding.
        self.deserialize_from_hex(hex_data)

    def get_word_from_id(self, token_id):
        """Return the string spelled by the path to the node holding *token_id*, or None."""
        return find_word_by_node_id(self.root, token_id)

# Helper functions to traverse the Radix Tree
def find_word_by_node_id(node, target_id, prefix=''):
    """Return the text on the path to the first node whose tokenID equals *target_id*.

    Searches depth-first in pre-order from *node*. *prefix* is the text
    accumulated on the path so far. Returns ``None`` when no node matches.
    """
    if node.tokenID == target_id:
        return prefix
    for child_prefix, child_node in node.children.items():
        result = find_word_by_node_id(child_node, target_id, prefix + child_prefix)
        # Compare against None explicitly: '' is a valid match and must not be
        # mistaken for "not found".
        if result is not None:
            return result
    return None

def find_children_by_node_id(node, target_id):
    """Return the children dict of the first node (pre-order) whose tokenID is *target_id*.

    Returns ``None`` if no node in the subtree rooted at *node* matches.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current.tokenID == target_id:
            return current.children
        # Push in reverse so children are popped in insertion order, which
        # reproduces a recursive pre-order walk.
        pending.extend(reversed(list(current.children.values())))
    return None

def clean_word(word):
    """Normalize a raw vocabulary entry before inserting it into the tree.

    Leading underscores are stripped unless the entry is exactly ``'_'``.
    Then every ``'</w>'`` end-of-word marker is removed.
    """
    head_stripped = word if word == '_' else word.lstrip('_')
    return head_stripped.replace('</w>', '')

def load_and_insert_vocab_into_tree(filename, radix_tree, encoding='utf-8'):
    """Insert every line of a vocabulary file into *radix_tree*.

    Each line is normalized with ``clean_word`` and inserted with its 0-based
    line number as the token ID.

    Args:
        filename: Path of the vocabulary file (one entry per line).
        radix_tree: Object providing ``insert(word, tokenID)``.
        encoding: Text encoding of the file. It defaults to UTF-8 rather than
            the platform-dependent locale encoding, so non-ASCII entries decode
            the same way on every machine.
    """
    with open(filename, 'r', encoding=encoding) as file:
        vocab = file.read().splitlines()

    for idx, word in enumerate(vocab):
        radix_tree.insert(clean_word(word), idx)

# Build the radix tree from the vocabulary at `bpe_vocab_filepath` and write
# its hex serialization to `radix_tree_filepath`.
radix_tree = RadixTree()
load_and_insert_vocab_into_tree(filename=bpe_vocab_filepath, radix_tree=radix_tree)
radix_tree.save_to_hex_file(filename=radix_tree_filepath)

In [None]:
# Replace the in-memory tree with the one read back from the hex file.
radix_tree.load_from_hex_file(filename=radix_tree_filepath)

In [None]:
import re
import struct

# Assuming the RadixNode, RadixTree classes, and other helper functions are already defined as before.

def clean_input(phrase):
    """Normalize user text before tokenization.

    Lower-cases *phrase*, drops every non-ASCII character, and then strips
    leading and trailing whitespace. Stripping comes last so that whitespace
    exposed by removing non-ASCII characters (e.g. ``"ok é"`` -> ``"ok "``)
    is also trimmed.
    """
    # Lower-case before filtering: a few non-ASCII characters (e.g. the Kelvin
    # sign) lower-case to ASCII letters and are kept.
    lowered = phrase.lower()
    ascii_only = ''.join(char for char in lowered if ord(char) < 128)
    return ascii_only.strip()

def tokenize_phrase(phrase):
    """Split *phrase* into runs of word characters and single punctuation symbols.

    Whitespace is discarded. Every character that is neither a word character
    nor whitespace becomes its own token.
    """
    splitter = re.compile(r'\w+|[^\w\s]', re.UNICODE)
    return splitter.findall(phrase)

def encode_node_ids_to_hex(node_ids):
    """Pack token IDs as consecutive 4-byte little-endian signed ints and return the hex.

    The byte order is fixed explicitly (``'<'``) rather than taken from the
    host, so the encoding is identical on every machine.

    Raises:
        struct.error: If an ID does not fit in a signed 32-bit int.
    """
    ids = list(node_ids)
    return struct.pack(f'<{len(ids)}i', *ids).hex()

def encode_input_phrase(phrase, radix_tree):
    """Normalize *phrase*, look up each token in *radix_tree*, and hex-encode the IDs.

    Returns:
        A tuple ``(cleaned_input, node_ids, hex_encoding)``.

    Raises:
        ValueError: On the first token for which ``radix_tree.traverse``
            returns None.
    """
    normalized = clean_input(phrase)
    ids = []
    for piece in tokenize_phrase(normalized):
        token_id = radix_tree.traverse(piece)
        if token_id is None:
            raise ValueError(f"Token '{piece}' not found in the Radix Tree.")
        ids.append(token_id)
    return normalized, ids, encode_node_ids_to_hex(ids)

# Demo: encode a sample sentence with the tree currently held in `radix_tree`.
user_input = "Hello There 5! How are You?"
cleaned_input, node_ids, hex_encoding = encode_input_phrase(phrase=user_input, radix_tree=radix_tree)

for label, value in (
    ("Original Input:", user_input),
    ("Cleaned Input:", cleaned_input),
    ("Node IDs:", node_ids),
    ("Hex Encoding:", hex_encoding),
):
    print(label, value)

In [None]:
def convert_hex_to_formatted_header(input_filepath, output_filepath, array_name="radix_tree_hex", bytes_per_line=12):
    """Wrap a file of raw hex digits in a C header defining a ``uint8_t`` array.

    Whitespace (including newlines) and commas in the input are ignored. The
    header defines ``const uint8_t {array_name}[]`` with one element per byte
    and ``const size_t {array_name}_size``. It is protected by an include
    guard derived from the output file name.

    Args:
        input_filepath: File containing the hex string (e.g. written by
            ``RadixTree.save_to_hex_file``).
        output_filepath: Path of the header to create or overwrite.
        array_name: C identifier for the array; ``_size`` is appended for the
            size constant.
        bytes_per_line: Number of array elements per line.

    Raises:
        ValueError: If the input is not valid hex (odd digit count or non-hex
            characters), contains no data, or ``bytes_per_line`` < 1.
    """
    if bytes_per_line < 1:
        raise ValueError("bytes_per_line must be at least 1")

    # Macro names may only contain [A-Za-z0-9_] and must not start with a digit.
    # Without sanitizing, "bpe-vocab_radix-tree.h" would give
    # "BPE-VOCAB_RADIX-TREE_H", which the preprocessor reads as the macro "BPE"
    # followed by stray tokens.
    stem = os.path.splitext(os.path.basename(output_filepath))[0].upper()
    sanitized = ''.join(c if (c.isascii() and c.isalnum()) else '_' for c in stem)
    if not sanitized or sanitized[0].isdigit():
        sanitized = 'H_' + sanitized
    include_guard = f"{sanitized}_H"

    with open(input_filepath, 'r', encoding='ascii') as infile:
        # bytes.fromhex skips ASCII whitespace and rejects odd-length or non-hex
        # input, so malformed data fails loudly instead of producing garbage bytes.
        data = bytes.fromhex(infile.read().replace(',', ''))

    if not data:
        raise ValueError(f"{input_filepath!r} contains no hex data; cannot emit a zero-length C array")

    lines = [
        "  " + ", ".join(f"0x{byte:02x}" for byte in data[start:start + bytes_per_line])
        for start in range(0, len(data), bytes_per_line)
    ]

    parts = [
        f"#ifndef {include_guard}\n",
        f"#define {include_guard}\n\n",
        # uint8_t and size_t are only guaranteed to be declared by these headers.
        "#include <stddef.h>\n",
        "#include <stdint.h>\n\n",
        f"const uint8_t {array_name}[] = {{\n",
        ",\n".join(lines),
        "\n};\n\n",
        f"const size_t {array_name}_size = sizeof({array_name});\n\n",
        f"#endif // {include_guard}\n",
    ]

    with open(output_filepath, 'w', encoding='utf-8') as outfile:
        outfile.write(''.join(parts))

# Emit the formatted C header into the current working directory, reading the
# raw hex tree previously written to `radix_tree_filepath`.
output_filepath = 'bpe-vocab_radix-tree.h'
convert_hex_to_formatted_header(input_filepath=radix_tree_filepath, output_filepath=output_filepath)

### Testing the C Header Hex File

In [None]:
import re
import struct

class RadixNode:
    """A single node of the radix tree.

    Attributes:
        children: dict mapping an edge label (a substring) to the child node.
        tokenID: vocabulary index of the token ending here, or -1 if none.
    """

    def __init__(self):
        self.children = {}
        self.tokenID = -1


class RadixTree:
    """Radix tree that can be rebuilt from the ``0xNN`` bytes of a generated C header.

    Serialized layout (all integers are 4-byte little-endian signed ints,
    nodes written depth-first)::

        node  := tokenID:int32  num_children:int32  child*
        child := prefix_len:int32  prefix:bytes[prefix_len] (UTF-8)  node
    """

    # Explicit little-endian, standard-size int (same as native 'i' on x86/ARM).
    _INT = struct.Struct('<i')

    def __init__(self):
        self.root = RadixNode()

    def insert(self, word, tokenID):
        """Insert *word* and store *tokenID* at the node where it ends.

        Edges sharing a prefix with *word* are split, so sibling edges never
        start with the same character.
        """
        node = self.root
        while word:
            for prefix in node.children:
                # Length of the common prefix of this edge label and the rest of word.
                common_length = 0
                for i in range(min(len(prefix), len(word))):
                    if prefix[i] != word[i]:
                        break
                    common_length += 1

                if common_length > 0:
                    if common_length < len(prefix):
                        # Split the edge: prefix -> prefix[:common] + prefix[common:].
                        existing_child = node.children.pop(prefix)
                        new_node = RadixNode()
                        new_node.children[prefix[common_length:]] = existing_child
                        node.children[prefix[:common_length]] = new_node
                        node = new_node
                    else:
                        node = node.children[prefix]

                    word = word[common_length:]
                    # Leave the loop immediately: node.children may have been mutated.
                    break
            else:
                # No edge shares a first character with word: add a new leaf edge.
                node.children[word] = RadixNode()
                node = node.children[word]
                word = ""
        node.tokenID = tokenID

    def traverse(self, word):
        """Return the tokenID stored for exactly *word*, or None if it is not a token."""
        node = self.root
        while word:
            found = False
            for prefix, child in node.children.items():
                if word.startswith(prefix):
                    word = word[len(prefix):]
                    node = child
                    found = True
                    break
            if not found:
                return None
        return node.tokenID if node.tokenID != -1 else None

    def deserialize_from_hex(self, hex_data):
        """Replace the tree with the one encoded in *hex_data*.

        Raises:
            ValueError: If *hex_data* is not valid hex, is truncated, or holds
                a prefix that is not valid UTF-8.
        """
        # Decode exactly once. Re-decoding the whole string in every recursive
        # call made loading quadratic in the size of the tree.
        byte_data = bytes.fromhex(hex_data)
        total = len(byte_data)
        int_size = self._INT.size
        unpack_int = self._INT.unpack_from

        def read_int(offset, what):
            if offset + int_size > total:
                raise ValueError(f"Unexpected end of data while reading {what}")
            return unpack_int(byte_data, offset)[0], offset + int_size

        def read_node(offset):
            tokenID, offset = read_int(offset, "tokenID")
            num_children, offset = read_int(offset, "num_children")

            node = RadixNode()
            node.tokenID = tokenID

            for _ in range(num_children):
                prefix_len, offset = read_int(offset, "prefix_len")
                if prefix_len < 0 or offset + prefix_len > total:
                    raise ValueError("Unexpected end of data while reading prefix")
                prefix = byte_data[offset:offset + prefix_len].decode('utf-8')
                offset += prefix_len

                child, offset = read_node(offset)
                node.children[prefix] = child

            return node, offset

        self.root, _ = read_node(0)

    def load_from_hex_file(self, filename):
        """Load the tree from a C header whose array initializer holds ``0xNN`` byte literals.

        Only the first ``{ ... }`` initializer is scanned, falling back to the
        whole file if there is none. This keeps ``0x..`` text in comments or
        other declarations from being read as tree data.
        """
        with open(filename, 'r', encoding='utf-8') as file:
            content = file.read()

        open_brace = content.find('{')
        close_brace = content.find('}', open_brace + 1) if open_brace != -1 else -1
        body = content[open_brace + 1:close_brace] if close_brace != -1 else content

        # Extract each two-digit byte literal and drop its "0x" prefix.
        hex_bytes = re.findall(r'0x[0-9a-fA-F]{2}', body)
        self.deserialize_from_hex(''.join(h[2:] for h in hex_bytes))

    def get_word_from_id(self, token_id):
        """Return the string spelled by the path to the node holding *token_id*, or None."""
        return find_word_by_node_id(self.root, token_id)

# Helper functions to traverse the Radix Tree
def find_word_by_node_id(node, target_id, prefix=''):
    """Return the text on the path to the first node whose tokenID equals *target_id*.

    Searches depth-first in pre-order from *node*. *prefix* is the text
    accumulated on the path so far. Returns ``None`` when no node matches.
    """
    if node.tokenID == target_id:
        return prefix
    for child_prefix, child_node in node.children.items():
        result = find_word_by_node_id(child_node, target_id, prefix + child_prefix)
        # Compare against None explicitly: '' is a valid match and must not be
        # mistaken for "not found".
        if result is not None:
            return result
    return None

def find_children_by_node_id(node, target_id):
    """Return the children dict of the first node (pre-order) whose tokenID is *target_id*.

    Returns ``None`` if no node in the subtree rooted at *node* matches.
    """
    to_visit = [node]
    while to_visit:
        candidate = to_visit.pop()
        if candidate.tokenID == target_id:
            return candidate.children
        # Reverse before pushing so the first child is examined first.
        to_visit.extend(reversed(list(candidate.children.values())))
    return None

def clean_input(phrase):
    """Lower-case *phrase*, drop non-ASCII characters, then strip surrounding whitespace.

    Stripping comes last so that whitespace exposed by removing non-ASCII
    characters (e.g. ``"ok é"`` -> ``"ok "``) is also trimmed.
    """
    # Lower-case before filtering: a few non-ASCII characters (e.g. the Kelvin
    # sign) lower-case to ASCII letters and are kept.
    lowered = phrase.lower()
    ascii_only = ''.join(char for char in lowered if ord(char) < 128)
    return ascii_only.strip()

def tokenize_phrase(phrase):
    """Split *phrase* into word runs and single non-space symbols, discarding whitespace."""
    word_or_symbol = r'\w+|[^\w\s]'
    return [match.group(0) for match in re.finditer(word_or_symbol, phrase, re.UNICODE)]

def encode_node_ids_to_hex(node_ids):
    """Pack token IDs as consecutive 4-byte little-endian signed ints and return the hex.

    The byte order is fixed explicitly (``'<'``) rather than taken from the
    host, so the encoding is identical on every machine.

    Raises:
        struct.error: If an ID does not fit in a signed 32-bit int.
    """
    ids = list(node_ids)
    return struct.pack(f'<{len(ids)}i', *ids).hex()

def encode_input_phrase(phrase, radix_tree):
    """Normalize *phrase*, map each token to its ID via *radix_tree*, and hex-encode the IDs.

    Returns:
        A tuple ``(cleaned_input, node_ids, hex_encoding)``.

    Raises:
        ValueError: On the first token for which ``radix_tree.traverse``
            returns None.
    """
    normalized = clean_input(phrase)
    ids = []
    for piece in tokenize_phrase(normalized):
        looked_up = radix_tree.traverse(piece)
        if looked_up is None:
            raise ValueError(f"Token '{piece}' not found in the Radix Tree.")
        ids.append(looked_up)
    return normalized, ids, encode_node_ids_to_hex(ids)

# Rebuild the tree from the byte literals in the generated C header, then
# encode a sample word with it.
radix_tree_hex_filepath = '/Users/ciprian/Desktop/Projects/Microcontroller Tokenizer/microcontroller-tokenizer/data_processing_python/bpe-vocab_radix-tree.h'
radix_tree = RadixTree()
radix_tree.load_from_hex_file(filename=radix_tree_hex_filepath)

user_input = "buddies"
cleaned_input, node_ids, hex_encoding = encode_input_phrase(phrase=user_input, radix_tree=radix_tree)

for label, value in (
    ("Original Input:", user_input),
    ("Cleaned Input:", cleaned_input),
    ("Node IDs:", node_ids),
    ("Hex Encoding:", hex_encoding),
):
    print(label, value)

In [None]:
class RadixNode:
    """Radix-tree node used for C code generation.

    Attributes:
        children: dict mapping an edge label (a substring) to the child node.
        tokenID: vocabulary index of the token ending here, or -1 if none.
    """

    def __init__(self, tokenID=-1):
        self.children = {}
        self.tokenID = tokenID


class RadixTree:
    """Radix tree over vocabulary strings that can be emitted as static C data."""

    def __init__(self):
        self.root = RadixNode()

    def insert(self, word, tokenID):
        """Insert *word* and store *tokenID* at the node where it ends.

        Edges sharing a prefix with *word* are split, so sibling edges never
        start with the same character.
        """
        node = self.root
        while word:
            for prefix in node.children:
                # Length of the common prefix of this edge label and the rest of word.
                common_length = 0
                for i in range(min(len(prefix), len(word))):
                    if prefix[i] != word[i]:
                        break
                    common_length += 1

                if common_length > 0:
                    if common_length < len(prefix):
                        # Split the edge: prefix -> prefix[:common] + prefix[common:].
                        existing_child = node.children.pop(prefix)
                        new_node = RadixNode()
                        new_node.children[prefix[common_length:]] = existing_child
                        node.children[prefix[:common_length]] = new_node
                        node = new_node
                    else:
                        node = node.children[prefix]

                    word = word[common_length:]
                    # Leave the loop immediately: node.children may have been mutated.
                    break
            else:
                # No edge shares a first character with word: add a new leaf edge.
                node.children[word] = RadixNode()
                node = node.children[word]
                word = ""
        node.tokenID = tokenID

    def generate_c_code(self, max_children):
        """Return C definitions with one ``RadixNode`` variable per tree node.

        Children are emitted before their parent, so every ``&child``
        reference points at an already-declared object. The i-th child of node
        ``X`` is named ``X_child{i}``, starting from ``root``. Leaf nodes omit
        the ``.children`` initializer and the array is zero-filled, because an
        empty ``{}`` initializer is not valid C before C23.

        Args:
            max_children: Size of the ``children`` array in the emitted struct.

        Raises:
            ValueError: If some node has more than *max_children* children.
        """
        def escape_c_string(token):
            # Convert the `<space>` marker to a literal space.
            # NOTE(review): this runs per edge label, so a `<space>` that the
            # radix tree split across two edges is not converted.
            token = token.replace('<space>', ' ')
            pieces = []
            for byte in token.encode('utf-8'):
                char = chr(byte)
                if char in '\\"':
                    pieces.append('\\' + char)
                elif char == '?':
                    # Escape '?' so sequences such as "??=" cannot form trigraphs.
                    pieces.append('\\?')
                elif 0x20 <= byte < 0x7f:
                    pieces.append(char)
                else:
                    # Control and non-ASCII bytes become 3-digit octal escapes.
                    # Unlike \x escapes, these cannot absorb a following hex digit.
                    pieces.append(f'\\{byte:03o}')
            return ''.join(pieces)

        def emit(node, node_name, out):
            labels = list(node.children.items())
            if len(labels) > max_children:
                raise ValueError(
                    f"node {node_name} has {len(labels)} children but max_children is {max_children}"
                )
            child_names = [f"{node_name}_child{i}" for i in range(len(labels))]

            # Declare all children first so the parent can take their addresses.
            for (_, child), child_name in zip(labels, child_names):
                emit(child, child_name, out)

            if not labels:
                out.append(f"RadixNode {node_name} = {{.tokenID = {node.tokenID}}};\n")
                return

            out.append(f"RadixNode {node_name} = {{.tokenID = {node.tokenID}, .children = {{\n")
            for (prefix, _), child_name in zip(labels, child_names):
                out.append(f'  {{"{escape_c_string(prefix)}", &{child_name}}},\n')
            out.append("}};\n")

        parts = []
        emit(self.root, "root", parts)
        return ''.join(parts)

    def save_to_c_file(self, filename, max_children):
        """Write a C header holding the ``RadixNode`` typedef and the generated tree data.

        *max_children* sizes each node's ``children`` array. It is raised to at
        least 1 because a zero-length array is not valid C.
        """
        max_children = max(1, max_children)
        c_code = self.generate_c_code(max_children)
        header_content = f"""
#ifndef RADIX_TREE_H
#define RADIX_TREE_H

typedef struct RadixNode {{
    int tokenID;
    struct {{
        const char *prefix;
        struct RadixNode *node;
    }} children[{max_children}];  // Adjusted based on the max children needed
}} RadixNode;

{c_code}

#endif // RADIX_TREE_H
"""
        with open(filename, 'w', encoding='utf-8') as file:
            file.write(header_content)

    def find_max_children(self):
        """Return the largest number of children held by any node (0 for an empty tree)."""
        def max_children(node):
            if not node.children:
                return 0
            return max(len(node.children), max(max_children(child) for child in node.children.values()))

        return max_children(self.root)

def clean_word(word):
    """Strip leading underscores (unless *word* is exactly ``'_'``) and remove ``'</w>'`` markers."""
    if word == '_':
        without_marker_prefix = word
    else:
        without_marker_prefix = word.lstrip('_')
    return without_marker_prefix.replace('</w>', '')

def load_and_insert_vocab_into_tree(filename, radix_tree, encoding='utf-8'):
    """Insert every line of a vocabulary file into *radix_tree*.

    Each line is normalized with ``clean_word`` and inserted with its 0-based
    line number as the token ID.

    Args:
        filename: Path of the vocabulary file (one entry per line).
        radix_tree: Object providing ``insert(word, tokenID)``.
        encoding: Text encoding of the file. It defaults to UTF-8 rather than
            the platform-dependent locale encoding, so non-ASCII entries decode
            the same way on every machine.
    """
    with open(filename, 'r', encoding=encoding) as file:
        vocab = file.read().splitlines()

    for idx, word in enumerate(vocab):
        radix_tree.insert(clean_word(word), idx)

# Build the tree from the BPE vocabulary and emit it as static C data in
# `radix_tree_c_filepath`. Each node's children array is sized to the largest
# fan-out found in the tree.
bpe_vocab_filepath = '/Users/ciprian/Desktop/Projects/Microcontroller Tokenizer/microcontroller-tokenizer/data/bpe-vocab_bookcorpus-small_25000-merges_cleaned-v2.txt'
radix_tree_c_filepath = 'radix_tree.h'

radix_tree = RadixTree()
load_and_insert_vocab_into_tree(filename=bpe_vocab_filepath, radix_tree=radix_tree)

max_children = radix_tree.find_max_children()
radix_tree.save_to_c_file(filename=radix_tree_c_filepath, max_children=max_children)