In [2]:
from typing import Optional, Tuple, List, Dict, Iterable, Union
import torch
import torch.nn as nn
from PIL import Image
import numpy as np

In [2]:
# Check whether a CUDA device is visible to PyTorch. The recorded output
# below (False) shows this session was run without CUDA.
torch.cuda.is_available()

False

## Config

In [3]:
class SiglipVisionConfig:
    """Hyper-parameters of the SigLIP vision tower.

    Each named argument is stored unchanged as an attribute of the same name.
    Any additional keyword arguments are accepted and discarded.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        intermediate_size: int = 3072,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        num_channels: int = 3,
        image_size: int = 224,
        patch_size: int = 16,
        layer_norm_eps: float = 1e-6,
        attention_dropout: float = 0.0,
        num_image_tokens: Optional[int] = None,
        **kwargs,
    ):
        super().__init__()

        # Input image geometry.
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_image_tokens = num_image_tokens

        # Transformer dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Numerical / regularisation settings.
        self.layer_norm_eps = layer_norm_eps
        self.attention_dropout = attention_dropout
        
# Module-level vision configuration built with every default value.
config = SiglipVisionConfig()

## Visual Model

In [None]:
class SiglipPatchEmbedding(nn.Module):
    """Turn a raw image batch into patch tokens with learned position embeddings.

    The image is cut into non-overlapping ``patch_size x patch_size`` patches by
    a strided convolution; each patch becomes one ``hidden_size``-wide token and
    a learned embedding of the patch index is added to it.

    Args:
        config: vision configuration providing ``num_channels``, ``hidden_size``,
            ``patch_size`` and ``image_size``.
    """

    def __init__(self, config: "SiglipVisionConfig"):
        super().__init__()

        # kernel == stride == patch size, so every output location of the
        # convolution is the embedding of exactly one patch.
        self.patch_layer = nn.Conv2d(in_channels=config.num_channels,
                                     out_channels=config.hidden_size,
                                     kernel_size=config.patch_size,
                                     stride=config.patch_size,
                                     padding='valid')

        # A 'valid' convolution yields floor(image_size / patch_size) patches per
        # side (e.g. 224 // 16 = 14, so 14 * 14 = 196 patches). Squaring the
        # per-side count keeps this number equal to the convolution's output even
        # when image_size is not a multiple of patch_size.
        patches_per_side = config.image_size // config.patch_size
        self.patch_num = patches_per_side ** 2

        self.pos_embedding = nn.Embedding(num_embeddings=self.patch_num,
                                          embedding_dim=config.hidden_size)

        # Constant index row [0, 1, ..., patch_num - 1] of shape (1, patch_num).
        # persistent=False keeps it out of the module's state dict.
        self.register_buffer('position_ids',
                             torch.arange(0, self.patch_num).view(1, -1),
                             persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed an image batch.

        Args:
            x: tensor of shape (batch, channel, height, width).

        Returns:
            Tensor of shape (batch, num_patch, hidden_size).
        """
        # (batch, channel, height, width) -> (batch, hidden_size, patches_h, patches_w)
        x = self.patch_layer(x)
        # -> (batch, hidden_size, num_patch)
        x = x.flatten(start_dim=2)
        # -> (batch, num_patch, hidden_size)
        x = x.transpose(1, 2)
        # The (1, num_patch, hidden_size) position table broadcasts over the batch.
        return x + self.pos_embedding(self.position_ids)


