<a href="https://colab.research.google.com/github/abdul-basit-ai/JAX_for_LLMs/blob/main/CLIP_Jax_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# !pip install -U "jax[tpu]"
# !pip install git+https://github.com/google/flax.git
# print("Done:")

In [3]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
from flax import nnx
import optax
from typing import Sequence, Union, Tuple, Optional, Any
import numpy as np

In [4]:
from dataclasses import dataclass
# Import the dataclass decorator from the Python standard library.
# It auto-generates special methods (e.g. __init__, __repr__, __eq__) from the type-annotated class attributes.

@dataclass(frozen=True)  # frozen=True: instances are immutable after construction
class CLIPConfig:
    """Immutable bundle of hyper-parameters for the CLIP vision and text towers.

    Every field has a default, so ``CLIPConfig()`` yields the configuration
    used by the rest of this notebook. Field order is part of the positional
    constructor signature and must not be rearranged.
    """

    # Shared: size of the joint (output) embedding space.
    projection_dim: int = 512

    # ---- Vision encoder (ViT-B/32) ----
    vision_image_size: int = 224
    vision_patch_size: int = 32
    vision_num_layers: int = 12
    vision_num_heads: int = 12
    vision_hidden_size: int = 768  # width of the vision transformer (input embedding)
    vision_mlp_dim: int = 3072
    vision_dropout_rate: float = 0.0

    # ---- Text encoder ----
    text_vocab_size: int = 49408
    text_max_position_embeddings: int = 77
    text_num_layers: int = 12
    text_num_heads: int = 12
    text_hidden_size: int = 512
    text_mlp_dim: int = 2048
    text_dropout_rate: float = 0.0


# Module-level default configuration instance.
CONFIG = CLIPConfig()

In [5]:
class AttentionBlock(nnx.Module):
    """Pre-norm multi-head self-attention block with a residual connection.

    Computes ``x + Dropout(OutProj(MultiHeadAttention(LayerNorm(x))))``.
    No attention mask is applied: every position attends to every position.

    Args:
        embed_dim: Size of the last (feature) axis of the input. Must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability applied to the projected attention
            output.
        rngs: Keyword-only ``nnx.Rngs`` used to initialise the parameters.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout_rate: float, *, rngs: nnx.Rngs):
        # Raise explicitly instead of asserting: asserts are stripped under
        # ``python -O`` and the head split below would then silently floor.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.head_dim = embed_dim // num_heads
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        # Kept as an attribute for backward compatibility; not read by this class.
        self.rngs = rngs

        # Layer norm applied before attention (pre-normalisation).
        # NNX layers that own parameters need ``rngs`` passed explicitly.
        self.norm = nnx.LayerNorm(embed_dim, rngs=rngs)

        # Fused Q/K/V projection; the output is split into three parts later.
        self.qkv = nnx.Linear(embed_dim, embed_dim * 3, rngs=rngs)

        # Output projection applied after the heads are concatenated.
        self.out_proj = nnx.Linear(embed_dim, embed_dim, rngs=rngs)

        # Dropout draws its key from the ``rngs`` supplied at call time.
        self.dropout = nnx.Dropout(dropout_rate)

    def _split_heads(self, tensor: jax.Array) -> jax.Array:
        """Reshape (B, S, E) -> (B, num_heads, S, head_dim)."""
        batch, seq_len = tensor.shape[0], tensor.shape[1]
        tensor = tensor.reshape(batch, seq_len, self.num_heads, self.head_dim)
        return jnp.transpose(tensor, (0, 2, 1, 3))

    def _attention(self, x: jax.Array, rngs: Optional[nnx.Rngs], deterministic: bool) -> jax.Array:
        """Scaled dot-product multi-head self-attention over ``x``.

        Args:
            x: Array of shape (batch, seq_len, embed_dim).
            rngs: RNG container forwarded to dropout; may be ``None`` when
                dropout is inactive.
            deterministic: ``True`` disables dropout, ``False`` enables it.

        Returns:
            Array of shape (batch, seq_len, embed_dim).
        """
        qkv_out = self.qkv(x)  # (B, S, 3 * E)
        # axis=-1 splits along the last (feature) axis into three (B, S, E) arrays.
        q, k, v = jnp.split(qkv_out, 3, axis=-1)

        q_heads = self._split_heads(q)
        k_heads = self._split_heads(k)
        v_heads = self._split_heads(v)

        # (B, H, S, D) @ (B, H, D, S) -> (B, H, S, S): swap the last two axes of K.
        attn_weights = jnp.matmul(q_heads, jnp.swapaxes(k_heads, -1, -2))
        attn_weights = attn_weights / jnp.sqrt(self.head_dim)
        attn_weights = jax.nn.softmax(attn_weights, axis=-1)
        context = jnp.matmul(attn_weights, v_heads)  # (B, H, S, D)

        # Move heads back next to head_dim, then concatenate them: (B, S, E).
        context = jnp.transpose(context, (0, 2, 1, 3))
        context = context.reshape(context.shape[0], context.shape[1], self.embed_dim)

        output = self.out_proj(context)
        output = self.dropout(output, rngs=rngs, deterministic=deterministic)
        return output

    def __call__(self, x: jax.Array, *, deterministic: bool = False,
                 rngs: Optional[nnx.Rngs] = None) -> jax.Array:
        """Apply layer norm, self-attention and the residual connection.

        Args:
            x: Array of shape (batch, seq_len, embed_dim).
            deterministic: ``True`` disables dropout.
            rngs: RNG container for dropout; needed when dropout is active.

        Returns:
            ``x + attention(LayerNorm(x))`` with the same shape as ``x``.
        """
        norm_x = self.norm(x)
        attn_out = self._attention(norm_x, rngs, deterministic=deterministic)
        return x + attn_out



In [6]:
class MLP_Block(nnx.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear -> Dropout.

    Args:
        mlp_dim: Width of the hidden (expanded) layer.
        out_dim: Feature size of both the input and the output.
        dropout_rate: Dropout probability applied to the block output.
        rngs: Keyword-only ``nnx.Rngs`` used to initialise the parameters.
    """

    # The constructor was previously spelled ``___init__`` (three underscores),
    # so Python never invoked it and no sub-layers were created.
    def __init__(self, mlp_dim: int, out_dim: int, dropout_rate: float, *, rngs: nnx.Rngs):
        # Expand from out_dim to mlp_dim.
        self.Linear1 = nnx.Linear(out_dim, mlp_dim, rngs=rngs)
        # Project back to out_dim.
        self.Linear2 = nnx.Linear(mlp_dim, out_dim, rngs=rngs)
        # Dropout draws its key from the ``rngs`` supplied at call time.
        self.dropout = nnx.Dropout(dropout_rate)

    def __call__(self, x: jax.Array, *, deterministic: bool = False,
                 rngs: Optional[nnx.Rngs] = None) -> jax.Array:
        """Run the feed-forward block.

        Args:
            x: Array whose last axis has size ``out_dim``.
            deterministic: ``True`` disables dropout.
            rngs: RNG container for dropout; needed when dropout is active.

        Returns:
            Array with the same shape as ``x``.
        """
        x = self.Linear1(x)
        x = nnx.gelu(x)
        x = self.Linear2(x)
        x = self.dropout(x, rngs=rngs, deterministic=deterministic)
        return x

