In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
# from torchtune.modules import RMSNorm
from tokenizers import Tokenizer
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.utils.data import random_split

import albumentations as A
from albumentations.pytorch import ToTensorV2
from dataclasses import dataclass
from going_modular.going_modular import engine


In [3]:
@dataclass
class PaligemmaArgs:
    """Hyperparameters shared by the vision encoder, the language model and the data pipeline.

    Every attribute carries a type annotation so that ``@dataclass`` actually
    registers it as a field (un-annotated attributes are ignored by the
    decorator).  The defaults are still stored on the class, so existing
    class-level reads and writes such as ``PaligemmaArgs.vocab_size = ...``
    keep working, and ``PaligemmaArgs(lr=...)`` can now build an override.
    """

    batch_size: int = 8        # samples per DataLoader batch
    # Device for the trainable layers and for the assembled model.
    # NOTE(review): the language model is loaded with device_map='cuda' - confirm the two agree.
    device: str = 'cpu'
    vis_embd_out: int = 768    # input width of the vision projector
    text_embd_out: int = 3072  # output width of the vision projector / input width of the LM head
    vocab_size: int = 256000   # output width of the LM head; overwritten once special tokens are added
    block_size: int = 8094     # fixed token-sequence length produced by the dataset
    lr: float = 1e-3           # optimizer learning rate
    text_hidden: int = 768 * 4
    img_seq_len: int = 256     # number of "<image>" placeholder tokens prepended to every sample

In [4]:
# Loading the vision encoder (a CLIP ViT-B/16 checkpoint is used here, not SigLIP)

import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPFeatureExtractor

# Vision model class using CLIP
class VisionModel(nn.Module):
    """Frozen CLIP vision tower followed by a linear projector.

    The projector maps ``PaligemmaArgs.vis_embd_out`` features to
    ``PaligemmaArgs.text_embd_out`` features; the CLIP weights have
    ``requires_grad`` switched off.
    """

    def __init__(self):
        super().__init__()

        checkpoint = "openai/clip-vit-base-patch16"

        # Only the vision half of CLIP is kept.
        self.model = CLIPModel.from_pretrained(checkpoint).vision_model

        # Stored on the module; forward() does not use it.
        self.feature_extractor = CLIPFeatureExtractor.from_pretrained(checkpoint)

        # Bridges the vision feature width to the language model's embedding width.
        self.multimodalVisionLayerProjector = nn.Linear(
            in_features=PaligemmaArgs.vis_embd_out,
            out_features=PaligemmaArgs.text_embd_out,
            device=PaligemmaArgs.device,
        )

        self.main = nn.Sequential(nn.Flatten())

        # Freeze the pretrained encoder.
        for weight in self.model.parameters():
            weight.requires_grad_(False)

    def forward(self, x):
        """Encode a batch of images and project the pooled embedding.

        Args:
            x: pixel tensor passed straight to the CLIP vision model.

        Returns:
            Output of the projector applied to the flattened ``pooler_output``.
        """
        # The encoder is frozen, so no autograd graph is recorded for it.
        with torch.no_grad():
            pooled = self.model(x).pooler_output

        flattened = self.main(pooled)
        return self.multimodalVisionLayerProjector(flattened)


In [None]:
#Language Decoder

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit NF4 quantisation with double quantisation; matmuls are computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# The tokenizer must come from the same checkpoint as the model: token ids, the
# vocabulary size and the '<bos>' / '<eos>' / '<pad>' tokens used by the dataset
# are all tied to the model's embedding table.  (The previous GPT-2 tokenizer
# produced ids for a different vocabulary.)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b", device_map='cuda', torch_dtype='auto', output_hidden_states=True, quantization_config=bnb_config)

