In [1]:
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModel

# Validation split of Tiny-ImageNet, plus an image processor / encoder pair that are
# both loaded from a single checkpoint name so they cannot drift out of sync.
CHECKPOINT = 'facebook/dinov2-small'
tiny_imagenet = load_dataset('zh-plus/tiny-imagenet', split='valid')
processor, model = (
    AutoImageProcessor.from_pretrained(CHECKPOINT),
    AutoModel.from_pretrained(CHECKPOINT),
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from embedding_metrics import compute_cvm, compute_ks, create_dataframe

In [3]:
def embed(x, image_processor=None, encoder=None):
    """Embed image(s) and return the encoder's hidden states as one flattened row.

    Parameters
    ----------
    x : image(s) accepted by the image processor (e.g. a PIL image taken from the
        dataset's 'image' column).
    image_processor : optional callable used to preprocess ``x``; defaults to the
        notebook-level ``processor``.
    encoder : optional model run on the processed inputs; defaults to the
        notebook-level ``model``.

    Returns
    -------
    np.ndarray of shape (1, D): the whole ``last_hidden_state`` (all tokens of all
    images in ``x``) flattened into a single vector, with a leading axis added.
    """
    # Only look up the notebook globals when no override is given.
    image_processor = processor if image_processor is None else image_processor
    encoder = model if encoder is None else encoder

    inputs = image_processor(images=x, return_tensors="pt")
    # Inference only: don't build the autograd graph (saves activation memory/time).
    with torch.no_grad():
        outputs = encoder(**inputs)
    flat = outputs.last_hidden_state.flatten().detach().numpy()
    return flat[np.newaxis, :]

In [4]:
# Label ids to keep, mapped to the short names used as table headers below.
# NOTE(review): assumes these integer ids match the dataset's label indices —
# confirm against tiny_imagenet.features['label'].
classes = dict(
    [
        (0, "fish"),
        (1, "ground_spider"),
        (2, "frog"),
        (5, "snake"),
        (8, "web_spider"),
        (19, "penguin"),
    ]
)
id_labels = classes.keys()
text_labels = [classes[label_id] for label_id in id_labels]

In [5]:
# Group row indices by label using only the lightweight 'label' column. The split is
# scanned once, and only images of the selected classes are decoded. Iterating the
# dataset rows directly would decode every image, once per class.
indices_by_label = {label: [] for label in id_labels}
for row_idx, label in enumerate(tiny_imagenet['label']):
    if label in indices_by_label:
        indices_by_label[label].append(row_idx)

# embeddings[i] holds the (1, D) embeddings of the i-th class in id_labels,
# in dataset order.
embeddings = [
    [embed(tiny_imagenet[row_idx]['image']) for row_idx in indices_by_label[label]]
    for label in id_labels
]
# Every class must contribute samples, otherwise the pairwise tests are meaningless.
assert all(len(group) > 0 for group in embeddings), "a selected class has no images in this split"

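Compute metrics: Cramér–von Mises (CvM) and Kolmogorov–Smirnov (KS) statistics and p-values for every (query class, reference class) pair of embedding groups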

In [6]:
# Pairwise CvM and KS comparisons between the per-class embedding groups.
cvm, ks = (
    compute_cvm(embeddings, text_labels),
    compute_ks(embeddings, text_labels),
)

Create pandas DataFrames of the test statistics and p-values (rows: query class, columns: reference class)

In [7]:
# Each call yields a (statistics, p-values) pair of tables with query classes as rows
# and reference classes as columns.
(cvm_statistics, cvm_pvalues), (ks_statistics, ks_pvalues) = (
    create_dataframe(cvm, text_labels),
    create_dataframe(ks, text_labels),
)

In [8]:
# CvM statistic per (query class, reference class) pair; presumably larger means the
# two embedding distributions differ more (diagonal entries are the smallest).
cvm_statistics

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.294,277.493,214.82,283.315,297.393,320.182
Query,ground_spider,271.57,0.229,84.414,179.595,226.174,336.562
Query,frog,167.965,53.51,0.167,149.982,215.313,307.253
Query,snake,119.521,26.432,25.786,0.089,144.239,216.062
Query,web_spider,297.914,237.63,254.888,297.373,1.901,357.284
Query,penguin,131.408,151.044,129.168,163.965,198.624,5.812


In [9]:
# CvM p-values per (query class, reference class) pair.
# NOTE(review): the diagonal entries for web_spider and penguin are 0.0 even though
# query and reference are the same class — worth checking how compute_cvm builds
# its samples for the diagonal.
cvm_pvalues

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.141,0.0,0.0,0.0,0.0,0.0
Query,ground_spider,0.0,0.218,0.0,0.0,0.0,0.0
Query,frog,0.0,0.0,0.341,0.0,0.0,0.0
Query,snake,0.0,0.0,0.0,0.644,0.0,0.0
Query,web_spider,0.0,0.0,0.0,0.0,0.0,0.0
Query,penguin,0.0,0.0,0.0,0.0,0.0,0.0


In [11]:
# KS statistic per (query class, reference class) pair; presumably larger means the
# two embedding distributions differ more (diagonal entries are the smallest).
ks_statistics

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.047,0.702,0.616,0.712,0.726,0.763
Query,ground_spider,0.694,0.038,0.365,0.544,0.619,0.787
Query,frog,0.52,0.28,0.031,0.484,0.593,0.73
Query,snake,0.44,0.216,0.207,0.026,0.491,0.599
Query,web_spider,0.731,0.638,0.663,0.73,0.09,0.836
Query,penguin,0.471,0.501,0.461,0.522,0.594,0.144


In [12]:
# KS p-values per (query class, reference class) pair.
# NOTE(review): as with CvM, the web_spider and penguin diagonal entries are 0.0
# despite query and reference being the same class — investigate compute_ks.
ks_pvalues

Unnamed: 0_level_0,Unnamed: 1_level_0,Reference,Reference,Reference,Reference,Reference,Reference
Unnamed: 0_level_1,Unnamed: 1_level_1,fish,ground_spider,frog,snake,web_spider,penguin
Query,fish,0.128,0.0,0.0,0.0,0.0,0.0
Query,ground_spider,0.0,0.354,0.0,0.0,0.0,0.0
Query,frog,0.0,0.0,0.598,0.0,0.0,0.0
Query,snake,0.0,0.0,0.0,0.798,0.0,0.0
Query,web_spider,0.0,0.0,0.0,0.0,0.0,0.0
Query,penguin,0.0,0.0,0.0,0.0,0.0,0.0
