In [None]:
from gurobipy import GRB
import gurobipy as gp
from typing import Optional
from collections import defaultdict
import heapq
import time
from collections import defaultdict
from uuid import uuid4
import copy
import random
import threading
from enum import Enum, auto
import logging
from benchmarks.benchmark_utils import RequestFuncOutput
import os
import collections
import numpy as np

# Scratch counter for ad-hoc debugging; not referenced by the visible code.
DEBUG_COUNTER = 0
# NOTE: this rebinds the name `logging` from the stdlib module to a Logger
# instance, so `logging.<level>(...)` below goes through this module's logger.
logging = logging.getLogger(__name__)
class LpNode:
    """Solver-side companion of a tree node: one variable slot per GPU.

    Args:
        node_id: Identifier of the tree node this object belongs to.
        num_gpus: Number of GPUs; sizes ``variables`` and ``load_variables``.

    The variable slots start as ``None`` and are expected to be replaced by
    solver variables exposing an ``x`` attribute (``__repr__`` reads it).
    """

    def __init__(self, node_id, num_gpus):
        self.node_id = node_id
        # Will be initialized as binary variables in the model.
        self.variables = [None for _ in range(num_gpus)]
        # Issue is that depth_limit will cut off the tokens for children and
        # that will treat it as free.
        self.children_token_cost_at_max_depth = 0
        self.randomly_selected_gpu = None
        self.load_variables = [None for _ in range(num_gpus)]
        self.common_load = None

    def __repr__(self):
        def value_of(var):
            # Compare against None explicitly: a populated slot must be shown
            # even when the object (or its value) happens to be falsy.
            return None if var is None else var.x

        variable_values = [value_of(var) for var in self.variables]
        load_variable_values = [value_of(var) for var in self.load_variables]
        common_load = value_of(self.common_load)
        # Only omit the load variables when none of them has been set; a load
        # of 0.0 is a real value and must still be printed.
        if any(value is not None for value in load_variable_values):
            return f"LpNode(node_id={self.node_id}, variables={variable_values}, load_variables={load_variable_values}, common_load={common_load})"
        return f"LpNode(node_id={self.node_id}, variables={variable_values})"


class LPTreeNode:
    """A single node of the radix tree.

    Nodes are identified by a random UUID: equality and hashing use that id,
    while ordering (``<``) uses ``last_access_time`` so that nodes can be put
    in a heap with the least recently touched node first.
    """

    def __init__(self):
        # Identity and tree links.
        self.id = uuid4()
        self.parent: Optional[LPTreeNode] = None
        # NOTE: a defaultdict, so indexing a missing key creates a child.
        self.children = defaultdict(LPTreeNode)

        # Payload stored on this node and length bookkeeping.
        self.value = None
        self.context_length = 0
        self.decode_length = 0
        self.is_leaf = False

        # Scheduling / eviction metadata.
        self.gpu_selections = set()
        self.ref_counter = 0
        self.last_access_time = time.time()

    @property
    def num_tokens(self):
        """Length of ``value`` (``value`` must have been set)."""
        return len(self.value)

    def __lt__(self, other):
        # Older access time sorts first.
        return other.last_access_time > self.last_access_time

    def __eq__(self, other):
        # Anything that is not a tree node is never equal to one.
        if not isinstance(other, LPTreeNode):
            return False
        return self.id == other.id

    def __hash__(self):
        # Hash on the same unique id that equality uses.
        return hash(self.id)

    def __repr__(self) -> str:
        return (
            f"LPTreeNode(id={self.id}, ref_counter={self.ref_counter}, "
            f"gpu_selections={self.gpu_selections})"
        )


def match(key, seq):
    """Return the length of the common prefix of ``key`` and ``seq``."""
    matched = 0
    for matched, (left, right) in enumerate(zip(key, seq), start=1):
        if left != right:
            # `matched` already counts the mismatching pair; leave it out.
            return matched - 1
    return matched


