### Jupyter Notebook Settings

In [None]:
# Widen the notebook container so cells use the full browser width.
# `display` and `HTML` come from the public IPython.display module; importing
# `display` from IPython.core.display is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
# NOTE(review): this rebinds the name `display` from the function imported
# above to the IPython.display *module*. After this line, call
# `display.display(...)` rather than `display(...)`.
import IPython.display as display

### Libraries

In [None]:
import re
import struct

### Data

In [None]:
# Absolute path to the C header (.h) file used as this notebook's data source.
# The location is specific to one machine; adjust it before running elsewhere.
header_filepath = '/Users/ciprian/Desktop/Projects/Microcontroller Tokenizer/microcontroller-tokenizer/data/bpe-vocab_bookcorpus-small_100000-merges_cleaned_radix-tree_no-markers.h'

### Main Code

In [None]:
# Step 1: Load the byte array from the .h file
def load_byte_array_from_h_file(h_file_path):
    """Extract the first ``unsigned char`` array initializer from a C header.

    Args:
        h_file_path: Path to the ``.h`` file to read.

    Returns:
        A ``bytes`` object holding every hex literal (``0x..``) found inside
        the array's braces, in file order.

    Raises:
        ValueError: If no ``unsigned char ...[...] = { ... };`` definition is
            found, or if a hex literal does not fit in one byte.
    """
    # Only the ASCII hex literals matter, so undecodable bytes elsewhere in the
    # file (e.g. in comments) are replaced instead of aborting the load.
    with open(h_file_path, 'r', encoding='utf-8', errors='replace') as f:
        content = f.read()

    # Match the declaration up to its opening brace without crossing another
    # statement (no ';', '{', '}' or '=' allowed in between), so the FIRST array
    # in the file is selected. An optional size inside the brackets and
    # attributes between the brackets and '=' (e.g. PROGMEM) are accepted.
    array_match = re.search(
        r"unsigned\s+char\b[^;{}=]*?\[[^\]]*\][^;{}=]*=\s*\{(.*?)\}\s*;",
        content,
        re.DOTALL,
    )

    if not array_match:
        raise ValueError("No byte array found in the .h file.")

    # Convert every hex literal inside the braces to one byte value.
    array_body = array_match.group(1)
    hex_literals = re.findall(r'0[xX][0-9a-fA-F]+', array_body)
    return bytes(int(literal, 16) for literal in hex_literals)

# Load the byte array from the header file so that `byte_array` is defined for
# every later step, then inspect its first 32 bytes. The slice is printed
# directly because the inspect_byte_array helper is only defined further down.
byte_array = load_byte_array_from_h_file(header_filepath)
print(f"Bytes from 0 to 32: {byte_array[0:32]}")

# Step 2: Inspect the byte array in small chunks
def inspect_byte_array(byte_array, start=0, num_bytes=16):
    """Print the slice of ``byte_array`` covering ``start`` up to ``start + num_bytes``.

    Args:
        byte_array: Sliceable byte sequence to look into.
        start: Index of the first byte to show.
        num_bytes: Length of the window to show.
    """
    end = start + num_bytes
    window = byte_array[start:end]
    print(f"Bytes from {start} to {end}: {window}")

# Show a shorter window: only the first 16 bytes of the array
inspect_byte_array(byte_array, start=0, num_bytes=16)

# Step 3: Manually parse the first node with additional debugging
def parse_first_node(byte_array):
    """Decode the first node record of ``byte_array``, printing every field read.

    The record is read as an 8-byte header (two big-endian unsigned 32-bit
    integers: token ID and child count) followed, once per child, by a
    1-byte prefix length, the prefix bytes, and an 8-byte child header.

    Args:
        byte_array: Byte buffer to decode from offset 0.

    Raises:
        ValueError: If the buffer ends before a field can be read in full.
    """
    total = len(byte_array)
    print(f"Buffer size: {total}")

    def require(position, count, message):
        # Stop as soon as a read would run past the end of the buffer.
        if position + count > total:
            raise ValueError(message)

    cursor = 0
    require(cursor, 8, "Insufficient data to read tokenID and num_children")
    tokenID, num_children = struct.unpack_from('>II', byte_array, cursor)
    print(f"Token ID: {tokenID}, Number of Children: {num_children}")
    cursor += 8
    print(f"New Offset: {cursor}")

    for _ in range(num_children):
        # One length byte precedes each child's prefix.
        require(cursor, 1, "Insufficient data to read prefix length")
        (prefix_length,) = struct.unpack_from('B', byte_array, cursor)
        print(f"Prefix Length: {prefix_length}")
        cursor += 1
        print(f"Offset after reading prefix length: {cursor}")

        require(cursor, prefix_length, "Insufficient data to read prefix")
        prefix_end = cursor + prefix_length
        prefix = byte_array[cursor:prefix_end]
        print(f"Prefix: {prefix}")
        cursor = prefix_end
        print(f"Offset after reading prefix: {cursor}")

        # The child's own header has the same shape as the first header.
        require(cursor, 8, "Insufficient data to read child tokenID and num_children")
        child_tokenID, child_num_children = struct.unpack_from('>II', byte_array, cursor)
        print(f"Child Token ID: {child_tokenID}, Child Number of Children: {child_num_children}")
        cursor += 8
        print(f"Offset after reading child node: {cursor}")