In [7]:
class Transformer_Encoder_Layer(nnx.Module):
    """One pre-norm transformer encoder layer: attention block, then MLP block.

    The attention sub-block (``AttentionBlock``) applies its own layer norm
    and residual; this class adds the layer norm and residual around the MLP.

    Args:
        embed_dim: Feature size of the input and output.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward block.
        dropout_rate: Dropout probability used by both sub-blocks.
        rngs: Keyword-only ``nnx.Rngs`` used to initialise the parameters.
    """

    def __init__(self, embed_dim: int, num_heads: int, mlp_dim: int, dropout_rate: float,
                 *, rngs: nnx.Rngs):
        self.attn = AttentionBlock(embed_dim, num_heads, dropout_rate, rngs=rngs)
        # nnx.LayerNorm owns scale/bias parameters, so it needs ``rngs`` too.
        self.mlp_norm = nnx.LayerNorm(embed_dim, rngs=rngs)
        self.mlp = MLP_Block(mlp_dim, embed_dim, dropout_rate, rngs=rngs)

    def __call__(self, x: jax.Array, *, deterministic: bool = False,
                 rngs: Optional[nnx.Rngs] = None) -> jax.Array:
        """Apply the attention sub-block followed by the MLP sub-block.

        Args:
            x: Input array; its last axis has size ``embed_dim``.
            deterministic: ``True`` disables dropout in both sub-blocks.
            rngs: RNG container for dropout; needed when dropout is active.

        Returns:
            Array with the same shape as ``x``.
        """
        # The attention block already returns ``x + attention(norm(x))``.
        x = self.attn(x, deterministic=deterministic, rngs=rngs)
        mlp_out = self.mlp(self.mlp_norm(x), deterministic=deterministic, rngs=rngs)
        return x + mlp_out


