In [2]:
import base64
import io
import pandas as pd
from PIL import Image
import torchvision.transforms as transforms
import torch
import torch.nn as nn  # why
from torch.nn import functional as F # why
from torch.nn import init  # why

In [3]:
class PatchEmbeddings(nn.Module):

    def __init__(self , img_size=96 , patch_size=16 , hidden_dim=512):
        super().__init__()

        self.img_size = img_size

        self.patch_size = patch_size

        self.num_patches = (img_size // patch_size)**2

        self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim,kernel_size=patch_size,stride=patch_size)


    def forward(self,X):

        X = self.conv(X)

        X = X.flatten(2)

        X = X.transpose(1,2)

        return X

        
 


In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
text_path = "./input.txt"

with open(text_path , 'r' , encoding='utf-8') as f :
    text = f.read()

chars = sorted(list(set(text)))

stoi = { ch:i for i,ch in enumerate(chars)}

stoi[''] = 65

itos = { i:ch for i,ch in enumerate(chars)}

itos[65] = ''
# what does lambda function do ?
encode = lambda s: [stoi[c] for c in s]

decode = lambda l: ''.join([itos[i] for i in l])

vocab_size = len(stoi.keys())

In [6]:
img_size , patch_size , num_hiddens , batch_size = 96,16,512,4

patch_embeddings = PatchEmbeddings(img_size,patch_size,num_hiddens)

X = torch.zeros(batch_size , 3, img_size , img_size)

patch_embeddings(X).shape

torch.Size([4, 36, 512])

lets define a MLP class now

In [7]:
class MLP(nn.Module):

    def __init__(self , n_embd , dropout=0.1 , is_decoder=True):

        super().__init__()

        layers = [

                nn.Linear(n_embd , 4 * n_embd),

                nn.ReLU() if is_decoder else nn.GELU(),

                nn.Linear(4 * n_embd , n_embd),

                nn.Dropout(dropout)
                ]

        self.net = nn.Sequential(*layers)


    def forward(self , x):
        return self.net(x)

# lets check what will happen if we give it a input embedding of size 128

In [8]:
n_embd = 128
testmlp = MLP(n_embd)
mlp_input = torch.zeros(batch_size, 3, n_embd)
testmlp_out = testmlp(mlp_input)
testmlp_out.shape

torch.Size([4, 3, 128])

In [9]:
class Head(nn.Module):

    def __init__(self , n_embd , head_size , dropout = 0.1 , is_decoder = False):
        super().__init__()

        self.key = nn.Linear(n_embd , head_size , bias=False)

        self.query = nn.Linear(n_embd,head_size,bias=False)

        self.value = nn.Linear(n_embd,head_size,bias=False)

        self.dropout = nn.Dropout(dropout)

        self.is_decoder = is_decoder

    
    def forward(self , x):

        B , T , C = x.shape

        k = self.key(x)
        q = self.query(x)
        v = self.value(x)


        wei = q @ k.transpose(-2,-1) * ( C ** -0.5)

        if self.is_decoder:

            tril = torch.tril(torch.ones(T,T, dtype = torch.bool , device = x.device))

            wei = wei.masked_fill(tril == 0 , float('-inf'))

        wei = F.softmax(wei , dim=-1)

        wei = self.dropout(wei)

        out = wei @ v

        return out




In [10]:
n_embd, head_size, batch_size = 128, 16, 4

testhead = Head(n_embd, head_size)
head_input = torch.zeros(batch_size, 3, n_embd)
testhead_out = testhead(head_input)
testhead_out.shape 

torch.Size([4, 3, 16])

