<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/vit/VIT_2nd_try.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# https://www.youtube.com/watch?v=4XgDdxpXHEQ

In [None]:
!curl -L http://i.imgur.com/8o9DXSj.jpeg --output image.jpg

In [None]:
from PIL import Image

In [None]:
# Open the file written by the download cell above; the bare `img` expression
# on the last line makes the notebook render the image as the cell output.
img = Image.open("image.jpg")
img

In [None]:
from transformers import AutoProcessor, SiglipVisionModel, SiglipVisionConfig

In [None]:
# Load the input processor and the vision model from the same pretrained checkpoint.
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
# The model is built with a freshly constructed transformers SiglipVisionConfig
# that sets vision_use_head=False.
# NOTE(review): presumably this drops the pooling head so only the encoder is
# instantiated - confirm against the transformers SigLIP documentation.
# NOTE: this cell must run BEFORE the later cell that defines a local dataclass
# also named SiglipVisionConfig; after that cell the name no longer refers to
# the transformers class and re-running this line would fail.
vision_model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224", config=SiglipVisionConfig(vision_use_head=False))
# Display the module tree as the cell output.
vision_model

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from dataclasses import dataclass

from torchvision import transforms

def preprocess_image(image, image_size=224,
                     mean=(0.485, 0.456, 0.406),
                     std=(0.229, 0.224, 0.225)):
  """Turn a PIL image into a normalized, batched tensor.

  Args:
    image: PIL image to convert. Images that are not already in RGB mode
      (grayscale, RGBA, CMYK, palette, ...) are converted to RGB first.
    image_size: Side length, in pixels, of the square the image is resized to.
    mean: Per-channel mean subtracted after scaling to [0, 1]. The default is
      the commonly used ImageNet statistic.
    std: Per-channel standard deviation the result is divided by. The default
      is the commonly used ImageNet statistic.

  Returns:
    A float tensor of shape (1, 3, image_size, image_size).
  """
  # Normalize below is configured for exactly three channels, so any other
  # colour mode would fail with a channel-count mismatch.
  if getattr(image, "mode", "RGB") != "RGB":
    image = image.convert("RGB")

  preprocess = transforms.Compose([
      transforms.Resize((image_size, image_size)),
      transforms.ToTensor(),
      transforms.Normalize(mean=list(mean), std=list(std)),
  ])

  # unsqueeze(0) adds the leading batch dimension expected by the model.
  return preprocess(image).unsqueeze(0)

