In [None]:
import torch
from pathlib import Path
from latentis import PROJECT_ROOT

In [None]:
# Directory holding one serialized result file per experiment run.
exp_dir: Path = PROJECT_ROOT / "results" / "exp1"
# Fail fast on a fresh kernel instead of merely displaying ``False`` and letting
# every later cell silently run on an empty experiment list.
assert exp_dir.is_dir(), f"Missing results directory: {exp_dir}"
# Display the path relative to the project root so saved outputs do not embed
# the machine-specific absolute location.
exp_dir.relative_to(PROJECT_ROOT)

In [None]:
# Sorted for a deterministic order (``glob`` order depends on the filesystem);
# only regular files are kept because every entry is later passed to ``torch.load``.
experiments = sorted(path for path in exp_dir.glob("*") if path.is_file())
len(experiments)

In [None]:
# Show file names only: full paths embed the local project root, which is noise
# for readers and leaks the machine's directory layout into saved outputs.
[exp.name for exp in experiments]

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from residual.nn.model_registry import model_names

# Colormap shared by the heatmaps in the cells below
# (reversed RdBu: low values blue, high values red).
cmap = "RdBu_r"

In [None]:
# Per-head intrinsic-dimension (ID) statistics on ImageNet for the listed encoders.
models = ["openclip_l", "clip_l", "dinov2_l", "vit_l", "blip_l_flickr"]
n_layers, n_heads = 24, 16  # layers x heads per model — presumably ViT-L geometry; TODO confirm
std_components = 64  # EVR std is taken over this many leading components

twonn_id = torch.zeros(len(models), n_layers, n_heads)
pca_id = torch.zeros(len(models), n_layers, n_heads)
evr_std = torch.zeros(len(models), n_layers, n_heads)

pca_threshold = 0.99


def _pca_components_above(evr: torch.Tensor, threshold: float) -> int:
    """Return the 1-based index of the first component whose EVR exceeds ``threshold``.

    ``evr`` is presumably the cumulative explained-variance ratio per component
    (TODO confirm). If it never exceeds ``threshold`` within the stored
    components, the number of stored components is returned (a lower bound on the
    PCA dimension) instead of raising ``IndexError``.
    """
    above = torch.nonzero(evr > threshold)
    if above.numel() == 0:
        return len(evr)
    return int(above[0, 0]) + 1


for exp_path in experiments:
    # NOTE(review): weights_only=False unpickles arbitrary objects, so only load
    # result files produced by this project.
    exp_data = torch.load(exp_path, map_location="cpu", weights_only=False)
    if exp_data["dataset_name"] != "imagenet":
        continue
    encoder_name = exp_data["encoder_name"]
    if encoder_name not in models:
        continue
    model_idx = models.index(encoder_name)  # fixed for the whole file
    # Keys are (layer, head, unit_type); unit_type is ignored, so if several unit
    # types share a (layer, head) pair, the last one iterated wins.
    for (layer_idx, head_idx, _unit_type), unit_data in exp_data["data"].items():
        pca_id[model_idx, layer_idx, head_idx] = _pca_components_above(
            unit_data["pca_evr"], pca_threshold
        )
        twonn_id[model_idx, layer_idx, head_idx] = unit_data["id_twonn"]
        evr_std[model_idx, layer_idx, head_idx] = unit_data["pca_evr"][
            :std_components
        ].std()

In [None]:
from tueplots import figsizes  # also used by the EVR figure cell below

# One heatmap panel per model: rows are layers; columns are head-averaged
# statistics. Colours use per-column max-normalised values; annotations show raw values.
stat_labels = ["PCA", "TwoNN", "Ratio", "Std"]
# One column per model (the previous ``len(models) - 1`` overflowed the axes
# array on the last model); ``squeeze=False`` keeps 2-D indexing for any count.
fig, axes = plt.subplots(
    nrows=1, ncols=len(models), figsize=(2 * len(models), 4), squeeze=False
)
for m, model in enumerate(models):
    ax = axes[0, m]
    # Head-averaged statistics, each of shape (n_layers, 1).
    pca_mean = pca_id[m].mean(dim=1, keepdim=True)
    twonn_mean = twonn_id[m].mean(dim=1, keepdim=True)
    ratio_mean = (pca_id[m] / twonn_id[m]).mean(dim=1, keepdim=True)
    std_mean = evr_std[m].mean(dim=1, keepdim=True)

    annot_tensor = torch.cat([pca_mean, twonn_mean, ratio_mean, std_mean], dim=1)
    # Divide each column by its own max so all statistics share the [0, 1] colour scale.
    plot_tensor = annot_tensor / annot_tensor.max(dim=0, keepdim=True).values
    annot_data = annot_tensor.numpy()
    plot_data = plot_tensor.numpy()

    ax.imshow(plot_data, aspect="auto", cmap=cmap, origin="lower", vmin=0, vmax=1)
    for i, j in np.ndindex(plot_data.shape):
        # RdBu_r is light mid-range and saturated at the ends: pick a readable text colour.
        color = "black" if 0.3 < plot_data[i, j] < 0.8 else "white"
        ax.text(
            j, i, f"{annot_data[i, j]:.2f}", ha="center", va="center", color=color
        )
    ax.set_title(model_names[model])
    ax.set_xticks(np.arange(plot_data.shape[1]))
    ax.set_xticklabels(stat_labels)
    if m == 0:
        ax.set_yticks(np.arange(plot_data.shape[0]))
        ax.set_yticklabels(np.arange(plot_data.shape[0]))
        ax.set_ylabel("Layer")
    else:
        ax.set_yticks([])