class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable affine map.

    Computes ``(x - mean) / sqrt(var + eps) * scale + shift`` where ``mean`` and
    ``var`` are taken over the last dimension and ``var`` is the biased
    (population) variance, i.e. the same formula as ``torch.nn.LayerNorm``.

    Args:
        config: object providing ``layer_norm_eps`` and ``hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.eps = config.layer_norm_eps
        self.hidden_size = config.hidden_size
        # Per-feature gain and bias, initialised to the identity transform.
        self.scale = nn.Parameter(torch.ones(self.hidden_size))
        self.shift = nn.Parameter(torch.zeros(self.hidden_size))

    def forward(self, x):
        """Normalise ``x`` along its last dimension; the shape is unchanged."""
        x_mean = torch.mean(x, dim=-1, keepdim=True)
        # unbiased=False divides by N rather than N - 1, and eps goes inside
        # the square root, so the result matches the standard definition.
        x_var = torch.var(x, dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - x_mean) * torch.rsqrt(x_var + self.eps)
        return x_norm * self.scale + self.shift

class SiglipAttention(nn.Module):
    """Multi-head self-attention over a sequence of patch embeddings.

    Args:
        config: object providing ``hidden_size`` and ``num_attention_heads``
            (``hidden_size`` must be divisible by the number of heads) and,
            optionally, ``attention_dropout`` (defaults to 0.0 when absent).
    """

    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        assert self.hidden_size % self.num_attention_heads == 0, "hidden size should be divisible by num_attention_heads"
        self.attn_head_size = self.hidden_size // self.num_attention_heads
        # Dropout probability applied to the attention weights while training.
        self.attention_dropout = getattr(config, "attention_dropout", 0.0)

        self.q_W = nn.Linear(self.hidden_size, self.hidden_size)
        self.k_W = nn.Linear(self.hidden_size, self.hidden_size)
        self.v_W = nn.Linear(self.hidden_size, self.hidden_size)
        self.out_W = nn.Linear(self.hidden_size, self.hidden_size)

    def _split_heads(self, tensor, batch_size, patch_len):
        """(batch, patch, hidden_size) -> (batch, heads, patch, head_embed)."""
        return tensor.view(batch_size, patch_len, self.num_attention_heads, self.attn_head_size).transpose(1, 2)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: tensor of shape (batch, patch, hidden_size).

        Returns:
            Tuple ``(contextual_embeddings, attention_scores)`` with shapes
            (batch, patch, hidden_size) and (batch, heads, patch, patch).
        """
        batch_size, patch_len, _ = x.shape

        queries = self._split_heads(self.q_W(x), batch_size, patch_len)
        keys = self._split_heads(self.k_W(x), batch_size, patch_len)
        values = self._split_heads(self.v_W(x), batch_size, patch_len)

        # attention_logits (b, head, patch, patch), scaled by 1 / sqrt(head_embed)
        attention_logits = (queries @ keys.transpose(-1, -2)) * self.attn_head_size**-0.5

        # The softmax is evaluated in float32, then cast back to the activation
        # dtype so the matmul with `values` also works when the model runs in
        # another precision (matmul requires both operands to share a dtype).
        attention_scores = nn.functional.softmax(attention_logits, dim=-1, dtype=torch.float32).to(queries.dtype)
        attention_scores = nn.functional.dropout(attention_scores, p=self.attention_dropout, training=self.training)

        # (b, head, patch, patch) @ (b, head, patch, embed) => (b, head, patch, embed)
        contextual_embeddings = attention_scores @ values

        # (b, head, patch, embed) -> (b, patch, head, embed) -> (b, patch, hidden_size)
        contextual_embeddings = contextual_embeddings.transpose(1, 2)
        contextual_embeddings = contextual_embeddings.contiguous().view(batch_size, patch_len, self.hidden_size)

        contextual_embeddings = self.out_W(contextual_embeddings)

        return contextual_embeddings, attention_scores
        


class SiglipEncoderLayer(nn.Module):
    """One encoder block: normalised self-attention then a normalised MLP.

    Both sub-blocks normalise their input first and add the un-normalised input
    back to their output (residual connection).
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Attention sub-block.
        self.layer_norm_1 = LayerNorm(config)
        self.attention = SiglipAttention(config)

        # Feed-forward sub-block: expand to intermediate_size, tanh-approximated
        # GELU, then project back to hidden_size.
        self.layer_norm_2 = LayerNorm(config)
        expand = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = nn.GELU('tanh')
        contract = nn.Linear(config.intermediate_size, config.hidden_size)
        self.final_linear = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Run the block on ``x``; the output has the same shape as the input."""
        # The attention weights returned alongside the output are not needed here.
        attended, _ = self.attention(self.layer_norm_1(x))
        hidden = attended + x

        projected = self.final_linear(self.layer_norm_2(hidden))
        return projected + hidden

class SiglipEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` encoder layers applied in order.

    Args:
        config: object providing ``num_hidden_layers`` plus whatever
            ``SiglipEncoderLayer`` reads from it.
    """

    def __init__(self, config):
        super().__init__()
        self.n_layer = config.num_hidden_layers
        self.encoder_layers = nn.ModuleList(
            [SiglipEncoderLayer(config) for _ in range(self.n_layer)]
            )

    def forward(self, x):
        """Feed ``x`` through every layer and return the final hidden states."""
        for layer in self.encoder_layers:
            x = layer(x)
        # Without this return the module would hand `None` to its caller.
        return x

