# Define utility functions

In [1]:
import pandas as pd


def apply_filters(filters, info):
    """Return whether one parsed record satisfies every filter.

    Args:
        filters: Mapping of field name to a constraint. A non-list value
            must equal the record's value. A list is a collection of allowed
            values, unless its first element is the string ``"not"``, in
            which case the remaining elements are excluded values.
        info: Mapping of field name to value for a single record.

    Returns:
        ``False`` as soon as one constraint is violated, ``True`` otherwise.
        Fields that are absent from ``info`` (or stored as ``None``) are not
        checked. An empty list allows no value at all, so any record that
        carries the field fails it.
    """
    for key, expected in filters.items():
        actual = info.get(key)
        if actual is None:
            # The record does not carry this field; the filter is skipped.
            continue
        if not isinstance(expected, list):
            if actual != expected:
                return False
        # `expected and ...` keeps an empty list from raising IndexError.
        elif expected and expected[0] == "not":
            if actual in expected[1:]:
                return False
        elif actual not in expected:
            return False
    return True


def read_eq_dict_file(file_path, **filters):
    """Load whitespace-separated ``key=value`` records into a DataFrame.

    Each non-comment line of the file is one record made of ``key=value``
    tokens separated by whitespace. Lines whose first non-blank character is
    ``#`` are ignored.

    Args:
        file_path: Path of the text file to read.
        **filters: Field constraints forwarded to ``apply_filters``; only
            records that pass are kept.

    Returns:
        A ``pandas.DataFrame`` with one row per kept record.

    Raises:
        ValueError: If a value cannot be converted by its field's converter
            (``float`` for fields without a dedicated converter).
    """
    # Per-field converters; any field not listed here is parsed as a float.
    convert = {
        # Keeps the integer after the last "." of the value.
        "best_layer": lambda x: int(x.split(".")[-1]),
        # "+" stands for a space inside a label; "," separates the labels.
        "custom_label": lambda x: tuple(x.replace("+", " ").split(",")),
        "model": str,
        "task": str,
        "shot": int,
        "ext_shot": int,
        "ext_batch": int,
        "episodes": int,
        "ie_episodes": int,
        "n_top_heads": int,
        "seed": int,
        "eval_episodes": int,
    }
    with open(file_path) as f:
        lines = f.read().strip().splitlines()

    data = []
    for line in lines:
        if line.lstrip().startswith("#"):
            continue
        info = {}
        for token in line.split():
            try:
                key, raw = token.split("=")
            except ValueError:
                # Token has no "=" or more than one; report it and move on.
                print(f"Error in line: {token}")
                continue
            info[key] = convert.get(key, float)(raw)
        if not info:
            # A blank line (or one with only malformed tokens) would pass
            # every filter and end up as an all-NaN row, so drop it here.
            continue
        if apply_filters(filters, info):
            data.append(info)
    return pd.DataFrame(data)


def best_hyperparams(
    result_fn,
    filters,
    split_by="task",
    card_split=1,
    metric="macro",
    hyperparams=["strength", "ext_strength", "ext_shot"],
    fish_shell=False,
    return_pivot=False,
):
    """Pick the best-scoring row per group and print shell variables for it.

    Reads the results file, keeps the row with the highest ``metric`` within
    each ``split_by`` group, and prints ``card_split`` blocks of shell
    variable assignments. Each block repeats the filter values and carries
    one contiguous share of the per-group values.

    Args:
        result_fn: Path of the results file, read by ``read_eq_dict_file``.
        filters: Mapping of field name to constraint, used both to select
            rows and to print one shell variable per entry. ``True`` prints
            the flag ``--<name>``, ``False`` prints an empty value, and
            anything else is printed as is.
        split_by: Column whose distinct values define the groups.
        card_split: Number of blocks to spread the groups over; the first
            ``n_groups % card_split`` blocks get one extra group.
        metric: Column that is maximised within each group.
        hyperparams: Columns whose best values are printed as shell arrays
            named ``<column>s``. The default list is never mutated.
        fish_shell: Print fish syntax (``set name value``) when true,
            otherwise bash syntax (``name="value"`` / ``name=(...)``).
        return_pivot: Return a pivot table indexed by ``split_by`` instead
            of the raw best rows.

    Returns:
        ``None`` when no row matches ``filters``. Otherwise the best rows
        sorted by ``split_by``, or their pivot table when ``return_pivot``
        is true.
    """
    val_results = read_eq_dict_file(result_fn, **filters)
    if val_results.empty:
        print(f"No data found for {result_fn} and {filters}")
        return

    # warn if there are duplicates (rows equal on everything but the scores)
    dup_col = val_results.columns.difference(["macro", "micro", "weighted"])
    dups = val_results.duplicated(subset=dup_col)
    if dups.any():
        print(f"{dups.sum()} duplicate(s) found:")
        display(val_results[dups])

    best_indices = val_results.groupby(split_by)[metric].idxmax()
    best_results = val_results.loc[best_indices]

    # De-duplicate while keeping the given order, so the printed variables
    # come out in the same order on every run (a set would not guarantee it).
    columns = dict.fromkeys([split_by, *hyperparams])
    params = {col: list(map(str, best_results[col].values)) for col in columns}

    # Share the groups as evenly as possible; earlier cards take the extras.
    n_tasks = len(params[split_by])
    base, extra = divmod(n_tasks, card_split)
    card_tasks = [base + 1 if i < extra else base for i in range(card_split)]

    # The filter assignments are identical for every card: build them once.
    args = []
    for name, value in filters.items():
        if value is True:
            args.append(
                f"set {name} --{name}" if fish_shell else f'{name}="--{name}"'
            )
        elif value is False:
            args.append(f"set {name}" if fish_shell else f'{name}=""')
        else:
            args.append(
                f"set {name} {value}" if fish_shell else f'{name}="{value}"'
            )
    filter_lines = "\n".join(args)

    start = 0
    for count in card_tasks:
        stop = start + count
        print(filter_lines)
        for col, values in params.items():
            if col in filters:
                # Already printed as a plain variable above.
                continue
            chunk = " ".join(values[start:stop])
            print(f"set {col}s {chunk}" if fish_shell else f"{col}s=({chunk})")
        print()
        start = stop

    # Rebind instead of sorting in place: best_results is a selection taken
    # from val_results and should not be modified through inplace=True.
    best_results = best_results.sort_values(by=split_by)
    if return_pivot:  # pivot for better view
        return best_results.pivot_table(
            index=split_by,
            values=[
                "strength",
                "ext_strength",
                "macro",
                "micro",
                "weighted",
                "ext_shot",
                "seed",
            ],
        )
    return best_results

# Pick best hyperparameters

In [None]:
# Path to the results file.
result_fn = "results/<your_config>/edit.txt"
# Select the model here; every entry is also used as a row filter.
filters = {"model": "llama-2-7b", "episodes": 200}
# How many GPU cards to split the tasks into.
card_split = 4
# Change to False if you use bash.
fish_shell = True

best_hyperparams(
    result_fn,
    filters,
    card_split=card_split,
    fish_shell=fish_shell,
)