In [None]:
from typing import Optional, Tuple, List, Dict, Iterable, Union
import torch
import torch.nn as nn
from PIL import Image
import numpy as np

In [2]:
# Check whether a CUDA-capable GPU is available to PyTorch
torch.cuda.is_available()

False

In [3]:
class SiglipVisionConfig:
    """Plain container for the hyper-parameters of the SigLIP vision encoder.

    Each named keyword argument is stored unchanged as an attribute of the
    same name. Any additional keyword arguments are accepted through
    ``**kwargs`` and ignored.
    """

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=16,
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        num_image_tokens: int = None,
        **kwargs
    ):
        super().__init__()

        # Transformer width and depth.
        self.hidden_size, self.intermediate_size = hidden_size, intermediate_size
        self.num_hidden_layers, self.num_attention_heads = num_hidden_layers, num_attention_heads

        # Input image geometry.
        self.num_channels, self.image_size, self.patch_size = num_channels, image_size, patch_size

        # Layer-norm epsilon and attention dropout probability.
        self.layer_norm_eps, self.attention_dropout = layer_norm_eps, attention_dropout

        # Left as None unless the caller supplies a value.
        self.num_image_tokens = num_image_tokens


# Module-level instance built entirely from the default values.
config = SiglipVisionConfig()

In [4]:
class SiglipPatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each patch.

    A strided convolution (kernel size == stride == ``config.patch_size``)
    projects every patch to ``config.hidden_size`` features, after which a
    learned positional embedding is added to each patch.

    Args:
        config: object exposing ``num_channels``, ``hidden_size``,
            ``image_size`` and ``patch_size``.
    """

    def __init__(self, config: "SiglipVisionConfig"):
        super().__init__()

        self.patch_layer = nn.Conv2d(in_channels=config.num_channels,  # follow the config instead of assuming RGB
                                     out_channels=config.hidden_size,
                                     kernel_size=config.patch_size,
                                     stride=config.patch_size,
                                     padding='valid')

        # With 'valid' padding the convolution yields floor(image_size / patch_size)
        # patches per side, so the patch count is that number squared. This also
        # holds when image_size is not an exact multiple of patch_size.
        patches_per_side = config.image_size // config.patch_size
        self.patch_num = patches_per_side ** 2

        self.pos_embedding = nn.Embedding(num_embeddings=self.patch_num,
                                          embedding_dim=config.hidden_size)

        # Shape (1, patch_num); persistent=False keeps it out of the state dict.
        self.register_buffer('position_ids',
                             torch.arange(0, self.patch_num).view(1, -1),
                             persistent=False)

    def forward(self, x: torch.Tensor):
        """Embed a batch of images.

        Args:
            x: tensor of shape (batch, num_channels, height, width).

        Returns:
            Tensor of shape (batch, num_patches, hidden_size).
        """
        # (batch, channels, height, width) -> (batch, hidden_size, grid_h, grid_w)
        x = self.patch_layer(x)
        # Merge the spatial grid into one patch axis; unlike .view, flatten
        # also works when the conv output is not contiguous.
        x = x.flatten(2)
        # (batch, hidden_size, num_patches) -> (batch, num_patches, hidden_size)
        x = x.transpose(-2, -1)
        # Positional embedding has shape (1, num_patches, hidden_size) and broadcasts over the batch.
        x = x + self.pos_embedding(self.position_ids)
        return x

In [5]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learned scale and shift.

    Computes ``(x - mean) / sqrt(var + eps) * scale + shift`` where ``mean``
    and the biased (population) ``var`` are taken over the last axis.

    Args:
        config: object exposing ``layer_norm_eps`` and ``hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.eps = config.layer_norm_eps
        self.hidden_size = config.hidden_size
        self.scale = nn.Parameter(torch.ones(self.hidden_size))
        self.shift = nn.Parameter(torch.zeros(self.hidden_size))

    def forward(self, x):
        """Normalise ``x`` along its last dimension.

        Args:
            x: tensor whose last dimension has size ``hidden_size``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x_mean = torch.mean(x, dim=-1, keepdim=True)
        # Biased variance with eps inside the square root: the standard
        # layer-norm definition. The unbiased estimator would also produce NaN
        # when the normalised dimension has a single element.
        x_var = torch.var(x, dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - x_mean) / torch.sqrt(x_var + self.eps)
        return x_norm * self.scale + self.shift

