In [3]:
import sys
import os
import pandas as pd
import copy

# Add the parent directory of the 'src' directory to the Python path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname("."), "..")))

In [7]:
from transformers import AutoTokenizer
from benchmark_workload_gen import ToolBenchDataLoader, LoadDistribution
from sglang.srt.managers.router.radix_cache import RadixCache

num_workloads = 100
num_requests = 4096
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
dataloader = ToolBenchDataLoader('G1_workload_updated_input_output_lengths_4096_cropped_to_50.json', num_workloads, num_requests, tokenizer, LoadDistribution.EVEN)

In [8]:
workload = dataloader.generate_workload(k=1.1)

4000


In [12]:
import re
def get_tool(workload_item):
    text = workload_item["text"]
    match = re.search(r"You have access of the following tools:\n1.(.+?): ", text)
    if match:
        tool = match.group(1)
        return tool
get_tool(workload[0]), get_tool(workload[1])

('surebets', 'manga_scrapper')

In [130]:
import heapq
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Tuple

import torch

class TreeNode:
    def __init__(self):
        self.children = defaultdict(TreeNode)
        self.parent = None
        self.value = None
        self.ref_counter = 0
        self.last_access_time = time.time()

    @property
    def num_tokens(self):
        return len(self.value)
        
    def __lt__(self, other):
        return self.last_access_time < other.last_access_time


def match(key, seq):
    i = 0
    for k, w in zip(key, seq):
        if k != w:
            break
        i += 1
    return i


class RadixCache:
    def __init__(self, disable=False):
        self.reset()
        self.disable = disable

    ##### Public API #####

    def reset(self):
        self.root_node = TreeNode()
        self.root_node.value = []
        self.root_node.ref_counter = 1
        self.evictable_size_ = 0

    def match_prefix(self, key):
        if self.disable:
            return [], self.root_node

        value = []
        last_node = [self.root_node]
        self._match_prefix_helper(self.root_node, key, value, last_node)
        if value:
            if isinstance(value[0], torch.Tensor):
                value = torch.concat(value)
            else:
                concatenated_value = []
                for v in value:
                    concatenated_value.extend(v)  # Assuming each element in value is a list itself
                value = concatenated_value
        return value, last_node[0]

    def match_prefix_return_str(self, key):
        return "".join(self.match_prefix(key)[0])

    def insert(self, key, value=None):
        if self.disable:
            return len(key)

        if value is None:
            value = [x for x in key]
        return self._insert_helper(self.root_node, key, value)

    def pretty_print(self):
        self._print_helper(self.root_node, 0)
        print(f"#tokens: {self.total_size()}")

    def total_size(self):
        return self._total_size_helper(self.root_node)

    def evict(self, num_tokens, evict_callback):
        if self.disable:
            raise RuntimeError()

        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)

            if x == self.root_node:
                break
            if x.ref_counter > 0:
                continue

            num_evicted += evict_callback(x.value)
            self._delete_leaf(x)

            if len(x.parent.children) == 0:
                heapq.heappush(leaves, x.parent)

    def inc_ref_counter(self, node):
        delta = 0
        while node != self.root_node:
            if node.ref_counter == 0:
                self.evictable_size_ -= len(node.value)
                delta -= len(node.value)
            node.ref_counter += 1
            node = node.parent
        return delta

    def dec_ref_counter(self, node):
        delta = 0
        while node != self.root_node:
            if node.ref_counter == 1:
                self.evictable_size_ += len(node.value)
                delta += len(node.value)
            node.ref_counter -= 1
            node = node.parent
        return delta

    def evictable_size(self):
        return self.evictable_size_

    ##### Internal Helper Functions #####
    def _match_prefix_helper(self, node, key, value, last_node):
        node.last_access_time = time.time()
        for c_key, child in node.children.items():
            prefix_len = match(c_key, key)
            if prefix_len != 0:
                if prefix_len < len(c_key):
                    new_node = self._split_node(c_key, child, prefix_len)
                    value.append(new_node.value)
                    last_node[0] = new_node
                else:
                    value.append(child.value)
                    last_node[0] = child
                    self._match_prefix_helper(child, key[prefix_len:], value, last_node)
                break

    def _split_node(self, key, child, split_len):
        # new_node -> child
        new_node = TreeNode()
        new_node.children = {key[split_len:]: child}
        new_node.parent = child.parent
        new_node.ref_counter = child.ref_counter
        new_node.value = child.value[:split_len]
        child.parent = new_node
        child.value = child.value[split_len:]
        new_node.parent.children[key[:split_len]] = new_node
        del new_node.parent.children[key]
        return new_node

    def _insert_helper(self, node, key, value):
        node.last_access_time = time.time()
        node.ref_counter += 1
        for c_key, child in node.children.items():
            prefix_len = match(c_key, key)

            if prefix_len == len(c_key):
                if prefix_len == len(key):
                    child.ref_counter += 1
                    return prefix_len
                else:
                    key = key[prefix_len:]
                    value = value[prefix_len:]
                    return prefix_len + self._insert_helper(child, key, value)

            if prefix_len:
                new_node = self._split_node(c_key, child, prefix_len)
                return prefix_len + self._insert_helper(
                    new_node, key[prefix_len:], value[prefix_len:]
                )

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.value = value
            new_node.ref_counter = 1
            node.children[key] = new_node
            self.evictable_size_ += len(value)
        return 0

    def _print_helper(self, node, indent, depth=0):
        if depth == 5:
            return
        for key, child in node.children.items():
            print(" " * indent, len(key), key[:10], f"r={child.ref_counter}")
            self._print_helper(child, indent=indent + 2, depth=depth + 1)
        

    def _delete_leaf(self, node):
        for k, v in node.parent.children.items():
            if v == node:
                break
        del node.parent.children[k]
        self.evictable_size_ -= len(k)

    def _total_size_helper(self, node):
        x = len(node.value)
        for child in node.children.values():
            x += self._total_size_helper(child)
        return x

    def _collect_leaves(self):
        ret_list = []

        def dfs_(cur_node):
            if len(cur_node.children) == 0:
                ret_list.append(cur_node)

            for x in cur_node.children.values():
                dfs_(x)

        dfs_(self.root_node)
        return ret_list

    

