In [None]:
import sys
import os
# Make the parent directory importable so that `utils.*` resolves when the
# notebook is run from its own sub-folder (path is relative to the kernel's cwd).
project_root = os.path.abspath("..")
# Guard the insert so re-running this cell does not stack duplicate entries.
if project_root not in sys.path:
    sys.path.insert(0, project_root)

In [None]:
import torch
from matplotlib import pyplot as plt
import numpy as np
from utils.linalg import project_onto_plane
from utils.plotting import bivariate_color_map


In [None]:
# Embedding keys are (x, y, size) tuples; only entries whose `size` component
# equals this value are selected for plotting.
SIZE = 224
# Layer indices 8 through 16 inclusive.
INTERESTING_LAYERS = [*range(8, 17)]

In [None]:
def project_onto_computed_xy(embeds, x_axes, y_axes, title, axes_token, token, layer, extra_tokens=None, grid_wh=4):
    """Scatter one token's embeddings projected onto a precomputed x/y plane.

    Every embedding of ``token`` at ``layer`` whose key has ``size == SIZE`` is
    projected with ``project_onto_plane`` onto the plane given by
    ``x_axes[layer][axes_token]`` and ``y_axes[layer][axes_token]`` and drawn
    as a circle coloured by ``bivariate_color_map(x, y, grid_wh, grid_wh)``.
    Optional extra embeddings are projected onto the same plane and drawn as
    black glyphs. The y axis is inverted and the x ticks are placed on top.

    Args:
        embeds: Nested mapping ``embeds[layer][token][(x, y, size)] -> tensor``.
        x_axes: Nested mapping ``x_axes[layer][axes_token] -> tensor``.
        y_axes: Nested mapping ``y_axes[layer][axes_token] -> tensor``.
        title: Prefix of the figure title (rendered as ``"<title>: Layer <layer>"``).
        axes_token: Token whose axes define the projection plane.
        token: Token whose embeddings are plotted.
        layer: Key used to index ``embeds``, ``x_axes``, ``y_axes`` and any
            non-string value in ``extra_tokens``.
        extra_tokens: Optional mapping ``label -> value``. A ``str`` value is
            looked up in ``embeds[layer]`` and averaged over its ``SIZE``
            entries; any other value is indexed with ``[layer]`` to obtain a
            single embedding. ``label`` is rendered as a mathtext marker.
        grid_wh: Passed as both width and height to ``bivariate_color_map``.

    Raises:
        ValueError: If no embedding with ``size == SIZE`` exists for ``token``
            or for a string-valued extra token at ``layer``.
    """
    # Collect embeddings (and their grid colours) for the main token.
    color = []
    selected_embeds = []
    for (x, y, size), embed in embeds[layer][token].items():
        if size == SIZE:
            color.append(bivariate_color_map(x, y, grid_wh, grid_wh))
            selected_embeds.append(embed)

    # Fail with a readable message instead of torch.stack's error on an empty list.
    if not selected_embeds:
        raise ValueError(f"No embeddings found for token '{token}' at layer {layer}.")

    # Cast to float32 (numpy has no bfloat16) and move to host memory before .numpy().
    main_array = torch.stack(selected_embeds).to(torch.float32).cpu().numpy()
    x_axis = x_axes[layer][axes_token].to(torch.float32).cpu().numpy()
    y_axis = y_axes[layer][axes_token].to(torch.float32).cpu().numpy()
    # NOTE(review): `explained` is presumably the fraction of variance captured
    # by the plane; it is only formatted into the subtitle here — confirm in utils.linalg.
    coords_main, explained = project_onto_plane(main_array, x_axis, y_axis)

    # Create the figure only after the inputs are validated so a failure
    # does not leave an empty figure behind.
    fig, ax = plt.subplots(1, 1, figsize=(6, 6.2))

    # Plot main token points.
    ax.scatter(coords_main[:, 0], coords_main[:, 1], c=np.array(color), marker="o", label=token, s=150)

    if extra_tokens is not None:
        for label, val in extra_tokens.items():
            if isinstance(val, str):
                # Token name: average all of its SIZE-matching embeddings at this layer.
                token_embeds = [
                    embed_val
                    for (_, _, size_), embed_val in embeds[layer].get(val, {}).items()
                    if size_ == SIZE
                ]
                if len(token_embeds) == 0:
                    raise ValueError(f"No embeddings found for extra token '{val}' at layer {layer}.")
                avg_embed = torch.stack(token_embeds).to(torch.float32).mean(dim=0).cpu().numpy()
            elif isinstance(val, torch.Tensor):
                # Per-layer tensor: pick this layer's row.
                avg_embed = val[layer].to(torch.float32).cpu().numpy()
            else:
                # Anything else indexable by `layer` (array, list, mapping).
                avg_embed = np.asarray(val[layer], dtype=np.float32)

            # Project the single vector with the same helper; it expects a 2-D batch.
            coords_extra, _ = project_onto_plane(avg_embed[np.newaxis, :], x_axis, y_axis)
            coord = coords_extra[0]
            ax.scatter(coord[0], coord[1], marker=f"${label}$", c="black", s=150)

    # Finalize plot.
    subtitle = f"Explained: {explained:.2f}"
    ax.invert_yaxis()
    ax.xaxis.tick_top()
    ax.set_title(subtitle)
    fig.subplots_adjust(top=0.85)
    fig.suptitle(f"{title}: Layer {layer}")