In [6]:
class SiglipAttention(nn.Module):
    """Multi-head self-attention without masking.

    Args:
        config: object exposing ``num_attention_heads`` and ``hidden_size``;
            ``attention_dropout`` is read when present (defaults to 0.0).
    """

    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        assert self.hidden_size % self.num_attention_heads == 0, "hidden size should be divisible by num_attention_heads"
        self.attn_head_size = self.hidden_size // self.num_attention_heads
        # Dropout probability applied to the attention weights during training.
        self.attention_dropout = getattr(config, "attention_dropout", 0.0)

        self.q_W = nn.Linear(self.hidden_size, self.hidden_size)
        self.k_W = nn.Linear(self.hidden_size, self.hidden_size)
        self.v_W = nn.Linear(self.hidden_size, self.hidden_size)
        self.out_W = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: tensor of shape (batch, patches, hidden_size).

        Returns:
            Tuple ``(output, attention_scores)`` where ``output`` has shape
            (batch, patches, hidden_size) and ``attention_scores`` has shape
            (batch, heads, patches, patches).
        """
        batch_size, patch_len, _ = x.shape

        queries = self.q_W(x)
        keys = self.k_W(x)
        values = self.v_W(x)

        # (batch, patches, hidden_size) -> (batch, patches, heads, head_dim) -> (batch, heads, patches, head_dim)
        queries = queries.view(batch_size, patch_len, self.num_attention_heads, self.attn_head_size).transpose(1, 2)
        keys = keys.view(batch_size, patch_len, self.num_attention_heads, self.attn_head_size).transpose(1, 2)
        values = values.view(batch_size, patch_len, self.num_attention_heads, self.attn_head_size).transpose(1, 2)

        # Scaled dot-product logits: (batch, heads, patches, patches)
        attention_logits = (queries @ keys.transpose(-1, -2)) * self.attn_head_size**-0.5

        # Softmax is evaluated in float32, then cast back to the activation
        # dtype so the matmul with `values` works for non-float32 models.
        attention_scores = nn.functional.softmax(attention_logits, dim=-1, dtype=torch.float32).to(queries.dtype)
        # Identity when the probability is 0.0 or the module is in eval mode.
        attention_scores = nn.functional.dropout(attention_scores, p=self.attention_dropout, training=self.training)

        # (batch, heads, patches, patches) @ (batch, heads, patches, head_dim) -> (batch, heads, patches, head_dim)
        contextual_embeddings = attention_scores @ values

        # (batch, heads, patches, head_dim) -> (batch, patches, heads, head_dim)
        contextual_embeddings = contextual_embeddings.transpose(1, 2)

        # Concatenate the heads back into a single hidden dimension.
        contextual_embeddings = contextual_embeddings.contiguous().view(batch_size, patch_len, self.num_attention_heads*self.attn_head_size)

        contextual_embeddings = self.out_W(contextual_embeddings)

        return contextual_embeddings, attention_scores
        

In [7]:
class SiglipEncoderLayer(nn.Module):
    """One pre-norm transformer block: self-attention, then a two-layer MLP.

    Both sub-layers are wrapped as ``sublayer(norm(h)) + h``.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        width, mlp_width = config.hidden_size, config.intermediate_size
        self.hidden_size = width

        # Attention sub-layer and the norm that precedes it.
        self.layer_norm_1 = LayerNorm(config)
        self.attention = SiglipAttention(config)

        # Feed-forward sub-layer (hidden -> intermediate -> hidden) and its norm.
        self.layer_norm_2 = LayerNorm(config)
        self.final_linear = nn.Sequential(
            nn.Linear(width, mlp_width),
            nn.GELU('tanh'),
            nn.Linear(mlp_width, width),
        )

    def forward(self, x):
        """Run the block on ``x`` and return a tensor produced by the two residual sums."""
        # The attention module returns (output, weights); the weights are unused here.
        attended, _ = self.attention(self.layer_norm_1(x))
        after_attention = attended + x

        mlp_out = self.final_linear(self.layer_norm_2(after_attention))
        return mlp_out + after_attention

