In [None]:
from pathlib import Path

import pandas as pd
import torch
from latentis import PROJECT_ROOT

In [None]:
# Directory holding the exp2 result files that are loaded further down.
exp_dir: Path = PROJECT_ROOT.joinpath("results", "exp2")
# Displayed as a quick sanity check that the directory is present.
exp_dir.exists()

In [None]:
# Each regular file in the results directory is treated as one saved experiment.
# Sorting makes the load order deterministic (glob order is filesystem-dependent);
# sub-directories are skipped because `torch.load` cannot read them.
experiments = sorted(path for path in exp_dir.glob("*") if path.is_file())
len(experiments)

In [None]:
# plt.rcParams.update(bundles.tmlr2023())

In [None]:
from residual.data.data_registry import dataset_names

# Metric tensor to read from each experiment file; also used as the column name.
metric = "spectral_distances"
# Only runs whose `dataset1` equals this value are kept.
reference_dataset = "imagenet"
datasets = list(dataset_names.keys())
models = [
    "openclip_l",
    "clip_l",
    "blip_l_flickr",
    "blip_l_coco",
    "dinov2_l",
    "vit_l",
]

records = []
for exp in experiments:
    # weights_only=True restricts unpickling to tensors and plain containers.
    exp_data = torch.load(exp, map_location="cpu", weights_only=True)
    dataset1 = exp_data["dataset1"]
    dataset2 = exp_data["dataset2"]
    encoder = exp_data["encoder_name"]

    if dataset1 != reference_dataset:
        continue
    if dataset2 not in datasets or encoder not in models:
        continue

    distances = exp_data[metric]
    # Only the (reference_dataset -> dataset2) direction is recorded; the
    # mirrored (dataset2 -> reference_dataset) row is not added.
    records.append(
        {
            "encoder": encoder,
            "dataset1": dataset1,
            "dataset2": dataset2,
            metric: distances,
            "distances_shape": distances.shape,
        }
    )

if not records:
    # Without this, sort_values below fails with an opaque KeyError('encoder').
    raise ValueError(
        f"No experiments in {exp_dir} matched dataset1={reference_dataset!r}, "
        "the registered datasets and the selected models"
    )

data = pd.DataFrame(records).sort_values(by=["encoder", "dataset1", "dataset2"])

# Later cells stack the metric tensors into an (encoder, dataset2, ...) grid,
# which is only valid if every (encoder, dataset2) pair occurs exactly once.
duplicated_runs = data.duplicated(subset=["encoder", "dataset2"], keep=False)
assert not duplicated_runs.any(), (
    "Duplicate (encoder, dataset2) runs found:\n"
    f"{data.loc[duplicated_runs, ['encoder', 'dataset2']]}"
)

# visualize data without the metric column since it's too large
data.drop(columns=[metric])

In [None]:
# Per-run tensors are reshaped to (layers, heads, layers, heads).
# NOTE(review): 24 layers x 16 heads presumably matches the ViT-L-sized
# encoders selected above — confirm against the experiment script.
NUM_LAYERS = 24
NUM_HEADS = 16

n_encoders = data.encoder.nunique()
n_datasets = data.dataset2.nunique()

# The reshape has no dataset1 axis and relies on rows being ordered
# encoder-major / dataset2-minor, so require a single dataset1 and a complete
# encoder x dataset2 grid with exactly one run per cell.
assert data.dataset1.nunique() == 1, f"Expected one dataset1, got {data.dataset1.unique()}"
pair_counts = data.groupby(["encoder", "dataset2"]).size()
assert len(pair_counts) == n_encoders * n_datasets and (pair_counts == 1).all(), (
    f"Expected one run per (encoder, dataset2) pair "
    f"({n_encoders} x {n_datasets}), got:\n{pair_counts}"
)

# NOTE: values are spectral *distances*; the `similarities` name is kept
# because later cells refer to it.
similarities = torch.stack(data["spectral_distances"].tolist(), dim=0).reshape(
    n_encoders,
    n_datasets,
    NUM_LAYERS,
    NUM_HEADS,
    NUM_LAYERS,
    NUM_HEADS,
)
similarities.shape

In [None]:
from matplotlib.figure import Figure

from residual.plot import blocked_heatmap

In [None]:
import matplotlib.pyplot as plt

# Row labels (one per dataset2) are the same for every encoder; compute once.
dataset_labels = data.dataset2.unique().tolist()

for model_index, model in enumerate(data.encoder.unique()):
    # assumes layout (dataset2, layer, head, layer, head) from the reshape cell
    model_similarities = similarities[model_index]
    # Heads per layer: used as the heatmap block size instead of a hard-coded 16.
    heads_per_layer = model_similarities.shape[2]
    model_similarities = model_similarities.view(
        model_similarities.shape[0],
        model_similarities.shape[1] * model_similarities.shape[2],
        -1,
    )

    # Keep only the diagonal: unit i compared with the same unit i
    # (presumably the unit on dataset1 vs. the same unit on dataset2).
    unit_index = torch.arange(model_similarities.shape[1])
    model_similarities = model_similarities[:, unit_index, unit_index]

    heatmap: Figure = blocked_heatmap(
        data=model_similarities,
        block_size=heads_per_layer,
        y_labels=dataset_labels,
    )
    heatmap.suptitle(f"{model}")
    # Figure.show() only warns on Jupyter's non-GUI inline backend; render the
    # figure explicitly with IPython's built-in `display` instead.
    display(heatmap)

    heatmap.savefig(f"{model}_allheads.pdf", dpi=200, bbox_inches="tight", format="pdf")
    # Release the figure so open figures do not accumulate across iterations.
    plt.close(heatmap)