In [None]:
def visualize_grid(model, frog_token, left_token, right_token, layer):
    """Load the camel/frog grid embeddings for ``model`` and plot one layer.

    Reads four checkpoint files under ``embeds/id_grid/`` (image-grid
    embeddings, their x and y axes, and the text embeddings), then calls
    ``project_onto_computed_xy`` with ``frog_token`` as both the plotted token
    and the axes token, marking the ``left_token`` / ``right_token`` text
    embeddings as ``L`` and ``R``.

    Args:
        model: Checkpoint stem, e.g. ``"llava-7b"``.
        frog_token: Token whose grid embeddings and axes are plotted.
        left_token: Key of the "left" text embedding within each layer.
        right_token: Key of the "right" text embedding within each layer.
        layer: Layer key to visualize.

    Raises:
        KeyError: If ``layer`` is not present in the text-embedding file.
    """
    embeds = torch.load(f"embeds/id_grid/camel_frog/{model}.pt", map_location="cpu", weights_only=True)
    x_axes = torch.load(f"embeds/id_grid/camel_frog/{model}_x.pt", map_location="cpu", weights_only=True)
    y_axes = torch.load(f"embeds/id_grid/camel_frog/{model}_y.pt", map_location="cpu", weights_only=True)

    lr_embed = torch.load(f"embeds/id_grid/camel_frog_text/{model}.pt", map_location="cpu", weights_only=True)
    if layer not in lr_embed:
        raise KeyError(f"Layer {layer} not found in text embeddings for model '{model}'.")

    # Key the extra embeddings by the real layer key. A positional array built
    # from `lr_embed.keys()` only lines up with `val[layer]` downstream when the
    # keys happen to be exactly 0..N-1 in order.
    # .float() casts to float32 first because numpy cannot represent bfloat16.
    left_embed = {layer: lr_embed[layer][left_token].float().numpy()}
    right_embed = {layer: lr_embed[layer][right_token].float().numpy()}

    project_onto_computed_xy(
        embeds,
        x_axes,
        y_axes,
        axes_token=frog_token,
        title="Avg Frog Across Varying Reference",
        token=frog_token,
        extra_tokens={"L": left_embed, "R": right_embed},
        layer=layer,
    )

## Llava-7b

In [None]:
# NOTE(review): "rog" is presumably the final sub-word piece of "frog" for this
# tokenizer, and "▁" its word-start marker — confirm against the tokenizer.
visualize_grid(
    model="llava-7b",
    frog_token="rog",
    left_token="▁left",
    right_token="▁right",
    layer=13,
)

## Llama-11b

In [None]:
# NOTE(review): "Ġ" is presumably this tokenizer's leading-space marker — confirm.
visualize_grid(
    model="llama-11b",
    frog_token="Ġfrog",
    left_token="Ġleft",
    right_token="Ġright",
    layer=14,
)