class LPRadixCache:
    """Radix tree of token sequences annotated with GPU placement choices.

    Every edge is keyed by a slice of an inserted key and leads to an
    ``LPTreeNode`` whose ``value`` holds the matching slice of the inserted
    value. Nodes also carry a reference counter, a last-access timestamp and
    the set of GPUs chosen for them (``gpu_selections``).

    Args:
        disable: When true, lookups and inserts short-circuit without touching
            the tree (each method documents what it returns in that mode).
    """

    def __init__(self, disable=False):
        self.reset()
        self.disable = disable

    ##### Public API #####

    def reset(self):
        """Discard the tree and start over from an empty root."""
        self.root_node = LPTreeNode()
        self.root_node.value = []
        # The root starts with a non-zero reference count.
        self.root_node.ref_counter = 1
        self.evictable_size_ = 0

    def find_node(self, key):
        """Return the deepest node reached by ``key`` (``None`` if disabled)."""
        if self.disable:
            return None
        _, node = self.match_prefix_get_gpu_selection(key)
        return node

    def match_prefix_get_gpu_selection(self, key, path_to_node=None):
        """Walk ``key`` down the tree and report where it ends.

        Args:
            key: Sequence to look up.
            path_to_node: Unused; kept so existing callers keep working.

        Returns:
            ``(gpu_selection, node)``: ``node`` is the deepest node reached by
            fully matching edges, and ``gpu_selection`` is the last non-empty
            ``gpu_selections`` set met on the way (the root's set if no child
            on the path had one). ``([], root_node)`` when disabled.

        Raises:
            AssertionError: If ``key`` ends or diverges in the middle of an
                edge.
        """
        if self.disable:
            return [], self.root_node

        value = []
        return self._match_prefix_helper_gpu_selection(
            self.root_node, key, value, self.root_node.gpu_selections
        )

    def _match_prefix_helper_gpu_selection(
        self, node, key, value, current_gpu_selection
    ):
        """Recursive worker for ``match_prefix_get_gpu_selection``.

        Appends the ``value`` of every node it descends into to ``value`` and
        returns ``(current_gpu_selection, last_node)``.
        """
        child: LPTreeNode
        for c_key, child in node.children.items():
            prefix_len = match(c_key, key)
            if prefix_len == 0:
                continue
            if child.gpu_selections:
                current_gpu_selection = child.gpu_selections
            if prefix_len < len(c_key):
                # Lookups never split nodes, so a partially matched edge is
                # treated as an invariant violation. Raised explicitly so the
                # check also holds when asserts are stripped (python -O).
                raise AssertionError(
                    f"key matches only {prefix_len} of {len(c_key)} tokens of an edge"
                )
            value.append(child.value)
            return self._match_prefix_helper_gpu_selection(
                child, key[prefix_len:], value, current_gpu_selection
            )
        return current_gpu_selection, node

    def match_prefix_return_str(self, key):
        """Return the stored prefix matching ``key`` joined into one string.

        Only meaningful when the stored values are sequences of strings (for
        example when string keys were inserted without an explicit value);
        joining non-string items raises ``TypeError``. Returns ``""`` when the
        cache is disabled.
        """
        if self.disable:
            return ""
        matched_chunks = []
        self._match_prefix_helper_gpu_selection(
            self.root_node, key, matched_chunks, self.root_node.gpu_selections
        )
        return "".join(item for chunk in matched_chunks for item in chunk)

    def insert(
        self,
        key,
        value=None,
        node_map=None,
        all_modified_nodes=None,
        split_nodes=None,
        depth_limit=0,
    ):
        """Insert ``key`` and return the node at which it ends.

        Args:
            key: Sliceable sequence to insert.
            value: Payload stored along the path; defaults to ``list(key)``.
            node_map: Forwarded to the helpers (not read by them here).
            all_modified_nodes: Optional set that receives the returned node
                and its ancestors, stopping at the first one already present.
            split_nodes: Optional dict that receives ``child -> new_node`` for
                every existing node that had to be split.
            depth_limit: Forwarded to the helpers (not read by them here).

        Returns:
            The ``LPTreeNode`` where ``key`` ends, or ``len(key)`` when the
            cache is disabled.
        """
        if node_map is None:
            node_map = {}
        if all_modified_nodes is None:
            all_modified_nodes = set()
        if split_nodes is None:
            split_nodes = {}  # key -> node
        if self.disable:
            return len(key)

        if value is None:
            value = list(key)
        modified_nodes = set()
        created_node = self._insert_helper(
            self.root_node,
            key,
            value,
            node_map=node_map,
            modified_nodes=modified_nodes,
            depth_limit=depth_limit,
            current_depth=0,
            split_nodes=split_nodes,
        )

        # Record the touched path bottom-up; once an ancestor is already
        # recorded, everything above it is too.
        node: LPTreeNode = created_node
        while node is not None and node not in all_modified_nodes:
            all_modified_nodes.add(node)
            node = node.parent
        return created_node

    def pretty_print(self):
        """Print the tree (top levels only) followed by the total size."""
        self._print_helper(self.root_node, 0)
        print(f"#tokens: {self.total_size()}")

    def total_size(self):
        """Return the summed ``len(value)`` over all nodes."""
        return self._total_size_helper(self.root_node)

    def evict(self, num_tokens, evict_callback):
        """Evict unreferenced leaves, least recently accessed first.

        Args:
            num_tokens: Stop once at least this many tokens were evicted.
            evict_callback: Called with each evicted node; its return value is
                added to the evicted-token count.

        Raises:
            RuntimeError: If the cache is disabled.
        """
        if self.disable:
            raise RuntimeError()

        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and leaves:
            x = heapq.heappop(leaves)

            if x == self.root_node:
                break
            if x.ref_counter > 0:
                continue

            num_evicted += evict_callback(x)
            self._delete_leaf(x)

            # A parent that just lost its last child becomes a candidate.
            if len(x.parent.children) == 0:
                heapq.heappush(leaves, x.parent)

    def inc_ref_counter(self, node):
        """Add one reference to ``node`` and every ancestor below the root.

        Returns:
            The (non-positive) change in evictable size caused by nodes whose
            counter was 0 before the call.
        """
        delta = 0
        while node != self.root_node:
            if node.ref_counter == 0:
                self.evictable_size_ -= len(node.value)
                delta -= len(node.value)
            node.ref_counter += 1
            node = node.parent
        return delta

    def dec_ref_counter(self, node):
        """Drop one reference from ``node`` and every ancestor below the root.

        Returns:
            Always 0. NOTE: unlike ``inc_ref_counter``, nodes whose counter
            drops are not added back to ``evictable_size_``; the original
            author left that accounting disabled with an open TODO.
        """
        delta = 0
        while node != self.root_node:
            node.ref_counter -= 1
            node = node.parent
        return delta

    def remove_completed_input_ids(self, input_ids):
        """Release the references held along the path of ``input_ids``."""
        node = self.find_node(input_ids)
        if node is None:
            # find_node returns None when the cache is disabled; there is
            # nothing to release in that case.
            return
        self.dec_ref_counter(node)  # remove reference counter up to parent

    def evictable_size(self):
        """Return the running evictable-size counter."""
        return self.evictable_size_

    def _split_node(
        self, key, child: LPTreeNode, split_len, node_map, depth_limit, current_depth
    ):
        """Split the edge ``key`` leading to ``child`` after ``split_len`` items.

        A new node takes over the first ``split_len`` items of the edge key and
        of ``child.value`` and becomes the parent of ``child``, which keeps the
        remainder. The new node copies the child's ``gpu_selections`` and
        ``ref_counter``. ``node_map``, ``depth_limit`` and ``current_depth``
        are accepted but not used.

        Returns:
            The newly created intermediate node.
        """
        # new_node -> child
        new_node = LPTreeNode()
        new_node.gpu_selections = copy.deepcopy(child.gpu_selections)
        new_node.children = {key[split_len:]: child}
        new_node.parent = child.parent
        new_node.ref_counter = child.ref_counter
        new_node.context_length = child.parent.context_length + split_len

        new_node.value = child.value[:split_len]
        child.parent = new_node
        child.value = child.value[split_len:]

        # Re-key the parent's edge so it now points at the new node.
        new_node.parent.children[key[:split_len]] = new_node
        del new_node.parent.children[key]
        return new_node

    def _insert_helper(
        self,
        node: LPTreeNode,
        key,
        value,
        node_map,
        modified_nodes,
        depth_limit,
        current_depth,
        split_nodes,
        parent_context_length = 0
    ):
        """Recursive worker for ``insert``.

        Refreshes ``last_access_time`` and increments ``ref_counter`` on every
        node it visits, splits an edge when ``key`` diverges inside it, and
        creates a new node for whatever part of ``key`` is left unmatched.

        Returns:
            The node at which ``key`` ends.
        """
        node.last_access_time = time.time()
        node.ref_counter += 1

        # Iterate over a snapshot: _split_node rewrites node.children.
        for c_key, child in list(node.children.items()):
            prefix_len = match(c_key, key)
            if prefix_len == len(c_key):
                if prefix_len == len(key):
                    # Exact hit on an existing node.
                    child.ref_counter += 1
                    modified_nodes.add(child)
                    return child
                # The whole edge matched; continue below the child.
                return self._insert_helper(
                    child,
                    key[prefix_len:],
                    value[prefix_len:],
                    node_map=node_map,
                    modified_nodes=modified_nodes,
                    depth_limit=depth_limit,
                    current_depth=current_depth + 1,
                    split_nodes=split_nodes,
                    parent_context_length=parent_context_length + prefix_len,
                )

            if prefix_len:
                # The key diverges inside this edge: split it and continue
                # from the new intermediate node.
                new_node = self._split_node(
                    c_key,
                    child,
                    prefix_len,
                    node_map,
                    depth_limit=depth_limit,
                    current_depth=current_depth + 1,
                )
                # Report the split so callers can mirror it in their own maps.
                split_nodes[child] = new_node
                return self._insert_helper(
                    new_node,
                    key[prefix_len:],
                    value[prefix_len:],
                    node_map=node_map,
                    modified_nodes=modified_nodes,
                    depth_limit=depth_limit,
                    current_depth=current_depth + 1,
                    split_nodes=split_nodes,
                    parent_context_length=parent_context_length + prefix_len,
                )

        if len(key):
            # Nothing shares a prefix with the rest of the key: new child.
            new_node = LPTreeNode()
            new_node.gpu_selections = set()
            new_node.parent = node
            new_node.value = value
            new_node.ref_counter = 1
            new_node.context_length = parent_context_length + len(key)

            node.children[key] = new_node
            self.evictable_size_ += len(value)
            modified_nodes.add(new_node)
            return new_node
        return node

    def _print_helper(self, node, indent, depth=0):
        """Print the subtree below ``node``, at most five levels deep."""
        if depth == 5:
            return
        for child in node.children.values():
            print(" " * indent, child)
            self._print_helper(child, indent=indent + 2, depth=depth + 1)

    def _delete_leaf(self, node):
        """Detach ``node`` from its parent and shrink the evictable size.

        Does nothing if ``node`` is not among its parent's children, instead of
        removing an unrelated sibling.
        """
        siblings = node.parent.children
        for k, v in siblings.items():
            if v == node:
                break
        else:
            return
        del siblings[k]
        self.evictable_size_ -= len(k)

    def _total_size_helper(self, node):
        """Return ``len(value)`` summed over ``node`` and its descendants."""
        return len(node.value) + sum(
            self._total_size_helper(child) for child in node.children.values()
        )

    def _collect_leaves(self):
        """Return every node that currently has no children."""
        ret_list = []

        def dfs_(cur_node):
            if len(cur_node.children) == 0:
                ret_list.append(cur_node)

            for x in cur_node.children.values():
                dfs_(x)

        dfs_(self.root_node)
        return ret_list

