In [2]:
import wandb 
import pandas as pd
from pathlib import Path
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib

# Weights & Biases coordinates. WANDB_ENTITY is the prefix of the project path
# used when querying runs.
# NOTE(review): WANDB_PROJECT is not referenced by the cells visible here.
WANDB_ENTITY = "hdbo-benchmark"
WANDB_PROJECT = "hdbo-embeddings-benchmark"

# Plot styling: render figure text through LaTeX, then apply the seaborn theme
# with enlarged fonts.
matplotlib.rcParams["text.usetex"] = True
sns.set_theme(style="whitegrid", font_scale=1.75)

In [3]:
def get_all_runs_for_experiment(
    experiment_name: str,
    solver_name: str | None = None,
    function_name: str | None = None,
    n_dimensions: int | None = None,
    seed: int | None = None,
    tags: list[str] | None = None,
) -> list[wandb.apis.public.Run]:
    """Fetch the W&B runs of one project that match the given criteria.

    The project queried is ``f"{WANDB_ENTITY}/{experiment_name}"``. Only runs
    whose state is ``finished``, ``running``, ``failed`` or ``crashed`` are
    requested. Every criterion left as ``None`` is not filtered on.

    Args:
        experiment_name: Name of the W&B project inside ``WANDB_ENTITY``.
        solver_name: Required value of ``config.solver_name``.
        function_name: Required value of ``config.function_name``.
        n_dimensions: Required value of ``config.n_dimensions``. Ignored when
            ``solver_name`` is ``"bounce"`` or ``"pr"`` (discrete solvers).
        seed: Required value of ``config.seed``.
        tags: Keep runs carrying at least one of these tags.

    Returns:
        The matching runs, materialised as a list.
    """
    client = wandb.Api()

    # Discrete solvers are never filtered by dimensionality.
    skip_dimension_filter = solver_name in ("bounce", "pr")

    # Candidate criteria in query-key form; a value of None means "not set".
    optional_criteria = {
        "config.solver_name": solver_name,
        "config.function_name": function_name,
        "config.n_dimensions": None if skip_dimension_filter else n_dimensions,
        "config.seed": seed,
        "tags": None if tags is None else {"$in": tags},
    }

    query: dict[str, str | int | dict] = {
        "state": {"$in": ["finished", "running", "failed", "crashed"]},
    }
    query.update(
        {key: value for key, value in optional_criteria.items() if value is not None}
    )

    # The query always contains the "state" entry, so it is passed as-is.
    matching_runs = client.runs(f"{WANDB_ENTITY}/{experiment_name}", query)
    return list(matching_runs)

In [5]:
def convert_data_to_dataframes(
    all_data: list[wandb.apis.public.Run],
) -> list[pd.DataFrame]:
    """Turn each run into its history table, annotated with run metadata.

    For every run the frame returned by ``run.history()`` gets one constant
    column per entry of ``config_columns`` (read from ``run.config``) plus a
    ``state`` column holding ``run.state``. A progress line is printed per run.

    Args:
        all_data: Runs to convert.

    Returns:
        One data frame per run, in the order of ``all_data``.

    Raises:
        KeyError: If a run's config lacks one of the copied keys.
    """
    # NOTE(review): wandb's `history()` may return a sampled subset of the
    # logged rows by default - confirm this is acceptable for long runs.
    config_columns = (
        "solver_name",
        "function_name",
        "seed",
        "poli_hash",
        "hdbo_benchmark_hash",
        "poli_baselines_hash",
    )

    total = len(all_data)
    frames: list[pd.DataFrame] = []
    for position, run in enumerate(all_data, start=1):
        print(f"Processing run {position}/{total}")
        history = run.history()
        for column in config_columns:
            history[column] = run.config[column]
        history["state"] = run.state
        frames.append(history)

    return frames


