In [None]:
from latentis import PROJECT_ROOT

In [None]:
import pandas as pd
import wandb

In [None]:
# Fetch every run of the W&B project that carries the "cr" tag.
WANDB_PROJECT = "resi_dual/residual"
RUN_FILTERS = {"tags": "cr"}

api = wandb.Api()
runs = api.runs(WANDB_PROJECT, filters=RUN_FILTERS)

In [None]:
from wandb.apis.public.runs import Run

# One record per W&B run: its full config plus identifiers and test metrics.
records = []
for run in runs:
    run: Run

    records.append(
        {
            **run.config,
            "id": run.id,
            "name": run.name,
            "accuracy": run.summary.get("test/accuracy", None),
            "logits_loss": run.summary.get("test/logits_loss", None),
        }
    )

# Shorten the config keys that later cells use as grouping columns.
data = pd.DataFrame(records).rename(
    columns={
        "encoder_name": "encoder",
        "dataset_name": "dataset",
    },
)

# Later cells pivot on these columns; fail here with a clear message if the
# run filter matched nothing or the run configs lack the expected keys.
missing = {"dataset", "encoder", "exp_type"} - set(data.columns)
assert not missing, f"W&B runs are missing expected config columns: {sorted(missing)}"
data

In [None]:
# Distinct encoders among the fetched runs.
data["encoder"].unique()

In [None]:
# Wide accuracy table: one row per (dataset, encoder), one column per exp_type.
score_data = (
    data.pivot(
        index=["dataset", "encoder"],
        columns="exp_type",
        values="accuracy",
    )
    .rename_axis(columns=None)
    .reset_index()
)
score_data

In [None]:
# results/ may be absent on a fresh checkout, and to_csv does not create parent dirs.
results_dir = PROJECT_ROOT / "results"
results_dir.mkdir(parents=True, exist_ok=True)
score_data.to_csv(results_dir / "exp4.tsv", index=False, sep="\t")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from typing import Sequence

from residual.data import data_registry
from residual.nn.model_registry import model_names

# Display label for each exp_type; insertion order is the order methods are plotted.
exp_type2label = dict(
    residual_fine="ResiDual*",
    residual_coarse="Optimized",
    residual_full="ResiDual",
)


def create_model_diamond_plot(
    df, encoder: str, exp_types: Sequence[str], output_dir=None
):
    """Draw and save a radar ("diamond") chart of per-dataset scores for one encoder.

    Each requested exp_type column of ``df`` becomes one closed polygon whose
    vertices are the encoder's scores on every dataset row. The figure is saved
    as ``<output_dir>/<encoder>_exp4.pdf``.

    Args:
        df: Table with ``encoder`` and ``dataset`` columns plus one score column
            per exp_type (the layout of the ``score_data`` pivot).
        encoder: Encoder key to plot; looked up in ``model_names`` for the title.
        exp_types: Exp-type columns to draw. Types absent from ``df`` are skipped;
            types without an ``exp_type2label`` entry use their raw key as label.
        output_dir: Directory for the PDF. Defaults to ``PROJECT_ROOT / "results"``
            and is created if missing.

    Returns:
        The ``(fig, ax)`` pair of the polar plot.

    Raises:
        ValueError: If ``df`` has no rows for ``encoder``.
    """
    from pathlib import Path

    # Boolean indexing already returns a new frame, so no explicit copy is needed.
    encoder_data = df.loc[df["encoder"] == encoder]
    if encoder_data.empty:
        raise ValueError(f"No rows for encoder {encoder!r}; cannot draw a radar plot.")

    dataset_labels = [
        data_registry.dataset_names[name] for name in encoder_data["dataset"]
    ]

    # One spoke per dataset; repeat the first angle so every polygon closes.
    angles = np.linspace(0, 2 * np.pi, len(dataset_labels), endpoint=False).tolist()
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))

    for exp_type in exp_types:
        if exp_type not in encoder_data.columns:
            continue
        values = encoder_data[exp_type].to_numpy()
        values = np.concatenate((values, values[:1]))  # close the loop
        ax.plot(
            angles, values, linewidth=2, label=exp_type2label.get(exp_type, exp_type)
        )
        ax.fill(angles, values, alpha=0.1)

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(dataset_labels)
    ax.set_yticklabels([])
    ax.set_title(f"{model_names[encoder]}", size=15, color="b", y=1.1)
    # Upper-right corner of the legend is pinned near the lower-left of the axes.
    ax.legend(loc="upper right", bbox_to_anchor=(0.1, 0.1))

    out_dir = Path(PROJECT_ROOT) / "results" if output_dir is None else Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / f"{encoder}_exp4.pdf")

    return fig, ax


# Draw (and save) one radar chart per encoder present in the score table.
encoders = score_data["encoder"].unique()
for encoder in encoders:
    fig, ax = create_model_diamond_plot(
        df=score_data,
        encoder=encoder,
        exp_types=list(exp_type2label),
    )

In [None]:
# Re-display the per-(dataset, encoder) accuracy table built earlier.
score_data

In [None]:
# Encoders to tabulate; other encoders in score_data (e.g. "clip_b", "clip_l",
# "openai_l", "blip_l_flickr") can be added here.
tab_encoders = {
    "openclip_b",
    "openclip_l",
}
# exp_type columns of score_data to tabulate (add "residual_coarse" to include it).
tab_exp_types = [
    "residual_fine",
    "residual_full",
]

# Rows = datasets; columns = (exp_type, encoder), grouped by encoder.
encoder_data = (
    score_data[score_data["encoder"].isin(tab_encoders)]
    .pivot(index="dataset", columns="encoder", values=tab_exp_types)
    .sort_index(axis=1, level=1)
    .reset_index()
)
encoder_data["dataset"] = encoder_data["dataset"].replace(data_registry.dataset_names)

encoder_data = encoder_data.rename(
    columns={
        "dataset": "Dataset",
        "openclip_b": "OpenCLIP-b",
        "openclip_l": "OpenCLIP-l",
        "clip_l": "CLIP-l",
        "blip_l_flickr": "BLIP-l",
        "linear_adapter": "Linear",
        "residual_full": "ResiDual",
        "residual_fine": "ResiDual*",
    },
)

# Append an "Average" row holding the mean of every numeric column.
numeric_means = encoder_data.select_dtypes(include="number").mean()
average_row = {
    col: numeric_means[col] if col in numeric_means.index else None
    for col in encoder_data.columns
}
# to_latex(index=False) drops the index label, so name the row in the Dataset
# column itself (reset_index() placed it at position 0).
average_row[encoder_data.columns[0]] = "Average"
encoder_data.loc["Average"] = average_row

# One left-aligned Dataset column, then one centered column per score column;
# derived from the frame so the spec always matches the actual column count.
column_format = "l" + "c" * (encoder_data.shape[1] - 1)
print(
    encoder_data.to_latex(
        float_format="{:0.2f}".format,
        index=False,
        column_format=column_format,
        multicolumn_format="c",
        multicolumn=True,
    )
)