In [11]:
class MultiHeadAttention(nn.Module):

    def __init__(self , n_embd , num_heads , dropout=0.1 , is_decoder = False):

            super().__init__()

            assert n_embd % num_heads == 0 , "n_embd must be divisible by num_heads"

            self.heads = nn.ModuleList([

                Head(n_embd , n_embd // num_heads , dropout , is_decoder)
                for _ in range(num_heads)


            ])

            self.proj = nn.Linear(n_embd , n_embd)

            self.dropout = nn.Dropout(dropout)

    

    def forward(self , x):

        head_outputs = [h(x) for h in self.heads]

        out = torch.cat(head_outputs , dim = -1)

        out = self.proj(out)

        out = self.dropout(out)

        return out

In [12]:
n_embd, n_head = 128, 8
testmha = MultiHeadAttention(n_embd, n_head)
head_input = torch.zeros(batch_size, 3, n_embd)
testmha_out = testmha(head_input)
testmha_out.shape

torch.Size([4, 3, 128])

# Now lets code the encoder block

In [13]:
class Block(nn.Module):

    def __init__(self, n_embd , num_heads , dropout = 0.1 , is_decoder = False):

        super().__init__()

        self.ln1 = nn.LayerNorm(n_embd)

        self.attn = MultiHeadAttention(n_embd , num_heads , dropout , is_decoder)

        self.ln2 = nn.LayerNorm(n_embd)

        self.ffn = nn.Sequential(

                nn.Linear(n_embd , 4 * n_embd),
                nn.GELU(),
                nn.Linear(4 * n_embd , n_embd) ,

        )


    def forward(self , x):
        original_x = x

        x = self.ln1(x)

        attn_output = self.attn(x)

        x = original_x + attn_output

        x = self.ln2(x)

        ffn_output = self.ffn(x)

        x = x + ffn_output

        return x


In [14]:
n_embd, head_size, batch_size = 128, 16, 4

testblock = Block(n_embd, n_head)
block_input = torch.zeros(batch_size, 3, n_embd)
testblock_out = testblock(block_input)
testblock_out.shape

torch.Size([4, 3, 128])

Now lets use all the chunks together to create a ViT


In [15]:
class ViT(nn.Module):

    def __init__(self , img_size , patch_size , num_hiddens,num_heads , num_blks , emb_dropout, blk_dropout):
        super().__init__()

        self.patch_embedding = PatchEmbeddings(img_size, patch_size , num_hiddens)

        self.cls_token = nn.Parameter(torch.zeros(1,1,num_hiddens))

        num_patches = (img_size // patch_size)**2

        self.pos_embedding = nn.Parameter(torch.randn(1,num_patches + 1 , num_hiddens)) # what does nn.Parameter do??
        self.dropout = nn.Dropout(emb_dropout)

        self.blocks = nn.ModuleList([Block(num_hiddens, num_heads, blk_dropout, is_decoder=False) for _ in range(num_blks)])

        self.layer_norm = nn.LayerNorm(num_hiddens)

    
    def forward(self , X):

        x = self.patch_embedding(X)

        cls_tokens = self.cls_token.expand(x.shape[0], -1 , -1)

        x = torch.cat((cls_tokens,x) , dim = 1)

        x += self.pos_embedding

        x = self.dropout(x)

        for blocks in self.blocks:

            x = blocks(x)

        x = self.layer_norm(x[:,0])

        return x




In [16]:
img_size, patch_size, num_hiddens, n_head, num_blks, dropout = 96, 16, 512, 8, 3, 0.1

testvit = ViT(img_size, patch_size, num_hiddens, n_head, num_blks, dropout, dropout)
vit_input = torch.zeros(batch_size, 3, img_size, img_size)
testvit_out = testvit(vit_input)
testvit_out.shape

torch.Size([4, 512])

### Now we have the embedding from the images and we need to concatenate with the text embedding , but we can't do that directly . first we need to project the dimensionality of image embeddings to dimensionality of text embeddings

# Lets do that

In [17]:
class MultiModalProjector(nn.Module):
    def __init__(self, n_embd, image_embed_dim, dropout=0.1):
        super().__init__()

        
        self.net = nn.Sequential(
            
            nn.Linear(image_embed_dim, 4 * image_embed_dim),

            
            nn.GELU(),

            
            nn.Linear(4 * image_embed_dim, n_embd),

            
            nn.Dropout(dropout)
        )

    def forward(self, x):
        
        x = self.net(x)
        return x

In [18]:
n_embd,num_hiddens = 128, 512

testmmp = MultiModalProjector(n_embd,num_hiddens)
mmp_input = testvit_out
testmmp_out = testmmp(mmp_input)
testmmp_out.shape

torch.Size([4, 128])

In [19]:
class DecoderLanguageModel(nn.Module):

    def __init__(self, n_embd, image_embed_dim, vocab_size, num_heads, n_layer, use_images=False):
        super().__init__()

        self.use_images = use_images

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)

        # Position embedding table
        self.position_embedding_table = nn.Embedding(1000, n_embd)

        if use_images:
            # Image projection layer to align image embeddings with text embeddings
            self.image_projection = MultiModalProjector(n_embd, image_embed_dim)

        # Stack of transformer decoder blocks
        self.blocks = nn.Sequential(*[Block(n_embd, num_heads, is_decoder=True) for _ in range(n_layer)])

        # Final layer normalization
        self.ln_f = nn.LayerNorm(n_embd)

        # Language modeling head
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, image_embeds=None, targets=None):
        # Get token embeddings from the input indices
        tok_emb = self.token_embedding_table(idx)

        if self.use_images and image_embeds is not None:
            # Project and concatenate image embeddings with token embeddings
            img_emb = self.image_projection(image_embeds).unsqueeze(1)
            tok_emb = torch.cat([img_emb, tok_emb], dim=1)

        # Get position embeddings
        pos_emb = self.position_embedding_table(torch.arange(tok_emb.size(1), device=device)).unsqueeze(0)

        # Add position embeddings to token embeddings
        x = tok_emb + pos_emb

        # Pass through the transformer decoder blocks
        x = self.blocks(x)

        # Apply final layer normalization
        x = self.ln_f(x)

        # Get the logits from the language modeling head
        logits = self.lm_head(x)

        if targets is not None:
            if self.use_images and image_embeds is not None:
                # Prepare targets by concatenating a dummy target for the image embedding
                batch_size = idx.size(0)
                targets = torch.cat([torch.full((batch_size, 1), -100, dtype=torch.long, device=device), targets], dim=1)

            # Compute the cross-entropy loss
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
            return logits, loss

        return logits

    def generate(self, idx, image_embeds, max_new_tokens):
        # Get the batch size and sequence length
        B, T = idx.shape

        # Initialize the generated sequence with the input indices
        generated = idx

        if self.use_images and image_embeds is not None:
            # Project and concatenate image embeddings with token embeddings
            img_emb = self.image_projection(image_embeds).unsqueeze(1)
            current_output = torch.cat([img_emb, self.token_embedding_table(idx)], dim=1)
        else:
            current_output = self.token_embedding_table(idx)

        # Generate new tokens iteratively
        for i in range(max_new_tokens):
            # Get the current sequence length
            T_current = current_output.size(1)

            # Get position embeddings for the current sequence length
            current_pos_emb = self.position_embedding_table(torch.arange(T_current, device=device)).unsqueeze(0)

            
            current_output += current_pos_emb

            
            for block in self.blocks:
                current_output = block(current_output)

            
            logits = self.lm_head(current_output[:, -1, :])

           
            probs = F.softmax(logits, dim=-1)

            
            idx_next = torch.multinomial(probs, num_samples=1)

            
            generated = torch.cat((generated, idx_next), dim=1)

            
            idx_next_emb = self.token_embedding_table(idx_next)

            
            current_output = torch.cat((current_output, idx_next_emb), dim=1)

        return generated


