In [1]:
from collections import defaultdict
import json
import os
import pickle

import numpy as np
import jax.numpy as jnp
import matplotlib.image as mpimg
import nltk
from flax import nnx
import matplotlib.patches as patches
import pandas as pd
import matplotlib.pyplot as plt
import umap
from nltk.corpus import wordnet as wn

In [2]:
# Make the WordNet corpus available for the `wn` (nltk.corpus.wordnet) lookups
# used by the is_* helpers; the recorded output shows it was already up to date.
nltk.download('wordnet')

[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/hannahb./nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [3]:
def get_trials(trials_path):
    """Load trials from a JSON Lines file.

    Each non-blank line is parsed as one JSON object. Only the keys
    ``"triplet"``, ``"context"`` and ``"choice"`` are kept; a key that is
    missing from a record is stored as ``None``.

    Parameters
    ----------
    trials_path : str or path-like
        Path to the ``.jsonl`` file.

    Returns
    -------
    list of dict
        One ``{"triplet", "context", "choice"}`` dict per record, in file order.

    Raises
    ------
    json.JSONDecodeError
        If a non-blank line is not valid JSON.
    """
    trials = []

    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(trials_path, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file),
            # which json.loads would otherwise reject.
            if not line.strip():
                continue

            trial = json.loads(line)
            trials.append({
                "triplet": trial.get("triplet"),
                "context": trial.get("context"),
                "choice": trial.get("choice"),
            })

    return trials

In [4]:
def get_img_feat_dict(features_path):
    """Unpickle and return the object stored at ``features_path``.

    Callers in this notebook treat the result as a mapping from image file
    name to feature vector.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(features_path, "rb") as handle:
        return pickle.load(handle)

In [5]:
def get_img_cls_dict(labels_path):
    """Map image file names to class labels read from a CSV file.

    The CSV must have ``ImageId`` and ``PredictionString`` columns. The label
    of an image is the first whitespace-separated token of its
    ``PredictionString``; the key is its ``ImageId`` with ``".JPEG"`` appended.

    Returns
    -------
    dict
        ``{"<ImageId>.JPEG": "<label>"}``.
    """
    table = pd.read_csv(labels_path)
    first_tokens = table["PredictionString"].str.split().str[0]
    file_names = [image_id + ".JPEG" for image_id in table["ImageId"]]
    return dict(zip(file_names, first_tokens))

In [6]:
def is_animal(wordnet_id):
    """Return True if ``animal.n.01`` is a (transitive) hypernym of the synset.

    ``wordnet_id`` is a string whose first character is the WordNet
    part-of-speech tag and whose remainder is the integer synset offset.
    """
    pos_tag, offset_digits = wordnet_id[0], wordnet_id[1:]
    query = wn.synset_from_pos_and_offset(pos=pos_tag, offset=int(offset_digits))
    target = wn.synset("animal.n.01")

    def parents(node):
        return node.hypernyms()

    # closure() walks the hypernym relation transitively from `query`.
    return target in query.closure(parents)

In [7]:
def is_clothing(wordnet_id):
    """Return True if ``clothing.n.01`` is a (transitive) hypernym of the synset.

    ``wordnet_id`` is a string whose first character is the WordNet
    part-of-speech tag and whose remainder is the integer synset offset.
    """
    pos_tag, offset_digits = wordnet_id[0], wordnet_id[1:]
    query = wn.synset_from_pos_and_offset(pos=pos_tag, offset=int(offset_digits))
    target = wn.synset("clothing.n.01")

    def parents(node):
        return node.hypernyms()

    # closure() walks the hypernym relation transitively from `query`.
    return target in query.closure(parents)

In [8]:
def is_container(wordnet_id):
    """Return True if ``container.n.01`` is a (transitive) hypernym of the synset.

    ``wordnet_id`` is a string whose first character is the WordNet
    part-of-speech tag and whose remainder is the integer synset offset.
    """
    pos_tag, offset_digits = wordnet_id[0], wordnet_id[1:]
    query = wn.synset_from_pos_and_offset(pos=pos_tag, offset=int(offset_digits))
    target = wn.synset("container.n.01")

    def parents(node):
        return node.hypernyms()

    # closure() walks the hypernym relation transitively from `query`.
    return target in query.closure(parents)

In [9]:
def is_food(wordnet_id):
    """Return True if ``food.n.01`` is a (transitive) hypernym of the synset.

    ``wordnet_id`` is a string whose first character is the WordNet
    part-of-speech tag and whose remainder is the integer synset offset.
    """
    pos_tag, offset_digits = wordnet_id[0], wordnet_id[1:]
    query = wn.synset_from_pos_and_offset(pos=pos_tag, offset=int(offset_digits))
    target = wn.synset("food.n.01")

    def parents(node):
        return node.hypernyms()

    # closure() walks the hypernym relation transitively from `query`.
    return target in query.closure(parents)

In [10]:
def is_furniture(wordnet_id):
    """Return True if ``furniture.n.01`` is a (transitive) hypernym of the synset.

    ``wordnet_id`` is a string whose first character is the WordNet
    part-of-speech tag and whose remainder is the integer synset offset.
    """
    pos_tag, offset_digits = wordnet_id[0], wordnet_id[1:]
    query = wn.synset_from_pos_and_offset(pos=pos_tag, offset=int(offset_digits))
    target = wn.synset("furniture.n.01")

    def parents(node):
        return node.hypernyms()

    # closure() walks the hypernym relation transitively from `query`.
    return target in query.closure(parents)

In [11]:
def is_plant(wordnet_id):
    """Return True if ``plant.n.02`` is a (transitive) hypernym of the synset.

    ``wordnet_id`` is a string whose first character is the WordNet
    part-of-speech tag and whose remainder is the integer synset offset.
    """
    pos_tag, offset_digits = wordnet_id[0], wordnet_id[1:]
    query = wn.synset_from_pos_and_offset(pos=pos_tag, offset=int(offset_digits))
    target = wn.synset("plant.n.02")

    def parents(node):
        return node.hypernyms()

    # closure() walks the hypernym relation transitively from `query`.
    return target in query.closure(parents)

In [12]:
def is_vehicle(wordnet_id):
    """Return True if ``vehicle.n.01`` is a (transitive) hypernym of the synset.

    ``wordnet_id`` is a string whose first character is the WordNet
    part-of-speech tag and whose remainder is the integer synset offset.
    """
    pos_tag, offset_digits = wordnet_id[0], wordnet_id[1:]
    query = wn.synset_from_pos_and_offset(pos=pos_tag, offset=int(offset_digits))
    target = wn.synset("vehicle.n.01")

    def parents(node):
        return node.hypernyms()

    # closure() walks the hypernym relation transitively from `query`.
    return target in query.closure(parents)

In [14]:
def get_repr_feats(img_feat_dict, img_cls_dict):
    """Pick one representative image feature per class and order classes by superclass.

    For every class, the representative is the image whose L2-normalised
    feature has the largest dot product with the L2-normalised class mean.
    Each class is then assigned one of eight superclasses using the
    WordNet-based ``is_*`` predicates (first match wins, ``"other"`` if none
    match). Classes are ordered by superclass and, within a superclass,
    by sorted class id. The index range occupied by each superclass is printed.

    Parameters
    ----------
    img_feat_dict : dict
        Image name -> feature vector.
    img_cls_dict : dict
        Image name -> class id; must contain every key of ``img_feat_dict``.

    Returns
    -------
    repr_feats : ndarray
        Representative features stacked along axis 0 in the sorted class order.
    cls_repr_feat_dict : dict
        Class id -> representative (unnormalised) feature.
    superclasses_sorted : list of str
        Superclass name of each row of ``repr_feats``.
    """
    # Bucket image names by class.
    members = defaultdict(list)
    for name in img_feat_dict:
        members[img_cls_dict[name]].append(name)

    centroids = {}
    for label, names in members.items():
        centroids[label] = np.mean([img_feat_dict[name] for name in names], axis=0)

    # Unit-length copies of every image feature, used for the cosine comparison.
    unit_feats = {}
    for name, vector in img_feat_dict.items():
        unit_feats[name] = vector / np.linalg.norm(vector)

    cls_repr_feat_dict = {}
    for label, centroid in centroids.items():
        unit_centroid = centroid / np.linalg.norm(centroid)

        def closeness(name):
            return float(np.dot(unit_feats[name], unit_centroid))

        cls_repr_feat_dict[label] = img_feat_dict[max(members[label], key=closeness)]

    # Ordered rules: the first predicate that accepts a class decides its superclass.
    rules = (
        ("animal", is_animal),
        ("clothing", is_clothing),
        # some vehicles are also containers
        ("container", lambda c: is_container(c) and not is_vehicle(c)),
        ("food", is_food),
        ("furniture", is_furniture),
        ("plant", is_plant),
        ("vehicle", is_vehicle),
    )
    superclasses = [rule_name for rule_name, _ in rules] + ["other"]

    classes = list(cls_repr_feat_dict.keys())
    cls_to_superclass = {
        cls: next((rule_name for rule_name, accepts in rules if accepts(cls)), "other")
        for cls in classes
    }

    start = 0
    classes_sorted = []
    for superclass in superclasses:
        group = sorted(cls for cls in classes if cls_to_superclass[cls] == superclass)
        classes_sorted.extend(group)

        stop = start + len(group)
        print(f"{superclass}: {start}-{stop}")
        start = stop

    repr_feats = np.stack([cls_repr_feat_dict[cls] for cls in classes_sorted], axis=0)
    superclasses_sorted = [cls_to_superclass[cls] for cls in classes_sorted]

    return repr_feats, cls_repr_feat_dict, superclasses_sorted

In [15]:
def get_model(model_path):
    """Load a pickled model from ``model_path``.

    Parameters
    ----------
    model_path : str or path-like
        Location of the pickle file.

    Returns
    -------
    object
        The unpickled model.

    Raises
    ------
    FileNotFoundError
        If ``model_path`` does not exist.
    RuntimeError
        If the file exists but cannot be opened or unpickled; the original
        exception is chained as ``__cause__``.
    """
    # Fail loudly instead of printing and then hitting an unbound `model`.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file does not exist: {model_path}")

    try:
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(model_path, "rb") as f:
            return pickle.load(f)
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {e}") from e

In [16]:
def get_independent_feats(independent_model, feats):
    """Project ``feats`` with ``independent_model.P`` and L2-normalise the last axis.

    ``1e-8`` is added to the norm before dividing so that an all-zero
    projection does not produce a division by zero.
    """
    projected = independent_model.P(feats)
    lengths = jnp.linalg.norm(projected, axis=-1, keepdims=True) + 1e-8
    return projected / lengths

In [17]:
def get_anchor_feats(anchor_model, context_feat, feats):
    """Compute context-conditioned features with the anchor model.

    ``feats`` are projected with ``anchor_model.P`` and L2-normalised along the
    last axis (``1e-8`` is added to the norm). The output of
    ``anchor_model.B_network(context_feat)`` is reshaped so its last axis
    becomes ``(rank, embedding_dim)``, and the normalised features are
    contracted with it over the embedding dimension, giving features whose
    last axis has length ``rank``.
    """
    projected = anchor_model.P(feats)
    unit = projected / (jnp.linalg.norm(projected, axis=-1, keepdims=True) + 1e-8)

    flat_basis = anchor_model.B_network(context_feat)
    basis_shape = flat_basis.shape[:-1] + (anchor_model.rank, anchor_model.embedding_dim)
    basis = flat_basis.reshape(basis_shape)

    # k: rank index, d: embedding index, t: item index; leading axes broadcast.
    return jnp.einsum('...kd,...td->...tk', basis, unit)

In [18]:
def get_choice(model, feats):
    """Return the model's choice index for each triplet in ``feats``.

    ``feats`` holds three items along its second-to-last axis. The three
    pairwise dot products are stacked so that index ``i`` is the similarity
    of the two items other than item ``i``; they are scaled by
    ``model.temperature``, passed through a softmax, and the argmax index
    (0, 1 or 2) is returned.
    """
    first, second, third = (feats[..., position, :] for position in range(3))

    def dot(left, right):
        return jnp.sum(left * right, axis=-1)

    # Entry i leaves item i out of the pair.
    pair_sims = jnp.stack([dot(second, third), dot(first, third), dot(first, second)], axis=-1)

    probabilities = nnx.softmax(model.temperature * pair_sims, axis=-1)
    return jnp.argmax(probabilities, axis=-1)

In [19]:
def _visualize_trial_on_axes(
    axes, 
    images_path,
    context,
    triplet,
    choice,
    independent_choice,
    anchor_choice
):
    """
    Visualize the trial (context + 3 triplet images) on a provided axes.
    """
    paths = [
        os.path.join(images_path, context),
        os.path.join(images_path, triplet[0]),
        os.path.join(images_path, triplet[1]),
        os.path.join(images_path, triplet[2]),
    ]
    images = [mpimg.imread(p) for p in paths]

    # Context title
    axes[0].set_title("Context", fontsize=11)

    # Build labels for the three choices
    labels = [[] for _ in range(3)]
    if choice in (0, 1, 2):
        labels[choice].append("Human")
    if independent_choice in (0, 1, 2):
        labels[independent_choice].append("Independent")
    if anchor_choice in (0, 1, 2):
        labels[anchor_choice].append("Anchor")

    for j in range(3):
        axes[j + 1].set_title("\n".join(labels[j]), fontsize=11)

    # Show images aligned to the top
    for ax, img in zip(axes, images):
        ax.imshow(img)
        ax.set_anchor("N")
        ax.set_xticks([])
        ax.set_yticks([])

    # Add colored frames around the 3 triplet images
    colors=("red", "green", "blue")
    
    for ax, color in zip(axes[1:], colors):
        rect = patches.Rectangle(
            (0, 0), 1, 1,
            transform=ax.transAxes,
            fill=False,
            edgecolor=color,
            linewidth=6,
            clip_on=False
        )
        ax.add_patch(rect)

In [20]:
def visualize_trial(
    idx,
    images_path,
    context,
    triplet,
    choice,
    independent_choice,
    anchor_choice
):
    """
    Show a single trial in its own 1x4 figure titled "Trial <idx>".

    Creates the figure, delegates the drawing to `_visualize_trial_on_axes`
    and displays the result.
    """
    fig, axes = plt.subplots(1, 4, figsize=(12, 3), constrained_layout=True)
    fig.suptitle(f"Trial {idx}", fontsize=14)

    _visualize_trial_on_axes(
        axes=axes,
        images_path=images_path,
        context=context,
        triplet=triplet,
        choice=choice,
        independent_choice=independent_choice,
        anchor_choice=anchor_choice,
    )

    plt.show()

In [21]:
def _visualize_rsm_on_axes(
    axes,
    independent_repr_feats,
    anchor_repr_feats,
):
    """
    Visualize Independent/Anchor RSMs on provided axes.
    """
    independent_RSM = independent_repr_feats @ independent_repr_feats.T
    anchor_RSM = anchor_repr_feats @ anchor_repr_feats.T

    im0 = axes[0].imshow(independent_RSM, origin="lower")
    axes[0].set_title("Independent RSM")

    im1 = axes[1].imshow(anchor_RSM, origin="lower")
    axes[1].set_title("Anchor RSM")

    fig = axes[0].figure
    fig.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)
    fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

    return im0, im1

In [22]:
def visualize_rsm(independent_repr_feats, anchor_repr_feats):
    """
    Show the Independent and Anchor RSMs side by side in their own 1x2 figure.

    Creates the figure, delegates the drawing to `_visualize_rsm_on_axes`
    and displays the result.
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)

    _visualize_rsm_on_axes(
        axes=axes,
        independent_repr_feats=independent_repr_feats,
        anchor_repr_feats=anchor_repr_feats,
    )

    plt.show()

