In [1]:
import numpy as np
import pandas as pd
from collections import defaultdict
from scipy.spatial.distance import cdist
from scipy.stats import ks_2samp, cramervonmises_2samp

from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModel

tiny_imagenet = load_dataset('zh-plus/tiny-imagenet', split='valid')
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
model = AutoModel.from_pretrained('facebook/dinov2-small')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def embed(x):
    inputs = processor(images=x, return_tensors="pt")
    outputs = model(**inputs)
    retval = (
        outputs
        .last_hidden_state
        .flatten()
        .detach()
        .numpy()
    )
    retval = retval[np.newaxis, :]
    return retval

In [3]:
def compute_distances(A, B):
    return np.array(
        [
            cdist(a, b)
            for a in A
            for b in B
        ]
    )

def compute_self_distances(A):
    return np.array(
        [
            cdist(a, b)
            for a_idx, a in enumerate(A)
            for b_idx, b in enumerate(A)
            if a_idx != b_idx
        ]
    )

In [4]:
def compute_metrics(data, classes, func):
    pvalues = defaultdict(dict)
    statistics = defaultdict(dict)
    
    def cast_and_round(x):
        x = float(x)
        return round(x,3)

    for i, query in enumerate(data):
        for j, reference in enumerate(data):
            reference_distance = compute_self_distances(reference)
            if i == j:
                # split the set in two and measure how similarly distributed it is.
                split_idx = len(reference_distance) // 2
                metric = func(reference_distance[:split_idx], reference_distance[split_idx:])
            else:
                query_distance = compute_distances(reference, query)
                metric = func(reference_distance, query_distance)

            label_i = classes[i]
            label_j = classes[j]
            pvalues[label_i][label_j] = cast_and_round(metric.pvalue)
            statistics[label_i][label_j] = cast_and_round(metric.statistic)
            
    return (pvalues, statistics)

def compute_cvm(data, classes):
    return compute_metrics(data, classes, func=cramervonmises_2samp)

def compute_ks(data, classes):
    return compute_metrics(data, classes, func=ks_2samp)

In [5]:
classes = {
    0: "fish",
    1: "ground_spider",
    2: "frog",
    5: "snake",
    8: "web_spider",
    19: "penguin"
}

embeddings = []
for label in classes.keys():
    embeddings.append(
        [
            embed(image['image'])
            for image in tiny_imagenet 
            if image['label'] == label
        ]
    )

In [6]:
embeddings[0][0].shape

(1, 98688)

In [7]:
labels = list(classes.values())

cvm_pvalues, cvm_statistics = compute_cvm(embeddings, labels)
ks_pvalues, ks_statistics = compute_ks(embeddings, labels)

  x = float(x)


In [8]:
col_ix = pd.MultiIndex.from_product([['Reference'], labels]) 
row_ix = pd.MultiIndex.from_product([['Query'], labels])

In [9]:
cvm_statistic_df = pd.DataFrame(cvm_statistics)
cvm_statistic_df = cvm_statistic_df.set_index(row_ix)
cvm_statistic_df.columns = col_ix

cvm_pvalue_df = pd.DataFrame(cvm_pvalues)
cvm_pvalue_df = cvm_pvalue_df.set_index(row_ix)
cvm_pvalue_df.columns = col_ix

print("Cramer-Von Mises")
print(" === statistic ===")
print(cvm_statistic_df)
print(" === p-values ===")
print(cvm_pvalue_df)

Cramer-Von Mises
 === statistic ===
                    Reference                                             \
                         fish ground_spider     frog    snake web_spider   
Query fish              0.294       277.493  214.820  283.315    297.393   
      ground_spider   271.570         0.229   84.414  179.595    226.174   
      frog            167.965        53.510    0.167  149.982    215.313   
      snake           119.521        26.432   25.786    0.089    144.239   
      web_spider      297.914       237.630  254.888  297.373      1.901   
      penguin         131.408       151.044  129.168  163.965    198.624   

                              
                     penguin  
Query fish           320.182  
      ground_spider  336.562  
      frog           307.253  
      snake          216.062  
      web_spider     357.284  
      penguin          5.812  
 === p-values ===
                    Reference                                               
            

In [10]:
ks_statistic_df = pd.DataFrame(ks_statistics)
ks_statistic_df = ks_statistic_df.set_index(row_ix)
ks_statistic_df.columns = col_ix

ks_pvalue_df = pd.DataFrame(ks_pvalues)
ks_pvalue_df = ks_pvalue_df.set_index(row_ix)
ks_pvalue_df.columns = col_ix

print("Kolmgorov-Smirnov")
print(" === statistic ===")
print(ks_statistic_df)
print()
print(" === p-values ===")
print(ks_pvalue_df)

Kolmgorov-Smirnov
 === statistic ===
                    Reference                                               
                         fish ground_spider   frog  snake web_spider penguin
Query fish              0.047         0.702  0.616  0.712      0.726   0.763
      ground_spider     0.694         0.038  0.365  0.544      0.619   0.787
      frog              0.520         0.280  0.031  0.484      0.593   0.730
      snake             0.440         0.216  0.207  0.026      0.491   0.599
      web_spider        0.731         0.638  0.663  0.730      0.090   0.836
      penguin           0.471         0.501  0.461  0.522      0.594   0.144

 === p-values ===
                    Reference                                               
                         fish ground_spider   frog  snake web_spider penguin
Query fish              0.128         0.000  0.000  0.000        0.0     0.0
      ground_spider     0.000         0.354  0.000  0.000        0.0     0.0
      frog          