In [1]:
import torch
import torch.nn as nn

In [4]:
class SIGLIP_VISION_CONFIG(nn.Module):
    """Hyper-parameter container for the SigLIP vision model.

    The class stays an ``nn.Module`` subclass for backward compatibility,
    but it owns no parameters or buffers; it only stores plain values.

    Args:
        hidden_size: Width of the token embeddings.
        ffn_2: Width of the intermediate feed-forward layer.
        num_hidden_layer: Number of encoder blocks to stack.
        num_channels: Number of channels in the input image.
        image_size: Side length of the input image, in pixels.
        patch_size: Side length of one patch, in pixels.
        layer_norm: Presumably the LayerNorm epsilon -- TODO confirm; it is
            only stored here.
        attention_dropout: Dropout probability intended for the attention
            weights.
        num_attn_heads: Number of attention heads; stored as ``num_heads``.
    """

    def __init__(
        self,
        hidden_size = 768,
        ffn_2 = 3072,
        num_hidden_layer = 12,
        num_channels = 3,
        image_size = 224,
        patch_size = 16,
        layer_norm = 1e-6,
        attention_dropout = 0.0,
        num_attn_heads = 12):

        # `super` must be *called* to obtain the proxy object.
        super().__init__()
        # No trailing commas: each attribute holds the scalar itself, not a
        # 1-tuple wrapping it.
        self.hidden_size = hidden_size
        self.ffn_2 = ffn_2
        self.num_hidden_layer = num_hidden_layer
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size
        self.layer_norm = layer_norm
        self.attention_dropout = attention_dropout
        self.num_heads = num_attn_heads


In [5]:
class SIGLIP_VISION_EMBEDDING(nn.Module):
    """Turn an image batch into a sequence of position-aware patch embeddings.

    A strided convolution cuts the image into non-overlapping
    ``patch_size`` x ``patch_size`` patches and projects each one to
    ``hidden_size`` features; a learned position embedding is then added
    to every patch.

    Args:
        config: Object exposing ``hidden_size``, ``image_size``,
            ``patch_size`` and ``num_channels``.
    """

    def __init__(self, config: "SIGLIP_VISION_CONFIG"):
        # Must run before any sub-module or buffer is attached.
        super().__init__()
        self.config = config
        self.embed_size = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Integer division keeps the patch count an int, which nn.Embedding
        # and torch.arange both require.
        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.pos_embeddings = nn.Embedding(
            num_embeddings=self.num_patches,
            embedding_dim=self.embed_size,
        )
        self.patch_embeddings = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding=0,  # no padding: patches tile the image exactly
        )

        # Shape [1, Num_Patches] so it broadcasts over the batch dimension.
        # The buffer keeps its historical (misspelled) name so existing
        # attribute access keeps working.
        self.register_buffer(
            "postions_ids",
            torch.arange(self.num_patches).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values):
        """Embed ``pixel_values``.

        Args:
            pixel_values: Tensor of shape [Batch_Size, Channels, Height, Width].

        Returns:
            Tensor of shape [Batch_Size, Num_Patches, Embed_Dim].
        """
        # [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W]
        patch_embeds = self.patch_embeddings(pixel_values)
        # -> [Batch_Size, Embed_Dim, Num_Patches] -> [Batch_Size, Num_Patches, Embed_Dim]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        positions = self.postions_ids[:, :self.num_patches]
        # [1, Num_Patches, Embed_Dim], broadcast over the batch.
        pos_embds = self.pos_embeddings(positions)
        return patch_embeds + pos_embds