# Hyper-parameters of the patch-embedding stage.
embed_dim = 768
patch_size = 16
image_size = 224
# Number of non-overlapping patch_size x patch_size tiles in a square image.
num_patches = (image_size // patch_size) ** 2

# Batched, normalized tensor for the sample image.
image_tensor = preprocess_image(img, image_size=image_size)

# A convolution whose kernel and stride both equal the patch size projects each
# patch to an embed_dim vector; the weights here are randomly initialised.
with torch.no_grad():
  patch_embedding = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
  patches = patch_embedding(image_tensor)

patches.shape, num_patches

In [None]:
# Lookup table holding one embed_dim-sized vector per patch position
# (randomly initialised at this point).
position_embedding = nn.Embedding(num_embeddings=num_patches, embedding_dim=embed_dim)
# Indices 0 .. num_patches-1 with a leading axis of size 1 for the batch.
position_ids = torch.arange(num_patches).unsqueeze(0)

position_ids.shape

In [None]:
# Merge the two spatial axes into a single patch axis, then swap it with the
# channel axis: (1, 768, 14, 14) -> (1, 768, 196) -> (1, 196, 768).
embeddings = patches.flatten(start_dim=2).transpose(1, 2)
# Add the positional vector of each patch index to its patch embedding.
embeddings = embeddings + position_embedding(position_ids)
embeddings.shape

In [None]:
import matplotlib.pyplot as plt

# Heat map of the first batch item: one row per patch, one column per
# embedding dimension.
patches_viz = embeddings.detach()[0].numpy()

plt.figure(figsize=(15, 8))
plt.imshow(patches_viz, cmap="viridis", aspect="auto")
plt.colorbar()
plt.xlabel("Embedding Dimension")
plt.ylabel("Patch Number")
plt.title("Visualization of all patch embeddings")
plt.show()

In [None]:
# Same visualisation as before, but using the embedding layer of the
# pretrained model and the processor's own preprocessing of the image.
vision_model.eval()
inputs = processor(images=img, return_tensors="pt")

with torch.no_grad():
  patch_embeddings = vision_model.vision_model.embeddings(inputs["pixel_values"])

print(patch_embeddings.shape)

# First batch item: one row per patch, one column per embedding dimension.
patches_viz = patch_embeddings.detach()[0].numpy()

plt.figure(figsize=(15, 8))
plt.imshow(patches_viz, cmap="viridis", aspect="auto")
plt.colorbar()
plt.xlabel("Embedding Dimension")
plt.ylabel("Patch Number")
plt.title("Trained Model: All Patch Embeddings")
plt.show()

In [None]:
@dataclass
class SiglipVisionConfig:
  """Hyper-parameters of the patch-embedding stage.

  NOTE: this name shadows the SiglipVisionConfig imported from transformers
  earlier in the notebook; cells that need the transformers class must run
  before this one.
  """
  num_channels: int = 3  # channels of the input image
  embed_dim: int = 768   # size of each patch / position vector
  image_size: int = 224  # side length of the (square) input image
  patch_size: int = 16   # side length of each (square) patch


class SiglipVisionEmbeddings(nn.Module):
  """Split an image into patches and embed them with learned positions.

  A convolution whose kernel size and stride both equal the patch size
  projects every non-overlapping patch to an ``embed_dim`` vector; a learned
  position embedding, one vector per patch index, is then added.
  """

  def __init__(self, config: SiglipVisionConfig):
    super().__init__()
    self.config = config

    self.num_channels = config.num_channels
    self.embed_dim = config.embed_dim
    self.image_size = config.image_size
    self.patch_size = config.patch_size
    self.patch_embedding = nn.Conv2d(
        in_channels=self.num_channels,
        out_channels=self.embed_dim,
        kernel_size=self.patch_size,
        stride=self.patch_size,
        padding="valid"
    )

    self.num_patches = (self.image_size // self.patch_size) ** 2
    self.num_positions = self.num_patches
    self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
    # Non-persistent: the index row follows the module across devices but is
    # not written to (or expected in) the state dict.
    self.register_buffer(
        "position_ids",
        torch.arange(self.num_positions).expand((1, -1)),
        persistent=False,
    )

  def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Embed a batch of images.

    Args:
      pixel_values: Tensor of shape (batch, num_channels, height, width) whose
        patch grid matches the configured image size.

    Returns:
      Tensor of shape (batch, num_patches, embed_dim).

    Raises:
      ValueError: If the input is not 4-D or its patch grid differs from the
        one the position embeddings were built for.
    """
    if pixel_values.dim() != 4:
      raise ValueError(
          "pixel_values must have shape (batch, channels, height, width), "
          f"got {tuple(pixel_values.shape)}"
      )

    _, _, height, width = pixel_values.shape
    grid_size = self.image_size // self.patch_size
    # Compare the patch grid rather than the patch count: a 448x112 input to a
    # 224 model also yields 196 patches, but their positions would not line up
    # with the learned position embeddings.
    if height // self.patch_size != grid_size or width // self.patch_size != grid_size:
      raise ValueError(
          f"Input size {height}x{width} does not match the configured "
          f"image size {self.image_size}x{self.image_size}"
      )

    patch_embeds = self.patch_embedding(pixel_values)
    # (B, D, H', W') -> (B, D, H'*W') -> (B, H'*W', D)
    embeddings = patch_embeds.flatten(start_dim=2).transpose(1, 2)
    # The (1, num_patches, D) position table broadcasts over the batch axis.
    return embeddings + self.position_embedding(self.position_ids)

In [None]:
# Build our embedding module with the default configuration and smoke-test it
# on the sample image; the resulting shape is shown as the cell output.
embd = SiglipVisionEmbeddings(config=SiglipVisionConfig())
embd(pixel_values=image_tensor).shape

In [None]:
from transformers import SiglipVisionModel as HFSiglipVisionModel

# Take the embedding-layer tensors from the pretrained model and strip the
# module prefix so the keys line up with the ones used by our module.
hf_state_dict = {
  name.replace("vision_model.embeddings.", ""): tensor
  for name, tensor in vision_model.state_dict().items()
  if "vision_model.embeddings." in name
}

# Start from our own state dict and overwrite matching entries with the
# pretrained tensors, then load the merged result back into our module.
our_state_dict = embd.state_dict()
our_state_dict.update(hf_state_dict)
embd.load_state_dict(our_state_dict)

# Feed the same tensor through both implementations; a maximum absolute
# difference of 0 means the two embedding layers agree.
with torch.no_grad():
  our_output = embd(image_tensor)
  hf_output = vision_model.vision_model.embeddings(image_tensor)
  print("Max difference between our output and HF output:", (our_output - hf_output).abs().max())