class SiglipVisionTransformer(nn.Module):
    """Vision tower: patch embedding, encoder stack, then a final layer norm."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.patch_embedding = SiglipPatchEmbedding(config)
        self.encoder = SiglipEncoder(config)
        self.post_layer_norm = LayerNorm(config)

    def forward(self, x):
        """Encode an image batch.

        (b, channel, height, width) -> (b, patch_len, patch_embedding); the
        encoder and the closing layer norm keep that shape.
        """
        patch_tokens = self.patch_embedding(x)
        encoded_tokens = self.encoder(patch_tokens)
        return self.post_layer_norm(encoded_tokens)
        
        
        
class SiglipVisionModel(nn.Module):
    """Thin wrapper that owns the vision transformer as ``visual_model``."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.visual_model = SiglipVisionTransformer(config)

    def forward(self, x):
        """[Batch, channel, height, width] -> [Batch, patch, embed_dim]."""
        image_embeddings = self.visual_model(x)
        return image_embeddings
        
        
        

## Siglip Input Processor

In [None]:
# Per-channel mean and standard deviation that PaliGemmaProcessor passes on to
# `normalize`. With 0.5 for both, (x - mean) / std maps pixel values that were
# rescaled to [0, 1] onto [-1, 1].
IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]

def resize(
    image: Image,
    size: Tuple[int, int],
    resample: Image.Resampling = None,
    reducing_gap: Optional[int] = None,
):
    """Resize an image to ``size``.

    Args:
        image: object exposing a PIL-style ``resize`` method.
        size: target ``(height, width)``.
        resample: resampling filter forwarded to ``image.resize``.
        reducing_gap: forwarded unchanged to ``image.resize``.

    Returns:
        Whatever ``image.resize`` returns for the requested size.
    """
    target_height, target_width = size
    # `image.resize` takes (width, height), the reverse of this function's `size`.
    return image.resize(
        (target_width, target_height),
        resample=resample,
        reducing_gap=reducing_gap,
    )

def rescale(image: np.ndarray,
            scale: float,
            dtype: np.dtype = np.float32):
    """Multiply every value of ``image`` by ``scale`` and cast to ``dtype``.

    Args:
        image: array of pixel values.
        scale: multiplicative factor (the processor in this file uses 1 / 255).
        dtype: dtype of the returned array.

    Returns:
        A new array of the given dtype.
    """
    return (image * scale).astype(dtype)

