In [None]:
from pathlib import Path

import torch
from latentis import PROJECT_ROOT

In [None]:
# Directory containing the serialized results of experiment 3; the trailing
# expression displays whether it exists on this machine.
exp_dir: Path = PROJECT_ROOT / "results" / "exp3"
exp_dir.exists()

In [None]:
# Collect every entry directly under the experiment directory.
# The listing is sorted because directory iteration order is filesystem-dependent
# and the loading cell keeps only the first row of each duplicate group, so an
# unsorted listing could change which duplicate survives between machines.
experiments = sorted(exp_dir.glob("*"))
len(experiments)

In [None]:
# Display the discovered experiment paths for a quick visual check.
experiments

In [None]:
import pandas as pd

In [None]:
from collections import defaultdict

# Flatten every ablation entry of every experiment file into one table row.
# Values are accumulated as parallel lists keyed by column name and converted to
# a DataFrame once at the end.
skipped_keys = {"keep_units", "ablated_shape", "decomp"}
columns = defaultdict(list)
for exp_path in experiments:
    # Sub-directories (and anything else that is not a regular file) are ignored.
    if not exp_path.is_file():
        continue
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
    exp_data = torch.load(exp_path, map_location="cpu", weights_only=False)
    for ablation in exp_data["ablations"]:
        for k, v in ablation.items():
            if k in skipped_keys:
                continue
            if k == "residual_indices":
                # Number of selected units, stored next to the indices themselves.
                columns["n_units"].append(v.numel())
            columns[k].append(v.numpy() if isinstance(v, torch.Tensor) else v)
        columns["model"].append(exp_data["model_name"])
        columns["dataset"].append(exp_data["dataset_name"])

df = pd.DataFrame(columns)
# The same (model, dataset, type, ablation) combination can appear in several
# files; only its first occurrence is kept.
df = df.drop_duplicates(subset=["model", "dataset", "type", "ablation"])
# Preview without the bulky columns. "decomp" is skipped while loading, so it is
# never a column here; errors="ignore" keeps the preview from raising KeyError.
df.drop(columns=["decomp", "residual_indices"], errors="ignore")

In [None]:
# Short labels used for the greedy selection types ({perc} is the percentage):
#   greedy_{perc}%_corr_full_out_heads -> "U"
#   greedy_{perc}%_corr_task_heads     -> "U|T"
#   greedy_{perc}%_supervised_heads    -> "S"
#   greedy_{perc}%_random_0_heads      -> "R"
df["type"].unique()

In [None]:
import matplotlib.pyplot as plt

from residual.data.data_registry import dataset_names
from residual.nn.model_registry import model_names

In [None]:
# For each head-selection strategy, compare the residual units selected on the
# different datasets (zero-ablation rows only) through their pairwise Jaccard
# similarity, then save and show one heatmap per encoder.
perc = 5


def _jaccard(first: set, second: set) -> float:
    """Return |first & second| / |first | second|.

    Two empty sets are treated as identical (similarity 1.0) instead of raising
    ZeroDivisionError.
    """
    union = first | second
    if not union:
        return 1.0
    return len(first & second) / len(union)


for selection_type, selection_label in (
    (f"greedy_{perc}%_corr_full_out_heads", "U"),
    (f"greedy_{perc}%_corr_task_heads", "U|T"),
    (f"greedy_{perc}%_supervised_heads", "S"),
):
    x = df[(df["type"] == selection_type) & (df["ablation"] == "zero")]
    for encoder in x["model"].unique():
        # Only the "openclip_l" encoder is plotted.
        if encoder != "openclip_l":
            continue
        encoder_x = x[x["model"] == encoder][["dataset", "residual_indices"]].to_dict(
            orient="records"
        )
        dataset2indices = {d["dataset"]: d["residual_indices"] for d in encoder_x}

        # Build each index set once instead of once per matrix cell.
        # assumes residual_indices is a flat collection of hashable ids — TODO confirm
        index_sets = [set(indices) for indices in dataset2indices.values()]

        # Plain (unweighted) Jaccard similarity between every pair of datasets.
        jaccard_matrix = torch.zeros((len(encoder_x), len(encoder_x)))
        for i, first_set in enumerate(index_sets):
            for j, second_set in enumerate(index_sets):
                jaccard_matrix[i, j] = _jaccard(first_set, second_set)

        encoder_label = model_names[encoder]
        dataset_labels = [dataset_names[d["dataset"]] for d in encoder_x]

        # Explicit figure/axes so each heatmap is independent of pyplot's
        # implicit "current figure" state.
        fig, ax = plt.subplots()
        image = ax.imshow(jaccard_matrix.numpy(), cmap="Blues")
        ax.set_title(f"{encoder_label} - {selection_label} - {perc}%")
        ax.set_xticks(range(len(encoder_x)))
        ax.set_xticklabels(dataset_labels, rotation=45)
        ax.set_yticks(range(len(encoder_x)))
        ax.set_yticklabels(dataset_labels)
        # Write the value in each cell; white text on the darker cells.
        for i in range(len(encoder_x)):
            for j in range(len(encoder_x)):
                value = jaccard_matrix[i, j].item()
                ax.text(
                    j,
                    i,
                    f"{value:.2f}",
                    ha="center",
                    va="center",
                    color="w" if value > 0.5 else "k",
                )
        fig.colorbar(image, ax=ax)
        fig.savefig(
            PROJECT_ROOT / "results" / f"{encoder}_{selection_type}_jaccard.pdf",
            bbox_inches="tight",
            pad_inches=0,
        )
        plt.show()
        # Release the figure so repeated runs do not accumulate open figures.
        plt.close(fig)

