In [None]:
import os
import glob
import re

import numpy as np
import pandas as pd
import orjson

import matplotlib.pyplot as plt

from sklearn.linear_model import LinearRegression, RANSACRegressor
from sklearn.metrics import mean_squared_error, r2_score

In [None]:
%matplotlib widget

data_list = []

for dir in os.listdir("."):
    if not os.path.isdir(dir):
        continue
    print(f"Processing dir: {dir}")
    freq = int(dir)
    for file in glob.glob(f"{dir}/*-extra.json"):
        print(f"Processing file: {file}")
        matches = re.match(r"rps-(\d+)", file)
        rps = int(matches.group(1)) if matches else None

        with open(file, "rb") as f:
            reqs = orjson.loads(f.read())
            for req_id, items in reqs.items():
                for item in items:
                    if item["forward_mode"] != 1:
                        continue
                    iteration_counter = item["iteration_counter"]
                    tokens = sum(item["num_total_computing_tokens_list"])
                    ttft_exec = (item["timestamp_before_output"] - item["timestamp_begin"]) * 1000
                    data_list.append({
                        "freq": freq,
                        "rps": rps,
                        "iteration_counter": iteration_counter,
                        "tokens": tokens,
                        "ttft_exec": ttft_exec,
                    })

df = pd.DataFrame(data_list)

df

In [None]:
# Keep a single row per iteration_counter (the last occurrence wins), then order
# the rows by frequency, token count and latency.
# Built as one chained expression instead of inplace=True so the cell produces
# its result from its input in a single step and can be re-run safely.
# NOTE(review): de-duplication keys on iteration_counter alone; if counters
# restart between runs, rows from different freq/rps runs would collide --
# confirm this is intended.
df = (
    df
    .drop_duplicates(subset=["iteration_counter"], keep="last")
    .sort_values(by=["freq", "tokens", "ttft_exec"])
)

df

In [None]:
plt.rcParams["font.size"] = 18

# Token window (exclusive bounds) used to *fit* each line. R² below is still
# evaluated on every point of the frequency, not just the fitted window.
FIT_MIN_TOKENS = 2000
FIT_MAX_TOKENS = 20000
# Only frequencies at or above this value are drawn; all fitted frequencies are
# written to the CSV regardless.
PLOT_MIN_FREQ = 1005
# RANSAC around a one-feature LinearRegression needs at least two samples.
MIN_FIT_SAMPLES = 2

TABLE_COLUMNS = ["freq", "coef_tokens", "coef_intercept", "r_2"]

fig, ax = plt.subplots(figsize=(12, 5))

table_list = []

for freq in df["freq"].unique():
    df_freq = df[df["freq"] == freq]
    xs = df_freq["tokens"].to_numpy()
    ys = df_freq["ttft_exec"].to_numpy()

    sel = (xs < FIT_MAX_TOKENS) & (xs > FIT_MIN_TOKENS)
    xs_sel = xs[sel]
    ys_sel = ys[sel]

    # Without enough points in the fit window, fit() would raise and abort the
    # whole cell; skip just this frequency and say so.
    if xs_sel.size < MIN_FIT_SAMPLES:
        print(f"Skipping freq {freq}: only {xs_sel.size} sample(s) in the fit range")
        continue

    # Robust linear fit with non-negative coefficients; fixed seed for
    # reproducible RANSAC sampling.
    model = RANSACRegressor(
        LinearRegression(
            positive=True,
        ),
        max_trials=100000,
        random_state=1,
    )
    model.fit(xs_sel.reshape(-1, 1), ys_sel)
    y_pred = model.predict(xs.reshape(-1, 1))
    r2 = r2_score(ys, y_pred)

    if freq >= PLOT_MIN_FREQ:
        ax.scatter(xs, ys, alpha=0.5)
        ax.plot(xs, y_pred, label=f"f = {freq} MHz (r² = {r2:.5f})", linewidth=2)

    table_list.append({
        "freq": freq,
        "coef_tokens": model.estimator_.coef_[0],
        "coef_intercept": model.estimator_.intercept_,
        "r_2": r2,
    })

ax.grid(True)
ax.legend(fontsize=11, loc='upper left')
ax.set_xlim(0)
ax.set_ylim(0)

ax.set_xlabel("#Batched Tokens")
ax.set_ylabel("TTFT (ms)")

fig.tight_layout()
fig.savefig("ttft_vs_tokens.pdf", dpi=300, bbox_inches='tight')
fig.show()

# Explicit columns keep the sort and the CSV header valid even when no
# frequency could be fitted.
table = (
    pd.DataFrame(table_list, columns=TABLE_COLUMNS)
    .sort_values(by="freq", ignore_index=True)
)

table.to_csv("latency-prefill.csv", index=False)