In [1]:
import numpy as np
import pandas as pd
from collections import defaultdict
from scipy.spatial.distance import cdist
from scipy.stats import ks_2samp, cramervonmises_2samp

from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModel

tiny_imagenet = load_dataset('zh-plus/tiny-imagenet', split='valid')
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
model = AutoModel.from_pretrained('facebook/dinov2-small')

  from .autonotebook import tqdm as notebook_tqdm
Using the latest cached version of the dataset since zh-plus/tiny-imagenet couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /home/czaloom/.cache/huggingface/datasets/zh-plus___tiny-imagenet/default/0.0.0/5a77092c28e51558c5586e9c5eb71a7e17a5e43f (last modified on Mon Jun 24 10:11:00 2024).


In [2]:
def embed(x):
    inputs = processor(images=x, return_tensors="pt")
    outputs = model(**inputs)
    retval = (
        outputs
        .last_hidden_state
        .flatten()
        .detach()
        .numpy()
    )
    retval = retval[np.newaxis, :]
    return retval

In [3]:
def compute_distances(A, B):
    return np.array(
        [
            cdist(a, b)
            for a in A
            for b in B
        ]
    )

In [4]:
def compute_metrics(data, classes, func):
    pvalues = defaultdict(dict)
    statistics = defaultdict(dict)
    
    def cast_and_round(x):
        x = float(x)
        return round(x,3)

    for i, query in enumerate(data):
        for j, reference in enumerate(data):
            reference_distance = compute_distances(reference, reference)
            if i == j:
                # split the set in two and measure how similarly distributed it is.
                split_idx = len(reference_distance) // 2
                metric = func(reference_distance[:split_idx], reference_distance[split_idx:])
            else:
                query_distance = compute_distances(reference, query)
                metric = func(reference_distance, query_distance)

            label_i = classes[i]
            label_j = classes[j]
            pvalues[label_i][label_j] = cast_and_round(metric.pvalue)
            statistics[label_i][label_j] = cast_and_round(metric.statistic)
            
    return (pvalues, statistics)

def compute_cvm(data, classes):
    return compute_metrics(data, classes, func=cramervonmises_2samp)

def compute_ks(data, classes):
    return compute_metrics(data, classes, func=ks_2samp)

In [5]:
classes = {
    0: "fish",
    1: "ground_spider",
    2: "frog",
    5: "snake",
    8: "web_spider",
    19: "penguin"
}

embeddings = []
for label in classes.keys():
    embeddings.append(
        [
            embed(image['image'])
            for image in tiny_imagenet 
            if image['label'] == label
        ]
    )

In [6]:
embeddings[0][0].shape

(1, 98688)

In [7]:
labels = list(classes.values())

cvm_pvalues, cvm_statistics = compute_cvm(embeddings, labels)
ks_pvalues, ks_statistics = compute_ks(embeddings, labels)

  x = float(x)


In [8]:
col_ix = pd.MultiIndex.from_product([['Reference'], labels]) 
row_ix = pd.MultiIndex.from_product([['Query'], labels])

In [9]:
cvm_statistic_df = pd.DataFrame(cvm_statistics)
cvm_statistic_df = cvm_statistic_df.set_index(row_ix)
cvm_statistic_df.columns = col_ix

cvm_pvalue_df = pd.DataFrame(cvm_pvalues)
cvm_pvalue_df = cvm_pvalue_df.set_index(row_ix)
cvm_pvalue_df.columns = col_ix

print("Cramer-Von Mises")
print(" === statistic ===")
print(cvm_statistic_df)
print(" === p-values ===")
print(cvm_pvalue_df)

Cramer-Von Mises
 === statistic ===
                    Reference                                             \
                         fish ground_spider     frog    snake web_spider   
Query fish              0.282       282.536  220.019  288.389    302.395   
      ground_spider   276.491         0.220   88.925  184.655    231.123   
      frog            172.883        57.489    0.161  155.017    220.365   
      snake           124.062        29.078   28.576    0.085    148.863   
      web_spider      302.893       242.703  260.031  302.358      1.825   
      penguin         136.432       156.242  134.315  169.142    203.883   

                              
                     penguin  
Query fish           325.079  
      ground_spider  341.346  
      frog           312.181  
      snake          221.046  
      web_spider     361.959  
      penguin          5.582  
 === p-values ===
                    Reference                                               
            

In [10]:
ks_statistic_df = pd.DataFrame(ks_statistics)
ks_statistic_df = ks_statistic_df.set_index(row_ix)
ks_statistic_df.columns = col_ix

ks_pvalue_df = pd.DataFrame(ks_pvalues)
ks_pvalue_df = ks_pvalue_df.set_index(row_ix)
ks_pvalue_df.columns = col_ix

print("Kolmgorov-Smirnov")
print(" === statistic ===")
print(ks_statistic_df)
print()
print(" === p-values ===")
print(ks_pvalue_df)

Kolmgorov-Smirnov
 === statistic ===
                    Reference                                               
                         fish ground_spider   frog  snake web_spider penguin
Query fish              0.046         0.706  0.620  0.716      0.730   0.765
      ground_spider     0.697         0.037  0.372  0.548      0.622   0.789
      frog              0.525         0.288  0.030  0.489      0.597   0.733
      snake             0.446         0.222  0.213  0.026      0.496   0.603
      web_spider        0.734         0.642  0.666  0.732      0.088   0.838
      penguin           0.478         0.507  0.468  0.528      0.598   0.141

 === p-values ===
                    Reference                                               
                         fish ground_spider   frog  snake web_spider penguin
Query fish              0.136         0.000  0.000  0.000        0.0     0.0
      ground_spider     0.000         0.366  0.000  0.000        0.0     0.0
      frog          