In [20]:
model = DecoderLanguageModel(n_embd= 128 , image_embed_dim= 256 , vocab_size= 1000 , num_heads=8 , n_layer= 6 , use_images= True)

model.to(device)

B, T = 10 , 50

idx = torch.randint(0,1000, (B,T)).to(device)

image_embeds = torch.randn(B, 256).to(device)

targets = torch.randint(0, vocab_size, (B, T)).to(device) 

if targets is not None:
    logits, loss = model(idx, image_embeds, targets)
    print(f"Logits shape: {logits.shape}, Loss: {loss}")
else:
    logits = model(idx, image_embeds)  # Call without targets
    print(f"Logits shape: {logits.shape}")

generated = model.generate(idx, image_embeds, max_new_tokens=20)
print(f"Generated sequence shape: {generated.shape}")



Logits shape: torch.Size([10, 51, 1000]), Loss: 7.0855607986450195
Generated sequence shape: torch.Size([10, 70])


### Now lets put all the blocks together

In [21]:
class VisionLanguageModel(nn.Module):
    def __init__(self, n_embd, image_embed_dim, vocab_size, n_layer, img_size, patch_size, num_heads, num_blks, emb_dropout, blk_dropout):
        super().__init__()

        # Set num_hiddens equal to image_embed_dim
        num_hiddens = image_embed_dim

        # Assert that num_hiddens is divisible by num_heads
        assert num_hiddens % num_heads == 0, "num_hiddens must be divisible by num_heads"

        # Initialize the vision encoder (ViT)
        self.vision_encoder = ViT(img_size, patch_size, num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout)

        # Initialize the language model decoder (DecoderLanguageModel)
        self.decoder = DecoderLanguageModel(n_embd, image_embed_dim, vocab_size, num_heads, n_layer, use_images=True)

    def forward(self, img_array, idx, targets=None):
        # Get the image embeddings from the vision encoder
        image_embeds = self.vision_encoder(img_array)

        # Check if the image embeddings are valid
        if image_embeds.nelement() == 0 or image_embeds.shape[1] == 0:
            raise ValueError("Something is wrong with the ViT model. It's returning an empty tensor or the embedding dimension is empty.")

        if targets is not None:
            # If targets are provided, compute the logits and loss
            logits, loss = self.decoder(idx, image_embeds, targets)
            return logits, loss
        else:
            # If targets are not provided, compute only the logits
            logits = self.decoder(idx, image_embeds)
            return logits

    def generate(self, img_array, idx, max_new_tokens):
        # Get the image embeddings from the vision encoder
        image_embeds = self.vision_encoder(img_array)

        # Check if the image embeddings are valid
        if image_embeds.nelement() == 0 or image_embeds.shape[1] == 0:
            raise ValueError("Something is wrong with the ViT model. It's returning an empty tensor or the embedding dimension is empty.")

        # Generate new tokens using the language model decoder
        generated_tokens = self.decoder.generate(idx, image_embeds, max_new_tokens)
        return generated_tokens