from transformers import AutoTokenizer
# Tokenizer shared by the workload generators and the tree printer below.
TOKENIZER_NAME = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

In [None]:
from benchmarks.benchmark_workload_gen import WorkloadPrefixDataLoader, ToolBenchDataLoader, LooGLEDataset, LoadDistribution

# Settings for the synthetic prefix-sharing workload.
random_workload_config = dict(
    num_patterns=200,
    total_num_requests=400,
    tokenizer=tokenizer,
    load_dist=LoadDistribution.EVEN,
    distribution_of_non_shared=0.2,
    output_len=16,
    num_in_context_examples=3,
    random_workload_path="benchmarks/datasets/ShareGPT_V3_unfiltered_cleaned_split.json",
)
random_workload = WorkloadPrefixDataLoader(**random_workload_config)
requests = random_workload.generate_workload(k=None)

In [None]:
# ToolBench workload; the dataset path is relative to the working directory.
toolbench_data_path = "benchmarks/datasets/G1_workload_updated_input_output_lengths_4096.json"
toolbench_workload = ToolBenchDataLoader(
    num_patterns=200,
    total_num_requests=400,
    tokenizer=tokenizer,
    load_dist=LoadDistribution.EVEN,
    data_path=toolbench_data_path,
)
# NOTE: the misspelt name is kept because later cells refer to it.
toolbench_requets = toolbench_workload.generate_workload(k=None)