# To experiment with GPT-2 instead, switch BOTH lines together:
# tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
# model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", device_map=PaligemmaArgs.device, torch_dtype='auto', output_hidden_states=True)
class TextModel(nn.Module):
    """Wrapper around the module-level causal LM ``model`` plus a trainable vocabulary head.

    The language model itself is frozen (``requires_grad = False`` on all of its
    parameters) and is referenced as a global rather than registered as a
    sub-module; only ``linear_layer`` belongs to this module.
    """

    def __init__(self):
        super().__init__()

        self.tokenizer = tokenizer
        # Maps final hidden states (text_embd_out features) to logits over vocab_size tokens.
        self.linear_layer = nn.Linear(in_features=PaligemmaArgs.text_embd_out, out_features=PaligemmaArgs.vocab_size, device=PaligemmaArgs.device, bias=False)

        for p in model.parameters():
            p.requires_grad = False

    def forward(self, x, embeds=True):
        """Run the language model.

        Args:
            x: when ``embeds`` is true, a tensor passed as ``inputs_embeds``;
                otherwise a mapping with ``'input_ids'`` and ``'attention_mask'``.
            embeds: selects between the two input forms described above.

        Returns:
            ``embeds=True``: logits from ``linear_layer`` applied to the last hidden state.
            ``embeds=False``: the raw model output object.
        """
        if embeds:
            hidden = model(inputs_embeds=x).hidden_states[-1]
            # The LM is loaded with torch_dtype='auto' on its own device map, so its
            # hidden states need not match the head's dtype/device; align them
            # before the projection (a no-op when they already agree).
            head_weight = self.linear_layer.weight
            hidden = hidden.to(device=head_weight.device, dtype=head_weight.dtype)
            return self.linear_layer(hidden)

        return model(input_ids=x['input_ids'], attention_mask=x['attention_mask'])




In [7]:
import pandas as pd
# Caption table: column 0 is read as the image file name and column 1 as the
# caption text by the dataset class further down.
df = pd.read_csv(
    filepath_or_buffer='/content/drive/MyDrive/data/flickr8000/captions.txt',
    delimiter=",",
)

In [None]:
# Preview only the first rows instead of echoing the whole caption table.
df.head()

In [30]:
# Marker tokens used when assembling a training sequence.
IMAGE_TOKEN = "<image>"
SEP = "<sep>"

# Register both markers as additional special tokens on the tokenizer.
tokens_to_add = dict(additional_special_tokens=[IMAGE_TOKEN, SEP])
tokenizer.add_special_tokens(tokens_to_add)

# Keep the embedding table and the configured vocabulary size in step with
# the tokenizer's new length.
model.resize_token_embeddings(len(tokenizer))
PaligemmaArgs.vocab_size = len(tokenizer)


In [32]:
import os
import numpy as np
from PIL import Image


# Training-time image pipeline: resize to 224x224, normalise, convert to a CHW tensor.
train_transforms = A.Compose(
    [
        A.Resize(height=224, width=224),
        A.CenterCrop(height=224, width=224),
        # max_pixel_value is the largest raw pixel value (255 for 8-bit images),
        # not the image side length: Normalize computes
        # (img - mean * max_pixel_value) / (std * max_pixel_value).
        A.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], max_pixel_value=255.0,),
        # No A.ToFloat afterwards: Normalize already yields floats, and dividing
        # the normalised values again would shrink them by another constant factor.
        ToTensorV2(),
    ]
)

def custom_collate_fn(batch):
    """Collate dataset items into one padded batch dictionary.

    Only the first element of every item (the input dictionary) is used.

    Args:
        batch: sequence of items whose element 0 is a mapping with
            ``'input_ids'``, ``'attention_mask'`` and ``'image'`` tensors.

    Returns:
        dict with ``'input_ids'`` (padded with ``tokenizer.pad_token_id``),
        ``'attention_mask'`` (padded with 0) and ``'image'`` (stacked).
    """
    def column(key):
        # Gather one field from the input dictionary of every item.
        return [sample[0][key] for sample in batch]

    id_seqs = column('input_ids')
    mask_seqs = column('attention_mask')
    image_list = column('image')

    pad_to_longest = torch.nn.utils.rnn.pad_sequence

    # Right-pad every sequence up to the longest one in this batch.
    padded_ids = pad_to_longest(id_seqs, batch_first=True, padding_value=tokenizer.pad_token_id)
    padded_masks = pad_to_longest(mask_seqs, batch_first=True, padding_value=0)

    # torch.stack requires all images to share one shape.
    stacked_images = torch.stack(image_list)

    return {
        'input_ids': padded_ids,
        'attention_mask': padded_masks,
        'image': stacked_images,
    }