if __name__ == "__main__":
    tree = RadixCache(disable=False)

    tree.insert("Hello")
    tree.insert("Hello There")
    tree.insert("Hello_L.A.!")
    # tree.insert("Hello_world! Happy")
    # tree.insert("I love you!")
    tree.pretty_print()

    print(tree.match_prefix_return_str("Hello T"))

    # def evict_callback(x):
    #    print("evict", x)
    #    return len(x)

    # tree.evict(5, evict_callback)
    # tree.evict(10, evict_callback)
    # tree.pretty_print()


 5 Hello r=3
   6  There r=1
   6 _L.A.! r=1
#tokens: 17
Hello T


In [152]:
cache = RadixCache()
for item in workload[:3]:
    cache.insert(tuple(item["input_ids"]["input_ids"]))
# for i in range(100):
#     cache.insert(tuple(workload[0]["input_ids"]["input_ids"]))
# for j in range(100):
#     cache.insert(tuple(workload[1]["input_ids"]["input_ids"]))
cache.insert(tuple(tokenizer.encode("bdd_special_tokens=True")))
cache.pretty_print()


 1 (1,) r=4
   357 (2135, 28747, 995, 460, 12191, 28777, 6316, 28725, 368, 541) r=3
     443 (17989, 28726, 1468, 28747, 7086, 354, 12875, 28726, 1468, 297) r=1
     3250 (28719, 14836, 28730, 824, 6131, 28747, 2483, 4686, 532, 266) r=1
     1141 (17064, 28730, 28713, 2729, 28730, 350, 3673, 28747, 451, 1036) r=1
   8 (287, 1036, 28730, 14908, 28730, 20228, 28746, 4365) r=1
#tokens: 5200


In [164]:
import numpy as np
from itertools import permutations


total_tokens_available = [33077, 33077, 33077]
nodes = {0: 0, 1: 0, 2:0}

def cost_fn(nodes, total_tokens):
    # how close are the loads of the two nodes
    # how many tokens are being used
    total_load = sum(nodes)

    return {
        "total_load": total_load,
        "total_tokens": sum(total_tokens),
        "load_ratio": np.std(nodes),
        "token_ratio": np.std(total_tokens)
    }


In [165]:
import cvxpy as cp