In [None]:
from collections import defaultdict
# Scheduler state for this cell: GPU count, accumulated cost per GPU, and the
# node -> {gpu ids} map consulted by the helpers below.
num_gpus = 2
mem_cost = [0] * num_gpus
gpu_allocations = defaultdict(set)

def get_recomp_cost(node: "LPTreeNode", gpu_id):
    """Return how many tokens ``gpu_id`` would have to recompute for ``node``.

    Walks from ``node`` towards the root, summing ``num_tokens`` until it
    reaches a node already placed on ``gpu_id``. Placement is read from
    ``node.gpu_selections``, the set that ``update_gpu_selections_of_parent``
    maintains; the previous lookup in ``gpu_allocations`` never matched
    because nothing records placements there.

    Args:
        node: Node to start from; ``None`` yields 0.
        gpu_id: GPU whose cached prefix is taken into account.

    Returns:
        Number of tokens on the path that are not yet placed on ``gpu_id``.
    """
    cost = 0
    while node is not None and gpu_id not in node.gpu_selections:
        cost += node.num_tokens
        node = node.parent
    return cost

def update_gpu_selections_of_parent(node: LPTreeNode, gpu_id):
    """Add ``gpu_id`` to ``gpu_selections`` of ``node`` and all its ancestors."""
    current = node
    while current:
        current.gpu_selections.add(gpu_id)
        current = current.parent
