In [1]:
import numpy as np
import pandas as pd
from collections import defaultdict
from scipy.spatial.distance import cdist
from scipy.stats import ks_2samp, cramervonmises_2samp

from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModel

tiny_imagenet = load_dataset('zh-plus/tiny-imagenet', split='valid')
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
model = AutoModel.from_pretrained('facebook/dinov2-small')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
images0 = [image['image'] for image in tiny_imagenet if image['label'] == 0]
images1 = [image['image'] for image in tiny_imagenet if image['label'] == 1]
images2 = [image['image'] for image in tiny_imagenet if image['label'] == 2]

In [3]:
def embed(x):
    inputs = processor(images=x, return_tensors="pt")
    outputs = model(**inputs)
    retval = (
        outputs
        .last_hidden_state
        .flatten()
        .detach()
        .numpy()
    )
    retval = retval[np.newaxis, :]
    return retval

embeddings0 = [embed(x) for x in images0]
embeddings1 = [embed(x) for x in images1]
embeddings2 = [embed(x) for x in images2]

In [4]:
embeddings0[0].shape

(1, 98688)

In [5]:
def compute_distances(A, B):
    return np.array(
        [
            cdist(a, b)
            for a in A
            for b in B
        ]
    )

def compute_metrics(data, func = cramervonmises_2samp):
    pvalues = defaultdict(dict)
    statistics = defaultdict(dict)
    
    def cast_and_round(x):
        x = float(x)
        return round(x,3)

    for i, query in enumerate(data):
        for j, reference in enumerate(data):
            reference_distance = compute_distances(reference, reference)
            if i == j:
                # split the set in two and measure how similarly distributed it is.
                split_idx = len(reference_distance) // 2
                metric = func(reference_distance[:split_idx], reference_distance[split_idx:])
            else:
                query_distance = compute_distances(reference, query)
                metric = func(reference_distance, query_distance)
            pvalues[i][j] = cast_and_round(metric.pvalue)
            statistics[i][j] = cast_and_round(metric.statistic)
    return (pvalues, statistics)

def compute_cvm(data):
    return compute_metrics(data, func=cramervonmises_2samp)

def compute_ks(data):
    return compute_metrics(data, func=ks_2samp)

data = [embeddings0, embeddings1, embeddings2]
cvm_pvalues, cvm_statistics = compute_cvm(data)
ks_pvalues, ks_statistics = compute_ks(data)

  x = float(x)


In [6]:
col_ix = pd.MultiIndex.from_product([['Reference'], [0,1,2]]) 
row_ix = pd.MultiIndex.from_product([['Query'], [0,1,2]])

query_cvm_statistic_df = pd.DataFrame(cvm_statistics)
query_cvm_statistic_df = query_cvm_statistic_df.set_index(row_ix)
query_cvm_statistic_df.columns = col_ix

query_cvm_pvalue_df = pd.DataFrame(cvm_pvalues)
query_cvm_pvalue_df = query_cvm_pvalue_df.set_index(row_ix)
query_cvm_pvalue_df.columns = col_ix

print("Cramer-Von Mises")
print(" === statistic ===")
print(query_cvm_statistic_df)
print()
print(" === p-value ===")
print(query_cvm_pvalue_df)

Cramer-Von Mises
 === statistic ===
        Reference                  
                0        1        2
Query 0     0.282  282.536  220.019
      1   276.491    0.220   88.925
      2   172.883   57.489    0.161

 === p-value ===
        Reference              
                0      1      2
Query 0     0.152  0.000  0.000
      1     0.000  0.233  0.000
      2     0.000  0.000  0.359


In [7]:
col_ix = pd.MultiIndex.from_product([['Reference'], [0,1,2]]) 
row_ix = pd.MultiIndex.from_product([['Query'], [0,1,2]])

query_cvm_statistic_df = pd.DataFrame(ks_statistics)
query_cvm_statistic_df = query_cvm_statistic_df.set_index(row_ix)
query_cvm_statistic_df.columns = col_ix

query_cvm_pvalue_df = pd.DataFrame(ks_statistics)
query_cvm_pvalue_df = query_cvm_pvalue_df.set_index(row_ix)
query_cvm_pvalue_df.columns = col_ix

print("Kolmgorov-Smirnov")
print(" === statistic ===")
print(query_cvm_statistic_df)
print()
print(" === p-value ===")
print(query_cvm_pvalue_df)

Kolmgorov-Smirnov
 === statistic ===
        Reference              
                0      1      2
Query 0     0.046  0.706  0.620
      1     0.697  0.037  0.372
      2     0.525  0.288  0.030

 === p-value ===
        Reference              
                0      1      2
Query 0     0.046  0.706  0.620
      1     0.697  0.037  0.372
      2     0.525  0.288  0.030