def normalize(image: np.ndarray,
              mean: Union[float, Iterable[float]],
              std: Union[float, Iterable[float]],
              ):
    """Standardise pixel values as ``(image - mean) / std``.

    Args:
        image: array of pixel values; ``mean`` and ``std`` broadcast against it,
            so per-channel statistics need the channel axis last.
        mean: scalar or per-channel mean.
        std: scalar or per-channel standard deviation.

    Returns:
        The normalised array. A floating-point input keeps its dtype (a float32
        image stays float32); any other dtype follows NumPy's default promotion.
    """
    image = np.asarray(image)

    # Build the statistics in the image's own float dtype. Left at NumPy's
    # default float64 they would upcast a float32 image to float64, which then
    # turns into a float64 tensor that float32 model weights cannot consume.
    stats_dtype = image.dtype if np.issubdtype(image.dtype, np.floating) else None
    mean = np.array(mean, dtype=stats_dtype)
    std = np.array(std, dtype=stats_dtype)

    return (image - mean) / std
    

def process_image(
    images: List[Image.Image],
    size: Union[Tuple[int, int], Dict[str, int]] = None,
    resample: Image.Resampling = None,
    rescale_factor: float=None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
):
    """Resize, rescale, normalise and channel-first-transpose a list of images.

    Args:
        images: images to process.
        size: target size, either a ``(height, width)`` sequence or a mapping
            with ``"height"`` and ``"width"`` keys.
        resample: resampling filter forwarded to ``resize``.
        rescale_factor: factor applied to the raw pixel values.
        image_mean: mean forwarded to ``normalize``.
        image_std: standard deviation forwarded to ``normalize``.

    Returns:
        One array per input image. The final ``transpose(2, 0, 1)`` moves the
        last axis of the converted image to the front, so a 3-D
        (height, width, channel) array comes out as (channel, height, width).
    """
    # Accept both spellings of the target size: the signature used to advertise
    # a mapping while the body indexed positionally.
    if isinstance(size, dict):
        height, width = size["height"], size["width"]
    else:
        height, width = size[0], size[1]

    processed = []
    for image in images:
        resized = resize(image=image, size=(height, width), resample=resample)
        pixels = np.array(resized)
        pixels = rescale(pixels, scale=rescale_factor)
        pixels = normalize(pixels, mean=image_mean, std=image_std)
        processed.append(pixels.transpose(2, 0, 1))

    return processed

def add_image_tokens_to_prompt(prefix_prompt,
                               bos_token,
                               image_seq_len,
                               image_token):
    """Build the model input string for one prompt.

    The result is ``image_token`` repeated ``image_seq_len`` times, followed by
    ``bos_token``, the prompt itself and a trailing newline.
    """
    image_placeholders = image_token * image_seq_len
    return f"{image_placeholders}{bos_token}{prefix_prompt}\n"
    
    