In [8]:
class SiglipEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` encoder layers applied in sequence.

    Args:
        config: object exposing ``num_hidden_layers`` plus whatever the
            individual ``SiglipEncoderLayer`` instances read from it.
    """

    def __init__(self, config):
        super().__init__()
        self.n_layer = config.num_hidden_layers
        self.encoder_layers = nn.ModuleList(
            [SiglipEncoderLayer(config) for _ in range(self.n_layer)]
            )

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the final output."""
        for layer in self.encoder_layers:
            x = layer(x)
        # Without this return the module would always evaluate to None.
        return x

In [9]:
class SiglipVisionTransformer(nn.Module):
    """Vision transformer body: patch embedding followed by the encoder stack.

    Args:
        config: configuration forwarded to ``SiglipPatchEmbedding`` and
            ``SiglipEncoder``.
    """

    def __init__(self, config: "SiglipVisionConfig"):
        super().__init__()
        self.config = config
        self.patch_embedding = SiglipPatchEmbedding(config)
        self.encoder = SiglipEncoder(config)
        # NOTE(review): the reference SigLIP vision tower also applies a final
        # layer norm after the encoder; it is not added here so the parameter
        # set stays unchanged - confirm whether it is wanted.

    def forward(self, x):
        """Embed the image patches and pass them through the encoder.

        Args:
            x: image batch accepted by ``self.patch_embedding``.

        Returns:
            Whatever ``self.encoder`` returns for the embedded patches.
        """
        patch_embeddings = self.patch_embedding(x)
        return self.encoder(patch_embeddings)
        

In [10]:
class SiglipVisionModel(nn.Module):
    """Thin top-level wrapper that owns a ``SiglipVisionTransformer``.

    The wrapped transformer is stored as ``visual_model`` and the
    configuration as ``config``.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.visual_model = SiglipVisionTransformer(config)

    def forward(self, x):
        """Delegate to the wrapped transformer and return its result.

        Intended mapping: [Batch, channel, height, width] -> [Batch, patch, embed_dim].
        """
        encoded = self.visual_model(x)
        return encoded
        
        

In [None]:
# Mean and standard deviation handed to `process_image` (and from there to
# `normalize`) by PaliGemmaProcessor.__call__; three entries each, presumably
# one per RGB channel. Applied as (x - 0.5) / 0.5, so pixel values already
# rescaled to [0, 1] end up in [-1, 1].
IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]

def resize(image: Image.Image,
           size: Tuple[int, int],
           resample: Optional[Image.Resampling] = None,
           reducing_gap: Optional[float] = None):
    """Resize an image to the requested size.

    Args:
        image: object exposing a PIL-style ``resize`` method.
        size: target size as ``(height, width)``.
        resample: resampling filter forwarded to ``image.resize``.
        reducing_gap: forwarded unchanged to ``image.resize``.

    Returns:
        The value returned by ``image.resize``.
    """
    height, width = size
    # `image.resize` is called with (width, height), the reverse of `size`.
    resize_image = image.resize(
        (width, height), resample=resample, reducing_gap=reducing_gap
    )

    return resize_image

def rescale(image: np.ndarray,
            scale: float,
            dtype: np.dtype = np.float32):
    """Multiply every element of ``image`` by ``scale`` and cast to ``dtype``.

    Args:
        image: array of pixel values.
        scale: multiplicative factor (e.g. ``1 / 255`` for 8-bit images).
        dtype: dtype of the returned array; ``np.float32`` by default.

    Returns:
        A new array holding ``image * scale`` converted to ``dtype``.
    """
    # The cast happens after the multiplication, never before it.
    return (image * scale).astype(dtype)

def normalize(image: np.ndarray,
              mean: Union[float, Iterable[float]],
              std: Union[float, Iterable[float]],
              ):
    """Standardise ``image`` as ``(image - mean) / std``.

    ``mean`` and ``std`` broadcast against ``image`` under the usual NumPy
    rules, so per-channel values line up with the last axis.

    Args:
        image: array of pixel values (anything ``np.asarray`` accepts).
        mean: scalar or sequence subtracted from the image.
        std: scalar or sequence the centred image is divided by.

    Returns:
        The normalised array. Floating-point inputs keep their dtype;
        other inputs are computed in float64.
    """
    image = np.asarray(image)
    # Build the statistics in the image's own float dtype so a float32 image
    # is not silently promoted to float64 by the arithmetic below.
    stats_dtype = image.dtype if np.issubdtype(image.dtype, np.floating) else np.float64
    mean = np.array(mean, dtype=stats_dtype)
    std = np.array(std, dtype=stats_dtype)
    normalized_image = (image - mean) / std

    return normalized_image
    

def process_image(
    images: List[Image.Image],
    size: Union[Tuple[int, int], Dict[str, int]] = None,
    resample: Optional[Image.Resampling] = None,
    rescale_factor: Optional[float] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
) -> List[np.ndarray]:
    """Resize, rescale, normalise and transpose a list of images.

    Each image goes through ``resize`` -> ``np.array`` -> ``rescale`` ->
    ``normalize`` and is finally transposed with ``(2, 0, 1)``, which moves
    the last axis of a 3-D array to the front.

    Args:
        images: images to process.
        size: target size, either a ``(height, width)`` sequence or a mapping
            with ``"height"`` and ``"width"`` keys. Required.
        resample: resampling filter forwarded to ``resize``.
        rescale_factor: factor passed to ``rescale``; rescaling is skipped
            when it is ``None``.
        image_mean: mean passed to ``normalize``.
        image_std: standard deviation passed to ``normalize``. Normalisation
            is skipped when both ``image_mean`` and ``image_std`` are ``None``.

    Returns:
        One array per input image, in input order.

    Raises:
        ValueError: if ``size`` is missing, or if only one of ``image_mean``
            and ``image_std`` is supplied.
    """
    if size is None:
        raise ValueError("`size` must be provided as (height, width)")
    if isinstance(size, dict) and "height" in size and "width" in size:
        height, width = size["height"], size["width"]
    else:
        # Positional form, e.g. the (height, width) tuple the processor passes.
        height, width = size[0], size[1]

    if (image_mean is None) != (image_std is None):
        raise ValueError("`image_mean` and `image_std` must be provided together")
    do_normalize = image_mean is not None

    processed = []
    for image in images:
        resized = resize(image=image, size=(height, width), resample=resample)
        pixels = np.array(resized)
        if rescale_factor is not None:
            pixels = rescale(pixels, scale=rescale_factor)
        if do_normalize:
            pixels = normalize(pixels, mean=image_mean, std=image_std)
        # Assumes a 3-D array at this point (e.g. an RGB image) - TODO confirm
        # how single-channel images should be handled.
        processed.append(pixels.transpose(2, 0, 1))

    return processed
    

class PaliGemmaProcessor:
    
    IMAGE_TOKEN = "<image>"
    
    def __init__(self, tokenizer, num_image_token: int, image_size: int):
        """Register the extra tokens on ``tokenizer`` and store the settings.

        Args:
            tokenizer: tokenizer exposing ``add_special_tokens``,
                ``add_tokens`` and ``convert_tokens_to_ids`` (presumably a
                Hugging Face tokenizer); it is modified in place.
            num_image_token: stored as ``image_seq_len`` (number of patches).
            image_size: stored as ``image_size``; used as both target height
                and width when images are processed.
        """
        super().__init__()

        self.image_seq_len = num_image_token # num_patches
        self.image_size = image_size

        # The key has to be the plural 'additional_special_tokens'; the
        # singular spelling is not a key Hugging Face tokenizers recognise.
        token_to_add = {'additional_special_tokens': [self.IMAGE_TOKEN]}
        tokenizer.add_special_tokens(token_to_add)

        # <loc0000>..<loc1023> followed by <seg000>..<seg127>.
        EXTRA_TOKENS = [
            f"<loc{i:04d}>" for i in range(1024)
        ]

        EXTRA_TOKENS += [
            f'<seg{i:03d}>' for i in range(128)
        ]

        tokenizer.add_tokens(EXTRA_TOKENS)

        # Looked up only after the token has been registered above.
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)

        # Turn off the tokenizer's automatic BOS/EOS insertion flags.
        tokenizer.add_bos_token = False
        tokenizer.add_eos_token = False

        self.tokenizer = tokenizer
        
    def __call__(self, 
                 texts: List[str],
                 images: List[Image.Image],
                 padding=False,
                 truncate=False) -> dict:
        
        assert len(images) == len(texts), "the number of text and image should be same"
        
        pixel_values = process_image(
            images,
            size=(self.image_size, self.image_size),
            resample=Image.Resampling.BICUBIC,
            rescale_factor=1 / 255.0,
            image_mean=IMAGENET_STANDARD_MEAN,
            image_std=IMAGENET_STANDARD_STD
        )
        
        
        