In [16]:
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from dataclasses import field, asdict, dataclass
from typing import List
from benchmarks.benchmark_utils import RequestFuncOutput
from benchmarks.benchmark_workload_gen import *

In [17]:
def retrive_request_outputs(path):
    """Load the benchmark results stored in the JSON file at ``path``.

    The decoded JSON is iterated and every entry is expanded as keyword
    arguments into ``RequestFuncOutput``.

    Args:
        path: Location of the JSON results file.

    Returns:
        A list with one ``RequestFuncOutput`` per decoded entry, in file order.
    """
    with open(path, 'r') as results_file:
        records = json.load(results_file)
    # The file is no longer needed once the JSON has been decoded.
    return [RequestFuncOutput(**record) for record in records]

In [21]:
def lat_tpot_ttft(outputs: "List[RequestFuncOutput]", match = None, plot=False):
    """Print p50/p90/p99 of request latency, ``tpot`` and ``ttft``.

    Args:
        outputs: Request outputs to summarize.
        match: Optional predicate; only outputs for which it returns a truthy
            value are analyzed. ``None`` keeps every output.
        plot: When true, also draw one cumulative KDE plot per metric.

    Returns:
        None. The summary is printed. If no output passes ``match`` the
        message ``No outputs to analyze`` is printed instead.

    A metric with no usable samples is reported as ``n/a`` (and skipped when
    plotting) because ``np.percentile`` cannot summarize an empty sequence.
    """
    if not match:
        match = lambda o: True
    outputs = [o for o in outputs if match(o)]
    if not outputs:
        print('No outputs to analyze')
        return

    # Falsy entries (None, 0) are dropped from each metric independently.
    lats = [o.request_latency for o in outputs if o.request_latency]
    tpots = [o.tpot for o in outputs if o.tpot]
    ttfts = [o.ttft for o in outputs if o.ttft]

    def _percentile_line(label, values):
        # Guard the empty case, e.g. when every matched request failed.
        if not values:
            return f'{label}: n/a'
        p50, p90, p99 = np.percentile(values, [50, 90, 99], method='nearest')
        return f'{label}: p50={p50:.2f}, p90={p90:.2f}, p99={p99:.2f}'

    print(f"Num finished: {sum(1 for o in outputs if o.success)}")
    print(_percentile_line('Latency', lats))
    print(_percentile_line('TPOT', tpots))
    print(_percentile_line('TTFT', ttfts))

    if plot:
        fig, axs = plt.subplots(1,3, figsize=(16, 4))
        panels = zip(axs, (lats, tpots, ttfts), ('Latency', 'TPOT', 'TTFT'))
        for ax, values, title in panels:
            if values:
                sns.kdeplot(x=values, ax=ax, cumulative=True)
            ax.set_title(title)
    
def ttft_slo(outputs, slo):
    """Return the fraction of recorded ``ttft`` values that are ``<= slo``.

    Args:
        outputs: Request outputs; entries with a falsy ``ttft`` (None, 0) are
            excluded from both numerator and denominator.
        slo: Threshold to compare each ``ttft`` against.

    Returns:
        A float in [0, 1]. When no output has a usable ``ttft`` the result is
        0.0 rather than a ``ZeroDivisionError``.
    """
    ttfts = [o.ttft for o in outputs if o.ttft]
    if not ttfts:
        # No measurements at all: report zero attainment instead of dividing by 0.
        return 0.0
    return sum(ttft <= slo for ttft in ttfts) / len(ttfts)