class PaliGemmaProcessor:
    """Turn paired images and prompts into model inputs.

    Images are resized, rescaled and normalised into a ``pixel_values`` tensor.
    Each prompt is prefixed with ``image_seq_len`` image placeholder tokens and
    the BOS token, then tokenised.

    Args:
        tokenizer: Hugging Face style tokenizer; it is modified in place (new
            tokens are registered and automatic BOS/EOS insertion is disabled).
        num_image_token: number of image placeholder tokens per prompt.
        image_size: side length, in pixels, that every image is resized to.
    """

    IMAGE_TOKEN = "<image>"

    def __init__(self, tokenizer, num_image_token: int, image_size: int):
        self.image_seq_len = num_image_token # num_patches
        self.image_size = image_size

        # Register the image placeholder as a special token. The key has to be
        # spelled `additional_special_tokens` (plural) for the tokenizer to
        # recognise it.
        token_to_add = {'additional_special_tokens': [self.IMAGE_TOKEN]}
        tokenizer.add_special_tokens(token_to_add)

        # 1024 `<locNNNN>` tokens followed by 128 `<segNNN>` tokens.
        EXTRA_TOKENS = [f"<loc{i:04d}>" for i in range(1024)]
        EXTRA_TOKENS += [f'<seg{i:03d}>' for i in range(128)]
        tokenizer.add_tokens(EXTRA_TOKENS)

        self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)

        # BOS is written into the prompt string explicitly (see
        # add_image_tokens_to_prompt), so the tokenizer must not add its own.
        tokenizer.add_bos_token = False
        tokenizer.add_eos_token = False

        self.tokenizer = tokenizer

    def __call__(self,
                 texts: List[str],
                 images: List[Image.Image],
                 padding=False,
                 truncate=False) -> dict:
        """Build the model inputs for a batch.

        Args:
            texts: one prompt per image.
            images: one image per prompt.
            padding: forwarded to the tokenizer as ``padding``.
            truncate: forwarded to the tokenizer as ``truncation``.

        Returns:
            Dict with ``pixel_values`` of shape (batch, channel, height, width)
            plus every entry returned by the tokenizer.

        Raises:
            AssertionError: if ``texts`` and ``images`` differ in length.
        """
        assert len(images) == len(texts), "the number of text and image should be same"

        pixel_values = process_image(
            images,
            size=(self.image_size, self.image_size),
            resample=Image.Resampling.BICUBIC,
            rescale_factor=1 / 255.0,
            image_mean=IMAGENET_STANDARD_MEAN,
            image_std=IMAGENET_STANDARD_STD
        )

        # list of (channel, height, width) arrays -> (batch, channel, height, width)
        pixel_values = np.stack(pixel_values, axis=0)
        pixel_values = torch.tensor(pixel_values)

        input_string = [
            add_image_tokens_to_prompt(
                prefix_prompt=prompt,
                bos_token=self.tokenizer.bos_token,
                image_seq_len=self.image_seq_len,
                image_token=self.IMAGE_TOKEN
            )
            for prompt in texts
        ]

        # `return_tensors='pt'` asks the tokenizer for PyTorch tensors.
        inputs = self.tokenizer(
            input_string,
            return_tensors='pt',
            padding=padding,
            truncation=truncate,
        )

        return {'pixel_values': pixel_values, **inputs} # input_ids and attention mask
        
        

## Language Model - Gemma

In [None]:
class GemmaConfig:
    """Hyper-parameters of the Gemma language model.

    The first six arguments are required. Every named argument is stored
    unchanged as an attribute of the same name; extra keyword arguments are
    accepted and discarded.
    """

    def __init__(
        self,
        vocab_size,
        hidden_size,
        intermediate_size,
        num_hidden_layers,
        num_attention_heads,
        num_key_value_heads,
        head_dim=256,
        max_position_embeddings=2048,
        rms_norm_eps=1e-6,
        rope_theta=100000.0,
        attention_bias=False,
        attention_dropout=0.0,
        pad_token_id=None,
        **kwargs,
    ):
        super().__init__()

        # Vocabulary.
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id

        # Network dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers

        # Attention layout.
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        # Positional encoding and normalisation.
        self.max_position_embeddings = max_position_embeddings
        # NOTE(review): 100000.0 is ten times the 10000.0 commonly used as the
        # rotary-embedding base - confirm this default is intended.
        self.rope_theta = rope_theta
        self.rms_norm_eps = rms_norm_eps

In [None]:

class PaligemmaConfig():
    """Top-level configuration combining the vision and text sub-configurations.

    Args:
        vision_config: mapping of keyword arguments for ``SiglipVisionConfig``;
            ``None`` selects that class's defaults.
        text_config: mapping of keyword arguments for ``GemmaConfig`` (required,
            because ``GemmaConfig`` has no defaults for its size arguments).
        ignore_index: stored as ``ignore_index``.
        image_token_index: token id of the image placeholder token.
        vocab_size: accepted for compatibility; the stored ``vocab_size`` is
            always taken from ``text_config``.
        projection_dim: stored here and copied onto the vision configuration.
        hidden_size: stored as ``hidden_size``.
        pad_token_id: id of the padding token, or ``None``.
        **kwargs: accepted and discarded.

    Raises:
        TypeError: if ``text_config`` is missing.
    """

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=256000,
        vocab_size=257152,
        projection_dim=2048,
        hidden_size=2048,
        pad_token_id = None,
        **kwargs,
    ):
        if text_config is None:
            raise TypeError("text_config is required: GemmaConfig has no defaults for its size arguments")

        # `**None` is not valid, so a missing vision_config falls back to defaults.
        self.vision_config = SiglipVisionConfig(**(vision_config or {}))
        self.text_config = GemmaConfig(**text_config)
        self.is_encoder_decoder = False
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.projection_dim = projection_dim
        self.hidden_size = hidden_size
        self.pad_token_id = pad_token_id
        self.vocab_size = self.text_config.vocab_size

        # One image token per patch. Squaring the per-side patch count matches
        # the number of patches a stride-`patch_size` 'valid' convolution emits,
        # including when image_size is not a multiple of patch_size.
        patches_per_side = self.vision_config.image_size // self.vision_config.patch_size
        self.text_config.num_image_tokens = patches_per_side ** 2
        self.vision_config.projection_dim = self.projection_dim