In [24]:
def _compute_pca(X):
    """
    Parameters
    ----------
    X : array-like, shape (N, D)
        Data matrix.

    Returns
    -------
    Y : ndarray, shape (N, 2)
        2D embedding.
    """
    X = np.asarray(X, dtype=float)
    if X.ndim != 2:
        raise ValueError("X must be a 2D array.")

    # Gram matrix
    K = X @ X.T

    # Double-centering
    N = X.shape[0]
    J = np.eye(N) - np.ones((N, N)) / N
    Kc = J @ K @ J

    # Eigen-decomposition
    eigvals, eigvecs = np.linalg.eigh(Kc)
    idx = np.argsort(eigvals)[::-1]
    eigvals = eigvals[idx]
    eigvecs = eigvecs[:, idx]

    # 2D embedding
    lam = np.maximum(eigvals[:2], 0.0)
    Y = eigvecs[:, :2] * np.sqrt(lam)

    return Y

In [25]:
def _plot_pca(
    ax, 
    Y, 
    n_base,
    triplet,
    title,
    superclasses
):
    superclass_to_color = {
        "animal": "tab:blue",
        "clothing": "tab:orange",
        "container": "tab:purple",
        "food": "tab:brown",
        "furniture": "tab:pink",
        "plant": "tab:green",
        "vehicle": "tab:red",
        "other": "tab:gray",
    }

    base_colors = [superclass_to_color[s] for s in superclasses]

    ax.scatter(
        Y[:n_base, 0],
        Y[:n_base, 1],
        c=base_colors,
        s=15,
        alpha=0.6
    )

    triplet_colors = ["red", "green", "blue"]
    for i in range(len(triplet)):
        ax.scatter(
            Y[n_base + i, 0],
            Y[n_base + i, 1],
            c=triplet_colors[i],
            s=80
        )

    ax.set_xlabel("PC 1")
    ax.set_ylabel("PC 2")
    ax.set_title(title)
    ax.axis("equal")