In [8]:
class Transformer_Encoder(nnx.Module):
    """Stack of ``Transformer_Encoder_Layer`` modules followed by a final layer norm.

    Args:
        embed_dim: Feature size of the input and output.
        num_layers: Number of encoder layers to stack.
        num_heads: Number of attention heads per layer.
        mlp_dim: Hidden width of each layer's feed-forward block.
        dropout_rate: Dropout probability used inside every layer.
        rngs: Keyword-only ``nnx.Rngs`` used to initialise the parameters.
    """

    def __init__(self, embed_dim: int, num_layers: int, num_heads: int, mlp_dim: int,
                 dropout_rate: float, *, rngs: nnx.Rngs):
        # Every layer owns parameters, so each one must receive ``rngs``.
        # NOTE(review): recent Flax releases may require module containers to
        # be wrapped (e.g. ``nnx.List``) — confirm against the installed version.
        self.layers: Sequence[Transformer_Encoder_Layer] = [
            Transformer_Encoder_Layer(
                embed_dim=embed_dim,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                dropout_rate=dropout_rate,
                rngs=rngs,
            )
            for _ in range(num_layers)
        ]

        self.final_norm = nnx.LayerNorm(embed_dim, rngs=rngs)

    def __call__(self, x: jax.Array, *, deterministic: bool = False,
                 rngs: Optional[nnx.Rngs] = None) -> jax.Array:
        """Run ``x`` through every layer in order, then apply the final layer norm.

        Args:
            x: Input array; its last axis has size ``embed_dim``.
            deterministic: ``True`` disables dropout in all layers.
            rngs: RNG container for dropout; needed when dropout is active.

        Returns:
            Array with the same shape as ``x``.
        """
        for layer in self.layers:
            x = layer(x, deterministic=deterministic, rngs=rngs)
        # Normalise and return only after ALL layers have run; doing this
        # inside the loop exits after the first layer.
        return self.final_norm(x)


In [9]:
class CLIP_Text_Encoder(nnx.Module):
    """Text tower: token + learned positional embeddings, transformer, pooling, projection.

    Args:
        vocab_size: Number of entries in the token embedding table.
        embed_dim: Size of the final projected embedding.
        hidden_dim: Width of the transformer; the MLP width is ``4 * hidden_dim``.
        num_layers: Number of transformer encoder layers.
        num_heads: Number of attention heads per layer.
        max_position_embeddings: Maximum supported sequence length.
        dropout_rate: Dropout probability used inside the transformer.
        rngs: Keyword-only ``nnx.Rngs`` used to initialise the parameters.
    """

    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int,
                 num_layers: int, num_heads: int, max_position_embeddings: int, dropout_rate: float,
                 *, rngs: nnx.Rngs):
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dim = hidden_dim

        # Token embedding table.
        self.token_embedding = nnx.Embed(num_embeddings=vocab_size, features=hidden_dim, rngs=rngs) \
            if False else nnx.Embed(vocab_size, hidden_dim, rngs=rngs)
        # Learned positional embeddings, initialised from N(0, 0.02^2).
        self.positional_encoding = nnx.Param(
            jax.random.normal(rngs.params(), (max_position_embeddings, hidden_dim)) * 0.02
        )

        # Transformer stack.
        self.transformer = Transformer_Encoder(
            embed_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            mlp_dim=hidden_dim * 4,
            dropout_rate=dropout_rate,
            rngs=rngs,
        )
        # Must match the attribute name read in __call__.
        self.final_projection = nnx.Linear(hidden_dim, embed_dim, rngs=rngs)

    def __call__(self, input_ids: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None, *,
                 deterministic: bool = False, rngs: Optional[nnx.Rngs] = None) -> jnp.ndarray:
        """Encode token ids into one embedding vector per sequence.

        Args:
            input_ids: Integer token ids of shape (batch, seq_len).
            attention_mask: Optional (batch, seq_len) mask, non-zero for real
                tokens. It is used only to locate the pooled position and is
                not applied inside the transformer. Assumes right-padding, so
                the last real token sits at index ``sum(mask) - 1`` — TODO
                confirm against the tokenizer.
            deterministic: ``True`` disables dropout.
            rngs: RNG container for dropout; needed when dropout is active.

        Returns:
            Array of shape (batch, embed_dim).

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_position_embeddings``.
        """
        sequence_length = input_ids.shape[1]
        # Validate before doing any work.
        if sequence_length > self.max_position_embeddings:
            raise ValueError(
                f"Sequence length ({sequence_length}) exceeds max position embeddings "
                f"({self.max_position_embeddings})." )

        token_embedd = self.token_embedding(input_ids)
        # Slice the learned positional matrix to the current sequence length.
        positional_embeds = self.positional_encoding.value[:sequence_length, :]
        hidden_states = token_embedd + positional_embeds

        # Run through the transformer stack.
        encoded_output = self.transformer(hidden_states, deterministic=deterministic, rngs=rngs)

        # Pooling: pick one position per sequence.
        if attention_mask is not None:
            # Index of the last non-padding token; the cast keeps the index
            # integral even if the mask is boolean or floating point.
            eos_indices = jnp.sum(attention_mask, axis=-1).astype(jnp.int32) - 1
            # Shape (batch, 1, 1) so it broadcasts over the hidden axis.
            eos_indices = eos_indices[:, jnp.newaxis, jnp.newaxis]
            # (batch, 1, hidden_dim) -> (batch, hidden_dim)
            pooled_output = jnp.take_along_axis(encoded_output, eos_indices, axis=1)
            pooled_output = jnp.squeeze(pooled_output, axis=1)
        else:
            # Fallback: take the last position when no mask is provided.
            pooled_output = encoded_output[:, -1, :]

        # Project the pooled vector into the fixed-size joint embedding space.
        final_embedding = self.final_projection(pooled_output)

        return final_embedding



In [11]:
!git add clip_vision_encoder.py

fatal: not a git repository (or any of the parent directories): .git
