# In [1]:
from sglang.srt.managers.router.radix_cache import RadixCache
from transformers import AutoTokenizer

# Module-level tokenizer that RadixBalancer.process_request uses to encode
# request text into token ids. This statement executes as soon as the
# module/cell runs.
# NOTE(review): from_pretrained presumably needs network access or a local
# Hugging Face cache for "mistralai/Mistral-7B-v0.1" -- confirm before running
# in an offline environment.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

class RadixBalancer:
    """Feed tokenized requests into a radix cache and periodically split a trie root.

    Attributes:
        root: Either a single dict of children (initial state) or, after the
            first split, a list of two dicts that partition those children.
        request_count: Requests processed since the last split.
        split_frequency: Number of requests between two splits.
        radix_cache: Object whose ``insert`` receives each request's token tuple.

    NOTE(review): ``process_request`` never writes into ``root``; only callers
    that populate ``root`` themselves will see non-empty partitions after a
    split. Confirm whether ``root`` is meant to mirror the radix cache.
    """

    def __init__(self, split_frequency: int = 2, tokenizer=None, radix_cache=None):
        """Create a balancer.

        Args:
            split_frequency: Split the trie once this many requests have been
                processed since the previous split.
            tokenizer: Object exposing ``encode(text)``. When ``None`` the
                module-level ``tokenizer`` is looked up at request time.
            radix_cache: Object exposing ``insert(key)``. When ``None`` a fresh
                ``RadixCache()`` is created.
        """
        self.root = {}  # Root node of the radix trie
        self.request_count = 0
        self.split_frequency = split_frequency
        self._tokenizer = tokenizer
        self.radix_cache = RadixCache() if radix_cache is None else radix_cache

    def process_request(self, request) -> None:
        """Tokenize ``request``, insert it into the cache, and split when due.

        Args:
            request: Text handed to the tokenizer's ``encode`` method.
        """
        # Fall back to the module-level tokenizer only when none was injected,
        # so the global is not required when a tokenizer is supplied.
        encoder = self._tokenizer if self._tokenizer is not None else tokenizer
        tokens = encoder.encode(request)
        self.radix_cache.insert(tuple(tokens))

        self.request_count += 1
        if self.request_count >= self.split_frequency:
            self.split_trie()
            self.request_count = 0  # Start counting towards the next split

    def split_trie(self) -> None:
        """Redistribute all root children evenly across two new nodes.

        After the call ``root`` is a list ``[left, right]`` of dicts. The method
        can be called repeatedly: if ``root`` is already a list of partitions,
        their children are gathered in order and redistributed, instead of
        indexing the list with its own elements (which raised ``TypeError``).
        """
        # Normalise both root shapes (single dict / list of dicts) to a list.
        partitions = [self.root] if isinstance(self.root, dict) else self.root

        new_nodes = [{}, {}]
        for partition in partitions:
            for child_key, child_node in partition.items():
                # Always fill the currently smaller node; ties go to the first.
                index = 0 if len(new_nodes[0]) <= len(new_nodes[1]) else 1
                new_nodes[index][child_key] = child_node

        # Update root to point to the new nodes
        self.root = new_nodes

# Example usage: build a balancer and push a handful of sample strings
# through it. Note that the variable is named `radix_cache` but holds a
# RadixBalancer instance, not a RadixCache.
radix_cache = RadixBalancer()

# Simulate processing requests. Each string is passed to process_request one
# at a time; with the default split_frequency of 2 set in
# RadixBalancer.__init__, split_trie is invoked after every second request.
requests = ["hello", "world", "python", "algorithm", "implementation", "splitting"]
for request in requests:
    radix_cache.process_request(request)