class LPTreeTraversal:
    def __init__(self, num_gpus):
        self.num_gpus = num_gpus
        self.num_gpus_param = cp.Parameter(nonneg=True, value=num_gpus)
        self.counter = 0
        self.constraints = []
        self.node_map = {}  # Maps PrefixTreeNode to LpNode

        self.time_per_token_fixed_cost = cp.Parameter(nonneg=True, value=4.59)
        self.time_per_token_variable_cost = cp.Parameter(nonneg=True, value=0.1)
        self.load_cost_per_token = cp.Parameter(nonneg=True, value=0.1)

        self.num_tokens_params = {}
        self.ref_count_params = {}

    def _traverse_tree(self, current_prefix_node, parent_lp_node=None):
        self.counter += 1
        current_lp_node = LpNode(self.counter, self.num_gpus)
        self.node_map[current_prefix_node] = current_lp_node

        self.constraints += [cp.sum(current_lp_node.variables) >= 1] # At least one GPU must be allocated for a prefix
        if parent_lp_node:
            # Add constraints based on the relationship between the current node and its parent
            # If the child takes a node, then the parent must also take a node
            for gpu in range(self.num_gpus):
                self.constraints.append(current_lp_node.variables[gpu] <= parent_lp_node.variables[gpu])

        for child_prefix_node in current_prefix_node.children.values():
            self._traverse_tree(child_prefix_node, current_lp_node)

    def solve(self, objective):
        # Create problem and solve
        problem = cp.Problem(objective, self.constraints)
        problem.solve()

    def traverse_and_optimize(self, prefix_tree_root):
        self._traverse_tree(prefix_tree_root)

        memory_cost_terms = []
        load_cost_params = []
        for prefix_node, lp_node in self.node_map.items():
            num_tokens_param = self.num_tokens_params.get(prefix_node, cp.Parameter(nonneg=True))
            num_tokens_param.value = prefix_node.num_tokens  # Update parameter value
            
            load_param = self.ref_count_params.get(prefix_node, cp.Parameter(nonneg=True))
            load_param.value = prefix_node.ref_counter  # Update parameter value

            # TODO use existing matrix to calculate the difference of new tokens
            memory_cost_terms += [num_tokens_param * gpu_var * self.time_per_token_variable_cost + self.time_per_token_fixed_cost * gpu_var for gpu_var in lp_node.variables]
            # Load cost param 
            for gpu_var in lp_node.variables:
                # Directly penalize the non-selection of a GPU. This represents 1/x and is used to simplify the cost.
                penalty = (1 - gpu_var) * self.load_cost_per_token * num_tokens_param * load_param
                load_cost_params.append(penalty)

        # Define your objective function here, for example:
        objective = cp.Minimize(cp.sum(load_cost_params) + cp.sum(memory_cost_terms))  # Adjust according to your specific problem

        # Solve the problem
        self._solve(objective)

    def _solve(self, objective):
        problem = cp.Problem(objective, self.constraints)
        problem.solve()

        # After solving, you can access and print the values of your variables like so:
        # for prefix_node, lp_node in self.node_map.items():
        #     print(f"Node {lp_node.node_id}: {[var.value for var in lp_node.variables]}")
        #     print(f"Node {prefix_node.value[:15]}")
    
    def pretty_print(self, prefix_node, indent=""):
        lp_node = self.node_map.get(prefix_node)
        if lp_node:
            # Assuming each lp_node.variables[i].value gives whether GPU i is selected
            selected_gpus = [i for i, var in enumerate(lp_node.variables) if var.value >= 0.99]  # Adjust threshold as needed
            print(f"{indent}Node {lp_node.node_id} (Prefix: {prefix_node.value[:15]}): GPUs {selected_gpus}")
        else:
            print(f"{indent}Node (Prefix: {prefix_node.value[:15]}) has no LP Node mapping")

        for child in prefix_node.children.values():
            self.pretty_print(child, indent + "  ")

# Assuming a Node class with attributes 'children' and 'node_id'
class LpNode:
    def __init__(self, node_id, num_gpus):
        self.node_id = node_id
        self.variables = [cp.Variable(name=f"node_{node_id}_gpu_{gpu}", boolean=True) for gpu in range(num_gpus)]


lp_tree_traversal = LPTreeTraversal(3)
lp_tree_traversal.traverse_and_optimize(cache.root_node)
lp_tree_traversal.pretty_print(cache.root_node)
print("Combo based")
# import numpy

Node 1 (Prefix: []): GPUs [0, 1, 2]
  Node 2 (Prefix: [1]): GPUs [0, 1, 2]
    Node 3 (Prefix: [2135, 28747, 995, 460, 12191, 28777, 6316, 28725, 368, 541, 938, 1287, 7040, 28732, 19659]): GPUs [0, 1, 2]
      Node 4 (Prefix: [17989, 28726, 1468, 28747, 7086, 354, 12875, 28726, 1468, 297, 8657, 22567, 13, 13, 4947]): GPUs [2]
      Node 5 (Prefix: [28719, 14836, 28730, 824, 6131, 28747, 2483, 4686, 532, 266, 732, 19607, 1178, 477, 16020]): GPUs [2]
      Node 6 (Prefix: [17064, 28730, 28713, 2729, 28730, 350, 3673, 28747, 451, 1036, 28713, 1178, 7086, 354, 17868]): GPUs [2]
    Node 7 (Prefix: [287, 1036, 28730, 14908, 28730, 20228, 28746, 4365]): GPUs [2]
Combo based
