In [9]:
from typing import Optional, Tuple
import torch
import torch.nn as nn

In [10]:
# Report whether a CUDA device is visible to this kernel.
has_cuda = torch.cuda.is_available()
has_cuda

False

In [11]:
class SiglipVisionConfig:
    """Hyper-parameters for the SigLIP vision model built in this notebook.

    Args:
        hidden_size: channel width of each patch embedding (output channels of
            the patch-embedding conv).
        intermediate_size: presumably the MLP hidden width of each encoder
            layer — not yet used in this notebook.
        num_hidden_layers: presumably the number of encoder layers — not yet
            used in this notebook.
        num_attention_heads: presumably the number of attention heads — not
            yet used in this notebook.
        num_channels: number of input image channels.
        image_size: side length (pixels) of the square input image.
        patch_size: side length (pixels) of each square patch.
        layer_norm_eps: epsilon used by layer normalization.
        attention_dropout: presumably the attention-probability dropout rate —
            not yet used in this notebook.
        num_image_tokens: optional token count; stored as given (``None`` by
            default) and not derived from the other fields.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=16,
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        num_image_tokens: Optional[int] = None,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size
        self.layer_norm_eps = layer_norm_eps
        self.attention_dropout = attention_dropout
        self.num_image_tokens = num_image_tokens

config = SiglipVisionConfig()

In [12]:
class SiglipPatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each patch.

    A Conv2d whose kernel size and stride both equal ``config.patch_size``
    projects every patch to ``config.hidden_size`` channels. A learned position
    embedding (one row per patch) is then added.

    Input:  (batch, num_channels, H, W) — assumes H == W == config.image_size
            (TODO confirm with callers); otherwise the patch count will not
            match the position table.
    Output: (batch, num_patches, hidden_size)
    """

    def __init__(self, config: "SiglipVisionConfig"):
        super().__init__()

        # Honour config.num_channels instead of hard-coding RGB.
        self.patch_layer = nn.Conv2d(in_channels=config.num_channels,
                                     out_channels=config.hidden_size,
                                     kernel_size=config.patch_size,
                                     stride=config.patch_size,
                                     padding='valid')

        # A 'valid' conv with kernel == stride == patch_size yields
        # floor(image_size / patch_size) patches per side (224 // 16 = 14, so
        # 14 * 14 = 196 patches). image_size**2 // patch_size**2 over-counts
        # whenever image_size is not a multiple of patch_size.
        patches_per_side = config.image_size // config.patch_size
        self.patch_num = patches_per_side ** 2

        self.pos_embedding = nn.Embedding(num_embeddings=self.patch_num,
                                          embedding_dim=config.hidden_size)

        # Indices 0..patch_num-1 with a leading batch dim of 1, used to look up
        # the position embeddings. persistent=False keeps this derived buffer
        # out of the state dict.
        self.register_buffer('position_ids',
                             torch.arange(self.patch_num).view(1, -1),
                             persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, C, H, W) -> (B, hidden_size, H // p, W // p)
        x = self.patch_layer(x)
        # (B, hidden_size, H // p, W // p) -> (B, hidden_size, num_patches)
        x = x.flatten(start_dim=2)
        # (B, hidden_size, num_patches) -> (B, num_patches, hidden_size)
        x = x.transpose(-2, -1)
        # position_ids is (1, num_patches), so its embedding broadcasts over B.
        return x + self.pos_embedding(self.position_ids)

In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, matching ``nn.LayerNorm``.

    Each vector along the last axis is shifted to zero mean and scaled to unit
    (biased) variance. A learned per-feature ``scale`` and ``shift`` are then
    applied.

    Args:
        config: object exposing ``hidden_size`` (size of the input's last
            dimension) and ``layer_norm_eps`` (added to the variance for
            numerical stability).
    """

    def __init__(self, config):
        super().__init__()
        self.eps = config.layer_norm_eps
        self.hidden_size = config.hidden_size
        self.scale = nn.Parameter(torch.ones(self.hidden_size))
        self.shift = nn.Parameter(torch.zeros(self.hidden_size))

    def forward(self, x):
        x_mean = x.mean(dim=-1, keepdim=True)
        x_centered = x - x_mean
        # Biased (population) variance, as in standard LayerNorm. torch.std's
        # default unbiased estimator rescales the output and is NaN when the
        # last dimension has size 1.
        x_var = x_centered.pow(2).mean(dim=-1, keepdim=True)
        # eps goes inside the square root, as in nn.LayerNorm.
        x_norm = x_centered / torch.sqrt(x_var + self.eps)
        return x_norm * self.scale + self.shift

In [13]:
class SiglipEncoder(nn.Module):
    """Placeholder for the SigLIP vision encoder.

    Work in progress: it currently only records ``config.layer_norm_eps`` as
    ``self.eps``. No submodules and no ``forward`` are defined yet.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        layer_norm_eps = config.layer_norm_eps
        self.eps = layer_norm_eps
        

In [14]:
class SiglipVisionTransformer(nn.Module):
    """Vision backbone: a patch-embedding stage followed by an encoder.

    NOTE(review): no ``forward`` is defined yet, so calling an instance falls
    through to nn.Module's default and raises NotImplementedError.
    SiglipVisionModel.forward currently hits this.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        # Build embeddings before the encoder, keeping submodule registration
        # (and parameter-initialisation) order.
        embeddings = SiglipPatchEmbedding(config)
        self.patch_embedding = embeddings
        self.encoder = SiglipEncoder(config)
        self.config = config
        

In [15]:
class SiglipVisionModel(nn.Module):
    """Top-level SigLIP vision model; delegates to ``SiglipVisionTransformer``."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.visual_model = SiglipVisionTransformer(config)

    def forward(self, x):
        """Run the wrapped vision transformer on ``x`` and return its output.

        Intended shapes: (batch, channel, height, width) ->
        (batch, num_patches, embed_dim).

        NOTE(review): SiglipVisionTransformer does not define ``forward`` yet,
        so this call currently raises NotImplementedError.
        """
        vision_tower = self.visual_model
        return vision_tower(x)
        
        

In [56]:
# Scratch input for quick shape checks: a (5, 5, 10) tensor of ones.
# The seed does not affect torch.ones; presumably it is for random ops added later.
torch.manual_seed(0)
x = torch.ones(size=(5, 5, 10))