In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

from pathlib import Path
from transformers import RobertaTokenizer, RobertaModel
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from PIL import Image

# from transformers import ViTImageProcessor, ViTForImageClassification, ViTFeatureExtractor
import os
# from going_modular import engine
# import engine

from going_modular import engine
import pandas as pd
import numpy as np
import itertools

import albumentations as A
from albumentations.pytorch import ToTensorV2
from dataclasses import dataclass


  check_for_updates()


In [3]:
df = pd.read_csv('/content/drive/MyDrive/data/flickr8000/captions.txt', sep=',')
# df = df.sample(frac=0.01, random_state=42)

In [5]:
@dataclass
class llavaArgs:
    batch_size = 32
    device = 'cuda'
    vis_embd_out = 768
    text_embd_out = 768
    vocab_size = 50257
    block_size = 256
    lr = 1e-3
    text_hidden =  768 * 4

In [6]:
import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPFeatureExtractor
from PIL import Image
import requests

# Vision model class using CLIP
class VisionModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").vision_model


        self.feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch16")

        self.multimodalVisionLayerProjector = nn.Linear(in_features=llavaArgs.vis_embd_out, out_features=llavaArgs.text_embd_out, device=llavaArgs.device) # Use proper dimensions

        self.main = nn.Sequential(
            nn.Flatten()
        )


        for p in self.model.parameters():
            p.requires_grad = False

    def forward(self, x):

        # inputs = self.feature_extractor(x['image'], return_tensors="pt")
        # inputs = inputs.to(llavaArgs.device)


        with torch.no_grad():
            outputs = self.model(x)


        x = outputs.pooler_output  # Get the pooled image embeddings (shape: [batch_size, 768])


        x = self.main(x)
        # return x
        return self.multimodalVisionLayerProjector(x)


In [None]:
#Language Decoder

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#      bnb_4bit_compute_dtype=torch.bfloat16
#     )

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

class TextModel(nn.Module):
    def __init__(self):
        super().__init__()




        self.model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", device_map='cuda', torch_dtype='auto', output_hidden_states=True)
        self.tokenizer = tokenizer
        self.linear_layer = nn.Linear(in_features=llavaArgs.text_embd_out, out_features=llavaArgs.vocab_size, device=llavaArgs.device, bias=False) # Takes in logits of dimensions- embeds_dims and converts it into dimension of vocab_size (logits in range of vocab_size)


        for p in self.model.parameters():
            p.requires_grad = False

    def forward(self, x, embeds=True):

        if(embeds):

          x = self.model(inputs_embeds=x).hidden_states[-1]
          x = self.linear_layer(x)
          return x
        else:
          x = self.model(input_ids = x['input_ids'], attention_mask = x['attention_mask'])
          return x






In [17]:
#Projector

class Projector(nn.Module):
    def __init__(self):

        super().__init__()

        self.linear_layer = nn.Linear(in_features=llavaArgs.vis_embd_out, out_features=llavaArgs.text_embd_out, device=llavaArgs.device)

    def forward(self, x):
        out = self.linear_layer(x)
        return out

In [19]:
train_transforms = A.Compose(
    [
        A.Resize(height=224, width=224),
        A.CenterCrop(height=224, width=224),
        # A.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], max_pixel_value=224.0,),
        # A.ToFloat(max_value=224),
        ToTensorV2(),
    ]
)

test_tyransforms = A.Compose(
    [
        # A.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], max_pixel_value=224.0,),
        # A.ToFloat(max_value=224),
        ToTensorV2(),
    ]
)