# Alias using the capitalisation shared by the other PaliGemma* names in this file.
PaliGemmaConfig = PaligemmaConfig
        

In [None]:
class PaliGemmaForConditionalGeneration(nn.Module):
    """Vision encoder + projector + language model, wired for conditional generation.

    Image features replace the embeddings of the image placeholder tokens in the
    text sequence, and the merged sequence is fed to the language model.

    NOTE(review): ``PaliGemmaMultiModelProjector``, ``GemmaForCasualLm`` and
    ``KVCache`` are not defined in the visible part of this notebook; they are
    assumed to be defined in another cell before this class is instantiated.
    """

    def __init__(self, config: "PaligemmaConfig"):
        super().__init__()

        self.config = config
        self.vision_model = SiglipVisionModel(config.vision_config)
        # A linear projector that standardises the text and image embedding size.
        self.multi_model_projector = PaliGemmaMultiModelProjector(config)
        self.vocab_size = config.vocab_size

        self.language_model = GemmaForCasualLm(config.text_config)

        # -1 never equals a real token id, so the padding mask stays empty when
        # the configuration does not define a pad token.
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

    # share the weights of embedding (vocab_size * embedding_size) with the output projection layer (embedding_size * vocab_size)
    def tie_weights(self):
        return self.language_model.tie_weights()

    def _merge_input_ids_with_image_features(
        self, image_features: torch.Tensor,
        input_embeds: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        kv_cache
    ):
        """Splice image features into the text embeddings and build mask and positions.

        Args:
            image_features: (batch, patch_len, embed) projected image embeddings.
            input_embeds: (batch, seq_len, embed) embeddings of ``input_ids``.
            input_ids: (batch, seq_len) token ids, image placeholders included.
            attention_mask: mask of ones; (batch, seq_len) while prefilling.
                # assumes it covers cached + current tokens while decoding - TODO confirm
            kv_cache: ``None`` or an object exposing ``num_items()``.
                # assumes num_items() returns the number of cached tokens (0 when empty) - TODO confirm

        Returns:
            Tuple ``(final_embedding, attention_mask, position_ids)``:
            the merged (batch, seq_len, embed) embeddings, an all-zero additive
            mask of shape (batch, 1, q_len, kv_len), and the position ids.
        """
        embed_dim = image_features.shape[-1]
        batch_size, seq_len = input_ids.shape
        dtype, device = input_embeds.dtype, input_embeds.device

        # Same scaling idea as in attention: keeps the magnitude of the image
        # features consistent as the hidden size changes between model sizes.
        scaled_image_features = image_features / (self.config.hidden_size**0.5)
        final_embedding = torch.zeros(batch_size, seq_len, embed_dim, dtype=dtype, device=device)

        image_mask = (input_ids == self.config.image_token_index)
        pad_mask = (input_ids == self.pad_token_id)
        text_mask = ~(image_mask | pad_mask)

        # Each mask is expanded from its own source: (b, seq_len) -> (b, seq_len, embed)
        text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        image_mask_expanded = image_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)

        final_embedding = torch.where(text_mask_expanded, input_embeds, final_embedding)
        # scaled_image_features is (b, patch_len, embed), not (b, seq_len, embed),
        # so torch.where cannot be used; masked_scatter copies the source values,
        # in order, into the positions where the mask is True.
        final_embedding = final_embedding.masked_scatter(image_mask_expanded, scaled_image_features)
        final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)

        # KV CACHE
        q_len = input_embeds.shape[1]
        cached_tokens = 0 if kv_cache is None else kv_cache.num_items()

        if cached_tokens == 0:
            # Prefilling phase: nothing is masked (valid only without padding).
            kv_len = q_len
        else:
            assert q_len == 1, "when enable kv cache the query len should be one"
            kv_len = cached_tokens + q_len

        # All-zero additive mask: every query may attend to every key.
        causal_mask = torch.full(size=(batch_size, q_len, kv_len),
                                 fill_value=0,
                                 dtype=dtype,
                                 device=device)
        # (batch, query_len, kv_len) -> (batch, att_head, query_len, kv_len);
        # needed in both phases because attention has a head dimension.
        causal_mask = causal_mask.unsqueeze(1)

        if cached_tokens > 0:
            # Decoding: the single query sits at the last attended position.
            position_ids = attention_mask.cumsum(-1)[:, -1:]
        else:
            # Prefill: running count of attended tokens; masked slots get position 1.
            position_ids = attention_mask.cumsum(-1).masked_fill(attention_mask == 0, 1).to(device)

        return final_embedding, causal_mask, position_ids

    def forward(self,
                input_ids: torch.LongTensor = None,
                pixel_values: torch.FloatTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                kv_cache: Optional["KVCache"] = None) -> Tuple :
        """Run the full model on a batch of token ids and images.

        Args:
            input_ids: (B, Seq_len) token ids containing image placeholder tokens.
            pixel_values: (B, channel, height, width) preprocessed images.
            attention_mask: mask for the sequence; must be all ones because this
                implementation does not support padded input.
            kv_cache: optional key/value cache passed on to the language model.

        Returns:
            Whatever ``self.language_model`` returns.
        """
        assert torch.all(attention_mask == 1), "The input can't be padded"

        # Token ids to embeddings. The image placeholder positions still hold
        # placeholder embeddings that are replaced by real image features below.
        input_embeddings = self.language_model.get_input_embeddings()(input_ids) # (B, Seq_len, embedding_size)

        # Extract visual features (b, patch_len, embedding). The pixels are cast
        # to the embedding dtype so they match the vision model's weights.
        visual_embeddings = self.vision_model(pixel_values.to(input_embeddings.dtype))

        # (b, patch_len, patch_embed) -> (b, patch_len, d_model), so the image
        # and text embeddings can be merged.
        visual_embeddings = self.multi_model_projector(visual_embeddings)

        # Replace the image placeholder embeddings with the real image embeddings.
        input_embeddings, attention_mask, position_ids =  self._merge_input_ids_with_image_features(visual_embeddings, input_embeddings, input_ids, attention_mask, kv_cache)

        # NOTE(review): these keyword names must match GemmaForCasualLm.forward,
        # which is not visible here - confirm.
        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            input_embeddings=input_embeddings,
            kv_cache=kv_cache
        )
        return outputs
               
        
    
        
        

In [12]:
# Scratch cell: tensors have no `nan_items` attribute, so this call raises the
# AttributeError recorded in the output below.
torch.rand(10, dtype=torch.float16).nan_items()

AttributeError: 'Tensor' object has no attribute 'nan_items'

In [17]:
# Scratch cell: build a zero-filled (3, 2, 2) float16 tensor with torch.full.
# device='mps' is hard-coded, so this only runs where PyTorch's MPS backend is available.
torch.full(size=(3, 2, 2), fill_value=0, dtype=torch.float16, device='mps')

tensor([[[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]]], device='mps:0', dtype=torch.float16)