In [1]:
import numpy as np
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModel

tiny_imagenet = load_dataset('zh-plus/tiny-imagenet', split='train')
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
model = AutoModel.from_pretrained('facebook/dinov2-small')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def embed(x):
    inputs = processor(images=x, return_tensors="pt")
    outputs = model(**inputs)
    retval = (
        outputs
        .last_hidden_state
        .flatten()
        .detach()
        .numpy()
    )
    retval = retval[np.newaxis, :]
    return retval

In [3]:
classes = {
    0: "fish",
    1: "ground_spider",
    2: "frog",
    5: "snake",
    8: "web_spider",
    19: "penguin"
}
id_labels = classes.keys()
text_labels = list(classes.values())

In [4]:
embeddings = [
    [
        embed(image['image'])
        for image in tiny_imagenet 
        if image['label'] == label
    ]
    for label in id_labels
]

In [5]:
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable

import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
from scipy.stats import cramervonmises_2samp, ks_2samp

@dataclass
class EmbeddingMatrix:
    """
    Maps scores from each query label to each reference label.
    """

    statistics: dict[str, dict[str, float]]
    pvalues: dict[str, dict[str, float]]

In [6]:
def get_inner_product_function(dataset: np.ndarray):
    return lambda x: np.dot(dataset, x)


def get_projections(
    basis: np.ndarray,
    inner_product: Callable[[np.ndarray], np.ndarray],
) -> np.ndarray:
    """
    Computes orthogonal projection of a dataset onto a plane with orthonormal basis `basis`.
    The dataset is represented abstractly by the Callable `inner_product`, which is a function
    that takes in a vector `v` and returns the array of inner products <x, v> for all x in the
    dataset.

    Parameters
    ----------
    basis: np.ndarray
        Should be k x d, all length 1, and orthogonal
    inner_product: Callable[np.ndarray, np.ndarray]

    Returns
    -------
    projections: np.ndarray
        N x k
    """
    # check that basis is orthonormal
    dots = np.einsum("ik,jk->ij", basis, basis)
    np.testing.assert_almost_equal(dots, np.identity(dots.shape[0]))

    res = []
    for x in basis:
        res.append(inner_product(x))
    res = np.array(res)
    return res.T

In [7]:
def _compute_metrics_from_projections(
    queries: list[list[np.ndarray]],
    references: list[list[np.ndarray]] | None,
    classes: list[str],
    method: str,
    func: Callable,
) -> EmbeddingMatrix:
    """
    Computes metrics using a selectable scipy.stats function.

    Parameters
    ----------
    queries : list[list[np.ndarray]]
        A list of query embedding lists.
    references : list[list[np.ndarray]], optional
        A list of reference embedding lists.
    classes : list[str]
        A list of labels mapped to the distance lists.
    method : str
        The method of calculating distance.
    func : Callable
        A scipy.stats function.

    Returns
    -------
    EmbeddingMatrix
        A metric containing a confusion matrix for both p-value and distance metric.
    """
    pvalues = defaultdict(lambda: defaultdict(float))
    statistics = defaultdict(lambda: defaultdict(float))

    def cast_and_round(x):
        x = float(x)
        return round(x, 3)

    queries_are_references = references is None
    references = queries

    vector = np.random.normal(size=(1, references[0][0].shape[1]))
    vector = vector / np.linalg.norm(vector)
    k = 3
    for _ in range(k):
        for i, query in enumerate(queries):
            for j, reference in enumerate(references):
                if queries_are_references and i == j:
                    # split the set in two and measure how similarly distributed it is.
                    split_idx = len(reference) // 2
                    reference_proj = get_projections(
                        vector,
                        get_inner_product_function(
                            np.concatenate(reference[:split_idx])
                        )
                    )
                    query_proj = get_projections(
                        vector,
                        get_inner_product_function(
                            np.concatenate(query[split_idx:])
                        )
                    )
                else:
                    reference_proj = get_projections(
                        vector,
                        get_inner_product_function(
                            np.concatenate(reference)
                        )
                    )
                    query_proj = get_projections(
                        vector,
                        get_inner_product_function(
                            np.concatenate(query)
                        )
                    )
                
                metric = func(reference_proj, query_proj)

                label_i = classes[i]
                label_j = classes[j]
                pvalues[label_i][label_j] += cast_and_round(metric.pvalue)
                statistics[label_i][label_j] += cast_and_round(metric.statistic)

    return EmbeddingMatrix(pvalues=pvalues, statistics=statistics)

In [8]:
def compute_cvm(
    queries: list[list[np.ndarray]],
    classes: list[str],
    references: list[list[np.ndarray]] | None = None,
    method: str = "cosine",
) -> EmbeddingMatrix:
    return _compute_metrics_from_projections(
        queries=queries,
        references=references,
        classes=classes,
        method=method,
        func=cramervonmises_2samp,
    )


def compute_ks(
    queries: list[list[np.ndarray]],
    classes: list[str],
    references: list[list[np.ndarray]] | None = None,
    method: str = "cosine",
) -> EmbeddingMatrix:
    return _compute_metrics_from_projections(
        queries=queries,
        references=references,
        classes=classes,
        method=method,
        func=ks_2samp,
    )

Compute metrics

In [9]:
cvm = compute_cvm(embeddings, text_labels)
ks = compute_ks(embeddings, text_labels)

  x = float(x)


Create pandas dataframes

In [10]:
from embedding_metrics import create_dataframe

cvm_statistics, cvm_pvalues = create_dataframe(cvm, text_labels)
ks_statistics, ks_pvalues = create_dataframe(ks, text_labels)

In [11]:
cvm_statistics

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.153,23.636,22.091,16.439,4.224,28.036
Query,ground_spider,23.636,0.037,0.349,0.83,11.616,0.547
Query,frog,22.091,0.349,0.168,0.392,9.903,1.611
Query,snake,16.439,0.83,0.392,0.595,6.332,2.471
Query,web_spider,4.224,11.616,9.903,6.332,0.307,16.102
Query,penguin,28.036,0.547,1.611,2.471,16.102,0.118


In [12]:
cvm_pvalues

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.381,0.0,0.0,0.0,0.0,0.0
Query,ground_spider,0.0,0.95,0.099,0.006,0.0,0.03
Query,frog,0.0,0.099,0.339,0.076,0.0,0.0
Query,snake,0.0,0.006,0.076,0.023,0.0,0.0
Query,web_spider,0.0,0.0,0.0,0.0,0.129,0.0
Query,penguin,0.0,0.03,0.0,0.0,0.0,0.505


In [13]:
ks_statistics

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.068,0.524,0.478,0.416,0.176,0.156
Query,ground_spider,0.524,0.096,0.066,0.114,0.404,0.416
Query,frog,0.478,0.066,0.056,0.076,0.378,0.374
Query,snake,0.416,0.114,0.076,0.052,0.314,0.308
Query,web_spider,0.176,0.404,0.378,0.314,0.06,0.034
Query,penguin,0.156,0.416,0.374,0.308,0.034,0.072


In [14]:
ks_pvalues

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.611,0.0,0.0,0.0,0.0,0.0
Query,ground_spider,0.0,0.2,0.226,0.003,0.0,0.0
Query,frog,0.0,0.226,0.829,0.111,0.0,0.0
Query,snake,0.0,0.003,0.111,0.889,0.0,0.0
Query,web_spider,0.0,0.0,0.0,0.0,0.76,0.935
Query,penguin,0.0,0.0,0.0,0.0,0.935,0.537