class CLiPDatatset(Dataset):
    def __init__(self, path):
        self.tokenizer = tokenizer
        self.path = path
        self.block_size = llavaArgs.block_size
        # self.dir = os.listdir(self.path)
    def __len__(self):

        return df.shape[0]

    def __getitem__(self, idx):

        self.tokenizer.pad_token = self.tokenizer.eos_token

        text, img = df.iloc[idx][1], df.iloc[idx][0]
        # print(text)
        # print(img)
        img_path = os.path.join(self.path, img)
        # print(img_path)
        img = np.array(Image.open(img_path))

        input_transformed = train_transforms(image = img)['image']

        text_tokenized = self.tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=llavaArgs.block_size)


        # Get the input_ids tensor
        input_ids = text_tokenized['input_ids'][0]  # Access the actual input IDs tensor

        if input_ids.size(0) > 1:
                    # Create x (input) and y (target) sequences for next word prediction
                    x = input_ids[:-1]  # All tokens except the last one
                    y = input_ids[1:]   # All tokens except the first one
        else:
            # If the sequence is too short, x and y will be identical
            x = input_ids
            y = input_ids

        # Apply block size limit
        if x.size(0) > self.block_size:
            x = x[:self.block_size]
            y = y[:self.block_size]

        # Padding x
        if x.size(0) < self.block_size:
            padding_size = self.block_size - x.size(0)
            x = torch.cat([x, torch.full((padding_size,), tokenizer.pad_token_id)])

        # Padding y
        if y.size(0) < self.block_size:
            padding_size = self.block_size - y.size(0)
            y = torch.cat([y, torch.full((padding_size,), tokenizer.pad_token_id)])


        attention_mask = torch.ones_like(x)
        torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(0), attention_mask)
        # Pack inputs and outputs in dictionaries
        x_encoded_items = {
            'input_ids': x,
            # 'attention_mask': text_tokenized['attention_mask'][0][:x.size(0)],  # Corresponding attention mask
            'attention_mask': attention_mask,
            'image': input_transformed,
        }
        y_encoded_items = {
            'input_ids': y,
        }

        return x_encoded_items, y_encoded_items





dir = '/content/drive/MyDrive/data/flickr8000/Images'
dataset = CLiPDatatset(dir)

# Assuming 'dataset' is already created
# Split the dataset into training and validation sets
train_size = int(0.2 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])


#Creating dataloaders

trainloader = DataLoader(train_dataset, batch_size=llavaArgs.batch_size, shuffle=True)
valloader = DataLoader(val_dataset, batch_size=llavaArgs.batch_size, shuffle=False)



In [20]:
class Llava(nn.Module):
    def __init__(self):

        super().__init__()


        self.projector = Projector()
        self.vision = VisionModel()
        self.lang = TextModel()
        # self.embd = nn.Embedding(num_embeddings=llavaArgs.vocab_size, embedding_dim=llavaArgs.text_embd_out, device=llavaArgs.device, dtype=torch.float32)

    def forward(self, x):
        # print(X)
        vis_out = self.vision(x['image']).unsqueeze(1)
        # print(vis_out.shape)
        text_out = self.lang(x, False).hidden_states[-1]


        combined = vis_out + text_out

        combined_out = self.lang(combined, embeds=True)


        return combined_out


In [22]:
llava = Llava()
llava.to(llavaArgs.device)



Llava(
  (projector): Projector(
    (linear_layer): Linear(in_features=768, out_features=768, bias=True)
  )
  (vision): VisionModel(
    (model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
        (position_embedding): Embedding(197, 768)
      )
      (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-11): 12 x CLIPEncoderLayer(
            (self_attn): CLIPSdpaAttention(
              (k_proj): Linear(in_features=768, out_features=768, bias=True)
              (v_proj): Linear(in_features=768, out_features=768, bias=True)
              (q_proj): Linear(in_features=768, out_features=768, bias=True)
              (out_proj): Linear(in_features=768, out_features=768, bias=True)
            )
            (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          

In [30]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Assuming 'llava' is an instance of the Llava model
total_params = count_parameters(llava)
print(f"Total trainable parameters: {total_params}")


Total trainable parameters: 39778560


In [31]:
optimizer = torch.optim.Adam(lr=llavaArgs.lr, params=llava.parameters())
loss_fn = nn.CrossEntropyLoss()

results = engine.train(model=llava,
                       writer=None,
                       train_dataloader=trainloader,
                       test_dataloader=valloader,
                       optimizer=optimizer,
                       loss_fn=loss_fn,
                       epochs=5,
                       device=llavaArgs.device)

  0%|          | 0/5 [00:00<?, ?it/s]

  text, img = df.iloc[idx][1], df.iloc[idx][0]


Epoch: 1 | train_loss: 0.7070 | test_loss: 0.2627 
Epoch: 2 | train_loss: 0.2632 | test_loss: 0.2507 
Epoch: 3 | train_loss: 0.2523 | test_loss: 0.2407 
Epoch: 4 | train_loss: 0.2428 | test_loss: 0.2315 
Epoch: 5 | train_loss: 0.2341 | test_loss: 0.2239 
