In [1]:
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModel

# Hugging Face Hub identifiers for the data and the embedding model.
DATASET_ID = 'zh-plus/tiny-imagenet'
MODEL_ID = 'facebook/dinov2-small'

tiny_imagenet = load_dataset(DATASET_ID, split='valid')
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def embed(x):
    """Embed an image (or a list of images) with the global ``processor``/``model``.

    The model's ``last_hidden_state`` (documented by Hugging Face as
    ``(batch, sequence_length, hidden_size)``) is flattened per image, so every
    input image becomes one row of features.

    Parameters
    ----------
    x
        An image, or a list of images, accepted by ``processor``.

    Returns
    -------
    np.ndarray
        2-D array of shape ``(n_images, sequence_length * hidden_size)``;
        a single image yields a ``(1, n_features)`` row vector.
    """
    inputs = processor(images=x, return_tensors="pt")
    # Inference only: don't build an autograd graph for every image.
    with torch.no_grad():
        outputs = model(**inputs)
    hidden = outputs.last_hidden_state
    # Keep the batch axis so each image gets its own row instead of a batch
    # being flattened into one long row.
    return hidden.reshape(hidden.shape[0], -1).cpu().numpy()

In [3]:
# (dataset label id, display name) for the classes compared below.
_CLASS_PAIRS = (
    (0, "fish"),
    (1, "ground_spider"),
    (2, "frog"),
    (5, "snake"),
    (8, "web_spider"),
    (19, "penguin"),
)
classes = dict(_CLASS_PAIRS)
id_labels = classes.keys()  # integer ids matched against each row's 'label'
text_labels = [name for _, name in _CLASS_PAIRS]

In [4]:
# Read the label column once and embed only the rows of each selected class.
# Iterating the whole dataset per class would load every row (images
# included) len(id_labels) times. flatnonzero keeps dataset order.
label_column = np.asarray(list(tiny_imagenet["label"]))
embeddings = [
    [
        embed(tiny_imagenet[int(row)]["image"])
        for row in np.flatnonzero(label_column == label)
    ]
    for label in id_labels
]

In [5]:
# Shape of one stored embedding (a single row vector).
np.shape(embeddings[0][0])

(1, 98688)

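## Compute metrics

For every pair of classes, compare the distribution of each embedding feature with two-sample Cramér–von Mises and Kolmogorov–Smirnov tests. When a class is compared with itself, its samples are split into two halves and the halves are compared.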

In [6]:
from embedding_metrics import create_dataframe, EmbeddingMatrix, defaultdict, cramervonmises_2samp, ks_2samp, Callable

In [None]:
def _compute_metrics_feature_wise(
    queries: list[list[np.ndarray]],
    references: list[list[np.ndarray]] | None,
    classes: list[str],
    method: str,
    func: Callable,
) -> EmbeddingMatrix:
    """
    Compare classes feature-by-feature with a two-sample scipy.stats test.

    For every (query class, reference class) pair, ``func`` is applied to each
    embedding feature (column) independently. The per-feature statistics and
    p-values are then averaged over all features and rounded to 3 decimals.

    When ``references`` is None the queries are compared with themselves. On
    the diagonal each class is split into a first and a second half, so the
    test measures within-class consistency instead of comparing a sample with
    itself.

    Parameters
    ----------
    queries : list[list[np.ndarray]]
        One list of embeddings per class. The embeddings of a class are
        concatenated along axis 0, so rows are samples and columns are features.
    references : list[list[np.ndarray]], optional
        Reference embeddings laid out like ``queries``; ``None`` means
        "use ``queries``".
    classes : list[str]
        Labels indexed by class position in ``queries`` / ``references``.
    method : str
        Accepted for interface compatibility; not used by this computation.
    func : Callable
        A scipy.stats two-sample test called as ``func(reference, query)``. It
        must return an object with ``statistic`` and ``pvalue`` attributes.

    Returns
    -------
    EmbeddingMatrix
        Nested ``pvalues[query_label][reference_label]`` and
        ``statistics[query_label][reference_label]`` holding per-feature means.
    """
    pvalues = defaultdict(lambda: defaultdict(float))
    statistics = defaultdict(lambda: defaultdict(float))

    def mean_rounded(total: float, count: int) -> float:
        # Round only the final mean: rounding every per-feature term before
        # accumulating would compound rounding error over many features.
        return round(total / count, 3) if count else 0.0

    # Fall back to the queries only when no references were supplied.
    # Previously they were always overwritten, ignoring the caller's references.
    queries_are_references = references is None
    if queries_are_references:
        references = queries

    # Stack each class once (rows = samples, columns = features) instead of
    # re-concatenating it for every feature.
    query_matrices = [np.concatenate(query) for query in queries]
    reference_matrices = (
        query_matrices
        if queries_are_references
        else [np.concatenate(reference) for reference in references]
    )

    for i, query_matrix in enumerate(query_matrices):
        label_i = classes[i]
        for j, reference_matrix in enumerate(reference_matrices):
            label_j = classes[j]
            split_self = queries_are_references and i == j
            n_features = reference_matrix.shape[1]
            pvalue_total = 0.0
            statistic_total = 0.0
            # NOTE(review): recent scipy versions accept ``axis=0`` for these
            # tests, which could replace this loop. Confirm the installed
            # version supports it before vectorising.
            for k in range(n_features):
                reference_feature = reference_matrix[:, k]
                if split_self:
                    # Split the class in two and measure how similarly the
                    # halves are distributed.
                    split_idx = reference_feature.shape[0] // 2
                    result = func(
                        reference_feature[:split_idx],
                        reference_feature[split_idx:],
                    )
                else:
                    result = func(reference_feature, query_matrix[:, k])
                pvalue_total += float(result.pvalue)
                statistic_total += float(result.statistic)

            # Mean over features keeps p-values in [0, 1]; a sum would not.
            pvalues[label_i][label_j] = mean_rounded(pvalue_total, n_features)
            statistics[label_i][label_j] = mean_rounded(statistic_total, n_features)

    return EmbeddingMatrix(pvalues=pvalues, statistics=statistics)


def compute_cvm(
    queries: list[list[np.ndarray]],
    classes: list[str],
    references: list[list[np.ndarray]] | None = None,
    method: str = "cosine",
) -> EmbeddingMatrix:
    """
    Feature-wise two-sample Cramér-von Mises comparison between classes.

    Delegates to ``_compute_metrics_feature_wise`` with
    ``func=cramervonmises_2samp``. With ``references=None`` the queries are
    compared with themselves. ``method`` is forwarded unchanged; the
    feature-wise computation does not use it.
    """
    stats_test = cramervonmises_2samp
    return _compute_metrics_feature_wise(
        func=stats_test,
        method=method,
        classes=classes,
        references=references,
        queries=queries,
    )


def compute_ks(
    queries: list[list[np.ndarray]],
    classes: list[str],
    references: list[list[np.ndarray]] | None = None,
    method: str = "cosine",
) -> EmbeddingMatrix:
    """
    Feature-wise two-sample Kolmogorov-Smirnov comparison between classes.

    Delegates to ``_compute_metrics_feature_wise`` with ``func=ks_2samp``.
    With ``references=None`` the queries are compared with themselves.
    ``method`` is forwarded unchanged; the feature-wise computation does not
    use it.
    """
    stats_test = ks_2samp
    return _compute_metrics_feature_wise(
        func=stats_test,
        method=method,
        classes=classes,
        references=references,
        queries=queries,
    )

In [7]:
# Run both tests over every class pair; no references are passed, so each
# class is compared with every class including itself.
cvm, ks = (
    compute_cvm(embeddings, text_labels),
    compute_ks(embeddings, text_labels),
)

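## Create pandas DataFrames

Tabulate each test as query-class rows × reference-class columns, with separate tables for the test statistics and the p-values.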

In [8]:
# create_dataframe presumably returns (statistics, p-values) frames — TODO confirm.
(cvm_statistics, cvm_pvalues), (ks_statistics, ks_pvalues) = [
    create_dataframe(metric_matrix, text_labels) for metric_matrix in (cvm, ks)
]

In [9]:
# Cramér-von Mises statistics per (query class, reference class) pair.
cvm_statistics

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.16,2.008,0.058,2.304,3.365,2.728
Query,ground_spider,2.008,0.26,1.992,0.81,6.457,5.513
Query,frog,0.058,1.992,0.055,2.381,3.598,2.864
Query,snake,2.304,0.81,2.381,0.072,6.407,5.66
Query,web_spider,3.365,6.457,3.598,6.407,0.175,0.202
Query,penguin,2.728,5.513,2.864,5.66,0.202,0.213


In [10]:
# Cramér-von Mises p-values per (query class, reference class) pair.
cvm_pvalues

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.369,0.0,0.839,0.0,0.0,0.0
Query,ground_spider,0.0,0.179,0.0,0.007,0.0,0.0
Query,frog,0.839,0.0,0.875,0.0,0.0,0.0
Query,snake,0.0,0.007,0.0,0.764,0.0,0.0
Query,web_spider,0.0,0.0,0.0,0.0,0.33,0.267
Query,penguin,0.0,0.0,0.0,0.0,0.267,0.248


In [11]:
# Kolmogorov-Smirnov statistics per (query class, reference class) pair.
ks_statistics

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.24,0.46,0.12,0.52,0.56,0.54
Query,ground_spider,0.46,0.28,0.48,0.34,0.82,0.74
Query,frog,0.12,0.48,0.16,0.5,0.6,0.6
Query,snake,0.52,0.34,0.5,0.16,0.8,0.72
Query,web_spider,0.56,0.82,0.6,0.8,0.24,0.18
Query,penguin,0.54,0.74,0.6,0.72,0.18,0.28


In [12]:
# Kolmogorov-Smirnov p-values per (query class, reference class) pair.
ks_pvalues

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.475,0.0,0.869,0.0,0.0,0.0
Query,ground_spider,0.0,0.285,0.0,0.006,0.0,0.0
Query,frog,0.869,0.0,0.915,0.0,0.0,0.0
Query,snake,0.0,0.006,0.0,0.915,0.0,0.0
Query,web_spider,0.0,0.0,0.0,0.0,0.475,0.396
Query,penguin,0.0,0.0,0.0,0.0,0.396,0.285
