In [None]:
import gc
import glob
import os
import re
import multiprocessing as mp
from typing import Optional, List, Tuple

import orjson
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression, RANSACRegressor
from sklearn.metrics import r2_score

In [None]:
def convert_to_df(metric_file: str, extra_file: str, freq: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Parse one benchmark run into a one-row metric frame and a per-iteration frame.

    Args:
        metric_file: Path of the metric log. Its name must contain
            ``bs-<N>-outlen-<M>``; the last non-empty line is parsed as JSON.
        extra_file: Path of the JSON file mapping a request id to a list of
            per-iteration info dicts.
        freq: Frequency tag copied into the ``freq`` column of both frames.

    Returns:
        ``(df_metric, df_extra)``. ``df_metric`` has the columns
        ``freq, bs, tokens, duration, energy`` and one row, or zero rows when
        the metric file cannot be read or parsed (best effort, the error is
        printed). ``df_extra`` has one row per info entry whose
        ``forward_mode`` equals 2 and always carries the full column set, even
        when no entry qualifies.

    Raises:
        ValueError: If ``metric_file`` does not contain ``bs-<N>-outlen-<M>``.
    """
    match = re.search(r"bs-(\d+)-outlen-(\d+)", metric_file)
    if match is None:
        # Fail with a readable message instead of AttributeError on None.
        raise ValueError(f"Cannot find 'bs-<N>-outlen-<M>' in metric file name: {metric_file}")
    bs = int(match.group(1))
    outlen = int(match.group(2))

    # assert f"bs-{bs}-outlen-{outlen}" in extra_file, f"Mismatch in bs/outlen for {metric_file} and {extra_file}"
    print(f"Loading {metric_file} and {extra_file} with bs={bs}, outlen={outlen}")

    metric_columns = ["freq", "bs", "tokens", "duration", "energy"]
    try:
        # Stream the file and keep only the last non-blank line, so a trailing
        # newline does not hide the final record.
        last_record = ""
        with open(metric_file, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    last_record = line
        # Map the non-standard tokens Infinity / -Infinity / NaN to null before
        # parsing (every occurrence, not just the first one).
        content = re.sub(r"-?\bInfinity\b|\bNaN\b", "null", last_record)
        metric_json = orjson.loads(content)
        metrics = {
            "freq": freq,
            "bs": len(metric_json["input_lens"]),
            "tokens": metric_json["total_output_tokens"],
            "duration": metric_json["duration"],
            "energy": metric_json["old_energy"],
        }
        df_metric = pd.DataFrame([metrics])
    except Exception as e:
        # Deliberate best effort: a broken metric file must not lose the extras.
        print(f"Error reading {metric_file}: {e}")
        df_metric = pd.DataFrame(columns=metric_columns)

    gc.collect()

    with open(extra_file, "r", encoding="utf-8") as f:
        extra_json = orjson.loads(f.read())

    # Only entries with forward_mode == 2 are kept (presumably the decode
    # phase -- TODO confirm against the producer of the extra file).
    extras = [
        {
            "freq": freq,
            "iteration_counter": info["iteration_counter"],
            "batch_size": info["batch_size"],
            "input_len": 1,
            "output_len": outlen,
            "tokens": sum(info["num_total_computed_tokens_list"]),
            "itl": (info.get("itl", 0) or 0) * 1000,  # ms
            "forward_time": (info["timestamp_after_forward"] - info["timestamp_before_forward"]) * 1000,  # ms
        }
        for extra_batch_infos in extra_json.values()
        for info in extra_batch_infos
        if info["forward_mode"] == 2
    ]
    del extra_json

    # Fixed schema: without it an empty result has no columns, which produces an
    # unreadable empty CSV and breaks column-based operations downstream.
    extra_dtypes = {
        "freq": "int64",
        "iteration_counter": "int64",
        "batch_size": "int64",
        "input_len": "int64",
        "output_len": "int64",
        "tokens": "int64",
        "itl": "float64",
        "forward_time": "float64",
    }
    df_extra = pd.DataFrame(extras, columns=list(extra_dtypes))
    if df_extra.empty:
        # Numeric dtypes (assumed to match real rows) so that concatenating an
        # empty frame does not drag the other frames to object dtype.
        df_extra = df_extra.astype(extra_dtypes)

    gc.collect()

    return df_metric, df_extra

def load_data(dir_path: str, freq: int, mode: str, mod: int = 1, mod_rem: int = 0) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load one shard of the runs stored in ``dir_path``.

    Runs are discovered through their ``*-extra.json`` file (sorted by path);
    only those whose position ``idx`` satisfies ``idx % mod == mod_rem`` are
    processed, so several workers can split one directory.

    Args:
        dir_path: Directory holding the ``*-extra.json`` files.
        freq: Frequency tag forwarded to ``convert_to_df``.
        mode: ``"save"`` parses ``<stem>.json-detail`` and ``<stem>-extra.json``
            and writes ``<stem>.csv`` / ``<stem>-extra.csv``; ``"load"`` reads
            those CSV files back and skips runs whose CSVs are missing or empty.
        mod: Number of shards.
        mod_rem: Index of the shard handled by this call.

    Returns:
        ``(df_metric, df_extra)`` concatenated over the processed runs. Both are
        empty, column-less frames when no run was processed.

    Raises:
        AssertionError: If ``mode`` is neither ``"save"`` nor ``"load"``.
    """
    # Validate before touching the file system.
    assert mode in ["save", "load"], "Mode must be either 'save' or 'load'."

    extra_suffix = "-extra.json"
    extra_files = sorted(glob.glob(os.path.join(dir_path, "*" + extra_suffix)))

    metric_dfs = []
    extra_dfs = []

    for idx, extra_file in enumerate(extra_files):

        if idx % mod != mod_rem:
            continue

        # Derive sibling paths from the file-name suffix only; str.replace would
        # also rewrite "-extra" / ".json" occurring in directory names.
        stem = extra_file[: -len(extra_suffix)]
        metric_file = stem + ".json"
        df_metric_file = stem + ".csv"
        df_extra_file = stem + "-extra.csv"

        if mode == "load":
            print(f"Loading {df_metric_file} and {df_extra_file}")
            if not (os.path.exists(df_metric_file) and os.path.exists(df_extra_file)):
                print(f"Files {df_metric_file} or {df_extra_file} do not exist.")
                continue
            try:
                df_metric = pd.read_csv(df_metric_file)
                df_extra = pd.read_csv(df_extra_file)
            except pd.errors.EmptyDataError:
                # A cache written from a column-less frame has no header to parse.
                print(f"Files {df_metric_file} or {df_extra_file} are empty; skipping.")
                continue
        else:
            # The raw metric log lives next to the extra file as "<stem>.json-detail".
            df_metric, df_extra = convert_to_df(metric_file + "-detail", extra_file, freq)
            df_metric.to_csv(df_metric_file, index=False)
            df_extra.to_csv(df_extra_file, index=False)

        metric_dfs.append(df_metric)
        extra_dfs.append(df_extra)

        gc.collect()

    if metric_dfs:
        df_metric = pd.concat(metric_dfs, ignore_index=True)
        df_extra = pd.concat(extra_dfs, ignore_index=True)
    else:
        # pd.concat([]) raises; an empty shard is a valid outcome (for example
        # when the directory holds fewer runs than there are shards).
        df_metric = pd.DataFrame()
        df_extra = pd.DataFrame()

    gc.collect()

    print(f"Loaded {len(metric_dfs)} metric files and {len(extra_dfs)} extra files for freq={freq}, mode={mode}, mod={mod}, mod_rem={mod_rem}.")

    return df_metric, df_extra

In [None]:
# Convert every run of every frequency directory in parallel. Each directory is
# split into `mod` shards so that the work spreads over all cores.
df_metric_list = []
df_extra_list = []

mod = 10
mode = "save"  # switch to "load" to reuse the CSV caches written by a previous "save" run

args = []
for freq in [1005, 1095, 1200, 1305, 1410]:
    for mod_rem in range(mod):
        # The directory of a frequency is named after the frequency itself.
        args.append((str(freq), freq, mode, mod, mod_rem))

# The context manager terminates the workers if starmap raises, so a failing
# shard no longer leaves a pool of idle processes behind. There is no point in
# starting more workers than there are shards.
with mp.Pool(min(mp.cpu_count(), len(args))) as pool:
    results = pool.starmap(load_data, args)
    pool.close()
    pool.join()

for df_metric, df_extra in results:
    df_metric_list.append(df_metric)
    df_extra_list.append(df_extra)

print(f"Loaded {len(df_metric_list)} metric files and {len(df_extra_list)} extra files.")

In [None]:
# Stitch the per-shard iteration tables into a single frame and dump its text
# representation as a quick sanity check.
shard_frames = df_extra_list
df_dask = pd.concat(objs=shard_frames, ignore_index=True)

print(df_dask)

In [None]:
# Rebuild the combined iteration table and order it by frequency, batch size,
# token count and the two latency columns. `df` is an alias of the sorted
# frame, not a copy.
sort_keys = ["freq", "batch_size", "tokens", "itl", "forward_time"]
df_dask = pd.concat(df_extra_list, ignore_index=True).sort_values(by=sort_keys)

df = df_dask
gc.collect()

In [None]:
# Display the sorted per-iteration table as the cell output.
df

In [None]:
# The sort left the index shuffled; renumber it in place (df aliases df_dask,
# so both names see the change).
df.reset_index(drop=True, inplace=True)

# Keep the iterations whose ITL lies strictly inside (3, 300) ms.
# These bounds are meant to be tuned per dataset.
itl_ms = df["itl"]
df_itl = df.loc[itl_ms.gt(3) & itl_ms.lt(300)]

gc.collect()

df_itl

In [None]:
# Interactive matplotlib backend (ipympl) so the 3D figure below can be rotated.
%matplotlib widget

# Base font size applied to every figure created after this point.
plt.rcParams["font.size"] = 14

# Inclusive upper bound of every batch-size bucket: 4, then multiples of 8 up
# to 128, then multiples of 128 up to 2048.
# (For a sweep that stops at batch size 1024, use range(256, 1025, 128) instead.)
steps_end = [4, *range(8, 129, 8), *range(256, 2049, 128)]
# Inclusive lower bound of every bucket: one past the previous upper bound.
steps_beg = [1, *(upper + 1 for upper in steps_end[:-1])]

# Plot colour per frequency.
freq_color_map = dict(
    zip(
        (1005, 1095, 1200, 1305, 1410),
        ("blue", "green", "purple", "cyan", "orange"),
    )
)

def _fit_itl_plane(xy_sel: np.ndarray, zs_sel: np.ndarray):
    """Robustly fit ``z = coef_bs * x + coef_tokens * y + intercept``.

    Returns the fitted ``RANSACRegressor`` (non-negative coefficients, fixed
    seed), or ``None`` when RANSAC raises ``ValueError`` for this data.
    """
    regressor = RANSACRegressor(
        LinearRegression(positive=True),
        max_trials=10000,
        loss="squared_error",
        random_state=1,
    )
    try:
        regressor.fit(xy_sel, zs_sel)
    except ValueError as exc:
        print(f"RANSAC fit failed: {exc}")
        return None
    return regressor


def draw(df: pd.DataFrame, freq_list: Optional[List[int]] = None, save_csv: bool = False):
    """Fit one ITL plane per (frequency, batch-size bucket) and plot them in 3D.

    For every frequency and every bucket ``[steps_beg[i], steps_end[i]]`` the
    rows of ``df`` falling into the bucket are used to regress ``itl`` on
    ``batch_size`` and ``tokens``. Frequencies present in ``freq_color_map``
    are drawn as a sample of the points plus the fitted plane.

    Args:
        df: Frame with the columns ``freq``, ``batch_size``, ``tokens`` and
            ``itl``.
        freq_list: Frequencies to process; every frequency found in ``df`` when
            ``None`` or empty.
        save_csv: When true, write the fitted coefficients and R^2 of every
            bucket to ``latency-decode.csv``.
    """
    freqs_total = df["freq"].to_numpy()
    xs_total = df["batch_size"].to_numpy()
    ys_total = df["tokens"].to_numpy()
    # zs_total = df["forward_time"].to_numpy()
    zs_total = df["itl"].to_numpy()

    table_columns = ["freq", "bs_start", "bs_end", "coef_bs", "coef_tokens", "coef_intercept", "r_2"]
    table_items = []

    # RANSAC's default min_samples for a LinearRegression is n_features + 1,
    # i.e. 3 for the two features used here; smaller buckets cannot be fitted.
    min_points = 3
    # Upper bound on the number of scattered points per bucket, drawn with a
    # seeded generator so that the figure is reproducible.
    max_scatter_points = 100
    rng = np.random.default_rng(0)

    # figures
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111, projection='3d')

    freq_list = freq_list if freq_list else df["freq"].unique().tolist()

    for freq in freq_list:

        print(f"Processing frequency: {freq}")

        in_freq = freqs_total == freq
        xs_freq = xs_total[in_freq]
        ys_freq = ys_total[in_freq]
        zs_freq = zs_total[in_freq]
        color = freq_color_map.get(freq, "gray")

        label_added = False

        for beg, end in zip(steps_beg, steps_end):
            in_bucket = (xs_freq >= beg) & (xs_freq <= end)
            if not in_bucket.any():
                continue
            xs_sel = xs_freq[in_bucket]
            ys_sel = ys_freq[in_bucket]
            zs_sel = zs_freq[in_bucket]
            print(f"BS {beg}-{end}: {len(xs_sel)} points")

            if len(xs_sel) < min_points:
                print(f"Skipping batch size range {beg}-{end} due to insufficient data points.")
                continue

            xy_sel = np.vstack((xs_sel, ys_sel)).T
            regressor = _fit_itl_plane(xy_sel, zs_sel)
            if regressor is None:
                # One degenerate bucket must not abort the whole sweep.
                print(f"Skipping batch size range {beg}-{end} because no model could be fitted.")
                continue

            coef_bs = regressor.estimator_.coef_[0]
            coef_tokens = regressor.estimator_.coef_[1]
            intercept = regressor.estimator_.intercept_
            # R^2 is computed over every point of the bucket, outliers included.
            r_2 = r2_score(zs_sel, regressor.predict(xy_sel))
            print(f"RANSAC Regression Coefficients: {coef_bs:.10f} (#reqs), {coef_tokens:.10f} (#computed tokens), Intercept: {intercept:.10f}, R^2: {r_2:.10f}")

            table_items.append({
                "freq": freq,
                "bs_start": beg,
                "bs_end": end,
                "coef_bs": coef_bs,
                "coef_tokens": coef_tokens,
                "coef_intercept": intercept,
                "r_2": r_2,
            })

            # Only the frequencies with an assigned colour are drawn.
            if freq not in freq_color_map:
                continue

            # Scatter at most max_scatter_points points; small buckets are
            # drawn in full instead of being left out of the figure.
            if len(xs_sel) > max_scatter_points:
                picked = rng.choice(len(xs_sel), max_scatter_points, replace=False)
            else:
                picked = slice(None)
            ax.scatter(xs_sel[picked], ys_sel[picked], zs_sel[picked], c=color, s=50, alpha=0.5)

            # Evaluate the fitted plane on a grid spanning the bucket.
            x_range = np.linspace(beg, end, 100)
            y_range = np.linspace(ys_sel.min(), ys_sel.max(), 100)
            X, Y = np.meshgrid(x_range, y_range)
            Z = regressor.predict(np.vstack((X.ravel(), Y.ravel())).T).reshape(X.shape)

            surface_kwargs = dict(color=color, alpha=0.5, rstride=100, cstride=100)
            if not label_added:
                # One legend entry per frequency, attached to its first surface.
                surface_kwargs["label"] = f"f = {freq} MHz"
                label_added = True
            ax.plot_surface(X, Y, Z, **surface_kwargs)

    ax.set_xlim(0)
    ax.set_ylim(0)
    ax.view_init(elev=35, azim=-135)
    ax.set_xlabel("#Reqs")
    ax.set_ylabel("#Computed Tokens", labelpad=10)
    ax.set_zlabel("ITL (ms)")
    fig.legend(fontsize=11, loc='upper left', bbox_to_anchor=(0.1, 0.85))
    # ax.set_title("ITL vs #Reqs and #Computed Tokens")

    fig.tight_layout()

    # To export the figure with the crop used by the original author:
    # from matplotlib.transforms import Bbox
    # fig.savefig("itl-prediction.pdf", dpi=300, bbox_inches=Bbox.from_bounds(x0=-0.2, y0=0, width=5.95, height=5.55))
    plt.show()

    if save_csv:
        # Explicit columns keep the sort (and the CSV header) valid even when
        # no bucket could be fitted.
        table_df = pd.DataFrame(table_items, columns=table_columns).sort_values(
            by=["freq", "bs_start", "bs_end"], ignore_index=True
        )
        csv_file = "latency-decode.csv"
        print(f"Saving data to {csv_file}")
        table_df.to_csv(csv_file, index=False)

# Fit and plot every frequency present in the ITL-filtered table and write the
# fitted coefficients to latency-decode.csv.
draw(df_itl, save_csv=True)