fig.savefig("id_columns_99.pdf", dpi=200, bbox_inches="tight", format="pdf")

In [None]:
# Per-component (non-cumulative) explained-variance ratios for every head on ImageNet.
models = ["openclip_l", "clip_l", "dinov2_l", "vit_l", "blip_l_flickr", "blip_l_coco"]
max_components = 64
n_layers, n_heads = 24, 16  # presumably ViT-L geometry — TODO confirm against result files
unit_evr = torch.zeros(len(models), n_layers, n_heads, max_components)

for exp_path in experiments:
    # Cheap filename pre-filter so unrelated runs are never deserialised.
    if "imagenet" not in exp_path.stem or "_l" not in exp_path.stem:
        continue
    # NOTE(review): weights_only=False unpickles arbitrary objects — trusted files only.
    exp_data = torch.load(exp_path, map_location="cpu", weights_only=False)
    # NOTE(review): the ID-statistics cell reads ``encoder_name``; fall back to it
    # when ``model_name`` is absent — confirm which key the result files use.
    model_name = exp_data.get("model_name", exp_data.get("encoder_name"))
    if model_name not in models:
        # The "_l" filename test is loose; skip models without a row here instead
        # of crashing in ``models.index``.
        continue
    model_idx = models.index(model_name)
    for (layer_idx, head_idx, _unit_type), unit_data in exp_data["data"].items():
        evr = unit_data["pca_evr"][:max_components]
        # Presumably cumulative EVR -> per-component EVR: the first value, then
        # successive differences.
        per_component = torch.diff(evr, prepend=evr.new_zeros(1))
        # Left-aligned so units with fewer than ``max_components`` components still
        # fit; the remaining slots stay zero.
        unit_evr[
            model_idx, layer_idx, head_idx, : per_component.shape[0]
        ] = per_component

In [None]:
normalized_mean = torch.nn.functional.normalize(unit_evr[m].mean(dim=1), dim=-1, p=1)
models = ["openclip_l", "dinov2_l"]
plt.rcParams.update(figsizes.iclr2024(ncols=1, nrows=1, height_to_width_ratio=0.8))
fig, ax = plt.subplots(nrows=1, ncols=len(models))
k = 15
for m, model in enumerate(models):
    normalized_mean = torch.nn.functional.normalize(
        unit_evr[m].mean(dim=1), dim=-1, p=1
    )[:, :k]
    im = ax[m].imshow(
        normalized_mean,
        aspect="auto",
        cmap=cmap,
        origin="lower",
        # vmin=0,
        # vmax=1,
    )
    ax[m].set_title(model_names[model])
    ax[m].set_xticks(np.arange(normalized_mean.shape[1], step=2))
    ax[m].set_xticklabels(np.arange(1, normalized_mean.shape[1] + 1, step=2))
    if m == 1:
        ax[m].yaxis.tick_right()
        ax[m].yaxis.set_label_position("right")
        ax[m].set_yticks(np.arange(normalized_mean.shape[0]))
        ax[m].set_yticklabels(np.arange(normalized_mean.shape[0]))
        ax[m].set_ylabel("Layer")
    else:
        ax[m].set_yticks([])
    for i in range(normalized_mean.shape[0]):
        row = normalized_mean[i]
        cumsum = np.cumsum(row)
        condition = cumsum >= 0.5
        threshold_index = np.argmax(condition)
        if (
            threshold_index < k and condition.any()
        ):  # Check if the threshold is within the plotted range
            ax[m].axvline(
                x=threshold_index + 0.4,
                ymin=i / normalized_mean.shape[0],
                ymax=(i + 1) / normalized_mean.shape[0],
                color="y",
                linewidth=2,
            )
plt.savefig("evr_openclip_dino.svg", dpi=200, bbox_inches="tight", format="svg")
!rsvg-convert -f pdf -o evr_openclip_dino.pdf evr_openclip_dino.svg
!rm evr_openclip_dino.svg