# Evaluation-time image pipeline: same resize and normalisation as training.
test_transforms = A.Compose(
    [
        # Bring every image to the 224x224 size used for training so that
        # batches can be stacked and match the encoder's input size.
        A.Resize(height=224, width=224),
        # max_pixel_value is the largest raw pixel value (255 for 8-bit images),
        # not the image side length; no extra ToFloat rescale after Normalize.
        A.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)

# Backward-compatible alias for the original (misspelled) name.
test_tyransforms = test_transforms



# IMPORTANT: clear the tokenizer's add_bos_token / add_eos_token flags.
# PaliGemmaDataset encodes '<bos>' and '<eos>' explicitly where it wants them
# (presumably these flags stop encode() from inserting them on its own -
# confirm for the tokenizer class in use).
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False

class PaliGemmaDataset(Dataset):
    """Image-captioning samples built from the module-level ``df`` caption table.

    Every item pairs a transformed image with a fixed-length token sequence

        "<image>" * img_seq_len + prompt + "<bos>" + "<sep>" + caption + "<eos>"

    right-padded to ``block_size``, together with next-token targets for the
    caption part of that sequence.
    """

    def __init__(self):
        super(PaliGemmaDataset, self).__init__()

        self.tokenizer = tokenizer
        self.block_size = PaligemmaArgs.block_size

    def __len__(self):
        """Number of rows in the caption table."""
        return df.shape[0]

    @staticmethod
    def _build_example(prefix_ids, caption_ids, pad_id, block_size):
        """Assemble fixed-length input ids, attention mask and next-token labels.

        Args:
            prefix_ids: token ids that precede the caption (image markers, prompt, separators).
            caption_ids: token ids of the caption including its end marker.
            pad_id: id used to fill unused positions in the inputs and the labels.
            block_size: length of every returned tensor.

        Returns:
            ``(input_ids, attention_mask, labels)``, three int64 tensors of length
            ``block_size``.  ``attention_mask`` is 1 on real tokens and 0 on
            padding.  ``labels[t]`` holds ``input_ids[t + 1]`` wherever that next
            token belongs to the caption and ``pad_id`` everywhere else, so
            position ``t`` of the model output is scored against the token that
            follows it.
        """
        tokens = (list(prefix_ids) + list(caption_ids))[:block_size]
        n_real = len(tokens)
        n_pad = block_size - n_real

        # One allocation each instead of growing a tensor token by token.
        input_ids = torch.tensor(tokens + [pad_id] * n_pad, dtype=torch.long)
        attention_mask = torch.tensor([1] * n_real + [0] * n_pad, dtype=torch.long)

        labels = torch.full((block_size,), pad_id, dtype=torch.long)
        # The last prefix position predicts the first caption token; the last
        # real position has no successor and therefore no target.
        first = max(len(prefix_ids) - 1, 0)
        last = n_real - 1
        if last > first:
            labels[first:last] = input_ids[first + 1:last + 1]

        return input_ids, attention_mask, labels

    def __getitem__(self, idx):
        """Return ``(x_values, y_values)`` for row ``idx`` of the caption table.

        ``x_values`` has ``'input_ids'``, ``'attention_mask'`` and ``'image'``;
        ``y_values`` has ``'input_ids'`` holding the next-token targets.
        """
        txt, img_name = df.iloc[idx, 1], df.iloc[idx, 0]

        img_path = os.path.join('/content/drive/MyDrive/data/flickr8000/Images/', img_name)
        # Close the file promptly and force three channels so the 3-value
        # mean/std normalisation applies to grayscale or RGBA files as well.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert("RGB"))

        input_transformed = train_transforms(image=img)['image']

        tok = self.tokenizer
        prompt = "Explain the different components of the picture."
        prefix_ids = (
            tok.encode("<image>") * PaligemmaArgs.img_seq_len
            + tok.encode(prompt)
            + tok.encode('<bos>')
            + tok.encode('<sep>')
        )
        caption_ids = tok.encode(txt) + tok.encode('<eos>')

        # Prefer the tokenizer's declared pad id; fall back to encoding the
        # literal pad token when none is configured.
        pad_id = tok.pad_token_id
        if pad_id is None:
            pad_id = tok.encode('<pad>')[0]

        input_ids, attention_mask, labels = self._build_example(
            prefix_ids, caption_ids, pad_id, self.block_size
        )

        # NOTE(review): label positions outside the caption hold pad_id; pass
        # ignore_index=pad_id to the loss if they should not contribute -
        # confirm against the training loop.
        x_values = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'image': input_transformed,
        }
        y_values = {
            "input_ids": labels
        }

        return x_values, y_values