def handle_split_nodes(split_nodes, gpu_allocations):
    """Give nodes created by a split the GPU allocation of the node they split.

    Args:
        split_nodes: Mapping ``existing child -> newly created parent`` as
            filled in by ``LPRadixCache.insert``.
        gpu_allocations: Mapping ``node -> set of gpu ids``; updated in place.

    The new node holds a prefix of what the child used to hold, so it inherits
    a copy of the child's allocation. (Copying in the other direction would
    overwrite the child's allocation with the new node's empty one.)
    """
    for child_node, new_node in split_nodes.items():
        gpu_allocations[new_node] = gpu_allocations[child_node].copy()

# Greedy placement: each request goes to the GPU with the lowest
# (recomputation cost + accumulated cost).
cache = LPRadixCache()
for request in requests[:64]:
    split_nodes = {}
    leaf_node = cache.insert(tuple(request["input_ids"]), split_nodes=split_nodes)
    handle_split_nodes(split_nodes, gpu_allocations)

    recom_costs = [get_recomp_cost(leaf_node, gpu_id) for gpu_id in range(num_gpus)]
    total_costs = [cost + mem_cost[gpu_id] for gpu_id, cost in enumerate(recom_costs)]
    gpu_selected = np.argmin(total_costs)

    # Charge the chosen GPU and mark the whole path as placed on it.
    mem_cost[gpu_selected] += recom_costs[gpu_selected]
    update_gpu_selections_of_parent(leaf_node, gpu_selected)

In [None]:
# Token-id lists of the ToolBench requests, in order.
reqs = [req["input_ids"] for req in toolbench_requets]

In [None]:
def _print_helper(node, indent, depth=0):
    """Print the subtree below ``node``, one line per node.

    Each line shows the first 20 characters of the decoded node value, the
    node's ``gpu_selections`` and its value length; children are indented two
    further spaces per level.
    """
    pad = " " * indent
    for child in node.children.values():
        preview = tokenizer.decode(child.value)[:20]
        print(pad, preview, child.gpu_selections, len(child.value))
        _print_helper(child, indent=indent + 2, depth=depth + 1)


_print_helper(cache.root_node, 0)

In [None]:
from collections import defaultdict
# Fresh scheduler state for this cell: GPU count, accumulated cost per GPU,
# and the node -> {gpu ids} map consulted by the helpers below.
num_gpus = 2
mem_cost = [0] * num_gpus
gpu_allocations = defaultdict(set)

