In [1]:
import base64
import io
import pandas as pd
from PIL import Image
import torchvision.transforms as transforms
import torch
import torch.nn as nn  # why
from torch.nn import functional as F # why
from torch.nn import init  # why

In [2]:
class PatchEmbeddings(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A single Conv2d whose kernel size and stride both equal ``patch_size`` is
    equivalent to cutting the image into ``patch_size x patch_size`` tiles and
    applying the same linear projection to every tile.

    Args:
        img_size: side length of the (square) input image, in pixels.
        patch_size: side length of each square patch, in pixels.
        hidden_dim: size of the embedding produced for each patch.
        in_channels: number of input image channels (default 3, i.e. RGB).
    """

    def __init__(self, img_size=96, patch_size=16, hidden_dim=512, in_channels=3):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Patches per side, squared: the image is assumed square (one img_size).
        self.num_patches = (img_size // patch_size) ** 2
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=hidden_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, X):
        """Embed every patch of a batch of images.

        Args:
            X: image batch of shape (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, num_patches, hidden_dim).
        """
        X = self.conv(X)          # (B, hidden_dim, H // patch_size, W // patch_size)
        X = X.flatten(2)          # (B, hidden_dim, num_patches)
        return X.transpose(1, 2)  # (B, num_patches, hidden_dim)

        
 


In [3]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [4]:
text_path = "./input.txt"

with open(text_path, 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level vocabulary: every distinct character of the corpus, sorted so
# that the char -> id mapping is deterministic.
chars = sorted(set(text))

stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

# Reserve one extra id, right after the real characters, for the empty string
# '' (it decodes to nothing; presumably used as padding — confirm with callers).
# Deriving it from len(chars) instead of hard-coding 65 keeps it from colliding
# with a real character id, or falling outside vocab_size, when the corpus does
# not contain exactly 65 distinct characters.
empty_id = len(chars)
stoi[''] = empty_id
itos[empty_id] = ''


# Named functions instead of lambdas assigned to names (PEP 8 recommends `def`):
# same call behaviour, but they carry a real name and a docstring.
def encode(s):
    """Map a string to a list of integer token ids, one per character."""
    return [stoi[c] for c in s]


def decode(l):
    """Map a list of integer token ids back to a string."""
    return ''.join(itos[i] for i in l)


vocab_size = len(stoi)

In [5]:
# Sanity check: a batch of blank 96x96 RGB images should become
# (batch, (96 // 16) ** 2 = 36 patches, 512) embeddings.
img_size, patch_size = 96, 16
num_hiddens, batch_size = 512, 4

X = torch.zeros(batch_size, 3, img_size, img_size)
patch_embeddings = PatchEmbeddings(img_size, patch_size, num_hiddens)

patch_embeddings(X).shape

torch.Size([4, 36, 512])

Let's define an MLP class now.

In [6]:
class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, activate, project back, dropout.

    Args:
        n_embd: input and output feature size.
        dropout: dropout probability applied to the output.
        is_decoder: selects the activation — ReLU when True, GELU when False.
    """

    def __init__(self, n_embd, dropout=0.1, is_decoder=True):
        super().__init__()
        hidden = 4 * n_embd
        activation = nn.ReLU() if is_decoder else nn.GELU()
        self.net = nn.Sequential(
            nn.Linear(n_embd, hidden),
            activation,
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the feed-forward stack over the last dimension of ``x``."""
        return self.net(x)

# Let's check what happens if we give it an input embedding of size 128

In [7]:
# Feed an all-zero (batch, seq_len=3, n_embd) tensor through the MLP;
# the feed-forward stack preserves the input shape.
n_embd = 128
mlp_input = torch.zeros(batch_size, 3, n_embd)

testmlp = MLP(n_embd)
testmlp_out = testmlp(mlp_input)
testmlp_out.shape

torch.Size([4, 3, 128])

In [8]:
class Head(nn.Module):
    """A single head of scaled dot-product self-attention.

    Args:
        n_embd: feature size of the input tokens.
        head_size: feature size of keys, queries and values (and of the output).
        dropout: dropout probability applied to the attention weights.
        is_decoder: if True, apply a causal mask so each position only attends
            to itself and earlier positions.
    """

    def __init__(self, n_embd, head_size, dropout=0.1, is_decoder=False):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder

    def forward(self, x):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: tensor of shape (B, T, n_embd).

        Returns:
            Tensor of shape (B, T, head_size).
        """
        T = x.size(1)
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # Scale by 1/sqrt(d_k), where d_k is the key dimension (head_size),
        # not the model width n_embd; otherwise the softmax is over-flattened.
        wei = q @ k.transpose(-2, -1) * (k.size(-1) ** -0.5)  # (B, T, T)

        if self.is_decoder:
            # Lower-triangular mask: position t may only attend to positions <= t.
            tril = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
            wei = wei.masked_fill(~tril, float('-inf'))

        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        return wei @ v  # (B, T, head_size)




In [9]:
# A single non-causal head maps (batch, seq_len, n_embd) -> (batch, seq_len, head_size).
n_embd = 128
head_size = 16
batch_size = 4

head_input = torch.zeros(batch_size, 3, n_embd)
testhead = Head(n_embd, head_size)
testhead_out = testhead(head_input)
testhead_out.shape

torch.Size([4, 3, 16])

In [10]:
class MultiHeadAttention(nn.Module):
    """Several attention heads run in parallel, concatenated, then re-projected.

    Args:
        n_embd: model width; each of the ``num_heads`` heads gets
            ``n_embd // num_heads`` features.
        num_heads: number of parallel heads; must divide ``n_embd``.
        dropout: dropout probability used inside each head and after the projection.
        is_decoder: forwarded to every head to enable causal masking.
    """

    def __init__(self, n_embd, num_heads, dropout=0.1, is_decoder=False):
        super().__init__()
        assert n_embd % num_heads == 0 , "n_embd must be divisible by num_heads"

        per_head = n_embd // num_heads
        heads = [Head(n_embd, per_head, dropout, is_decoder) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on ``x``, concatenate on the feature axis, project, dropout."""
        concatenated = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.dropout(self.proj(concatenated))

In [11]:
# Eight heads of 128 // 8 = 16 features each; the output width matches the input.
n_embd, n_head = 128, 8
head_input = torch.zeros(batch_size, 3, n_embd)

testmha = MultiHeadAttention(n_embd, n_head)
testmha_out = testmha(head_input)
testmha_out.shape

torch.Size([4, 3, 128])

# Now let's code the transformer block (shared by the ViT encoder and the text decoder)

In [12]:
class Block(nn.Module):
    """Pre-LayerNorm transformer block: self-attention followed by a feed-forward net.

    Each sub-layer reads a LayerNorm-ed copy of the residual stream and adds
    its output back to the *un-normalised* stream:

        x = x + attn(ln1(x))
        x = x + ffn(ln2(x))

    Args:
        n_embd: model width.
        num_heads: number of attention heads.
        dropout: dropout probability for the attention and the feed-forward output.
        is_decoder: if True, attention is causally masked.
    """

    def __init__(self, n_embd, num_heads, dropout=0.1, is_decoder=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, num_heads, dropout, is_decoder)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            # Honour the `dropout` argument on the FFN output as well (as MLP does).
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, T, n_embd); the shape is preserved."""
        x = x + self.attn(self.ln1(x))
        # Add to the un-normalised residual stream. Adding to ln2(x) instead
        # would re-normalise the stream in every block and break the pre-LN
        # residual path.
        x = x + self.ffn(self.ln2(x))
        return x


In [13]:
# Define n_head here instead of relying on the value left over from the
# MultiHeadAttention cell, so this cell runs on its own after the class cells.
n_embd, n_head, batch_size = 128, 8, 4

testblock = Block(n_embd, n_head)
block_input = torch.zeros(batch_size, 3, n_embd)
testblock_out = testblock(block_input)
testblock_out.shape

torch.Size([4, 3, 128])

Now let's put all the pieces together to build a ViT (Vision Transformer).


In [14]:
class ViT(nn.Module):
    """Vision Transformer encoder that returns one embedding per image.

    The image is split into patch embeddings, a learnable [CLS] token is
    prepended, learned position embeddings are added, and the sequence passes
    through ``num_blks`` non-causal transformer blocks. The LayerNorm-ed [CLS]
    output is the image representation, shape (B, num_hiddens).

    Args:
        img_size: side length of the square input image.
        patch_size: side length of each square patch.
        num_hiddens: embedding width.
        num_heads: attention heads per block.
        num_blks: number of transformer blocks.
        emb_dropout: dropout applied to the embedded sequence.
        blk_dropout: dropout used inside each block.
    """

    def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout):
        super().__init__()
        self.patch_embedding = PatchEmbeddings(img_size, patch_size, num_hiddens)

        # nn.Parameter registers a tensor as a learnable weight of this module:
        # it shows up in .parameters() / state_dict() and is updated by the optimizer.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))

        seq_len = (img_size // patch_size) ** 2 + 1  # all patches plus the [CLS] token
        self.pos_embedding = nn.Parameter(torch.randn(1, seq_len, num_hiddens))
        self.dropout = nn.Dropout(emb_dropout)

        self.blocks = nn.ModuleList(
            Block(num_hiddens, num_heads, blk_dropout, is_decoder=False)
            for _ in range(num_blks)
        )
        self.layer_norm = nn.LayerNorm(num_hiddens)

    def forward(self, X):
        """Encode a batch of images ``X`` into (B, num_hiddens) embeddings."""
        patches = self.patch_embedding(X)                           # (B, num_patches, D)
        cls = self.cls_token.expand(patches.shape[0], -1, -1)       # (B, 1, D)
        x = torch.cat((cls, patches), dim=1) + self.pos_embedding   # (B, num_patches + 1, D)
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x)
        # Only the [CLS] position is kept as the image representation.
        return self.layer_norm(x[:, 0])




In [15]:
img_size, patch_size, num_hiddens = 96, 16, 512
n_head, num_blks, dropout = 8, 3, 0.1

# The same dropout rate is used for the embeddings and inside the blocks.
testvit = ViT(img_size, patch_size, num_hiddens, n_head, num_blks,
              emb_dropout=dropout, blk_dropout=dropout)
vit_input = torch.zeros(batch_size, 3, img_size, img_size)
testvit_out = testvit(vit_input)
testvit_out.shape

torch.Size([4, 512])

### Now we have the image embeddings, and we want to concatenate them with the text embeddings. We can't do that directly, because the two have different dimensionalities: first we need to project the image embeddings to the dimensionality of the text embeddings.

# Let's do that

In [16]:
class MultiModalProjector(nn.Module):
    """Map image embeddings into the text-embedding space.

    A two-layer MLP: expand the image feature to 4x its width, apply GELU,
    project down to the text width ``n_embd``, then apply dropout.

    Args:
        n_embd: target (text) embedding size.
        image_embed_dim: size of the incoming image embedding.
        dropout: dropout probability applied to the projected output.
    """

    def __init__(self, n_embd, image_embed_dim, dropout=0.1):
        super().__init__()
        expanded = 4 * image_embed_dim
        self.net = nn.Sequential(
            nn.Linear(image_embed_dim, expanded),
            nn.GELU(),
            nn.Linear(expanded, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Project ``x`` (last dim ``image_embed_dim``) to last dim ``n_embd``."""
        return self.net(x)

In [17]:
# Project the ViT's 512-d image embeddings down to the 128-d text width.
n_embd, num_hiddens = 128, 512
mmp_input = testvit_out

testmmp = MultiModalProjector(n_embd, num_hiddens)
testmmp_out = testmmp(mmp_input)
testmmp_out.shape

torch.Size([4, 128])

In [18]:
class DecoderLanguageModel(nn.Module):
    """Causal (decoder-only) transformer language model, optionally image-conditioned.

    When ``use_images`` is True and image embeddings are supplied, they are
    projected to the text width and prepended to the token sequence as one
    extra position.

    Args:
        n_embd: model width.
        image_embed_dim: size of the incoming image embedding (used only if use_images).
        vocab_size: number of token ids.
        num_heads: attention heads per block.
        n_layer: number of decoder blocks.
        use_images: whether to build the image projection and accept image embeddings.
    """

    def __init__(self, n_embd, image_embed_dim, vocab_size, num_heads, n_layer, use_images=False):
        super().__init__()

        self.use_images = use_images

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)

        # Learned position embeddings; 1000 is the maximum sequence length
        # (image position included) the model can process.
        self.position_embedding_table = nn.Embedding(1000, n_embd)

        if use_images:
            # Aligns image embeddings with the text embedding width.
            self.image_projection = MultiModalProjector(n_embd, image_embed_dim)

        # Stack of causally-masked transformer blocks.
        self.blocks = nn.Sequential(*[Block(n_embd, num_heads, is_decoder=True) for _ in range(n_layer)])

        # Final layer normalization.
        self.ln_f = nn.LayerNorm(n_embd)

        # Language-modeling head: hidden state -> vocabulary logits.
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, image_embeds=None, targets=None):
        """Compute next-token logits (and the loss when ``targets`` is given).

        Args:
            idx: integer token ids of shape (B, T).
            image_embeds: optional image embeddings of shape (B, image_embed_dim);
                ignored unless ``use_images`` is True.
            targets: optional target ids of shape (B, T); entries equal to -100
                are ignored by the loss.

        Returns:
            ``logits`` of shape (B, T', vocab_size), where T' = T + 1 when an
            image was prepended, or ``(logits, loss)`` when ``targets`` is given.
        """
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
        with_image = self.use_images and image_embeds is not None

        if with_image:
            img_emb = self.image_projection(image_embeds).unsqueeze(1)  # (B, 1, n_embd)
            tok_emb = torch.cat([img_emb, tok_emb], dim=1)

        # Build positions on the input's device rather than a global `device`,
        # so the model works wherever its inputs live.
        positions = torch.arange(tok_emb.size(1), device=idx.device)
        x = tok_emb + self.position_embedding_table(positions).unsqueeze(0)

        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits

        if with_image:
            # The image position has no next-token target; -100 is the ignore_index.
            ignore = torch.full((idx.size(0), 1), -100, dtype=torch.long, device=targets.device)
            targets = torch.cat([ignore, targets], dim=1)

        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=-100,
        )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, image_embeds, max_new_tokens):
        """Sample ``max_new_tokens`` tokens autoregressively.

        Each step re-runs ``forward`` (the same embedding, position, block and
        final-LayerNorm path used in training) on the sequence generated so far.
        It then samples the next id from the softmax of the last position's logits.

        Args:
            idx: prompt token ids of shape (B, T).
            image_embeds: image embeddings forwarded to ``forward`` (may be None).
            max_new_tokens: number of tokens to append.

        Returns:
            Tensor of shape (B, T + max_new_tokens): the prompt followed by the samples.
        """
        # Keep the context within the position-embedding table; when an image
        # is used it occupies one of those positions.
        uses_image = self.use_images and image_embeds is not None
        max_tokens = self.position_embedding_table.num_embeddings - (1 if uses_image else 0)

        generated = idx
        for _ in range(max_new_tokens):
            context = generated[:, -max_tokens:]
            logits = self(context, image_embeds)                 # (B, T', vocab_size)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)   # (B, 1)
            generated = torch.cat((generated, idx_next), dim=1)

        return generated