def create_base_table(
    n_dimensions: int = 128,
    save_cache: bool = True,
    use_cache: bool = False,
    tags: list[str] | None = None,
) -> pd.DataFrame:
    CACHE_PATH = Path("cache/results_cache")
    CACHE_PATH.mkdir(exist_ok=True, parents=True)
    tags_str = "-".join(tags) if tags is not None else "all"
    CACHE_FILE = (
        CACHE_PATH / f"base_table-n_dimensions-{n_dimensions}-tags-{tags_str}.csv"
    )

    if use_cache and CACHE_FILE.exists():
        df = pd.read_csv(CACHE_FILE)
        return df

    all_runs = get_all_runs_for_experiment(
        experiment_name="benchmark_on_pmo", n_dimensions=n_dimensions, tags=tags
    )

    # Append with the results from PR on 2D
    if n_dimensions != 2:
        pr_runs = get_all_runs_for_experiment(
            experiment_name="benchmark_on_pmo",
            solver_name="pr",
            n_dimensions=2,
            tags=tags,
        )
        all_runs.extend(pr_runs)

    # Append the results of Bounce on 128D
    if n_dimensions != 128:
        bounce_runs = get_all_runs_for_experiment(
            experiment_name="benchmark_on_pmo",
            solver_name="bounce",
            n_dimensions=128,
            tags=tags,
        )
        all_runs.extend(bounce_runs)

    all_dfs = convert_data_to_dataframes(all_runs)

    df = pd.concat(all_dfs)
    if save_cache:
        df.to_csv(CACHE_FILE, index=False)

    return df

if __name__ == "__main__":
    # Rebuild the 2-dimensional base table from W&B: any existing cache is
    # ignored (use_cache=False), the result is written to the cache
    # (save_cache=True), and no tag filter is applied. Then dump the table.
    df = create_base_table(n_dimensions=2, save_cache=True, use_cache=False, tags=None)
    print(df)

Processing run 1/1025
Processing run 2/1025
Processing run 3/1025
Processing run 4/1025
Processing run 5/1025
Processing run 6/1025
Processing run 7/1025
Processing run 8/1025
Processing run 9/1025
Processing run 10/1025
Processing run 11/1025
Processing run 12/1025
Processing run 13/1025
Processing run 14/1025
Processing run 15/1025
Processing run 16/1025
Processing run 17/1025
Processing run 18/1025
Processing run 19/1025
Processing run 20/1025
Processing run 21/1025
Processing run 22/1025
Processing run 23/1025
Processing run 24/1025
Processing run 25/1025
Processing run 26/1025
Processing run 27/1025
Processing run 28/1025
Processing run 29/1025
Processing run 30/1025
Processing run 31/1025
Processing run 32/1025
Processing run 33/1025
Processing run 34/1025
Processing run 35/1025
Processing run 36/1025
Processing run 37/1025
Processing run 38/1025
Processing run 39/1025
Processing run 40/1025
Processing run 41/1025
Processing run 42/1025
Processing run 43/1025
Processing run 44/10

In [6]:
# Display names of the reported solvers, keyed by the `solver_name` stored with
# each run. The insertion order is the order in which the solvers are laid out
# in the heatmap and in the printed table.
#
# The values are plain/HTML text. The LaTeX variant of each name wraps it in
# \texttt{...}, e.g. r"Hvarfner's \texttt{VanillaBO}".
# Solvers currently left out: cma_es (CMAES), alebo (ALEBO), turbo (Turbo),
# baxus (BAxUS).
solver_name_but_pretty = dict(
    random_mutation="HillClimbing",
    vanilla_bo_hvarfner="Hvarfner's&nbsp;VanillaBO",
    line_bo="RandomLineBO",
    saas_bo="SAASBO",
    bounce="Bounce",
    pr="ProbRep",
)