In [22]:
image_embed_dim = num_hiddens

n_layer, block_size =  8, 32

# Initialize the model
model = VisionLanguageModel(n_embd, image_embed_dim, vocab_size,  n_layer, img_size, patch_size, n_head, num_blks, dropout, dropout)
model.to(device)

# Create dummy data with correct dimensions
dummy_img = torch.randn(1, 3, img_size, img_size).to(device)  # Correct shape for image input
dummy_idx = torch.randint(0, vocab_size, (1, block_size)).to(device)  # Correct shape for text input

# Forward pass to initialize all parameters
try:
    output = model(dummy_img, dummy_idx)  # Output for debugging
    print("Output from initialization forward pass:", output)
except RuntimeError as e:
    print(f"Runtime Error during forward pass: {str(e)}")
    print("Check layer configurations and input shapes.")

Output from initialization forward pass: tensor([[[-0.5262, -0.0951, -0.9254,  ..., -0.1929, -0.3010, -0.1736],
         [-0.3884, -0.2352, -0.2910,  ...,  0.1570,  0.2544,  0.7743],
         [-0.0957, -0.5236,  0.1347,  ...,  0.4445,  0.2568,  0.4157],
         ...,
         [ 0.3338,  0.4779, -0.5849,  ..., -0.4378, -0.2896,  0.1432],
         [-0.3068,  0.6018,  0.8893,  ...,  0.3438, -0.6156, -0.8650],
         [-0.5292,  0.1760, -0.1767,  ..., -0.8456, -0.0648,  0.2146]]],
       device='cuda:0', grad_fn=<ViewBackward0>)


In [23]:

def base64_to_tensor(base64_str, img_size=96):
    image = Image.open(io.BytesIO(base64.b64decode(base64_str)))
    if image.mode != 'RGB':
        image = image.convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0)

In [24]:
def get_batch(df, batch_size, split='train', img_size=96, val_batch_size=8):
    # Split data into training and validation sets
    n = int(0.9 * len(df))  # first 90% will be train, rest val
    df_train = df.iloc[:n]
    df_val = df.iloc[n:]
    data = df_train if split == 'train' else df_val
    batch_size = batch_size if split == 'train' else val_batch_size
    replace = False if split == 'train' else True
    batch = data.sample(n=batch_size, replace=replace)

    images = torch.cat([base64_to_tensor(img, img_size) for img in batch['b64string_images']], dim=0).to(device)
    text_indices = [torch.tensor(encode(desc), dtype=torch.long) for desc in batch['caption']]
    max_length = max(len(t) for t in text_indices)

    padded_text = torch.full((batch_size, max_length), fill_value=stoi[''], dtype=torch.long).to(device)
    for i, text in enumerate(text_indices):
        padded_text[i, :len(text)] = text

    targets = torch.cat([padded_text[:, 1:], torch.full((batch_size, 1), fill_value=stoi[''], dtype=torch.long, device=device)], dim=1)

    # Truncate or pad targets to match the length of padded_text
    if targets.size(1) > padded_text.size(1):
        targets = targets[:, :padded_text.size(1)]
    elif targets.size(1) < padded_text.size(1):
        targets = torch.cat([targets, torch.full((batch_size, padded_text.size(1) - targets.size(1)), fill_value=stoi[''], dtype=torch.long, device=device)], dim=1)

    return images, padded_text, targets