def get_recomp_cost(node: "LPTreeNode", gpu_id):
    """Return how many tokens ``gpu_id`` would have to recompute for ``node``.

    Walks from ``node`` towards the root, summing ``num_tokens`` until it
    reaches a node already placed on ``gpu_id``. Placement is read from
    ``node.gpu_selections``, the set that ``update_gpu_selections_of_parent``
    maintains; the previous lookup in ``gpu_allocations`` never matched
    because nothing records placements there.

    Args:
        node: Node to start from; ``None`` yields 0.
        gpu_id: GPU whose cached prefix is taken into account.

    Returns:
        Number of tokens on the path that are not yet placed on ``gpu_id``.
    """
    cost = 0
    while node is not None and gpu_id not in node.gpu_selections:
        cost += node.num_tokens
        node = node.parent
    return cost

def update_gpu_selections_of_parent(node: LPTreeNode, gpu_id):
    """Merge the GPU ids in ``gpu_id`` into ``node`` and all its ancestors.

    ``gpu_id`` is an iterable of GPU ids here (the caller passes a set). Each
    node gets a new set object, so sets shared between nodes are not mutated.
    """
    current = node
    while current:
        current.gpu_selections = current.gpu_selections.union(gpu_id)
        current = current.parent
def handle_split_nodes(split_nodes, gpu_allocations):
    """Give nodes created by a split the GPU allocation of the node they split.

    Args:
        split_nodes: Mapping ``existing child -> newly created parent`` as
            filled in by ``LPRadixCache.insert``.
        gpu_allocations: Mapping ``node -> set of gpu ids``; updated in place.

    The new node holds a prefix of what the child used to hold, so it inherits
    a copy of the child's allocation. (Copying in the other direction would
    overwrite the child's allocation with the new node's empty one.)
    """
    for child_node, new_node in split_nodes.items():
        gpu_allocations[new_node] = gpu_allocations[child_node].copy()

def get_parent_gpu_selections(node: LPTreeNode):
    """Return the first non-empty ``gpu_selections`` from ``node`` upwards.

    The set is returned as-is (not copied). An empty set is returned when
    neither ``node`` nor any ancestor has a selection.
    """
    current = node
    while current:
        if current.gpu_selections:
            return current.gpu_selections
        current = current.parent
    return set()

# Placement with prefix affinity: a request whose new suffix is shorter than
# the prefix it shares with the tree follows the GPUs of that prefix;
# otherwise it is placed greedily on the cheapest GPU.
cache = LPRadixCache()
for request in toolbench_requets[:64]:
    split_nodes = {}
    leaf_node = cache.insert(tuple(request["input_ids"]), split_nodes=split_nodes)
    handle_split_nodes(split_nodes, gpu_allocations)
    print(leaf_node.num_tokens, leaf_node.context_length)

    gpu_selected: set
    tokens_above_leaf = leaf_node.context_length - leaf_node.num_tokens
    if leaf_node.num_tokens < tokens_above_leaf:
        # Reuse every GPU already chosen for the nearest ancestor.
        gpu_selected = get_parent_gpu_selections(leaf_node)
        for gpu in gpu_selected:
            mem_cost[gpu] += get_recomp_cost(leaf_node, gpu)
    else:
        recom_costs = [get_recomp_cost(leaf_node, gpu_id) for gpu_id in range(num_gpus)]
        cheapest_gpu = np.argmin(
            [cost + mem_cost[gpu_id] for gpu_id, cost in enumerate(recom_costs)]
        )
        mem_cost[cheapest_gpu] += recom_costs[cheapest_gpu]
        gpu_selected = {cheapest_gpu}
    update_gpu_selections_of_parent(leaf_node, gpu_selected)

In [None]:
from benchmarks.benchmark_workload_gen import LooGLEDataset, LooGLEDatasetType
# LooGLE short-QA workload; generated requests are capped at this length.
loogle_max_length = 32768
dataloader = LooGLEDataset(
    num_patterns=24,
    total_num_requests=250,
    tokenizer=tokenizer,
    loogle_dataset_type=LooGLEDatasetType.SHORT_QA,
)
requests = dataloader.generate_workload(max_length=loogle_max_length)

