In [1]:
import numpy as np
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModel

tiny_imagenet = load_dataset('zh-plus/tiny-imagenet', split='train')
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
model = AutoModel.from_pretrained('facebook/dinov2-small')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def embed(x):
    inputs = processor(images=x, return_tensors="pt")
    outputs = model(**inputs)
    retval = (
        outputs
        .last_hidden_state
        .flatten()
        .detach()
        .numpy()
    )
    retval = retval[np.newaxis, :]
    return retval

In [3]:
classes = {
    0: "fish",
    1: "ground_spider",
    2: "frog",
    5: "snake",
    8: "web_spider",
    19: "penguin"
}
id_labels = classes.keys()
text_labels = list(classes.values())

In [4]:
embeddings = [
    [
        embed(image['image'])
        for image in tiny_imagenet 
        if image['label'] == label
    ]
    for label in id_labels
]

In [5]:
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable

import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
from scipy.stats import cramervonmises_2samp, ks_2samp

@dataclass
class EmbeddingMatrix:
    """
    Maps scores from each query label to each reference label.
    """

    statistics: dict[str, dict[str, float]]
    pvalues: dict[str, dict[str, float]]

In [6]:
def get_inner_product_function(dataset: np.ndarray):
    return lambda x: np.dot(dataset, x)


def get_projections(
    basis: np.ndarray,
    inner_product: Callable[[np.ndarray], np.ndarray],
) -> np.ndarray:
    """
    Computes orthogonal projection of a dataset onto a plane with orthonormal basis `basis`.
    The dataset is represented abstractly by the Callable `inner_product`, which is a function
    that takes in a vector `v` and returns the array of inner products <x, v> for all x in the
    dataset.

    Parameters
    ----------
    basis: np.ndarray
        Should be k x d, all length 1, and orthogonal
    inner_product: Callable[np.ndarray, np.ndarray]

    Returns
    -------
    projections: np.ndarray
        N x k
    """
    # check that basis is orthonormal
    dots = np.einsum("ik,jk->ij", basis, basis)
    np.testing.assert_almost_equal(dots, np.identity(dots.shape[0]))

    res = []
    for x in basis:
        res.append(inner_product(x))
    res = np.array(res)
    return res.T

In [7]:
def _compute_metrics_from_projections(
    queries: list[list[np.ndarray]],
    references: list[list[np.ndarray]] | None,
    classes: list[str],
    method: str,
    func: Callable,
) -> EmbeddingMatrix:
    """
    Computes metrics using a selectable scipy.stats function.

    Parameters
    ----------
    queries : list[list[np.ndarray]]
        A list of query embedding lists.
    references : list[list[np.ndarray]], optional
        A list of reference embedding lists.
    classes : list[str]
        A list of labels mapped to the distance lists.
    method : str
        The method of calculating distance.
    func : Callable
        A scipy.stats function.

    Returns
    -------
    EmbeddingMatrix
        A metric containing a confusion matrix for both p-value and distance metric.
    """
    pvalues = defaultdict(lambda: defaultdict(float))
    statistics = defaultdict(lambda: defaultdict(float))

    def cast_and_round(x):
        x = float(x)
        return round(x, 3)

    queries_are_references = references is None
    references = queries

    vector = np.random.normal(size=(1, references[0][0].shape[1]))
    vector = vector / np.linalg.norm(vector)
    k = 3
    for _ in range(k):
        for i, query in enumerate(queries):
            for j, reference in enumerate(references):
                if queries_are_references and i == j:
                    # split the set in two and measure how similarly distributed it is.
                    split_idx = len(reference) // 2
                    reference_proj = get_projections(
                        vector,
                        get_inner_product_function(
                            np.concatenate(reference[:split_idx])
                        )
                    )
                    query_proj = get_projections(
                        vector,
                        get_inner_product_function(
                            np.concatenate(query[split_idx:])
                        )
                    )
                else:
                    reference_proj = get_projections(
                        vector,
                        get_inner_product_function(
                            np.concatenate(reference)
                        )
                    )
                    query_proj = get_projections(
                        vector,
                        get_inner_product_function(
                            np.concatenate(query)
                        )
                    )
                
                metric = func(reference_proj, query_proj)

                label_i = classes[i]
                label_j = classes[j]
                pvalues[label_i][label_j] += cast_and_round(metric.pvalue)
                statistics[label_i][label_j] += cast_and_round(metric.statistic)

    return EmbeddingMatrix(pvalues=pvalues, statistics=statistics)

In [8]:
def compute_cvm(
    queries: list[list[np.ndarray]],
    classes: list[str],
    references: list[list[np.ndarray]] | None = None,
    method: str = "cosine",
) -> EmbeddingMatrix:
    return _compute_metrics_from_projections(
        queries=queries,
        references=references,
        classes=classes,
        method=method,
        func=cramervonmises_2samp,
    )


def compute_ks(
    queries: list[list[np.ndarray]],
    classes: list[str],
    references: list[list[np.ndarray]] | None = None,
    method: str = "cosine",
) -> EmbeddingMatrix:
    return _compute_metrics_from_projections(
        queries=queries,
        references=references,
        classes=classes,
        method=method,
        func=ks_2samp,
    )

Compute metrics

In [9]:
cvm = compute_cvm(embeddings, text_labels)
ks = compute_ks(embeddings, text_labels)

  x = float(x)


Create pandas dataframes

In [10]:
from embedding_metrics import create_dataframe

cvm_statistics, cvm_pvalues = create_dataframe(cvm, text_labels)
ks_statistics, ks_pvalues = create_dataframe(ks, text_labels)

In [11]:
cvm_statistics

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.051,11.802,0.267,0.285,5.985,35.991
Query,ground_spider,11.802,0.201,14.85,9.981,32.802,9.546
Query,frog,0.267,14.85,0.891,0.684,4.452,41.73
Query,snake,0.285,9.981,0.684,0.633,6.747,32.883
Query,web_spider,5.985,32.802,4.452,6.747,0.402,63.909
Query,penguin,35.991,9.546,41.73,32.883,63.909,0.441


In [12]:
cvm_pvalues

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,2.997,0.0,1.935,1.836,0.0,0.0
Query,ground_spider,0.0,2.31,0.0,0.0,0.0,0.0
Query,frog,1.935,0.0,0.414,0.657,0.0,0.0
Query,snake,1.836,0.0,0.657,0.744,0.0,0.0
Query,web_spider,0.0,0.0,0.0,0.0,1.329,0.0
Query,penguin,0.0,0.0,0.0,0.0,0.0,1.2


In [13]:
ks_statistics

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.264,0.252,0.75,0.39,0.372,0.558
Query,ground_spider,0.252,0.36,0.906,0.342,0.498,0.438
Query,frog,0.75,0.906,0.252,0.984,0.564,1.116
Query,snake,0.39,0.342,0.984,0.336,0.63,0.288
Query,web_spider,0.372,0.498,0.564,0.63,0.156,0.846
Query,penguin,0.558,0.438,1.116,0.288,0.846,0.168


In [14]:
ks_pvalues

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.864,0.177,0.0,0.0,0.003,0.0
Query,ground_spider,0.177,0.165,0.0,0.009,0.0,0.0
Query,frog,0.0,0.0,1.026,0.0,0.0,0.0
Query,snake,0.0,0.009,0.0,0.261,0.0,0.06
Query,web_spider,0.003,0.0,0.0,0.0,2.667,0.0
Query,penguin,0.0,0.0,0.0,0.06,0.0,2.487
