In [None]:
!curl -L http://i.imgur.com/8o9DXSj.jpeg --output image.jpg

In [None]:
from PIL import  Image

# Open inside a context manager so the file handle is released, and force 3-channel
# RGB: the from-scratch patch embedding later in this notebook uses a Conv2d with
# in_channels=3, while a downloaded image could be grayscale, palette-based or RGBA.
with Image.open("image.jpg") as raw_image:
  img = raw_image.convert("RGB")
img

In [None]:
from transformers import AutoProcessor, SiglipVisionModel, SiglipVisionConfig

In [None]:
MODEL_ID = "google/siglip-base-patch16-224"

processor = AutoProcessor.from_pretrained(MODEL_ID)

# Start from the checkpoint's own vision config and only switch off the pooling head.
# Passing a freshly constructed SiglipVisionConfig(...) via `config=` would replace the
# checkpoint's architecture settings with SiglipVisionConfig's library defaults, which
# is fragile if MODEL_ID ever changes.
vision_config = SiglipVisionConfig.from_pretrained(MODEL_ID)
vision_config.vision_use_head = False
vision_model = SiglipVisionModel.from_pretrained(MODEL_ID, config=vision_config)

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from dataclasses import dataclass
from torchvision import transforms

In [None]:
def preprocesses_image(image, image_size=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
  """Resize, tensorize and normalize an image into a batch of one.

  Args:
    image: a PIL image (anything with a ``convert`` method is first converted to
      RGB so the 3-channel normalization below always applies); other inputs are
      passed to ``transforms.ToTensor`` unchanged.
    image_size: side length of the square output. The image is resized to exactly
      image_size x image_size, i.e. the aspect ratio is not preserved.
    mean, std: per-channel normalization statistics. The defaults are the
      standard ImageNet values (note the green-channel mean is 0.456).
      NOTE(review): the pretrained SigLIP processor may use different
      statistics — confirm before comparing against its outputs.

  Returns:
    A float tensor with a leading batch dimension of 1; for a PIL input this is
    (1, 3, image_size, image_size).
  """
  if hasattr(image, "convert"):
    image = image.convert("RGB")
  preprocess = transforms.Compose([
      transforms.Resize((image_size, image_size)),
      transforms.ToTensor(),  # HWC uint8 [0, 255] -> CHW float [0, 1]
      transforms.Normalize(mean=mean, std=std),
  ])
  image_tensor = preprocess(image)
  return image_tensor.unsqueeze(0)  # add the batch dimension

In [None]:
embed_dim = 768
patch_size = 16
image_size = 224
num_patches = (image_size // patch_size) ** 2

# Seed so the randomly initialized layers in this and the following cells are
# reproducible across "Restart & Run All".
torch.manual_seed(0)

# Pass image_size explicitly so the preprocessing and the patch grid stay in sync.
image_tensor = preprocesses_image(img, image_size=image_size)

# kernel_size == stride == patch_size: the conv visits non-overlapping
# patch_size x patch_size windows and projects each one to embed_dim channels.
# Weights are randomly initialized, not pretrained.
patch_embedding = nn.Conv2d(in_channels=3, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size)
with torch.no_grad():
  # (1, embed_dim, image_size // patch_size, image_size // patch_size)
  patches = patch_embedding(image_tensor)

assert patches.shape[-2] * patches.shape[-1] == num_patches, "patch grid does not match num_patches"
patches.shape, num_patches

In [None]:
# One learnable (here randomly initialized) embed_dim-vector per patch position.
position_embedding = nn.Embedding(num_embeddings=num_patches, embedding_dim=embed_dim)
# Patch indices 0 .. num_patches-1 with a leading batch axis: shape (1, num_patches).
position_ids = torch.arange(num_patches).unsqueeze(0)

position_ids.shape

In [None]:
# patches: (1, embed_dim, grid_h, grid_w)
#   --flatten spatial axes--> (1, embed_dim, num_patches)   e.g. (1, 768, 196)
#   --swap last two axes-->   (1, num_patches, embed_dim)   one token per patch
patch_tokens = patches.flatten(start_dim=2).transpose(1, 2)
# position_embedding(position_ids) is (1, num_patches, embed_dim); add it token-wise.
embeddings = patch_tokens + position_embedding(position_ids)
embeddings.shape

In [None]:
import matplotlib.pyplot as plt

# Detach from autograd (the position embedding carries gradients), then drop the
# batch axis to get a plain (num_patches, embed_dim) array for matplotlib.
patches_viz = embeddings.detach()[0].numpy()

In [None]:
# Heatmap: one row per patch token, one column per embedding dimension.
fig, ax = plt.subplots(figsize=(15, 10))
heatmap = ax.imshow(patches_viz, aspect='auto', cmap='viridis')
fig.colorbar(heatmap, ax=ax)
ax.set_title('Visualization of All Patch Embeddings')
ax.set_xlabel('Embedding Dimension')
ax.set_ylabel('Patch Number')
plt.show()

In [None]:
# Run only the pretrained embedding layer (patch conv + position embeddings) of
# SigLIP on the same image, for comparison with the from-scratch version above.
vision_model.eval()
inputs = processor(images=img, return_tensors="pt")

with torch.no_grad():
  embedding_layer = vision_model.vision_model.embeddings
  patch_embeddings = embedding_layer(inputs["pixel_values"])

print(patch_embeddings.shape)

# Drop the batch axis: (num_patches, hidden_size), i.e. [196, 768] for this checkpoint.
patches_viz = patch_embeddings.detach()[0].numpy()

# Same heatmap layout as the from-scratch plot: rows = patches, columns = dimensions.
fig, ax = plt.subplots(figsize=(15, 10))
heatmap = ax.imshow(patches_viz, aspect='auto', cmap='viridis')
fig.colorbar(heatmap, ax=ax)
ax.set_title('Trained Model: All Patch Embeddings')
ax.set_xlabel('Embedding Dimension')
ax.set_ylabel('Patch Number')
plt.show()

In [None]:
# 