In [None]:
import os

# Move one directory up from where the kernel started (presumably the project
# root) so relative paths like "outputs/..." used below resolve. The start
# directory is remembered in a global, so re-running this cell lands in the
# same place instead of climbing one more level each time.
_NOTEBOOK_START_DIR = globals().get("_NOTEBOOK_START_DIR", os.getcwd())
os.chdir(os.path.join(_NOTEBOOK_START_DIR, "../"))
print(os.getcwd())

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
import json
from typing import List, Union, Iterable
from pprint import pprint
import seaborn as sns


# Global seaborn theme for every figure in this notebook: the "whitegrid" style,
# with seaborn's default context and palette stated explicitly.
sns.set_theme(context="notebook", style="whitegrid", palette="deep")

In [None]:
def search_json(search_json_dir):
    """Recursively collect every ``metrics.json`` under one or more directories.

    Args:
        search_json_dir: A single directory (``str`` or path-like object such
            as ``pathlib.Path``) or an iterable of directories.

    Returns:
        list[Path]: All ``metrics.json`` paths found. They are grouped by input
        directory in the given order and sorted within each directory, so the
        result does not depend on the filesystem's listing order.

    Raises:
        FileNotFoundError: If no ``metrics.json`` exists under any directory.
    """
    # A lone str/Path is one directory. Iterating a str would yield characters,
    # and a Path is not iterable at all, so wrap it first.
    if isinstance(search_json_dir, (str, os.PathLike)):
        search_json_dir = [search_json_dir]
    search_dirs = [Path(_dir) for _dir in search_json_dir]

    json_path_list = []
    for _dir in search_dirs:
        json_path_list.extend(sorted(_dir.rglob("metrics.json")))

    if not json_path_list:
        raise FileNotFoundError(f"No json file found in the search_json_dir: {search_dirs}")
    return json_path_list


def get_clean_exp_name(exp_name):
    """Extract the integer id embedded in an experiment directory name.

    Takes the text before the first ``-``, then the piece after the last ``_``
    within it, and parses that piece as an ``int``. For example,
    ``"budget_1024-foo"`` becomes ``1024``. If parsing fails, ``exp_name`` is
    returned unchanged, so the result is either an ``int`` or the original
    ``str``.
    """
    prefix, _sep, _rest = exp_name.partition("-")
    _head, _sep, tail = prefix.rpartition("_")
    try:
        return int(tail)
    except ValueError:
        return exp_name


def load_metrics_to_df(json_path_list, use_clean_exp_name=False):
    """Load ``metrics.json`` files into a DataFrame with one row per file.

    Each row holds:

    - ``path``
    - ``exp_name``: the name of the file's grandparent directory.
    - ``clean_exp_name``
    - one column per top-level key of the JSON, holding that entry's
      ``"accuracy"`` field.

    Args:
        json_path_list: Iterable of ``pathlib.Path`` objects for metrics files.
        use_clean_exp_name: If truthy, ``clean_exp_name`` is derived with
            ``get_clean_exp_name`` (an ``int`` when one can be parsed);
            otherwise it is a copy of ``exp_name``.

    Returns:
        pandas.DataFrame sorted by ``clean_exp_name``. If the names mix ints and
        strings, which cannot be compared, rows are ordered by the names'
        string form instead. An empty input yields an empty frame with the
        ``path``/``exp_name``/``clean_exp_name`` columns.

    Raises:
        KeyError: If a JSON entry has no ``"accuracy"`` field.
    """
    metrics_dict_list = []

    for json_path in json_path_list:
        # The experiment name is two levels up; the immediate parent is
        # presumably a "version_*" run directory (TODO confirm).
        exp_name = json_path.parents[1].name
        clean_exp_name = get_clean_exp_name(exp_name) if use_clean_exp_name else exp_name

        metrics_dict = {
            "path": json_path,
            "exp_name": exp_name,
            "clean_exp_name": clean_exp_name,
        }

        with open(json_path, "r", encoding="utf-8") as f:
            raw_metrics = json.load(f)
        metrics_dict.update({k: v["accuracy"] for k, v in raw_metrics.items()})
        metrics_dict_list.append(metrics_dict)

    if not metrics_dict_list:
        # No rows would mean no "clean_exp_name" column to sort on.
        return pd.DataFrame(columns=["path", "exp_name", "clean_exp_name"])

    df = pd.DataFrame(metrics_dict_list)
    try:
        df = df.sort_values("clean_exp_name")
    except TypeError:
        # Some names parsed to int and others stayed str; they are not
        # mutually orderable, so sort on their string form instead.
        df = df.sort_values("clean_exp_name", key=lambda col: col.astype(str))
    return df


