In [None]:
from pathlib import Path

import pandas as pd
import torch
from latentis import PROJECT_ROOT

In [None]:
exp_dir: Path = PROJECT_ROOT / "results" / "exp1"
# Fail fast: a missing directory would otherwise make the glob below return
# nothing, and the problem would only surface later as a KeyError on an empty
# DataFrame.
assert exp_dir.is_dir(), f"Results directory not found: {exp_dir}"

In [None]:
# Result files to load. Sorted so the load order (and thus row order in the
# DataFrame built below) is deterministic; glob order is filesystem-dependent.
# Only regular files are kept because each entry is passed to torch.load.
experiments = sorted(path for path in exp_dir.glob("*") if path.is_file())
len(experiments)

In [None]:
# Show file names only; full paths would embed the local directory layout in the saved output.
[exp.name for exp in experiments]

In [None]:
import matplotlib.pyplot as plt
import numpy as np

from residual.nn.model_registry import model_names

# Main (diverging) colormap for the per-layer heatmaps.
# Paper styling via tueplots' `bundles.tmlr2023()` rcParams is intentionally not applied.
cmap: str = "RdBu_r"

In [None]:
# Encoders to compare, in plotting order (also used as keys into `model_names`).
models = ["openclip_l", "clip_l", "dinov2_l", "vit_l", "blip_l_flickr"]
dataset_name = "imagenet"
# PCA intrinsic dimension = number of components needed for the EVR curve to exceed this.
pca_threshold = 0.99

# One record per (model, layer, head, unit type); converted to a DataFrame once at the end.
records = []
for exp_path in experiments:
    # NOTE(review): weights_only=False unpickles arbitrary objects, so only load trusted result files.
    exp_data = torch.load(exp_path, map_location="cpu", weights_only=False)
    model_name = exp_data["encoder_name"]
    if exp_data["dataset_name"] != dataset_name or model_name not in models:
        continue
    for (layer_idx, head_idx, unit_type), unit_data in exp_data["data"].items():
        # presumably a 1-D tensor of cumulative explained-variance ratios — TODO confirm
        pca_evr = unit_data["pca_evr"]
        crossing = torch.where(pca_evr > pca_threshold)[0]
        # If the threshold is never crossed (e.g. float rounding just below it),
        # count every available component instead of raising IndexError.
        # Note this is only a lower bound if the stored PCA was truncated.
        pca_id = crossing[0].item() + 1 if crossing.numel() > 0 else pca_evr.shape[0]

        records.append(
            dict(
                model=model_name,
                dataset=exp_data["dataset_name"],
                # Embedding units are placed at pseudo-layer -1, below the first block.
                layer_idx=-1 if unit_type == "emb" else layer_idx,
                head_idx=head_idx,
                unit_type=unit_type,
                pca_id=pca_id,
                twonn_id=unit_data["id_twonn"],
                evr_1=pca_evr[0].item(),
            )
        )

data = pd.DataFrame(records)
assert not data.empty, f"No {dataset_name} results found for models {models}"
data

In [None]:
# Per-layer summary heatmap for one unit type: one panel per model, one row per
# layer, and columns L (mean PCA ID), N (mean TwoNN ID), their ratio L/N, and
# EVR_1. Cells are coloured by each column's values divided by that column's max
# and annotated with the raw (unnormalized) means.
# Optional paper styling (disabled):
#   from tueplots import figsizes
#   plt.rcParams.update(figsizes.tmlr2023(ncols=1, nrows=1))
import numpy.ma as ma

unit_type: str = "mlp"

column_labels = ["L", "N", "Ratio", r"EVR$_1$"]
# Column drawn with the secondary "Greys" colormap; all other columns use `cmap`.
column_index = 3

# squeeze=False keeps `ax` a 1-D array of Axes even when only one model is plotted.
fig, ax = plt.subplots(nrows=1, ncols=len(models), figsize=(15, 11), squeeze=False)
ax = ax[0]

for model_idx, model in enumerate(models):
    axis = ax[model_idx]
    model_data = data[(data["model"] == model) & (data["unit_type"] == unit_type)]

    # Mean of each metric per layer. The group keys become the (sorted) result
    # index, so the tick labels taken from it line up with the heatmap rows.
    layer_data = model_data.groupby(["layer_idx", "unit_type"], dropna=False).aggregate(
        {"pca_id": "mean", "twonn_id": "mean", "evr_1": "mean"}
    )
    y_labels = [
        f"{layer_idx}" for layer_idx in layer_data.index.get_level_values("layer_idx")
    ]

    pca_mean = torch.as_tensor(layer_data["pca_id"].to_numpy())
    twonn_mean = torch.as_tensor(layer_data["twonn_id"].to_numpy())
    evr_mean = torch.as_tensor(layer_data["evr_1"].to_numpy())
    ratio_mean = pca_mean / twonn_mean

    # Raw values are printed in the cells; colours use each column scaled by its own max.
    annot_tensor = torch.stack([pca_mean, twonn_mean, ratio_mean, evr_mean], dim=1)
    plot_data = (annot_tensor / annot_tensor.max(dim=0).values).numpy()
    annot_data = annot_tensor.numpy()

    im = axis.imshow(
        plot_data,
        aspect="auto",
        cmap=cmap,
        origin="lower",
        vmin=-0.1,
        vmax=1.2,
    )

    # Overlay the highlighted column with its own colormap by masking every other column.
    other_columns = np.ones_like(plot_data, dtype=bool)
    other_columns[:, column_index] = False
    axis.imshow(
        ma.masked_array(plot_data, mask=other_columns),
        aspect="auto",
        cmap="Greys",
        origin="lower",
        vmin=0,
        vmax=1.5,
    )

    # Draw each annotation exactly once, with a text colour that stays readable
    # on the underlying cell.
    for (i, j), value in np.ndenumerate(annot_data):
        map_value = plot_data[i, j]
        if j == column_index:
            # Greys: high values are dark.
            color = "white" if map_value > 0.8 else "black"
        else:
            # Diverging main colormap (e.g. RdBu_r): both extremes are dark.
            color = "black" if 0.2 < map_value < 0.95 else "white"
        axis.text(j, i, f"{value:.2f}", ha="center", va="center", color=color)

    axis.set_title(model_names[model])
    axis.set_xticks(np.arange(len(column_labels)))
    axis.set_xticklabels(column_labels)
    if model_idx == 0:
        # Only the leftmost panel carries layer labels; the panels share rows.
        axis.set_yticks(np.arange(plot_data.shape[0]))
        axis.set_yticklabels(y_labels)
        axis.set_ylabel("Layer")
    else:
        axis.set_yticks([])

fig.savefig(
    str(PROJECT_ROOT / "results" / f"{unit_type}_ids.pdf"),
    dpi=200,
    bbox_inches="tight",
    format="pdf",
)