def windowed_metric(start, end, outputs: "List[RequestFuncOutput]", exp_time, match = None):
    """Print latency, ``tpot`` and ``ttft`` statistics for one send-time window.

    Args:
        start: Lower bound (inclusive) on ``send_out_time``.
        end: Upper bound (inclusive) on ``send_out_time``.
        outputs: Request outputs to filter and summarize.
        exp_time: Reference time used for requests without a result; such a
            request contributes ``exp_time - send_out_time`` as its latency
            (when ``success`` is falsy) and as its ``ttft`` (when ``ttft`` is
            falsy).
        match: Optional predicate applied in addition to the window test.
            ``None`` keeps every output.

    Returns:
        None. The summary is printed between two dashed rules.

    Metrics without samples (an empty window, or a window in which no request
    recorded a ``tpot``) are printed as ``n/a`` because ``np.percentile`` and
    ``np.mean`` cannot summarize an empty sequence.
    """
    if not match:
        match = lambda o: True
    within_window = [o for o in outputs if start <= o.send_out_time <= end and match(o)]

    # Requests without a result are charged the time elapsed until exp_time.
    lats = [o.request_latency if o.success else exp_time - o.send_out_time for o in within_window]
    tpots = [o.tpot for o in within_window if o.tpot]
    ttfts = [o.ttft if o.ttft else exp_time - o.send_out_time for o in within_window]

    def _percentile_line(label, values):
        if not values:
            return f'{label}: n/a'
        p50, p90, p99 = np.percentile(values, [50, 90, 99], method='nearest')
        return f'{label}: p50={p50:.2f}, p90={p90:.2f}, p99={p99:.2f}'

    def _mean_text(values):
        return f'{np.mean(values):.2f}' if values else 'n/a'

    print('-'*20)
    print(f"Requests within window: {len(within_window)}")
    print(f"Num finished: {sum(1 for o in within_window if o.success)}")
    print(_percentile_line('Latency', lats))
    print(_percentile_line('TPOT', tpots))
    print(_percentile_line('TTFT', ttfts))
    print(f'Avg Latency: {_mean_text(lats)}, Avg TPOT: {_mean_text(tpots)}, Avg TTFT: {_mean_text(ttfts)}')
    print('-'*20)

def runtime_selection_consistency(outputs: "List[RequestFuncOutput]", match = None, plot=False):
    """Print how requests and prefix indices were distributed across runtimes.

    Three lines are printed:
      1. ``{runtime: [n_with_prefix_index, n_without_prefix_index]}``
      2. ``{runtime: set of prefix indices seen on that runtime}``
      3. a sorted list of ``(prefix_index, request_count)`` pairs

    The prefix index of an output is obtained from
    ``WorkloadPrefixDataLoader.get_prefix_index``; ``None`` means the output
    has no prefix index.

    Args:
        outputs: Request outputs to inspect.
        match: Optional predicate; only outputs for which it returns a truthy
            value are counted. ``None`` keeps every output.
        plot: Unused; kept so existing callers keep working.

    Returns:
        None. If no output passes ``match`` the message
        ``No outputs to analyze`` is printed instead.
    """
    if not match:
        match = lambda o: True
    outputs = [o for o in outputs if match(o)]
    if not outputs:
        print('No outputs to analyze')
        return
    runtime_load = {}
    runtime_prefix = {}
    # Plain dict so the function does not rely on `defaultdict` leaking in
    # through a star import.
    prefix_cnt = {}
    for o in outputs:
        load = runtime_load.setdefault(o.runtime_selected, [0, 0])
        prefixes = runtime_prefix.setdefault(o.runtime_selected, set())
        prefix_index = WorkloadPrefixDataLoader.get_prefix_index(o)
        if prefix_index is None:
            load[1] += 1
        else:
            load[0] += 1
            prefixes.add(prefix_index)
            prefix_cnt[prefix_index] = prefix_cnt.get(prefix_index, 0) + 1
    print(runtime_load)
    print(runtime_prefix)
    print(sorted(prefix_cnt.items()))

In [22]:
def is_cold(o):
    """Return True when ``WorkloadPrefixDataLoader.is_hot`` does not flag ``o`` as hot."""
    return not WorkloadPrefixDataLoader.is_hot(o)
def is_on_gpu(ks):
    """Build a predicate that accepts outputs whose ``runtime_selected`` is in ``ks``.

    Args:
        ks: Container of runtime identifiers to accept.

    Returns:
        A one-argument callable suitable for the ``match`` parameters used in
        this notebook.
    """
    def selected_runtime_matches(o: RequestFuncOutput):
        return o.runtime_selected in ks
    return selected_runtime_matches
def is_workload(i):
    """Build a predicate that accepts outputs whose prefix index equals ``i``.

    The prefix index comes from ``WorkloadPrefixDataLoader.get_prefix_index``;
    outputs for which it returns ``None`` are always rejected.

    Args:
        i: Prefix index to select.

    Returns:
        A one-argument callable suitable for the ``match`` parameters used in
        this notebook.
    """
    def prefix_index_matches(o: RequestFuncOutput):
        prefix_index = WorkloadPrefixDataLoader.get_prefix_index(o)
        if prefix_index is None:
            return False
        return prefix_index == i
    return prefix_index_matches

