In [1]:
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModel

# Hub identifiers for the data and the DINOv2 backbone used throughout.
DATASET_ID = 'zh-plus/tiny-imagenet'
MODEL_ID = 'facebook/dinov2-small'

# Validation split only; the preprocessing pipeline and encoder share one ID.
tiny_imagenet = load_dataset(DATASET_ID, split='valid')
processor, model = (
    AutoImageProcessor.from_pretrained(MODEL_ID),
    AutoModel.from_pretrained(MODEL_ID),
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def embed(x, image_processor=None, encoder=None):
    """Embed image input ``x`` as a single flat row vector.

    Runs the image processor and the encoder, takes the encoder's
    ``last_hidden_state`` and flattens *all* of it into one row.

    Parameters
    ----------
    x
        Image input accepted by the processor (in this notebook, the ``image``
        field of a dataset row).
    image_processor, encoder : optional
        Overrides for the notebook-level ``processor`` and ``model``; when
        omitted, those globals are used, so existing calls are unchanged.

    Returns
    -------
    numpy.ndarray
        Shape ``(1, D)``. Assuming ``last_hidden_state`` has the usual
        transformers layout ``(batch, tokens, hidden)`` (not checked here),
        D is ``batch * tokens * hidden``: every token's features, not a
        pooled/CLS vector.
    """
    image_processor = processor if image_processor is None else image_processor
    encoder = model if encoder is None else encoder

    inputs = image_processor(images=x, return_tensors="pt")
    # Inference only: without no_grad, autograd records the whole forward pass
    # (extra memory/time) even though the result is detached to NumPy below.
    with torch.no_grad():
        outputs = encoder(**inputs)
    flat = (
        outputs
        .last_hidden_state
        .flatten()
        .detach()
        .numpy()
    )
    # Add a leading axis so each embedding is a (1, D) row.
    return flat[np.newaxis, :]

In [3]:
# Tiny-ImageNet label ids to analyse -> short names used as table rows/columns.
# The ids are matched against each row's `label` field when embedding; the
# names are display text only.
# TODO(review): the names are hand-written, not read from the dataset; confirm
# them against `tiny_imagenet.features['label'].names`.
classes = {
    0: "fish",
    1: "ground_spider",
    2: "frog",
    5: "snake",
    8: "web_spider",
    19: "penguin",
}
# Parallel, indexable snapshots in the dict's insertion order, so that
# id_labels[i] and text_labels[i] always describe the same class (a keys()
# view could not be indexed and would track later mutations of `classes`).
id_labels = list(classes)
text_labels = list(classes.values())

In [4]:
# Embed every validation image of each selected class.
# Result: one list per entry of `id_labels` (same order), each holding the
# (1, D) arrays returned by `embed`, in dataset order.
#
# Iterating `tiny_imagenet` row by row yields every column, including the
# decoded `image`; doing that once per class decoded the whole split repeatedly.
# Reading only the `label` column once, then fetching just the matching rows,
# decodes each embedded image exactly once.
label_column = tiny_imagenet['label']
row_indices = {label: [] for label in id_labels}
for row_idx, row_label in enumerate(label_column):
    if row_label in row_indices:
        row_indices[row_label].append(row_idx)

embeddings = [
    [embed(tiny_imagenet[row_idx]['image']) for row_idx in row_indices[label]]
    for label in id_labels
]

# Downstream metrics compare per-class samples; an empty class signals a wrong id.
assert all(embeddings), "every selected class id should match at least one image"

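## Compute metrics

Compare the embedding distributions of every (query, reference) pair of classes. This uses `compute_cvm` and `compute_ks` from the local `embedding_metrics` module (presumably Cramér–von Mises and Kolmogorov–Smirnov two-sample tests).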

In [5]:
from embedding_metrics import compute_cvm, compute_ks, create_dataframe

In [6]:
# Both metrics consume the same per-class embedding lists and class names;
# `compute_cvm` runs first, then `compute_ks`. Their return values are turned
# into tables by `create_dataframe` in the next cell.
cvm, ks = (
    compute_cvm(embeddings, text_labels),
    compute_ks(embeddings, text_labels),
)

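## Create pandas DataFrames

Turn each metric into two tables — test statistics and p-values — with query classes as rows and reference classes as columns.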

In [7]:
# `create_dataframe` yields a (statistics, p-values) pair per metric; the
# rendered tables index rows by query class and columns by reference class.
(cvm_statistics, cvm_pvalues), (ks_statistics, ks_pvalues) = (
    create_dataframe(cvm, text_labels),
    create_dataframe(ks, text_labels),
)

In [8]:
# CvM statistic for every (query row, reference column) class pair; the
# diagonal compares a class with itself.
cvm_statistics

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.332,300.961,238.778,316.986,328.888,324.187
Query,ground_spider,326.16,0.893,101.768,181.94,250.241,345.564
Query,frog,222.88,47.346,0.125,170.924,255.216,323.082
Query,snake,197.857,17.905,43.895,1.097,191.244,238.504
Query,web_spider,342.481,262.132,295.753,328.617,2.745,360.07
Query,penguin,214.925,202.98,202.526,232.556,262.217,7.855


In [9]:
# p-values matching `cvm_statistics`: rows = query class, columns = reference class.
# NOTE(review): in the saved output several diagonal (same-class) p-values are
# below 0.05 (web_spider and penguin render as 0.0 at 3 decimals). That is
# surprising for a class compared with itself; check how `compute_cvm` forms
# the within-class comparison and how it combines results across embedding
# dimensions.
cvm_pvalues

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.11,0.0,0.0,0.0,0.0,0.0
Query,ground_spider,0.0,0.004,0.0,0.0,0.0,0.0
Query,frog,0.0,0.0,0.475,0.0,0.0,0.0
Query,snake,0.0,0.0,0.0,0.001,0.0,0.0
Query,web_spider,0.0,0.0,0.0,0.0,0.0,0.0
Query,penguin,0.0,0.0,0.0,0.0,0.0,0.0


In [10]:
# KS statistic for every (query row, reference column) class pair; the
# diagonal compares a class with itself.
ks_statistics

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.051,0.737,0.66,0.773,0.786,0.784
Query,ground_spider,0.779,0.065,0.404,0.534,0.659,0.816
Query,frog,0.626,0.258,0.033,0.521,0.669,0.774
Query,snake,0.577,0.166,0.272,0.073,0.571,0.643
Query,web_spider,0.831,0.689,0.74,0.798,0.095,0.866
Query,penguin,0.634,0.601,0.608,0.648,0.686,0.168


In [11]:
# p-values matching `ks_statistics`: rows = query class, columns = reference class.
# NOTE(review): as with the CvM table, the saved output shows several
# same-class p-values below 0.05 (web_spider and penguin render as 0.0);
# verify how `compute_ks` builds the within-class comparison.
ks_pvalues

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.087,0.0,0.0,0.0,0.0,0.0
Query,ground_spider,0.0,0.011,0.0,0.0,0.0,0.0
Query,frog,0.0,0.0,0.531,0.0,0.0,0.0
Query,snake,0.0,0.0,0.0,0.003,0.0,0.0
Query,web_spider,0.0,0.0,0.0,0.0,0.0,0.0
Query,penguin,0.0,0.0,0.0,0.0,0.0,0.0