In [33]:
# Create the dataset and split it into training and validation subsets.
dataset = PaliGemmaDataset()

# 80 % of the samples train the model, the remainder validates it (the
# fractions were previously the other way round).  A seeded generator makes
# the split reproducible across kernel restarts.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

import os
# DataLoaders for training and validation.  os.cpu_count() may return None,
# in which case loading falls back to the main process (num_workers=0).
# Validation batches are not shuffled.
train_loader = DataLoader(train_dataset, batch_size=PaligemmaArgs.batch_size, shuffle=True,  pin_memory=False, num_workers=os.cpu_count() or 0)
val_loader = DataLoader(val_dataset, batch_size=PaligemmaArgs.batch_size, shuffle=False,  pin_memory=False, num_workers=os.cpu_count() or 0)



In [None]:
# Pull one batch and show the shape of every tensor in it rather than
# printing the full token and pixel tensors.
sample = next(iter(train_loader))
[{key: tuple(value.shape) for key, value in part.items()} for part in sample]

In [35]:
class PaliGemma(nn.Module):
    """Vision-language model: projected image features are added to the text
    hidden states, and the sum is fed back through the language model."""

    def __init__(self):

        super().__init__()
        self.vision = VisionModel()
        self.lang = TextModel()

    def forward(self, x):
        """Compute vocabulary logits for a batch.

        Args:
            x: mapping with ``'image'``, ``'input_ids'`` and ``'attention_mask'``.

        Returns:
            Output of ``self.lang(..., embeds=True)`` on the combined features.
        """
        # unsqueeze(1) inserts an axis at position 1 so the image feature
        # broadcasts over the text positions in the addition below
        # (assumes the vision output is [batch, features] - TODO confirm).
        vis_out = self.vision(x['image']).unsqueeze(1)

        # The language model is frozen and its input here is token ids only, so
        # no trainable parameter can receive gradient through this pass;
        # skipping autograd bookkeeping saves memory without changing gradients.
        with torch.no_grad():
            text_out = self.lang(x, False).hidden_states[-1]

        # The vision projector and the language model may live on different
        # devices; align before adding (a no-op when they already match).
        combined = vis_out.to(text_out.device) + text_out

        combined_out = self.lang(combined, embeds=True)

        return combined_out

In [None]:
# Instantiate the combined model and move its registered sub-modules to the
# configured device; the bare trailing expression lets the notebook echo the module.
pali = PaliGemma()
pali.to(device=PaligemmaArgs.device)

In [None]:
def count_parameters(model):
    """Return the total number of elements across ``model``'s parameters
    that have ``requires_grad`` set."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# `pali` is the PaliGemma instance created above; report how many of its
# parameter elements have requires_grad=True.
total_params = count_parameters(pali)
print(f"Total trainable parameters: {total_params}")


In [None]:
# Adam over the parameters registered on `pali`, with the configured learning rate.
optimizer = torch.optim.Adam(params=pali.parameters(), lr=PaligemmaArgs.lr)
loss_fn = nn.CrossEntropyLoss()

# Hand everything to the project's training loop for 5 epochs.
results = engine.train(
    model=pali,
    train_dataloader=train_loader,
    test_dataloader=val_loader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=5,
    device=PaligemmaArgs.device,
    writer=None,
)