In [25]:
def train_model(model, df, epochs, vocab_size, img_size=96):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.to(device)
    for epoch in range(epochs):
        model.train()
        for _ in range(max_iters):
            images, idx, targets = get_batch(df, batch_size, 'train', img_size)
            optimizer.zero_grad()
            logits, loss = model(images, idx, targets)
            loss.backward()
            optimizer.step()
            if _ % eval_interval == 0:
                print(f"Loss at iteration {_}: {loss.item()}")
        val_loss = estimate_loss(model, df, 'val', img_size, val_batch_size=8)
        print(f"Validation Loss after epoch {epoch}: {val_loss}")

def estimate_loss(model, df, split, img_size=96, val_batch_size=8):
    losses = []
    model.eval()
    for _ in range(eval_iters):
        images, idx, targets = get_batch(df, batch_size, split, img_size, val_batch_size=val_batch_size)
        _, loss = model(images, idx, targets)
        losses.append(loss.item())
    return sum(losses) / len(losses)

In [28]:
import numpy as np
import pandas as pd
import base64
import io
from PIL import Image
import torch
from torchvision import transforms

# Assume existing stoi, itos, and vocab_size are already defined

# Modify the encode function to handle unknown characters
def encode(s):
    return [stoi.get(c, stoi['']) for c in s]  # Use '' for unknown characters

# Update get_batch function
def get_batch(df, batch_size, split='train', img_size=96, val_batch_size=None):
    # Split data into training and validation sets
    n = int(0.9 * len(df))  # first 90% will be train, rest val
    df_train = df.iloc[:n]
    df_val = df.iloc[n:]
    
    data = df_train if split == 'train' else df_val
    current_batch_size = batch_size if split == 'train' else (val_batch_size or batch_size)
    replace = False if split == 'train' else True
    
    batch = data.sample(n=current_batch_size, replace=replace)
    
    images = torch.cat([base64_to_tensor(img, img_size) for img in batch['b64string_images']], dim=0).to(device)
    text_indices = [torch.tensor(encode(desc), dtype=torch.long) for desc in batch['caption']]
    max_length = max(len(t) for t in text_indices)
    
    padded_text = torch.full((current_batch_size, max_length), fill_value=stoi[''], dtype=torch.long).to(device)
    for i, text in enumerate(text_indices):
        padded_text[i, :len(text)] = text
    
    targets = torch.cat([padded_text[:, 1:], torch.full((current_batch_size, 1), fill_value=stoi[''], dtype=torch.long, device=device)], dim=1)
    
    # Truncate or pad targets to match the length of padded_text
    if targets.size(1) > padded_text.size(1):
        targets = targets[:, :padded_text.size(1)]
    elif targets.size(1) < padded_text.size(1):
        targets = torch.cat([targets, torch.full((current_batch_size, padded_text.size(1) - targets.size(1)), fill_value=stoi[''], dtype=torch.long, device=device)], dim=1)
    
    return images, padded_text, targets

# The rest of your code remains the same
batch_size = 16
block_size = 32
max_iters = 100
eval_interval = 10
learning_rate = 1e-3
epochs = 1
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 40
num_blks = 3
head_size = 16
n_embd = 128
n_head = 8
n_layer = 8
dropout = 0.1
img_size = 96
patch_size = 16
image_embed_dim = 512
emb_dropout = blk_dropout = 0.1

# Assuming you have defined VisionLanguageModel and train_model functions
model = VisionLanguageModel(n_embd, image_embed_dim, vocab_size, n_layer, img_size, patch_size, n_head, num_blks, emb_dropout, blk_dropout)
model.to(device)

# Dummy data to initialize lazy modules
dummy_img = torch.randn(1, 3, img_size, img_size).to(device)
dummy_idx = torch.randint(0, vocab_size, (1, block_size)).to(device)
model(dummy_img, dummy_idx)  # Forward pass to initialize all parameters

# Train the model
train_model(model, df, epochs, vocab_size, img_size)

Loss at iteration 0: 4.3069024085998535
Loss at iteration 10: 0.1717265248298645
Loss at iteration 20: 0.08332188427448273
Loss at iteration 30: 0.06375972926616669
Loss at iteration 40: 0.06932108849287033
Loss at iteration 50: 0.08294254541397095
Loss at iteration 60: 0.07798963040113449
Loss at iteration 70: 0.05572696775197983
Loss at iteration 80: 0.03254299238324165
Loss at iteration 90: 0.0669950470328331
Validation Loss after epoch 0: 0.03557506361976266