In [8]:
class SIGLIP_VISION_ATTENTION(nn.Module):
    """Multi-head self-attention over a sequence of patch embeddings.

    Args:
        config: Object exposing ``hidden_size`` and ``num_heads``; an
            optional ``attention_dropout`` (default 0.0) sets the dropout
            applied to the attention weights during training.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, config: "SIGLIP_VISION_CONFIG", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        if config.hidden_size % config.num_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_heads ({config.num_heads})"
            )
        self.K = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size)
        self.Q = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size)
        self.V = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size)

        self.O_proj = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size)
        # Integer division: head_dim is used as a tensor dimension.
        self.head_dim = config.hidden_size // config.num_heads
        # 1 / sqrt(head_dim), computed once as a Python float.
        self.scale = self.head_dim ** -0.5
        self.dropout = getattr(config, "attention_dropout", 0.0)

    def forward(self, hidden_states):
        """Apply self-attention.

        Args:
            hidden_states: Tensor of shape [batch, seq_len, dim].

        Returns:
            Tensor of shape [batch, seq_len, dim].
        """
        batch_size, seq_len, _ = hidden_states.size()
        num_heads = self.config.num_heads

        def split_heads(tensor):
            # [batch, seq_len, dim] -> [batch, n_heads, seq_len, head_dim]
            return tensor.view(batch_size, seq_len, num_heads, self.head_dim).transpose(1, 2)

        query_states = split_heads(self.Q(hidden_states))
        key_states = split_heads(self.K(hidden_states))
        value_states = split_heads(self.V(hidden_states))

        # [batch, n_heads, seq_len, seq_len]
        attn_weights = (query_states @ key_states.transpose(2, 3)) * self.scale
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # [batch, n_heads, seq_len, head_dim]
        context = attn_weights @ value_states
        # Move the head axis back next to head_dim *before* merging; a bare
        # reshape would interleave data from different tokens.
        context = context.transpose(1, 2).reshape(batch_size, seq_len, self.config.hidden_size)
        return self.O_proj(context)  # [batch, seq_len, dim]


In [9]:
class SIGLIP_FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU (tanh approximation) -> Linear.

    Args:
        config: Object exposing ``hidden_size`` (input/output width) and
            ``ffn_2`` (intermediate width).
    """

    def __init__(self, config: "SIGLIP_VISION_CONFIG", *args, **kwargs):
        # nn.Module must be initialised before sub-modules are assigned;
        # otherwise the assignments below raise AttributeError.
        super().__init__(*args, **kwargs)
        self.config = config
        self.ffn1 = nn.Linear(in_features=config.hidden_size, out_features=config.ffn_2)
        self.ffn2 = nn.Linear(in_features=config.ffn_2, out_features=config.hidden_size)

    def forward(self, hidden_features):
        """Apply the feed-forward network to the last dimension.

        Args:
            hidden_features: Tensor whose last dimension is ``hidden_size``.

        Returns:
            Tensor with the same shape as ``hidden_features``.
        """
        expanded = self.ffn1(hidden_features)
        activated = nn.functional.gelu(expanded, approximate='tanh')
        return self.ffn2(activated)

In [13]:
class SIGLIP_ENCODER_BLOCK(nn.Module):
    """One pre-norm transformer encoder block.

    Computes ``x + attn(LN1(x))`` followed by ``h + ffn(LN2(h))``.

    Args:
        config: Object exposing ``hidden_size`` plus whatever the attention
            and feed-forward sub-modules read from it.
    """

    def __init__(self, config: "SIGLIP_VISION_CONFIG", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.self_attn = SIGLIP_VISION_ATTENTION(config)
        self.layerNorm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.ffnBlock = SIGLIP_FFN(config=config)
        self.layerNorm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)

    def forward(self, hidden_states):
        """Run the block.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size``.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        # Attention sub-layer with residual connection. The residual must be
        # the tensor itself (no trailing comma, which would wrap it in a
        # tuple and break the addition).
        residual_state = hidden_states
        hidden_states = self.layerNorm1(hidden_states)
        hidden_states = self.self_attn(hidden_states)
        hidden_states = hidden_states + residual_state

        # Feed-forward sub-layer with residual connection.
        residual_state = hidden_states
        hidden_states = self.layerNorm2(hidden_states)
        hidden_states = self.ffnBlock(hidden_states)
        hidden_states = hidden_states + residual_state

        return hidden_states


In [14]:
class SIGlIP_ENCODER(nn.Module):
    """Stack of ``config.num_hidden_layer`` encoder blocks applied in sequence.

    Args:
        config: Object exposing ``num_hidden_layer`` plus whatever each
            encoder block reads from it.
    """

    def __init__(self, config: "SIGLIP_VISION_CONFIG", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.blocks = nn.ModuleList(
            [SIGLIP_ENCODER_BLOCK(config) for _ in range(config.num_hidden_layer)]
        )

    def forward(self, hidden_states):
        """Pass ``hidden_states`` through every block in order.

        Args:
            hidden_states: Input tensor for the first block.

        Returns:
            Output of the last block (the input itself if there are no blocks).
        """
        # Each block consumes the previous block's output; feeding the
        # original input to every block would discard all but the last one.
        for layer in self.blocks:
            hidden_states = layer(hidden_states)

        return hidden_states

In [16]:
class SIGLIP_TRANSFORMER(nn.Module):
    """Vision tower: embedding -> encoder stack -> final LayerNorm."""

    def __init__(self,config:SIGLIP_VISION_CONFIG, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Sub-modules are registered in the order the data flows through them.
        self.embedding = SIGLIP_VISION_EMBEDDING(config)
        self.encoder_blocks = SIGlIP_ENCODER(config=config)
        self.post_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)

    def forward(self, pixel_values):
        """Embed ``pixel_values``, encode them, and normalise the result."""
        encoded = self.encoder_blocks(self.embedding(pixel_values))
        return self.post_layer_norm(encoded)

In [18]:
class SiglipVisionModel(nn.Module):
    """Top-level wrapper that owns the config and the vision transformer."""

    def __init__(self, config: SIGLIP_VISION_CONFIG):
        super().__init__()
        self.config = config
        self.vision_model = SIGLIP_TRANSFORMER(config)

    def forward(self, pixel_values):
        """Delegate to the wrapped transformer.

        Maps [Batch_Size, Channels, Height, Width] to
        [Batch_Size, Num_Patches, Embed_Dim].
        """
        encoded = self.vision_model(pixel_values=pixel_values)
        return encoded