<a href="https://colab.research.google.com/github/Vignesh-397/Image_Captioning/blob/main/BLIP/BLIP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
!pip install datasets
!pip install transformers
!pip install peft
!pip install bitsandbytes

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.13.0->peft)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.13.0->peft)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.13.0->peft)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.13.0->peft)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.13.0->peft)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.13.0->peft)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.4-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Downloading bitsandbytes-0.45.4-py3-none-manylinux_2_24_x86_64.whl (76.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.0/76.0 MB[0m [31m29.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.45.4


In [1]:
import os
import torch
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from datasets import Dataset as HFDataset
from transformers import AutoProcessor, Blip2ForConditionalGeneration, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

In [2]:
captions_file = "/content/drive/MyDrive/Major Project 2024/Dataset/Custom_Dataset/Image Captioning Data.xlsx"
image_folder = "/content/drive/MyDrive/Major Project 2024/Dataset/Custom_Dataset/Images"

In [3]:
df = pd.read_excel(captions_file)

# Convert captions to a dictionary
image_captions = {}
for _, row in df.iterrows():
    image_captions[row["Name"]] = [row[f"Caption-{i}"] for i in range(1, 6)]

In [4]:
# Image transformations
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [38]:
class ImageCaptioningDataset(Dataset):
    def __init__(self, image_folder, captions, processor):
        self.image_folder = image_folder
        self.captions = captions
        self.processor = processor
        self.image_filenames = list(self.captions.keys())
        #self.image_filenames = [int(os.path.splitext(filename)[0]) for filename in self.image_filenames] #change 1

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_name = self.image_filenames[idx]
        # Format image name to include leading zeros and .jpg extension
        image_name_formatted = f"{int(image_name):03d}.jpg"
        image_path = os.path.join(self.image_folder, image_name_formatted)

        # Load and process the image
        image = Image.open(image_path).convert("RGB")
        image = image_transform(image)

        # Select a random caption
        # Use the original image name (image_name) to access captions,
        # ensuring consistency with the keys in the self.captions dictionary
        caption = self.captions[image_name][torch.randint(0, 5, (1,)).item()]

        encoding = self.processor(images=image, padding="max_length", return_tensors="pt")
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        encoding["text"] = caption

        return encoding

In [43]:
def collate_fn(batch):
    processed_batch = {}
    for key in batch[0].keys():
        if key != "text":
            processed_batch[key] = torch.stack([example[key] for example in batch])
        else:
            # Ensure that the text input is a list of strings
            text_inputs = processor.tokenizer(
                [str(example["text"]) for example in batch],  # Convert to string
                padding=True,
                return_tensors="pt"
            )
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
    return processed_batch

In [32]:
# Load BLIP-2 model and processor
quant_config = BitsAndBytesConfig(load_in_8bit=True)
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "ybelkada/blip2-opt-2.7b-fp16-sharded",
    device_map="auto",
    quantization_config=quant_config
)



Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [33]:
# Apply LoRA
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj"]
)
model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 5,242,880 || all params: 3,749,922,816 || trainable%: 0.1398


In [44]:
# Load dataset
train_dataset = ImageCaptioningDataset(image_folder, image_captions, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=3, collate_fn=collate_fn)

# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)


In [45]:

# Training Loop
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.train()

PeftModel(
  (base_model): LoraModel(
    (model): Blip2ForConditionalGeneration(
      (vision_model): Blip2VisionModel(
        (embeddings): Blip2VisionEmbeddings(
          (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
        )
        (encoder): Blip2Encoder(
          (layers): ModuleList(
            (0-38): 39 x Blip2EncoderLayer(
              (self_attn): Blip2Attention(
                (dropout): Dropout(p=0.0, inplace=False)
                (qkv): Linear8bitLt(in_features=1408, out_features=4224, bias=True)
                (projection): Linear8bitLt(in_features=1408, out_features=1408, bias=True)
              )
              (layer_norm1): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)
              (mlp): Blip2MLP(
                (activation_fn): GELUActivation()
                (fc1): Linear8bitLt(in_features=1408, out_features=6144, bias=True)
                (fc2): Linear8bitLt(in_features=6144, out_features=1408, bias=True)
      

In [46]:
for epoch in range(10):  # Adjust epochs as needed
    print(f"Epoch {epoch+1}:")
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch.pop("input_ids").to(device)
        # Change the dtype to torch.float32 when moving to device
        pixel_values = batch.pop("pixel_values").to(device, torch.float32) # Changed to float32

        # Cast the model to float32 before the forward pass
        with torch.autocast(device_type=device, dtype=torch.float32): # New line to force float32
            outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=input_ids)
        loss = outputs.loss

        print(f"Batch {idx+1}, Loss: {loss.item()}")

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Epoch 1:




Batch 1, Loss: 3.2914822101593018
Batch 2, Loss: 2.8990631103515625
Batch 3, Loss: 2.9845874309539795
Batch 4, Loss: 2.824375867843628
Batch 5, Loss: 2.1578121185302734
Batch 6, Loss: 2.5487258434295654
Batch 7, Loss: 2.501188039779663
Batch 8, Loss: 2.3063182830810547
Batch 9, Loss: 3.972632884979248
Batch 10, Loss: 2.7438015937805176
Batch 11, Loss: 2.2262442111968994
Batch 12, Loss: 2.807004451751709
Batch 13, Loss: 2.2125654220581055
Batch 14, Loss: 2.403764009475708
Batch 15, Loss: 3.4833078384399414
Batch 16, Loss: 2.9445085525512695
Batch 17, Loss: 4.217895984649658
Batch 18, Loss: 1.90729558467865
Batch 19, Loss: 1.803754210472107
Batch 20, Loss: 2.274237871170044
Batch 21, Loss: 3.139176368713379
Batch 22, Loss: 2.4953346252441406
Batch 23, Loss: 2.0718533992767334
Batch 24, Loss: 1.9580312967300415
Batch 25, Loss: 2.3912527561187744
Batch 26, Loss: 2.333970546722412
Batch 27, Loss: 2.3069827556610107
Batch 28, Loss: 3.240786552429199
Batch 29, Loss: 2.140280246734619
Batch 30

KeyboardInterrupt: 