In [1]:
import numpy as np
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModel

tiny_imagenet = load_dataset('zh-plus/tiny-imagenet', split='valid')
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
model = AutoModel.from_pretrained('facebook/dinov2-small')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def embed(x):
    inputs = processor(images=x, return_tensors="pt")
    outputs = model(**inputs)
    retval = (
        outputs
        .last_hidden_state
        .flatten()
        .detach()
        .numpy()
    )
    retval = retval[np.newaxis, :]
    return retval

In [3]:
classes = {
    0: "fish",
    1: "ground_spider",
    2: "frog",
    5: "snake",
    8: "web_spider",
    19: "penguin"
}
id_labels = classes.keys()
text_labels = list(classes.values())

In [4]:
embeddings = [
    [
        embed(image['image'])
        for image in tiny_imagenet 
        if image['label'] == label
    ]
    for label in id_labels
]

In [5]:
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable

import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
from scipy.stats import cramervonmises_2samp, ks_2samp

@dataclass
class EmbeddingMatrix:
    """
    Maps scores from each query label to each reference label.
    """

    statistics: dict[str, dict[str, float]]
    pvalues: dict[str, dict[str, float]]

In [6]:
def get_inner_product_function(dataset: np.ndarray):
    return lambda x: np.dot(dataset, x)


def get_projections(
    basis: np.ndarray,
    inner_product: Callable[[np.ndarray], np.ndarray],
) -> np.ndarray:
    """
    Computes orthogonal projection of a dataset onto a plane with orthonormal basis `basis`.
    The dataset is represented abstractly by the Callable `inner_product`, which is a function
    that takes in a vector `v` and returns the array of inner products <x, v> for all x in the
    dataset.

    Parameters
    ----------
    basis: np.ndarray
        Should be k x d, all length 1, and orthogonal
    inner_product: Callable[np.ndarray, np.ndarray]

    Returns
    -------
    projections: np.ndarray
        N x k
    """
    # check that basis is orthonormal
    dots = np.einsum("ik,jk->ij", basis, basis)
    np.testing.assert_almost_equal(dots, np.identity(dots.shape[0]))

    res = []
    for x in basis:
        res.append(inner_product(x))
    res = np.array(res)
    return res.T

In [7]:
def _compute_metrics_from_projections(
    queries: list[list[np.ndarray]],
    references: list[list[np.ndarray]] | None,
    classes: list[str],
    method: str,
    func: Callable,
) -> EmbeddingMatrix:
    """
    Computes metrics using a selectable scipy.stats function.

    Parameters
    ----------
    queries : list[list[np.ndarray]]
        A list of query embedding lists.
    references : list[list[np.ndarray]], optional
        A list of reference embedding lists.
    classes : list[str]
        A list of labels mapped to the distance lists.
    method : str
        The method of calculating distance.
    func : Callable
        A scipy.stats function.

    Returns
    -------
    EmbeddingMatrix
        A metric containing a confusion matrix for both p-value and distance metric.
    """
    pvalues = defaultdict(dict)
    statistics = defaultdict(dict)

    def cast_and_round(x):
        x = float(x)
        return round(x, 3)

    queries_are_references = references is None
    references = queries

    vector = np.random.normal(size=(1, references[0][0].shape[1]))
    vector = vector / np.linalg.norm(vector)

    for i, query in enumerate(queries):
        for j, reference in enumerate(references):
            if queries_are_references and i == j:
                # split the set in two and measure how similarly distributed it is.
                split_idx = len(reference) // 2
                reference_proj = get_projections(
                    vector,
                    get_inner_product_function(
                        np.concatenate(reference[:split_idx])
                    )
                )
                query_proj = get_projections(
                    vector,
                    get_inner_product_function(
                        np.concatenate(query[split_idx:])
                    )
                )
            else:
                reference_proj = get_projections(
                    vector,
                    get_inner_product_function(
                        np.concatenate(reference)
                    )
                )
                query_proj = get_projections(
                    vector,
                    get_inner_product_function(
                        np.concatenate(query)
                    )
                )
            
            metric = func(reference_proj, query_proj)

            label_i = classes[i]
            label_j = classes[j]
            pvalues[label_i][label_j] = cast_and_round(metric.pvalue)
            statistics[label_i][label_j] = cast_and_round(metric.statistic)

    return EmbeddingMatrix(pvalues=pvalues, statistics=statistics)

In [8]:
def compute_cvm(
    queries: list[list[np.ndarray]],
    classes: list[str],
    references: list[list[np.ndarray]] | None = None,
    method: str = "cosine",
) -> EmbeddingMatrix:
    return _compute_metrics_from_projections(
        queries=queries,
        references=references,
        classes=classes,
        method=method,
        func=cramervonmises_2samp,
    )


def compute_ks(
    queries: list[list[np.ndarray]],
    classes: list[str],
    references: list[list[np.ndarray]] | None = None,
    method: str = "cosine",
) -> EmbeddingMatrix:
    return _compute_metrics_from_projections(
        queries=queries,
        references=references,
        classes=classes,
        method=method,
        func=ks_2samp,
    )

Compute metrics

In [9]:
cvm = compute_cvm(embeddings, text_labels)
ks = compute_ks(embeddings, text_labels)

  x = float(x)


Create pandas dataframes

In [11]:
from embedding_metrics import create_dataframe

cvm_statistics, cvm_pvalues = create_dataframe(cvm, text_labels)
ks_statistics, ks_pvalues = create_dataframe(ks, text_labels)

In [12]:
cvm_statistics

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.802,1.779,0.471,0.251,0.088,0.103
Query,ground_spider,1.779,0.272,0.501,0.927,2.385,1.201
Query,frog,0.471,0.501,0.18,0.082,0.927,0.31
Query,snake,0.251,0.927,0.082,0.368,0.568,0.149
Query,web_spider,0.088,2.385,0.927,0.568,0.17,0.271
Query,penguin,0.103,1.201,0.31,0.149,0.271,0.21


In [13]:
cvm_pvalues

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.007,0.0,0.047,0.189,0.659,0.578
Query,ground_spider,0.0,0.164,0.04,0.004,0.0,0.001
Query,frog,0.047,0.04,0.318,0.691,0.004,0.127
Query,snake,0.189,0.004,0.691,0.088,0.027,0.399
Query,web_spider,0.659,0.0,0.004,0.027,0.343,0.164
Query,penguin,0.578,0.001,0.127,0.399,0.164,0.254


In [14]:
ks_statistics

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.2,0.3,0.18,0.22,0.5,0.18
Query,ground_spider,0.3,0.4,0.2,0.12,0.42,0.4
Query,frog,0.18,0.2,0.24,0.14,0.4,0.3
Query,snake,0.22,0.12,0.14,0.16,0.42,0.32
Query,web_spider,0.5,0.42,0.4,0.42,0.16,0.62
Query,penguin,0.18,0.4,0.3,0.32,0.62,0.24


In [15]:
ks_pvalues

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.71,0.022,0.396,0.179,0.0,0.396
Query,ground_spider,0.022,0.036,0.272,0.869,0.0,0.001
Query,frog,0.396,0.272,0.475,0.717,0.001,0.022
Query,snake,0.179,0.869,0.717,0.915,0.0,0.012
Query,web_spider,0.0,0.0,0.001,0.0,0.915,0.0
Query,penguin,0.396,0.001,0.022,0.012,0.0,0.475