# Parse the first node: prints each decoded field and raises ValueError if the
# buffer is too short for a field.
parse_first_node(byte_array)

In [None]:
# Step 2: Create the RadixTree and traverse logic
class RadixNode:
    """One tree node: a token ID plus a mapping of edge prefix (bytes) -> child node."""

    def __init__(self, tokenID, children):
        self.tokenID = tokenID
        self.children = children


class RadixTree:
    """Radix tree rebuilt from a serialized byte buffer.

    Assumed layout (the same one ``parse_first_node`` in this file decodes;
    NOTE(review): confirm against the code that writes the header file):

        node  := tokenID (uint32, big-endian) num_children (uint32, big-endian)
                 child * num_children
        child := prefix_length (uint8) prefix (prefix_length bytes) node

    i.e. the nodes are stored in pre-order, each child preceded by its edge label.
    """

    def __init__(self, data):
        """Store ``data`` and parse it immediately into ``self.root``."""
        self.data = data
        self.root = self.load_tree_iterative()

    def load_tree_iterative(self):
        """Parse ``self.data`` into linked ``RadixNode`` objects without recursion.

        Returns:
            The root ``RadixNode``, or ``None`` when ``self.data`` is empty.

        Raises:
            ValueError: If the buffer ends before a field can be read in full.
        """
        data = self.data
        size = len(data)
        if size == 0:
            return None

        def read_header(offset, what):
            # Bounds-check before unpacking so truncation gives a clear error.
            if offset + 8 > size:
                raise ValueError(f"Insufficient data to read {what}")
            return struct.unpack_from('>II', data, offset)

        tokenID, num_children = read_header(0, "tokenID and num_children")
        offset = 8
        root_node = RadixNode(tokenID, {})

        # Each frame is [node, children still to read]. A mutable list is used
        # so the remaining count is updated in place on the stack.
        pending = [[root_node, num_children]]
        while pending:
            frame = pending[-1]
            if frame[1] == 0:
                pending.pop()
                continue
            frame[1] -= 1
            parent_node = frame[0]

            # Every child carries its own prefix length and prefix bytes.
            if offset + 1 > size:
                raise ValueError("Insufficient data to read prefix length")
            prefix_length = struct.unpack_from('B', data, offset)[0]
            offset += 1

            if offset + prefix_length > size:
                raise ValueError("Insufficient data to read prefix")
            # bytes(...) keeps the key hashable even if data is a bytearray.
            prefix = bytes(data[offset:offset + prefix_length])
            offset += prefix_length

            tokenID, num_children = read_header(offset, "child tokenID and num_children")
            offset += 8

            child_node = RadixNode(tokenID, {})
            parent_node.children[prefix] = child_node

            # The child's own children follow directly, so it goes on top.
            if num_children > 0:
                pending.append([child_node, num_children])

        return root_node

    def traverse(self, word):
        """Follow ``word`` from the root and return the token ID where it ends.

        Args:
            word: Text to look up; it is matched as UTF-8 bytes.

        Returns:
            The ``tokenID`` of the node reached after consuming all of ``word``,
            or ``-1`` if the tree is empty or no child prefix matches the
            remaining bytes at some step.
        """
        node = self.root
        if node is None:
            return -1

        remaining = word.encode('utf-8')  # Compare against the byte prefixes
        while remaining:
            for prefix, child in node.children.items():
                if remaining.startswith(prefix):
                    remaining = remaining[len(prefix):]
                    node = child
                    break
            else:
                # No edge from this node matches what is left of the word.
                return -1
        return node.tokenID

# Step 3: Clean input function
def clean_input(phrase):
    """Normalize ``phrase``: trim it, lowercase it, then drop non-ASCII characters.

    Args:
        phrase: Raw text to normalize.

    Returns:
        The trimmed, lowercased text with every character whose code point is
        128 or above removed.
    """
    normalized = phrase.strip().lower()
    # Encoding with errors="ignore" discards exactly the non-ASCII characters.
    return normalized.encode('ascii', 'ignore').decode('ascii')

# Initialize RadixTree with the loaded byte array.
# The constructor parses `byte_array` immediately and keeps the result in
# `radixTree.root`.
radixTree = RadixTree(byte_array)

# Step 4: Take user input and simulate
def simulate_user_input():
    """Prompt for a phrase, normalize it, and print the result of looking it up.

    Reads one line from standard input, cleans it with ``clean_input``, looks
    it up in the module-level ``radixTree``, and prints either the ID found or
    a not-found message.
    """
    raw_phrase = input("Enter a phrase: ")
    cleaned_input = clean_input(raw_phrase)
    print(f"Cleaned Input: {cleaned_input}")

    nodeID = radixTree.traverse(cleaned_input)

    # -1 is the lookup's "no match" sentinel.
    if nodeID == -1:
        print("Phrase not found in the tree.")
        return
    print(f"Node ID: {nodeID}")

# Run one interactive lookup; this waits for a line of keyboard input.
simulate_user_input()