In [None]:
# List the distinct ablation/selection types present in the loaded results.
df["type"].unique()

In [None]:
def _selection_method(ablation_type: str) -> str:
    """Derive the selection method from an ablation type string.

    Types starting with "greedy" keep everything after their second
    underscore-separated token; every other type is reported as "manual".
    """
    if not ablation_type.startswith("greedy"):
        return "manual"
    tokens = ablation_type.split("_")
    return "_".join(tokens[2:])


df["selection_method"] = df["type"].apply(_selection_method)
df

In [None]:
# Distinct numbers of selected units recorded across all ablations.
df["n_units"].unique()

In [None]:
# Distinct selection methods derived from the ablation type strings.
df["selection_method"].unique()

In [None]:
import wandb

In [None]:
# Query the Weights & Biases public API for the runs stored under
# "resi_dual/residual" (presumably "<entity>/<project>" — confirm) whose config
# has exp_type == "residual_coarse".
api = wandb.Api()
runs = api.runs("resi_dual/residual", filters={"config.exp_type": "residual_coarse"})

In [None]:
import torch.nn.functional as F
from latentis.space import Space
from wandb.apis.public.runs import Run

from residual.sparse_decomposition import SOMP

# Build one summary row per W&B run. When both the stored test encodings and the
# model's text dictionary exist on disk, the row also gets the textual
# descriptors returned by the sparse decomposition of those encodings.
coarse_rows = []
for run in runs:
    run: Run

    exp_type: str = run.config["exp_type"]
    dataset: str = run.config["dataset_name"]
    # Some runs store the encoder under "encoder_name" instead of "model_name".
    model: str = run.config.get("model_name", None)
    if model is None:
        model = run.config["encoder_name"]

    run_data = {
        "model": model,
        "dataset": dataset,
        "type": exp_type,
        "score": run.summary.get("test/accuracy", None),
        "selection_method": "optimized",
    }
    encoding_path = (
        PROJECT_ROOT / "optimized" / dataset / "test" / f"{model}_{exp_type}_encodings"
    )
    dictionary_path = PROJECT_ROOT / "dictionaries" / "textspan" / f"{model}.pt"

    # Runs without local artifacts keep their score but get no "descriptions".
    if encoding_path.exists() and dictionary_path.exists():
        space = Space.load_from_disk(path=encoding_path).as_tensor()
        decomposition = SOMP(k=10)

        # Load onto the CPU explicitly: the decomposition below runs with
        # device="cpu", and a dictionary saved from a CUDA device would otherwise
        # fail to load on a machine without a GPU.
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
        decomp_dictionary = torch.load(
            dictionary_path,
            map_location="cpu",
            weights_only=False,
        )

        decomp_out = decomposition(
            X=space,
            dictionary=F.normalize(decomp_dictionary["encodings"]),
            descriptors=decomp_dictionary["dictionary"],
            device="cpu",
        )
        run_data["descriptions"] = [str(x) for x in decomp_out["results"]]

    coarse_rows.append(run_data)

coarse_data = pd.DataFrame(coarse_rows)

coarse_data

In [None]:
# Prepare the ablation rows used in the final report: attach the textual
# descriptions, drop bulky columns, exclude mean-ablation rows and keep only the
# selection types that have a short label in `types`.
filtered_df = df.copy()