def gather_and_plot_metrics(
    search_json_dir,
    use_clean_exp_name=False,
    *args,
    **kwargs,
):
    """Find every ``metrics.json`` under ``search_json_dir`` and load them.

    Despite the name, this only gathers data and returns the DataFrame built by
    ``load_metrics_to_df``; no plotting happens here. Extra positional and
    keyword arguments are accepted and ignored, so a combined config dict that
    also carries plotting keys can be splatted straight in.
    """
    return load_metrics_to_df(search_json(search_json_dir), use_clean_exp_name)


def plot_df(
    df,
    ax,
    plot_title=None,
    plot_xlabel=None,
    plot_ylabel=None,
    plot_xlim=None,
    plot_ylim=None,
    legend=True,
    *args,
    **kwargs,
):
    """Draw one line per metric column of ``df`` against ``clean_exp_name``.

    Args:
        df: Frame with ``path``, ``exp_name`` and ``clean_exp_name`` columns
            plus one column per metric. ``path`` and ``exp_name`` are dropped,
            and ``clean_exp_name`` becomes the x axis.
        ax: Matplotlib Axes to draw on.
        plot_title, plot_xlabel, plot_ylabel: Optional text; ``None`` keeps
            the default.
        plot_xlim: Optional ``(left, right)`` passed to ``ax.set_xlim``.
        plot_ylim: Optional ``(bottom, top)``. Each non-``None`` bound is
            widened if needed so no plotted value falls outside it. A ``None``
            bound leaves that limit unchanged.
        legend: Whether to draw the legend.
        *args, **kwargs: Ignored; they let callers splat a config dict that
            also contains keys meant for other functions.

    Returns:
        The ``ax`` that was drawn on.
    """
    metrics = df.drop(["path", "exp_name"], axis=1).set_index("clean_exp_name")

    sns.lineplot(data=metrics, linewidth=2.5, markers=True, ax=ax, legend=legend)

    if legend:
        # Put the legend's upper-left corner at x=-0.6 in axes coordinates,
        # i.e. to the left of the plotting area.
        sns.move_legend(ax, "upper left", bbox_to_anchor=(-0.6, 1))
    if plot_title is not None:
        ax.set_title(plot_title)
    if plot_xlabel is not None:
        ax.set_xlabel(plot_xlabel)
    if plot_ylabel is not None:
        ax.set_ylabel(plot_ylabel)

    if plot_xlim is not None:
        ax.set_xlim(*plot_xlim)

    if plot_ylim is not None:
        y_bottom, y_top = plot_ylim
        # Never clip data: widen each requested bound to cover the data range.
        if y_bottom is not None:
            y_bottom = min(y_bottom, metrics.min().min())
        if y_top is not None:
            y_top = max(y_top, metrics.max().max())
        ax.set_ylim(y_bottom, y_top)

    return ax

In [None]:
fig_title = "Thinking Budget, 1k SFT Data, 7B Model"
# Per-experiment settings; any key here overrides the same key in exp_shared_arg.
exp_arg_list = [
    {
        "search_json_dir": "outputs/250318-eval-medical_llm",
        "plot_title": "Epoch 5",
    },
]
# Settings shared by every experiment (loading options and axis formatting).
exp_shared_arg = {
    "plot_xlabel": "Thinking Budget (in # Tokens)",
    "plot_ylabel": "Accuracy",
    # plot_df widens these limits if any value falls outside them.
    "plot_ylim": (0.45, 0.78),
    # "plot_xlim": (0, 7000),
    "use_clean_exp_name": False,
}

# Build fresh merged dicts instead of mutating exp_arg_list, so re-running the
# cell starts from the same config. Per-experiment keys win over shared ones.
exp_kwargs_list = [{**exp_shared_arg, **exp_arg} for exp_arg in exp_arg_list]

# Load the metrics for every experiment; plotting-only keys are ignored here.
df_list = [gather_and_plot_metrics(**exp_kwargs) for exp_kwargs in exp_kwargs_list]
all_df = pd.concat(df_list)

output_path = f"outputs/{fig_title.replace('/', '_')}.tsv"
all_df.to_csv(output_path, sep="\t", index=False)
print(f"Saved to {output_path}")
display(all_df)

# One panel per experiment. squeeze=False always returns a 2-D array of Axes,
# so the single-panel case needs no special handling.
num_plots = len(df_list)
fig, axes_grid = plt.subplots(1, num_plots, figsize=(7 * num_plots, 6), squeeze=False)
axes = axes_grid[0]
for idx, (df, ax, exp_kwargs) in enumerate(zip(df_list, axes, exp_kwargs_list)):
    # Only the leftmost panel draws a legend. Setting "legend" in the merged
    # dict avoids a duplicate-keyword clash if the config ever defines it.
    plot_df(df=df, ax=ax, **{**exp_kwargs, "legend": idx == 0})
fig.suptitle(fig_title)
plt.show()