In [23]:
# Load the results file whose name carries the "ORACLE:rr" policy tag, then
# print latency/TPOT/TTFT percentiles for all requests, for cold requests only
# (is_cold), and for hot requests only, followed by the per-runtime request
# counts and prefix-index assignment.
# NOTE(review): the path is an absolute location on one machine; the cell
# cannot be re-run elsewhere without editing it.
sim_oracle_fcfs = retrive_request_outputs('/mnt/ssd1/alm-os/sglang_multi_model/logs/debug/mistralai-Mistral-7B-v0.1_80_0.2_2700_9_DataParallelRuntimeSelectionPolicy.CUSTOM-CustomPolicyType.ORACLE:rr_inf.json')
lat_tpot_ttft(sim_oracle_fcfs)
lat_tpot_ttft(sim_oracle_fcfs, is_cold)
lat_tpot_ttft(sim_oracle_fcfs, WorkloadPrefixDataLoader.is_hot)
runtime_selection_consistency(sim_oracle_fcfs)

Num finished: 2700
Latency: p50=36.10, p90=78.20, p99=109.41
TPOT: p50=0.33, p90=0.41, p99=0.47
TTFT: p50=17.30, p90=54.68, p99=86.69
Num finished: 540
Latency: p50=61.19, p90=95.31, p99=112.94
TPOT: p50=0.33, p90=0.43, p99=0.47
TTFT: p50=40.99, p90=72.06, p99=92.56
Num finished: 2160
Latency: p50=31.15, p90=72.04, p99=106.35
TPOT: p50=0.33, p90=0.41, p99=0.47
TTFT: p50=9.81, p90=49.49, p99=82.48
{0: [540, 138], 1: [540, 145], 3: [540, 132], 2: [540, 125]}
{0: {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76}, 1: {1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61, 65, 69, 73, 77}, 3: {3, 7, 11, 15, 19, 23, 27, 31, 35, 39, 43, 47, 51, 55, 59, 63, 67, 71, 75, 79}, 2: {2, 6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54, 58, 62, 66, 70, 74, 78}}
[(0, 27), (1, 27), (2, 27), (3, 27), (4, 27), (5, 27), (6, 27), (7, 27), (8, 27), (9, 27), (10, 27), (11, 27), (12, 27), (13, 27), (14, 27), (15, 27), (16, 27), (17, 27), (18, 27), (19, 27), (20, 27), (21, 27

In [24]:
# Same report as the ORACLE:rr cell, for the results file tagged
# "HistogramBasedMemoryLoadScheduler:mem_cost_add_selected_only_fix_current_time".
# NOTE(review): `sim_oracle` is rebound by other cells in this notebook to
# different result files, so its value depends on which cell ran last.
sim_oracle = retrive_request_outputs('/mnt/ssd1/alm-os/sglang_multi_model/logs/debug/mistralai-Mistral-7B-v0.1_80_0.2_2700_9_DataParallelRuntimeSelectionPolicy.CUSTOM-CustomPolicyType.HistogramBasedMemoryLoadScheduler:mem_cost_add_selected_only_fix_current_time_inf.json')
lat_tpot_ttft(sim_oracle)
lat_tpot_ttft(sim_oracle, is_cold)
lat_tpot_ttft(sim_oracle, WorkloadPrefixDataLoader.is_hot)
runtime_selection_consistency(sim_oracle)

Num finished: 2700
Latency: p50=36.34, p90=81.42, p99=130.59
TPOT: p50=0.30, p90=0.41, p99=0.48
TTFT: p50=18.02, p90=61.29, p99=113.43
Num finished: 540
Latency: p50=64.69, p90=118.80, p99=137.78
TPOT: p50=0.29, p90=0.42, p99=0.49
TTFT: p50=48.00, p90=95.08, p99=115.15
Num finished: 2160
Latency: p50=30.13, p90=67.22, p99=120.83
TPOT: p50=0.30, p90=0.41, p99=0.47
TTFT: p50=10.07, p90=47.91, p99=98.13
{0: [459, 214], 1: [513, 159], 2: [594, 82], 3: [594, 85]}
{0: {2, 35, 68, 26, 3, 7, 66, 67, 76, 46, 48, 50, 53, 57, 58, 28, 63}, 1: {5, 6, 12, 15, 18, 22, 25, 27, 31, 37, 38, 39, 40, 41, 44, 45, 52, 55, 60}, 2: {1, 9, 10, 14, 16, 17, 20, 21, 29, 32, 34, 56, 62, 64, 69, 70, 72, 73, 74, 75, 77, 79}, 3: {0, 4, 8, 11, 13, 19, 23, 24, 30, 33, 36, 42, 43, 47, 49, 51, 54, 59, 61, 65, 71, 78}}
[(0, 27), (1, 27), (2, 27), (3, 27), (4, 27), (5, 27), (6, 27), (7, 27), (8, 27), (9, 27), (10, 27), (11, 27), (12, 27), (13, 27), (14, 27), (15, 27), (16, 27), (17, 27), (18, 27), (19, 27), (20, 27), (21, 

In [37]:

# Overall percentiles and per-runtime distribution for the results file tagged
# "HistogramBasedMemoryLoadScheduler:" (empty suffix). The cold/hot breakdowns
# are disabled in this cell.
# NOTE(review): this rebinds `sim_oracle`, which other cells also assign.
# NOTE(review): the recorded output of this cell ends with `set()`, which the
# runtime_selection_consistency defined in this notebook never prints (it
# prints a sorted list last) - the saved output may come from an older
# version of that function; confirm by re-running.
sim_oracle = retrive_request_outputs('/mnt/ssd1/alm-os/sglang_multi_model/logs/debug/mistralai-Mistral-7B-v0.1_80_0.2_2700_9_DataParallelRuntimeSelectionPolicy.CUSTOM-CustomPolicyType.HistogramBasedMemoryLoadScheduler:_inf.json')
lat_tpot_ttft(sim_oracle)
# lat_tpot_ttft(sim_oracle, is_cold)
# lat_tpot_ttft(sim_oracle, WorkloadPrefixDataLoader.is_hot)
runtime_selection_consistency(sim_oracle)

Num finished: 2700
Latency: p50=32.99, p90=94.25, p99=131.70
TPOT: p50=0.31, p90=0.41, p99=0.49
TTFT: p50=15.27, p90=71.59, p99=106.53
{0: [459, 213], 1: [540, 133], 2: [648, 34], 3: [486, 187]}
{0: {64, 33, 35, 68, 4, 71, 75, 77, 45, 46, 15, 49, 48, 19, 58, 28, 61}, 1: {5, 6, 11, 18, 22, 24, 27, 31, 36, 37, 41, 52, 54, 55, 57, 60, 63, 65, 66, 69}, 2: {1, 7, 9, 10, 14, 16, 17, 20, 25, 26, 29, 32, 34, 38, 43, 56, 62, 67, 70, 72, 73, 74, 76, 78}, 3: {2, 3, 39, 40, 8, 42, 44, 13, 12, 47, 79, 50, 51, 53, 21, 23, 59, 30}}
set()


In [24]:
# Percentiles (all / cold-only / hot-only) for the "ORACLE:10" results file
# stored under the workload_prefix/4r_sim_80_0.2_2700_9_baseline_cp_1024
# directory rather than logs/debug.
# NOTE(review): this rebinds `sim_oracle`, which other cells also assign; the
# per-runtime distribution is not printed for this run.
sim_oracle = retrive_request_outputs('/mnt/ssd1/alm-os/sglang_multi_model/workload_prefix/4r_sim_80_0.2_2700_9_baseline_cp_1024/mistralai-Mistral-7B-v0.1_80_0.2_2700_9_DataParallelRuntimeSelectionPolicy.CUSTOM-CustomPolicyType.ORACLE:10_inf.json')
lat_tpot_ttft(sim_oracle)
lat_tpot_ttft(sim_oracle, is_cold)
lat_tpot_ttft(sim_oracle, WorkloadPrefixDataLoader.is_hot)

Num finished: 2700
Latency: p50=46.66, p90=81.98, p99=91.89
TPOT: p50=0.22, p90=0.23, p99=0.25
TTFT: p50=33.76, p90=68.73, p99=78.50
Num finished: 540
Latency: p50=47.77, p90=82.60, p99=92.50
TPOT: p50=0.21, p90=0.23, p99=0.24
TTFT: p50=34.72, p90=69.67, p99=79.16
Num finished: 2160
Latency: p50=46.22, p90=81.83, p99=91.69
TPOT: p50=0.22, p90=0.23, p99=0.25
TTFT: p50=33.43, p90=68.44, p99=78.34