In [26]:
def _visualize_pca_on_axes(
    axes,
    independent_repr_feats,
    independent_triplet_feats,
    anchor_repr_feats,
    anchor_triplet_feats,
    triplet,
    superclasses
):
    """
    Visualize independent/anchor PCA on provided axes.
    """
    X_ind = np.asarray(independent_repr_feats, dtype=float)
    T_ind = np.asarray(independent_triplet_feats, dtype=float)
    X_anc = np.asarray(anchor_repr_feats, dtype=float)
    T_anc = np.asarray(anchor_triplet_feats, dtype=float)

    if X_ind.ndim != 2 or T_ind.ndim != 2 or X_anc.ndim != 2 or T_anc.ndim != 2:
        raise ValueError("All feature inputs must be 2D arrays.")
    if X_ind.shape[1] != T_ind.shape[1]:
        raise ValueError("independent_repr_feats and independent_triplet_feats must have the same dimensionality.")
    if X_anc.shape[1] != T_anc.shape[1]:
        raise ValueError("anchor_repr_feats and anchor_triplet_feats must have the same dimensionality.")
    if len(triplet) != T_ind.shape[0] or len(triplet) != T_anc.shape[0]:
        raise ValueError("Length of triplet must match the number of triplet feature rows for both inputs.")

    X_all_ind = np.vstack([X_ind, T_ind])
    X_all_anc = np.vstack([X_anc, T_anc])
    
    Y_ind = _compute_pca(X_all_ind)
    Y_anc = _compute_pca(X_all_anc)

    _plot_pca(
        axes[0],
        Y_ind,
        n_base=X_ind.shape[0],
        triplet=triplet,
        title=f"Independent (PCA)",
        superclasses=superclasses
    )
    _plot_pca(
        axes[1],
        Y_anc,
        n_base=X_anc.shape[0],
        triplet=triplet,
        title=f"Anchor (PCA)",
        superclasses=superclasses
    )