def summary_per_function(
    df: pd.DataFrame,
    normalized_per_row: bool = True,
    seeds: tuple[int, ...] = (1, 2, 3),
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Summarise the best value reached per (function, solver) pair.

    For every combination of ``function_name`` and ``solver_name`` present in
    ``df``, the maximum ``y`` of each selected seed is taken, and the mean and
    standard deviation of those per-seed maxima are reported.

    Args:
        df: Table with at least the columns ``function_name``,
            ``solver_name``, ``seed`` and ``y``.
        normalized_per_row: Divide every row (function) by its best average so
            that the best solver scores 1. Rows whose best average is 0 are
            left untouched. The standard deviations are rescaled by the same
            factor so both tables stay on the same scale.
        seeds: Seeds taken into account; rows with other seeds are ignored.

    Returns:
        ``(summary_avg, summary_std)``: two frames indexed by function name
        with one column per solver. Pairs without data hold NaN.
    """
    expected_seeds = list(seeds)
    n_expected_seeds = len(set(expected_seeds))

    # Filter by seed once instead of once per (function, solver) pair.
    selected = df[df["seed"].isin(expected_seeds)]
    solver_names = df["solver_name"].unique()

    rows = []
    for function_name in df["function_name"].unique():
        for solver_name in solver_names:
            slice_df = selected[
                (selected["function_name"] == function_name)
                & (selected["solver_name"] == solver_name)
            ]
            seeds_found = slice_df["seed"].unique()
            if (
                len(seeds_found) != n_expected_seeds
                and solver_name in solver_name_but_pretty
            ):
                # Incomplete data for a solver that is reported: warn, but
                # still summarise the seeds that are available.
                print(
                    f"Something's fishy with {function_name} and {solver_name} ({seeds_found})"
                )
            best_y_per_seed = slice_df.groupby("seed")["y"].max()
            rows.append(
                {
                    "function_name": function_name,
                    "solver_name": solver_name,
                    "average_best_y": best_y_per_seed.mean(),
                    "std_best_y": best_y_per_seed.std(),
                }
            )

    summary = pd.DataFrame(rows)
    summary_avg = summary.pivot(
        index="function_name", columns="solver_name", values="average_best_y"
    )
    summary_std = summary.pivot(
        index="function_name", columns="solver_name", values="std_best_y"
    )

    # Express each row as a fraction of that row's best average.
    if normalized_per_row:
        for function_name, row in summary_avg.iterrows():
            best_value = row.max()
            if best_value == 0:
                continue
            summary_avg.loc[function_name] = row / best_value  # type: ignore
            # A standard deviation scales with the magnitude of the factor.
            summary_std.loc[function_name] = (
                summary_std.loc[function_name] / abs(best_value)  # type: ignore
            )

    return summary_avg, summary_std


def plot_heatmap(df, normalized: bool = True, n_dimensions: int | None = None):
    """Show the average best value per solver and function as a heatmap.

    Rows are the solvers of ``solver_name_but_pretty`` (in that order, under
    their display names) and columns are the functions. The colorbar only
    labels its two extremes as "min" and "max".

    Args:
        df: Table accepted by ``summary_per_function``.
        normalized: Forwarded as ``normalized_per_row``.
        n_dimensions: Latent dimensionality shown in the title. When omitted,
            the notebook-level ``n_dimensions`` variable is used if it exists
            (the behaviour before this parameter was added).
    """
    summary_avg, _ = summary_per_function(df, normalized_per_row=normalized)

    # Keep the solver_name_but_pretty order. reindex (rather than plain column
    # selection) leaves an empty row for a solver without results instead of
    # raising KeyError, like print_table tolerates missing solvers.
    summary_avg = summary_avg.reindex(columns=list(solver_name_but_pretty))

    # Display names may contain the HTML entity "&nbsp;", which matplotlib
    # would draw literally (and LaTeX rejects); use a plain space on the plot.
    summary_avg.columns = [
        solver_name_but_pretty[col].replace("&nbsp;", " ")
        for col in summary_avg.columns
    ]

    # With text.usetex enabled, a bare underscore in a tick label makes the
    # LaTeX run fail, so escape the ones in the function names.
    if matplotlib.rcParams["text.usetex"]:
        summary_avg.index = [
            str(name).replace("_", r"\_") for name in summary_avg.index
        ]

    if n_dimensions is None:
        # Backward compatibility with callers relying on the global variable.
        n_dimensions = globals().get("n_dimensions")
    space_label = (
        f"{n_dimensions}D latent space" if n_dimensions is not None else "latent space"
    )

    # Wide and short figure so that the heatmap cells stay small.
    fig, ax = plt.subplots(1, 1, figsize=(17, 5))

    hmap = sns.heatmap(
        summary_avg.T,
        ax=ax,
        cmap="inferno",
        cbar_kws={"orientation": "vertical", "pad": 0.01},
    )

    # Only label the two ends of the colorbar.
    colorbar = hmap.collections[0].colorbar
    colorbar.set_ticks([colorbar.vmin, colorbar.vmax])
    colorbar.set_ticklabels(["min", "max"])

    ax.set_xlabel("")
    ax.set_ylabel("")

    plt.xticks(rotation=45, ha="right", rotation_mode="anchor")

    ax.set_title(
        f"Avg. Best Value ({space_label}, 3 seeds, max. 10+100 function calls)"
        + "\n",
        fontsize=25,
    )
    fig.tight_layout()
    plt.show()


def print_table(df, normalized: bool = False):
    """Build the "mean ± std" results table as a frame of strings.

    The table has one row per function (index named ``Oracle``) and one
    column per entry of ``solver_name_but_pretty``, labelled with the display
    name. Each cell reads ``<mean>&nbsp;&plusmn;&nbsp;<std>`` with two
    decimals; a missing std is shown as ``\\alert{?}``, and a missing mean or
    a solver without any column as ``\\alert{[TBD]}``.

    Args:
        df: Table accepted by ``summary_per_function``.
        normalized: Forwarded as ``normalized_per_row``.

    Returns:
        The formatted table. Nothing is printed.
    """
    means, spreads = summary_per_function(df, normalized_per_row=normalized)
    missing_cell = r"\alert{[TBD]}"

    def format_cell(oracle, solver) -> str:
        # Solver absent from the summary altogether.
        if solver not in means.columns:
            return missing_cell

        mean_value = means.loc[oracle, solver]
        spread_value = spreads.loc[oracle, solver]
        if np.isnan(mean_value):
            return missing_cell

        spread_text = (
            r"\alert{?}" if np.isnan(spread_value) else f"{spread_value:.2f}"
        )
        return f"{mean_value:.2f}&nbsp;&plusmn;&nbsp;{spread_text}"

    records: list[dict[str, str]] = [
        {
            "Oracle": oracle,
            **{
                pretty_name: format_cell(oracle, solver)
                for solver, pretty_name in solver_name_but_pretty.items()
            },
        }
        for oracle in means.index
    ]

    return pd.DataFrame(records).set_index("Oracle")


n_dimensions = 2
normalized = True

# Tag filter for the W&B query. None disables tag filtering; to restrict the
# table to specific batches use e.g.
# ["2024-06-02", "2024-06-01", "2024-05-31", "Old-PR-Results"].
tags: list[str] | None = None

# Prefer the cached table; W&B is only queried when no cache file exists, and
# the result is then not written back (save_cache=False).
df = create_base_table(
    n_dimensions=n_dimensions, save_cache=False, use_cache=True, tags=tags
)
# Heatmap rendering is currently disabled:
# plot_heatmap(df, normalized=normalized)
final_table = print_table(df, normalized=False)

# to_csv does not create missing parent directories, so make sure the output
# folder exists before writing.
Path("csv").mkdir(parents=True, exist_ok=True)
final_table.to_csv("csv/benchmark_task.csv")

Something's fishy with osimetrinib_mpo and line_bo ([3 2])
Something's fishy with rdkit_qed and line_bo ([2 1])