In [None]:
from collections import defaultdict
# Fresh scheduler state for this cell: GPU count, accumulated cost per GPU,
# and the node -> {gpu ids} map consulted by the helpers below.
num_gpus = 2
mem_cost = [0] * num_gpus
gpu_allocations = defaultdict(set)

def get_recomp_cost(node: "LPTreeNode", gpu_id):
    """Return how many tokens ``gpu_id`` would have to recompute for ``node``.

    Walks from ``node`` towards the root, summing ``num_tokens`` until it
    reaches a node already placed on ``gpu_id``. Placement is read from
    ``node.gpu_selections``, the set that ``update_gpu_selections_of_parent``
    maintains; the previous lookup in ``gpu_allocations`` never matched
    because nothing records placements there.

    Args:
        node: Node to start from; ``None`` yields 0.
        gpu_id: GPU whose cached prefix is taken into account.

    Returns:
        Number of tokens on the path that are not yet placed on ``gpu_id``.
    """
    cost = 0
    while node is not None and gpu_id not in node.gpu_selections:
        cost += node.num_tokens
        node = node.parent
    return cost

def update_gpu_selections_of_parent(node: LPTreeNode, gpu_id):
    """Merge the GPU ids in ``gpu_id`` into ``node`` and all its ancestors.

    ``gpu_id`` is an iterable of GPU ids here (the caller passes a set). Each
    node gets a new set object, so sets shared between nodes are not mutated.
    """
    current = node
    while current:
        current.gpu_selections = current.gpu_selections.union(gpu_id)
        current = current.parent
def handle_split_nodes(split_nodes, gpu_allocations):
    """Give nodes created by a split the GPU allocation of the node they split.

    Args:
        split_nodes: Mapping ``existing child -> newly created parent`` as
            filled in by ``LPRadixCache.insert``.
        gpu_allocations: Mapping ``node -> set of gpu ids``; updated in place.

    The new node holds a prefix of what the child used to hold, so it inherits
    a copy of the child's allocation. (Copying in the other direction would
    overwrite the child's allocation with the new node's empty one.)
    """
    for child_node, new_node in split_nodes.items():
        gpu_allocations[new_node] = gpu_allocations[child_node].copy()

def get_parent_gpu_selections(node: LPTreeNode):
    """Return the first non-empty ``gpu_selections`` from ``node`` upwards.

    The set is returned as-is (not copied). An empty set is returned when
    neither ``node`` nor any ancestor has a selection.
    """
    current = node
    while current:
        if current.gpu_selections:
            return current.gpu_selections
        current = current.parent
    return set()

# Same prefix-affinity placement as the ToolBench cell, replayed on the
# LooGLE workload stored in `requests`.
# NOTE(review): this loop used to iterate `toolbench_requets` again, which
# left the freshly generated LooGLE `requests` unused; switched on the
# assumption that this was a copy-paste slip. Confirm with the author.
cache = LPRadixCache()
for request in requests[:64]:
    split_nodes = {}
    leaf_node = cache.insert(tuple(request["input_ids"]), split_nodes=split_nodes)
    handle_split_nodes(split_nodes, gpu_allocations)
    print(leaf_node.num_tokens, leaf_node.context_length)

    gpu_selected: set
    tokens_above_leaf = leaf_node.context_length - leaf_node.num_tokens
    if leaf_node.num_tokens < tokens_above_leaf:
        # The new suffix is shorter than the shared prefix: follow the GPUs
        # already chosen for the nearest ancestor.
        gpu_selected = get_parent_gpu_selections(leaf_node)
        for gpu in gpu_selected:
            mem_cost[gpu] += get_recomp_cost(leaf_node, gpu)
    else:
        # Otherwise pick the GPU with the lowest recomputation + accumulated cost.
        recom_costs = [get_recomp_cost(leaf_node, gpu_id) for gpu_id in range(num_gpus)]
        cheapest_gpu = np.argmin(
            [cost + mem_cost[gpu_id] for gpu_id, cost in enumerate(recom_costs)]
        )
        mem_cost[cheapest_gpu] += recom_costs[cheapest_gpu]
        gpu_selected = {cheapest_gpu}
    update_gpu_selections_of_parent(leaf_node, gpu_selected)
_print_helper(cache.root_node, 0)