In [27]:
def visualize_pca(
    independent_repr_feats,
    independent_triplet_feats,
    anchor_repr_feats,
    anchor_triplet_feats,
    triplet,
    superclasses
):
    """
    Show the independent and anchor 2D embeddings in their own 1x2 figure.

    Creates the figure, delegates the drawing to `_visualize_pca_on_axes`
    and displays the result.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True)

    _visualize_pca_on_axes(
        axes=axes,
        independent_repr_feats=independent_repr_feats,
        independent_triplet_feats=independent_triplet_feats,
        anchor_repr_feats=anchor_repr_feats,
        anchor_triplet_feats=anchor_triplet_feats,
        triplet=triplet,
        superclasses=superclasses,
    )

    plt.show()

In [28]:
def visualize_trial_rsm_pca(
    idx,
    images_path,
    context,
    triplet,
    choice,
    independent_choice,
    anchor_choice,
    independent_repr_feats,
    anchor_repr_feats,
    independent_triplet_feats,
    anchor_triplet_feats,
    superclasses
):
    """
    Show one trial as a 2x4 figure titled "Trial <idx>".

    Top row: the context and triplet images. Bottom row: the two RSMs in the
    first two panels and the two 2D embeddings in the last two panels.
    """
    fig, axes = plt.subplots(2, 4, figsize=(16, 8), constrained_layout=True)
    fig.suptitle(f"Trial {idx}", fontsize=14)

    image_row = axes[0, :]
    analysis_row = axes[1, :]

    _visualize_trial_on_axes(
        axes=image_row,
        images_path=images_path,
        context=context,
        triplet=triplet,
        choice=choice,
        independent_choice=independent_choice,
        anchor_choice=anchor_choice,
    )

    _visualize_rsm_on_axes(
        axes=analysis_row[0:2],
        independent_repr_feats=independent_repr_feats,
        anchor_repr_feats=anchor_repr_feats,
    )

    _visualize_pca_on_axes(
        axes=analysis_row[2:4],
        independent_repr_feats=independent_repr_feats,
        independent_triplet_feats=independent_triplet_feats,
        anchor_repr_feats=anchor_repr_feats,
        anchor_triplet_feats=anchor_triplet_feats,
        triplet=triplet,
        superclasses=superclasses,
    )

    plt.show()

In [29]:
def ranking_preserved(T: np.ndarray, Y: np.ndarray) -> bool:
    """Check that an embedding keeps the similarity order of a triplet's pairs.

    The pairs (0, 1), (0, 2), (1, 2) are ranked once by descending dot product
    of the rows of ``T`` and once by ascending Euclidean distance between the
    last three rows of ``Y``. Returns True when both rankings are identical.
    """
    pairs = ((0, 1), (0, 2), (1, 2))

    # similarity in feature space: larger dot product = closer
    similarities = np.array([T[a] @ T[b] for a, b in pairs])

    # distance in the embedding: the triplet occupies the last three rows
    embedded = Y[-3:]
    distances = np.array([np.linalg.norm(embedded[a] - embedded[b]) for a, b in pairs])

    by_similarity = np.argsort(-similarities)
    by_distance = np.argsort(distances)
    return np.array_equal(by_similarity, by_distance)

In [30]:
def get_idxs(
    trials_path = "/home/space/datasets/context_project/data/bretts/behavioural_data/triplets_test.jsonl",
    features_path = "/home/space/datasets/context_project/data/bretts/feature_vectors/DINOv2-B-alignet/features.pkl",
    labels_path = "LOC_val_solution.csv",
    independent_model_path = "/home/space/datasets/context_project/models/" + "rare-valley-2178" + ".pkl",
    anchor_model_path = "/home/space/datasets/context_project/models/" + "sparkling-violet-2008" + ".pkl",
    images_path = "/home/space/datasets/imagenet/2012/val_set_unlabeled/all_classes/"
):
    """Select the indices of trials that satisfy three successive filters.

    A trial is kept when all of the following hold:

    1. the anchor model's choice equals the recorded choice while the
       independent model's choice does not;
    2. both models make the same choice when every image is replaced by the
       representative feature of its class;
    3. for both models, the 2D embedding from `_compute_pca` (class
       representatives plus the trial's triplet) ranks the three triplet
       pairs in the same order as their dot products (`ranking_preserved`).

    The fraction and count of surviving trials are printed after each stage,
    together with a progress line every 100 candidates in stage 3.

    ``images_path`` is accepted for signature parity with `visualize` but is
    not used here.

    Returns
    -------
    idxs : ndarray of int
        Indices into the trial list, in ascending order.
    """
    trials = get_trials(trials_path)

    img_feat_dict = get_img_feat_dict(features_path)
    img_cls_dict = get_img_cls_dict(labels_path)

    # get_repr_feats returns (repr_feats, cls_repr_feat_dict, superclasses_sorted);
    # the superclass labels are not needed for filtering.
    repr_feats, cls_repr_feat_dict, _ = get_repr_feats(img_feat_dict, img_cls_dict)

    independent_model = get_model(independent_model_path)
    anchor_model = get_model(anchor_model_path)

    T = len(trials)
    D = next(iter(img_feat_dict.values())).shape[0]

    contexts = np.empty((T, D), dtype=np.float32)
    triplets = np.empty((T, 3, D), dtype=np.float32)
    choices = np.empty(T, dtype=np.int64)

    contexts_repr = np.empty((T, D), dtype=np.float32)
    triplets_repr = np.empty((T, 3, D), dtype=np.float32)

    for i, trial in enumerate(trials):
        context = trial["context"]
        triplet = trial["triplet"]
        choice = trial["choice"]

        contexts[i] = img_feat_dict[context]
        triplets[i] = np.stack([img_feat_dict[img] for img in triplet])
        choices[i] = choice

        # Same trial, but with each image swapped for its class representative.
        contexts_repr[i] = cls_repr_feat_dict[img_cls_dict[context]]
        triplets_repr[i] = np.stack([cls_repr_feat_dict[img_cls_dict[img]] for img in triplet])

    independent_triplets = get_independent_feats(independent_model, triplets)
    anchor_triplets = get_anchor_feats(anchor_model, contexts, triplets)

    independent_choices = get_choice(independent_model, independent_triplets)
    anchor_choices = get_choice(anchor_model, anchor_triplets)

    # Filter for trials where anchor is correct and independent is incorrect
    mask_1 = np.array((choices == anchor_choices) & (choices != independent_choices))
    print(mask_1.mean(), mask_1.sum())

    independent_triplets_repr = get_independent_feats(independent_model, triplets_repr)
    anchor_triplets_repr = get_anchor_feats(anchor_model, contexts_repr, triplets_repr)

    independent_choices_repr = get_choice(independent_model, independent_triplets_repr)
    anchor_choices_repr = get_choice(anchor_model, anchor_triplets_repr)

    # Filter for trials where class representatives yield the same result
    mask_2 = np.array((independent_choices == independent_choices_repr) & (anchor_choices == anchor_choices_repr))
    print(mask_2.mean(), mask_2.sum())

    mask = mask_1 & mask_2
    print(mask.mean(), mask.sum())

    # Filter for samples where the pair ranking is preserved by the 2D embedding
    mask_sum = mask.sum()

    # The independent representative features do not depend on the trial,
    # so project and convert them once.
    independent_repr_feats = get_independent_feats(independent_model, repr_feats)
    X_ind = np.asarray(independent_repr_feats, dtype=float)

    # Only entries at already-visited indices are cleared below, so iterating
    # over a snapshot of the candidates is equivalent to re-checking the mask.
    candidate_idxs = np.where(mask)[0]
    for i, idx in enumerate(candidate_idxs):
        if i % 100 == 0 and i != 0:
            print(f"{i}/{mask_sum}")

        anchor_repr_feats = get_anchor_feats(anchor_model, contexts[idx], repr_feats)

        T_ind = np.asarray(independent_triplets[idx], dtype=float)
        X_anc = np.asarray(anchor_repr_feats, dtype=float)
        T_anc = np.asarray(anchor_triplets[idx], dtype=float)

        # Triplet rows go last: ranking_preserved reads the last three rows.
        Y_ind = _compute_pca(np.vstack([X_ind, T_ind]))
        Y_anc = _compute_pca(np.vstack([X_anc, T_anc]))

        independent_ranking_preserved = ranking_preserved(T_ind, Y_ind)
        anchor_ranking_preserved = ranking_preserved(T_anc, Y_anc)

        if not independent_ranking_preserved or not anchor_ranking_preserved:
            mask[idx] = False
    print(f"{mask_sum}/{mask_sum}")
    print(mask.mean(), mask.sum())

    idxs = np.where(mask)[0]

    return idxs

In [33]:
def visualize(
    idxs,
    trials_path = "/home/space/datasets/context_project/data/bretts/behavioural_data/triplets_test.jsonl",
    features_path = "/home/space/datasets/context_project/data/bretts/feature_vectors/DINOv2-B-alignet/features.pkl",
    labels_path = "LOC_val_solution.csv",
    independent_model_path = "/home/space/datasets/context_project/models/" + "rare-valley-2178" + ".pkl",
    anchor_model_path = "/home/space/datasets/context_project/models/" + "sparkling-violet-2008" + ".pkl",
    images_path = "/home/space/datasets/imagenet/2012/val_set_unlabeled/all_classes/"
):
    """Render the combined trial / RSM / embedding figure for each index in ``idxs``.

    Loads the trials, image features, class labels and both models, computes
    the class-representative features once, and then, for every requested
    trial, evaluates both models on the trial's own images and hands
    everything to `visualize_trial_rsm_pca`.
    """
    trials = get_trials(trials_path)

    img_feat_dict = get_img_feat_dict(features_path)
    img_cls_dict = get_img_cls_dict(labels_path)

    # The per-class dict (second return value) is not needed for plotting.
    repr_feats, _, superclasses = get_repr_feats(img_feat_dict, img_cls_dict)

    independent_model = get_model(independent_model_path)
    anchor_model = get_model(anchor_model_path)

    # Independent features do not depend on the context, so compute them once.
    independent_repr_feats = get_independent_feats(independent_model, repr_feats)

    for idx in idxs:
        record = trials[idx]
        context, triplet, choice = record["context"], record["triplet"], record["choice"]

        context_feat = img_feat_dict[context]
        triplet_feats = np.stack([img_feat_dict[img] for img in triplet])

        independent_triplet_feats = get_independent_feats(independent_model, triplet_feats)
        independent_choice = get_choice(independent_model, independent_triplet_feats)

        # Anchor features are conditioned on this trial's context image.
        anchor_repr_feats = get_anchor_feats(anchor_model, context_feat, repr_feats)
        anchor_triplet_feats = get_anchor_feats(anchor_model, context_feat, triplet_feats)
        anchor_choice = get_choice(anchor_model, anchor_triplet_feats)

        visualize_trial_rsm_pca(
            idx=idx,
            images_path=images_path,
            context=context,
            triplet=triplet,
            choice=choice,
            independent_choice=independent_choice,
            anchor_choice=anchor_choice,
            independent_repr_feats=independent_repr_feats,
            anchor_repr_feats=anchor_repr_feats,
            independent_triplet_feats=independent_triplet_feats,
            anchor_triplet_feats=anchor_triplet_feats,
            superclasses=superclasses,
        )

In [34]:
# Render the combined figure for trial 15615 using visualize()'s default paths.
# NOTE: those defaults are absolute paths; the recorded output below is a
# FileNotFoundError because the trials file was not found when this cell ran.
visualize([15615])

FileNotFoundError: [Errno 2] No such file or directory: '/home/space/datasets/context_project/data/bretts/behavioural_data/triplets_test.jsonl'