# "decomp" is skipped by the loading cell, so the column may be missing. In that
# case there is nothing to describe and the column is filled with None instead of
# raising KeyError.
if "decomp" in filtered_df.columns:
    filtered_df["descriptions"] = filtered_df["decomp"].apply(
        lambda x: x["results"][:10] if x is not None else None
    )
else:
    filtered_df["descriptions"] = None
filtered_df = filtered_df.drop(
    columns=["decomp", "residual_indices"], errors="ignore"
)
filtered_df = filtered_df[filtered_df["ablation"] != "mean"]

perc = 5

# Raw type identifier -> short label used in the tables.
types = {
    f"greedy_{perc}%_corr_full_out_heads": "U",
    f"greedy_{perc}%_corr_task_heads": "U|T",
    f"greedy_{perc}%_supervised_heads": "S",
    f"greedy_{perc}%_random_0_heads": "R",
    "heads": "H",
    "units": "B",
    "residual_coarse": "O",
}

# Every random-baseline row (taken before the type filter below), grouped per
# (model, dataset). The mean/std aggregation over these groups is currently
# disabled: only the single "greedy_{perc}%_random_0_heads" run is reported as "R".
random_rows = filtered_df[filtered_df["type"].str.contains("random")]
grouped_random = random_rows.groupby(["model", "dataset"])
new_rows = []

# Keep only the rows whose type has a label in `types`.
filtered_df = filtered_df[filtered_df["type"].isin(types.keys())]

result_df = filtered_df.copy()
filtered_df

In [None]:
# Merge the ablation rows with the optimized W&B runs of the same models, then
# replace raw identifiers with their display labels.
#
# The merge starts from `filtered_df` (of which `result_df` is a plain copy when
# this cell first runs) rather than from `result_df` itself, so re-running the
# cell neither appends `coarse_data` a second time nor tries to look up labels
# that were already mapped.
coarse_for_known_models = coarse_data[
    coarse_data["model"].isin(filtered_df["model"].unique())
]
result_df = pd.concat(
    [filtered_df, coarse_for_known_models],
    ignore_index=True,
).sort_values(by=["model", "dataset", "type"])
result_df["model"] = result_df["model"].apply(lambda x: model_names[x])
result_df["dataset"] = result_df["dataset"].apply(lambda x: dataset_names[x])

# Map raw types to their short labels (KeyError on an unknown type) and make the
# column categorical so it sorts in the order the labels are declared.
result_df["type"] = pd.Categorical(
    result_df["type"].apply(types.__getitem__), categories=list(types.values())
)
result_df

In [None]:
# One row per individual description: the list-valued "descriptions" column is
# exploded, descriptions are numbered from 1 within each (model, type, dataset)
# group, and only the first three of each group are kept.
descriptions_df = (
    result_df.explode("descriptions")
    .reset_index(drop=True)
    .assign(
        description_id=lambda frame: frame.groupby(
            ["model", "type", "dataset"], observed=True
        ).cumcount()
        + 1
    )
    .loc[lambda frame: frame["description_id"] <= 3]
    .rename(columns={"descriptions": "description"})
)
descriptions_df

In [None]:
# Wide table with one row per (dataset, model, type) and one column per
# description rank; missing descriptions are filled with 0 before export.
descriptions_pivot = descriptions_df.pivot(
    index=["dataset", "model", "type"],
    columns=["description_id"],
    values="description",
)
descriptions_pivot = descriptions_pivot.fillna(0)

# NOTE(review): column_format="c" declares a single LaTeX column although the
# table has several — confirm the output is adjusted before being compiled.
descriptions_table = descriptions_pivot.to_latex(
    multirow=True,
    column_format="c",
    multicolumn_format="c",
    float_format="%.2f",
)
print(descriptions_table)

In [None]:
# Score table: one row per dataset, one column per (model, type) pair, missing
# scores filled with 0, plus a final "Average" row, exported as LaTeX.
type_order = list(types.values())

score_pivot = result_df.pivot(
    index=["dataset"], columns=["model", "type"], values="score"
)
score_pivot = score_pivot.fillna(0)

# Order columns by model name, then by the position of the type label in `types`.
ordered_columns = sorted(
    score_pivot.columns,
    key=lambda column: (column[0], type_order.index(column[1])),
)
score_pivot = score_pivot[ordered_columns]

# Column-wise mean over the datasets (computed after the zero fill above).
score_pivot.loc["Average"] = score_pivot.mean()

table = score_pivot.to_latex(
    multirow=True,
    column_format="c",
    multicolumn_format="c",
    float_format="%.2f",
)
print(table)