<a href="https://colab.research.google.com/github/BangachevKiril/RepresentationLearningTheory/blob/main/GeometryOfTrainedModels.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LRH Notebook

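Estimating $\xi$ for a collection of SigLIP and SigLIP 2 models. Each model embeds the ImageNet validation images and their class names. We then compare three quantities: the mean squared distance between matched image and text embeddings, the squared norm of the mean image–text difference, and the mean squared distance for randomly re-paired images and texts.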

In [None]:
from datasets import load_dataset
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
import torchvision

In [None]:
# Optional: mount Google Drive so processed embeddings can be saved there and reloaded later.
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# get libraries necessary for Hugging Face models
!pip install -U transformers
import requests
from transformers import AutoProcessor, AutoModel



## Model Names

In [None]:
# Hugging Face checkpoints evaluated below: eight SigLIP variants, the same eight variants
# of SigLIP 2 (in the same order), and finally the SigLIP 2 giant (opt) model.
varieties_siglip = [
    f'google/{family}-{variant}'
    for family in ('siglip', 'siglip2')
    for variant in ('so400m-patch14-384', 'base-patch16-224', 'base-patch16-384',
                    'large-patch16-256', 'so400m-patch14-224', 'base-patch16-256',
                    'base-patch16-512', 'large-patch16-384')
] + ['google/siglip2-giant-opt-patch16-256']

# Getting Data and Embedding

In [None]:
!mkdir ImageNetVal
%cd ImageNetVal
!wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar

/content/ImageNetVal
--2025-09-10 14:26:50--  https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar
Resolving image-net.org (image-net.org)... 171.64.68.16
Connecting to image-net.org (image-net.org)|171.64.68.16|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6744924160 (6.3G) [application/x-tar]
Saving to: ‘ILSVRC2012_img_val.tar’


2025-09-10 14:30:56 (26.2 MB/s) - ‘ILSVRC2012_img_val.tar’ saved [6744924160/6744924160]



In [None]:
!mkdir val
# extract
!tar -xvf ILSVRC2012_img_val.tar -C val

In [None]:
# Read labels_text.txt (one English class name per line) into a list, newlines stripped.
with open('labels_text.txt', 'r') as label_file:
  labels_lookup = [entry.split('\n')[0] for entry in label_file]

In [None]:
# split the ImageNet dataset based on classes
import shutil

# Paths (relative to the ImageNetVal working directory)
val_dir = "val"                          # folder with the 50k extracted images
mapping_file = "labels_text.txt"         # line i holds the class name of validation image i+1
output_dir = "val_unpacked"              # one folder per class name: 998 folders, since two
                                         # class names cover 100 images instead of 50

# Create output folder
os.makedirs(output_dir, exist_ok=True)

# Read mapping
with open(mapping_file, "r") as f:
    lines = f.readlines()

for i, line in enumerate(lines):
    # Validation files are numbered from 1 and zero-padded to 8 digits.
    filename = f'ILSVRC2012_val_{i + 1:08d}.JPEG'
    src = os.path.join(val_dir, filename)
    dst_dir = os.path.join(output_dir, line.rstrip('\n'))
    dst = os.path.join(dst_dir, filename)
    if not os.path.exists(src) and os.path.exists(dst):
        continue  # already moved by a previous run of this cell
    os.makedirs(dst_dir, exist_ok=True)
    shutil.move(src, dst)

# Report the real folder count rather than assuming 1000.
n_class_folders = len(os.listdir(output_dir))
print(f"Done! Validation set organized into {n_class_folders} folders.")

Done! Validation set organized into 1000 folders.


### Run Inference

In [None]:
def calculate_xi(image_embeddings, text_embeddings, rng=None):
  """Compare paired image/text embeddings against a randomly re-paired baseline.

  Row i of ``image_embeddings`` is paired with row i of ``text_embeddings``. With
  d_i = image_i - text_i, the function returns

    [mean_i ||d_i||^2,                       # mean_of_norms
     ||mean_i d_i||^2,                       # norm_of_mean
     mean_i ||image_perm(i) - text_i||^2]    # random_mean_of_norms, perm a random permutation

  Args:
    image_embeddings: 2-D array of shape (n, d).
    text_embeddings: array combined with ``image_embeddings`` by broadcasting subtraction,
      normally of the same shape (n, d).
    rng: optional ``np.random.Generator`` / ``RandomState`` used to draw the permutation.
      Pass a seeded one for a reproducible baseline; defaults to NumPy's global RNG.

  Returns:
    np.ndarray of shape (3,): [mean_of_norms, norm_of_mean, random_mean_of_norms].
  """
  diff = image_embeddings - text_embeddings
  mean_of_norms = np.mean(np.linalg.norm(diff, axis=1) ** 2)
  norm_of_mean = np.linalg.norm(np.mean(diff, axis=0)) ** 2
  # Baseline: keep the texts fixed and shuffle the images, so pairs are (mostly) mismatched.
  sampler = np.random if rng is None else rng
  perm = sampler.permutation(image_embeddings.shape[0])
  random_diff = image_embeddings[perm, :] - text_embeddings
  random_mean_of_norms = np.mean(np.linalg.norm(random_diff, axis=1) ** 2)
  return np.array([mean_of_norms, norm_of_mean, random_mean_of_norms])

In [None]:
# siglip
# Embed every ImageNet-val image, plus its class name as text, with each model variety;
# compute xi and cache the embeddings on Drive for the Experiments section.

# Absolute paths: `%cd ImageNetVal` above leaves the working directory at
# /content/ImageNetVal, where 'ImageNetVal/...' and 'drive/...' would not resolve.
IMAGENET_VAL_DIR = '/content/ImageNetVal/val_unpacked'
SAVE_PREFIX = '/content/drive/My Drive/Research/SigLIP/imagenetval'
SAVE_EMBEDDINGS = True  # the Experiments section reloads these .npz files
# The SigLIP 2 processors define no default text length, so padding="max_length" alone
# falls back to *no* padding ("Asking to pad to max_length but no maximum length is
# provided ..."). The checkpoints are trained on text padded to 64 tokens, so pass it.
TEXT_MAX_LENGTH = 64


def embed_imagenet_val(variety, data_dir=IMAGENET_VAL_DIR):
  """Embed ImageNet-val images and their class names with one Hugging Face checkpoint.

  Args:
    variety: model id, e.g. 'google/siglip-base-patch16-224'.
    data_dir: directory with one sub-folder of images per class name.

  Returns:
    (image_embeddings, text_embeddings, splits, ordered_labels). Row i of the two embedding
    arrays is image i and the text embedding of its class name. splits[k] is the first row
    of the k-th class, and ordered_labels[i] is the class-folder name of row i.
  """
  processor = AutoProcessor.from_pretrained(variety)
  model = AutoModel.from_pretrained(variety).to('cuda')
  image_embeddings, text_embeddings, splits, ordered_labels = [], [], [], []
  so_far = 0
  # sorted(): os.listdir order is arbitrary, and downstream cells reuse one variety's
  # row layout for every variety, so the order must be deterministic.
  for label in sorted(os.listdir(data_dir)):
    class_dir = os.path.join(data_dir, label)
    images = []
    for filename in sorted(os.listdir(class_dir)):
      with Image.open(os.path.join(class_dir, filename)) as img:
        images.append(img.convert('RGB'))
    # SigLIP 2 was trained on lower-cased text (Hugging Face SigLIP 2 docs).
    text = label.lower() if 'siglip2' in variety else label
    inputs = processor(images=images, text=[text], return_tensors="pt",
                       padding="max_length", max_length=TEXT_MAX_LENGTH).to(model.device)
    with torch.no_grad():
      outputs = model(**inputs)
    splits.append(so_far)
    image_embeddings.append(outputs.image_embeds.cpu().numpy())
    # One copy of the class-name embedding per image, so rows pair up with the images.
    text_embeddings.append(np.repeat(outputs.text_embeds.cpu().numpy(), len(images), axis=0))
    ordered_labels.extend([label] * len(images))
    so_far += len(images)
  # Free GPU memory before the next (possibly larger) checkpoint is loaded.
  del model
  torch.cuda.empty_cache()
  return (np.concatenate(image_embeddings, axis=0), np.concatenate(text_embeddings, axis=0),
          np.array(splits), np.array(ordered_labels))


if SAVE_EMBEDDINGS:
  os.makedirs(os.path.dirname(SAVE_PREFIX), exist_ok=True)
for variety in varieties_siglip:
  print(variety)
  image_embeddings, text_embeddings, splits, ordered_labels = embed_imagenet_val(variety)
  xi = calculate_xi(image_embeddings, text_embeddings)
  if SAVE_EMBEDDINGS:
    np.savez(SAVE_PREFIX + variety.split('/')[-1] + '.npz',
             text=text_embeddings, image=image_embeddings, splits=splits,
             ordered_labels=ordered_labels, xi=xi)

google/siglip-so400m-patch14-384


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

google/siglip-base-patch16-224


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

google/siglip-base-patch16-384


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

google/siglip-large-patch16-256


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

google/siglip-so400m-patch14-224


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

google/siglip-base-patch16-256


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

google/siglip-base-patch16-512


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

google/siglip-large-patch16-384


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

google/siglip2-so400m-patch14-384


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.


google/siglip2-base-patch16-224


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.


google/siglip2-base-patch16-384


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.


google/siglip2-large-patch16-256


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.


google/siglip2-so400m-patch14-224


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.


google/siglip2-base-patch16-256


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.


google/siglip2-base-patch16-512


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.


google/siglip2-large-patch16-384


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.


google/siglip2-giant-opt-patch16-256


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.


# Experiments

In [None]:
import pandas as pd

In [None]:
# Ten ImageNet class names used for the subset experiments below.
subset_classes = ['pelican', 'banjo', 'bath towel', 'castle', 'acorn',
                  'red wine', 'Granny Smith', 'wig', 'radiator', 'honeycomb']

# Cosine-similarity statistics (5th/95th percentiles and means)

In [None]:
# Load the reference row layout (class offsets and per-row class names) from the last
# variety's saved embeddings; the same layout is used for every variety below.
# Absolute path: the earlier `%cd ImageNetVal` leaves the working directory at
# /content/ImageNetVal, so a relative 'drive/...' path would not reach the Drive mount.
prefix = '/content/drive/My Drive/Research/SigLIP/imagenetval'
# `data` stays open: the next cell reads data['text'].
data = np.load(prefix + varieties_siglip[-1].split('/')[-1] + '.npz')
splits = data['splits']
labels = data['ordered_labels']

In [None]:
# correct_label[i, k] is True iff row i belongs to the k-th class, whose rows start at
# offset splits[k] (each class occupies a contiguous run of rows).
n_rows = np.shape(data['text'])[0]
# For every row, take the last class offset <= row index. Unlike a running counter, this
# stays correct when two offsets coincide (an empty class). Clipping at 0 sends any rows
# before splits[0] to class 0.
row_class = np.clip(np.searchsorted(splits, np.arange(n_rows), side='right') - 1, 0, None)
correct_label = np.zeros((n_rows, np.shape(splits)[0]), dtype=bool)
correct_label[np.arange(n_rows), row_class] = True

In [None]:
# Rows that belong to the subset classes, and the offset (within the subset) at which each
# subset class starts, in the order the rows actually appear in the data.
missing_classes = sorted(set(subset_classes) - set(labels.tolist()))
assert not missing_classes, f'subset classes not found in the labels: {missing_classes}'
subset = np.isin(labels, subset_classes)
subset_row_labels = labels[subset]
# A new class starts wherever the label changes (rows of one class are contiguous).
# Accumulating counts in subset_classes list order would misalign the offsets with the
# rows whenever that order differs from the data order and class sizes differ.
subset_splits = np.flatnonzero(np.r_[True, subset_row_labels[1:] != subset_row_labels[:-1]])

In [None]:
# correct_label_subset[i, k] is True iff subset row i belongs to the k-th subset class,
# whose rows start at subset_splits[k].
n_subset_rows = int(np.sum(subset))
n_subset_classes = len(subset_classes)
# Last class offset <= row index, clipped to the valid column range. This stays correct
# when two offsets coincide, where a running counter would stop advancing.
subset_row_class = np.clip(
    np.searchsorted(subset_splits, np.arange(n_subset_rows), side='right') - 1,
    0, n_subset_classes - 1)
correct_label_subset = np.zeros((n_subset_rows, n_subset_classes), dtype=bool)
correct_label_subset[np.arange(n_subset_rows), subset_row_class] = True

In [None]:
# Row index (in the full data) of the first image of each subset class, in subset_classes order.
subset_labels = np.array([np.flatnonzero(labels == class_name)[0] for class_name in subset_classes])

In [None]:
def get_cos_similarity_statistics(text_embeddings, image_embeddings, correct_label, splits):
  """Summarize image-vs-class-text inner products for matching and non-matching pairs.

  The class text embedding of class k is taken from row ``splits[k]`` of
  ``text_embeddings``. Every image is scored against every class, and ``correct_label``
  (boolean, images x classes) selects the matching pairs; all other pairs are mismatching.

  Returns:
    np.ndarray [5th pct, 95th pct, mean] of the matching scores followed by the same
    three statistics for the mismatching scores.
  """
  class_texts = text_embeddings[splits, :]
  scores = image_embeddings @ class_texts.T
  positives = scores[correct_label]
  negatives = scores[~correct_label]
  pos_low, pos_high = np.percentile(positives, [5, 95])
  neg_low, neg_high = np.percentile(negatives, [5, 95])
  return np.array([pos_low, pos_high, positives.mean(),
                   neg_low, neg_high, negatives.mean()])

In [None]:
# Per-variety statistics: xi followed by cosine-similarity statistics, on the full
# validation set (vals) and on the subset classes (vals_subset).
# Absolute path: after `%cd ImageNetVal` the Drive mount is not reachable via 'drive/...'.
prefix = '/content/drive/My Drive/Research/SigLIP/imagenetval'
# calculate_xi draws a random permutation from NumPy's global RNG; seed it so the
# random_mean_of_norms column is reproducible.
np.random.seed(0)
vals = []
vals_subset = []
for variety in varieties_siglip:
  print(variety)
  with np.load(prefix + variety.split('/')[-1] + '.npz') as variety_data:
    # correct_label, splits and the subset masks were built from a single reference file;
    # they are only valid if this variety stored its rows in the same order.
    if not (np.array_equal(variety_data['splits'], splits)
            and np.array_equal(variety_data['ordered_labels'], labels)):
      raise ValueError(f'{variety}: row layout differs from the reference embeddings')
    text_full = variety_data['text']
    image_full = variety_data['image']

  vals.append(np.concatenate([calculate_xi(image_full, text_full),
                              get_cos_similarity_statistics(text_full, image_full, correct_label, splits)]))

  text_subset = text_full[subset, :]
  image_subset = image_full[subset, :]
  vals_subset.append(np.concatenate([calculate_xi(image_subset, text_subset),
                                     get_cos_similarity_statistics(text_subset, image_subset, correct_label_subset, subset_splits)]))

google/siglip-so400m-patch14-384
google/siglip-base-patch16-224
google/siglip-base-patch16-384
google/siglip-large-patch16-256
google/siglip-so400m-patch14-224
google/siglip-base-patch16-256
google/siglip-base-patch16-512
google/siglip-large-patch16-384
google/siglip2-so400m-patch14-384
google/siglip2-base-patch16-224
google/siglip2-base-patch16-384
google/siglip2-large-patch16-256
google/siglip2-so400m-patch14-224
google/siglip2-base-patch16-256
google/siglip2-base-patch16-512
google/siglip2-large-patch16-384
google/siglip2-giant-opt-patch16-256


In [None]:
import pandas as pd
# One row per model. The columns follow the concatenated output of calculate_xi (3 values)
# and get_cos_similarity_statistics (6 values).
summary = pd.DataFrame(
    np.vstack(vals),
    index=varieties_siglip,
    columns=['mean_of_norms', 'norm_of_mean', 'random_mean_of_norms',
             '5%_pos_cos', '95%_pos_cos', 'mean_pos_cos',
             '5%_neg_cos', '95%_neg_cos', 'mean_neg_cos'],
)
summary.loc[:, ['norm_of_mean', 'mean_of_norms', 'random_mean_of_norms',
                '5%_pos_cos', '95%_neg_cos', 'mean_pos_cos', 'mean_neg_cos']]

Unnamed: 0,norm_of_mean,mean_of_norms,random_mean_of_norms,5%_pos_cos,95%_neg_cos,mean_pos_cos,mean_neg_cos
google/siglip-so400m-patch14-384,1.116217,1.724865,2.002979,0.076926,0.048571,0.137568,-0.001476
google/siglip-base-patch16-224,1.222064,1.810037,2.060556,0.03826,0.018077,0.094982,-0.030491
google/siglip-base-patch16-384,1.215966,1.806814,2.063315,0.040821,0.01727,0.096593,-0.031865
google/siglip-large-patch16-256,1.242042,1.795498,2.071453,0.04001,0.015123,0.102251,-0.03585
google/siglip-so400m-patch14-224,1.106269,1.726981,2.004091,0.074733,0.048326,0.136509,-0.002205
google/siglip-base-patch16-256,1.22246,1.799137,2.058665,0.041268,0.019994,0.100431,-0.029425
google/siglip-base-patch16-512,1.215097,1.805874,2.063822,0.040906,0.016991,0.097063,-0.032233
google/siglip-large-patch16-384,1.234044,1.808413,2.076564,0.035337,0.011998,0.095794,-0.038382
google/siglip2-so400m-patch14-384,1.302343,1.970297,1.975328,-0.017915,0.042756,0.014851,0.012275
google/siglip2-base-patch16-224,1.36966,2.011628,2.034635,-0.049748,0.023278,-0.005814,-0.017362


In [None]:
import pandas as pd
# Same layout as the full-dataset summary, restricted to the subset classes.
summary_subset = pd.DataFrame(
    np.vstack(vals_subset),
    index=varieties_siglip,
    columns=['mean_of_norms', 'norm_of_mean', 'random_mean_of_norms',
             '5%_pos_cos', '95%_pos_cos', 'mean_pos_cos',
             '5%_neg_cos', '95%_neg_cos', 'mean_neg_cos'],
)
summary_subset.loc[:, ['norm_of_mean', 'mean_of_norms', 'random_mean_of_norms',
                       '5%_pos_cos', '95%_neg_cos', 'mean_pos_cos', 'mean_neg_cos']]

Unnamed: 0,norm_of_mean,mean_of_norms,random_mean_of_norms,5%_pos_cos,95%_neg_cos,mean_pos_cos,mean_neg_cos
google/siglip-so400m-patch14-384,1.179319,1.731823,1.962427,0.082902,0.047438,0.134088,0.00688
google/siglip-base-patch16-224,1.280315,1.813226,2.026054,0.040934,0.015833,0.093387,-0.023312
google/siglip-base-patch16-384,1.278733,1.811447,2.023058,0.042949,0.015528,0.094277,-0.023921
google/siglip-large-patch16-256,1.290667,1.802289,2.034465,0.042347,0.012932,0.098855,-0.030243
google/siglip-so400m-patch14-224,1.171229,1.73166,1.965315,0.082982,0.046566,0.13417,0.006587
google/siglip-base-patch16-256,1.278844,1.803811,2.021443,0.04595,0.017765,0.098094,-0.022763
google/siglip-base-patch16-512,1.279646,1.810664,2.020846,0.043339,0.01551,0.094668,-0.024039
google/siglip-large-patch16-384,1.291837,1.816628,2.034713,0.035621,0.011745,0.091686,-0.031676
google/siglip2-so400m-patch14-384,1.323267,1.966208,1.969006,-0.010208,0.042012,0.016896,0.015482
google/siglip2-base-patch16-224,1.405334,2.001502,2.023003,-0.041619,0.020486,-0.000751,-0.011653
