<a href="https://colab.research.google.com/github/Ramkanc/IIITHgrp20/blob/main/Caption_with_pytorch_CLIP_Transform.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import shutil

In [2]:
import nltk
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('punkt_tab')  # Download punkt_tab
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


In [3]:
# Download the punkt_tab data package if it's not already downloaded

In [None]:
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip -P dataset/
!unzip dataset/Flickr8k_Dataset.zip -d dataset/

--2025-02-14 09:03:34--  https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
Resolving github.com (github.com)... 20.205.243.166
Connecting to github.com (github.com)|20.205.243.166|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/124585957/47f52b80-3501-11e9-8f49-4515a2a3339b?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250214%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250214T090334Z&X-Amz-Expires=300&X-Amz-Signature=0a5ce5dfbce3024eb1f7ae21d1d463ef25bc8a96f98a3d9bdc0f8dfc32650426&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3DFlickr8k_Dataset.zip&response-content-type=application%2Foctet-stream [following]
--2025-02-14 09:03:35--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/124585957/47f52b80-3501-11e9-8f49-4515a2a3339b?X-Amz-Algorithm=AWS4-HMAC-

In [None]:
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip -P dataset/
!unzip dataset/Flickr8k_text.zip -d dataset/

In [None]:
shutil.rmtree('dataset/__MACOSX', ignore_errors=True)
if os.path.exists('dataset/Flickr8k_Dataset.zip'):
    os.remove('dataset/Flickr8k_Dataset.zip')
if os.path.exists('dataset/Flickr8k_text.zip'):
    os.remove('dataset/Flickr8k_text.zip')

In [None]:
image_data_location = "/content/dataset/Flicker8k_Dataset"
caption_data_location = "/content/dataset/Flickr8k.token.txt"

In [None]:
# Collect images from image_data_location folder which are getting opened
images_good = []
for filename in os.listdir(image_data_location):
    if filename.endswith(('.jpg', '.jpeg', '.png')):
        try:
            img = Image.open(os.path.join(image_data_location, filename))
            images_good.append(filename.split(".")[0]) #append to a list
        except IOError:
            print("Unable to open image:", filename)

print(f"Loaded {len(images_good)} images")

In [None]:
raw_df = pd.read_csv(caption_data_location, sep="\t", header=None, names=['image','caption'])

In [None]:
# Clean image names
raw_df['image'] = raw_df['image'].str.split('.').str[0]

raw_df.head()

In [None]:
# Find df[image] rows not in images_good list
indices_not_in_list = []
for index, image_name in raw_df['image'].items():
    if image_name not in images_good:
        print(image_name)
        indices_not_in_list.append(index)

In [None]:
df = raw_df.drop(indices_not_in_list, inplace=False)
print(df.head())
print(f"raw df shape - {raw_df.shape}")
print(f"new df shape - {df.shape}")


In [None]:
# Find the maximum words string from df['captions'] column
df['word_count'] = df['caption'].apply(lambda x: len(x.split()))
max_words_string = df.loc[df['word_count'].idxmax(), 'caption']
avg_words = df['word_count'].mean()
print(f"The string with the maximum words is:\n{max_words_string}")
print(f"the length of the string is {len(max_words_string.split())}")
print(f"The average length of the string is {avg_words}")

In [None]:
data_idx = 11
image_name = df.iloc[data_idx,0]
image_path = image_data_location + "/" + image_name+".jpg"
print(image_path)
img = mpimg.imread(image_path)
plt.imshow(img)
plt.show()

In [None]:
for i in range(data_idx, data_idx+5):
    print(f"Caption - {df.iloc[i,1]}")

In [None]:
def create_image_caption_dict(df):
    image_caption_dict = {}
    for image, group in df.groupby('image'):
        image_caption_dict[image] = group['caption'].tolist()
    return image_caption_dict

In [None]:
image_captions = create_image_caption_dict(df)

In [None]:
image_captions_iter = iter(image_captions.items())
print(len(image_captions))
print (next(image_captions_iter))

In [None]:
# Preprocess dataset
image_folder = image_data_location
dataframe = df

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
embed_dim = 512  # Embedding dimension of CLIP
hidden_dim = 512  # Hidden dimension of LSTM
tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32").tokenizer
vocab_size = tokenizer.vocab_size
print(f"length of tokenizer: {vocab_size}")

In [None]:
# Image Transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
# Preprocess images and captions
images_tensor = []
captions_tensor = []

In [None]:
for k,v in image_captions.items():
    img_path = os.path.join(image_folder, k+".jpg")
    img = Image.open(img_path).convert('RGB')
    img = transform(img)
    images_tensor.append(img)
    for caption in v:
        tokenized_caption = tokenizer(caption, return_tensors="pt", padding='max_length', truncation=True, max_length=16).input_ids.squeeze() # maximum length was 20
        captions_tensor.append(tokenized_caption)

images_tensor = torch.stack(images_tensor).to(device)
captions_tensor = torch.stack(captions_tensor).to(device)

In [None]:
print(f"Shape of captions_tensor: {captions_tensor.shape}")
print(f"Shape of images_tensor: {images_tensor.shape}")

In [None]:
# Dataset Class
class ImageCaptionDataset(Dataset):
    def __init__(self, images_tensor, captions_tensor):
        self.images_tensor = images_tensor
        self.captions_tensor = captions_tensor

    def __len__(self):
        return len(self.images_tensor)

    def __getitem__(self, idx):
        image = self.images_tensor[idx].to(torch.float32)
        caption = self.captions_tensor[idx]
        return image, caption

In [None]:
# Split dataset
train_size = int(0.8 * len(images_tensor))
val_size = int(0.1 * len(images_tensor))
test_size = len(images_tensor) - train_size - val_size

train_cap_size = int(0.8 * len(captions_tensor))
val_cap_size = int(0.1 * len(captions_tensor))
test_cap_size = len(captions_tensor) - train_cap_size - val_cap_size

train_images, val_images, test_images = torch.utils.data.random_split(images_tensor, [train_size, val_size, test_size])
train_captions, val_captions, test_captions = torch.utils.data.random_split(captions_tensor, [train_cap_size, val_cap_size, test_cap_size])

# Create Datasets and Loaders
train_dataset = ImageCaptionDataset(train_images, train_captions)
val_dataset = ImageCaptionDataset(val_images, val_captions)
test_dataset = ImageCaptionDataset(test_images, test_captions)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(type(train_dataset[0]))
print(len(train_dataset[0]))
print(train_dataset[0][0].shape)
print(train_dataset[0][1].shape)

In [None]:
# Encoder: CLIP
class CLIPEncoder(nn.Module):
    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        super(CLIPEncoder, self).__init__()
        self.clip_model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)


    def forward(self, images):
        with torch.no_grad():
            image_features = self.clip_model.get_image_features(images)

        return image_features


In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [None]:
# Decoder: GPT-2
class GPT2Decoder(nn.Module):
    def __init__(self, model_name="gpt2"):
        super(GPT2Decoder, self).__init__()
        self.decoder = GPT2LMHeadModel.from_pretrained(model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token  # Set pad token for GPT-2

    def forward(self, features, captions):
        # Prepare inputs for GPT-2 decoder
        input_ids = captions
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).type(torch.float32)

        # Project CLIP features to GPT-2's hidden state space (optional)
        # projection_layer = nn.Linear(clip_embedding_dim, gpt2_hidden_dim)
        # projected_features = projection_layer(features)

        # Pass image features as past key values (optional)
        # past_key_values = tuple(
        #     tuple(past_key_value.unsqueeze(0))  # Add batch dimension
        #     for past_key_value in projected_features.chunk(2, dim=-1)  # Split into key and value
        # )

        # Generate outputs from the decoder
        outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            #past_key_values=past_key_values,  # Optional: Pass image features as past key values
        )
        return outputs.logits

In [None]:
# Caption Generation
def generate_caption(image_tensor, encoder, decoder, max_length=20):
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        image_features = encoder(image_tensor.unsqueeze(0))  # Encode the image
        caption_ids = [decoder.tokenizer.bos_token_id]  # Start with BOS token
        for _ in range(max_length):
            caption_tensor = torch.tensor([caption_ids]).to(image_features.device)  # Move to the same device
            output = decoder(image_features, caption_tensor)
            predicted_id = output.argmax(2)[:, -1].item()
            if predicted_id == decoder.tokenizer.eos_token_id:  # Stop when EOS token is predicted
                break
            caption_ids.append(predicted_id)
        decoded_caption = decoder.tokenizer.decode(caption_ids[1:], skip_special_tokens=True)  # Decode, skipping BOS
        return decoded_caption

In [None]:
# Assuming you have an image_tensor and device
encoder = CLIPEncoder().to(device)
decoder = GPT2Decoder().to(device)
#decoder = GPT2DecoderWithCrossAttention().to(device)
#caption = generate_caption(image_tensor, encoder, decoder)
#print(caption)

In [None]:
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.001)  # Combine parameter

In [None]:
# Training Loop
num_epochs = 50  # Adjust as needed
for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    total_loss = 0
    for images, captions in train_loader:
        images, captions = images.to(device), captions.to(device)

        # Encode images
        features = encoder(images)

        # Prepare decoder inputs
        input_ids = captions[:, :-1]
        targets = captions[:, 1:]

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass through decoder
        outputs = decoder(features, input_ids)

        # Calculate loss
        outputs = outputs.view(-1, outputs.size(-1))
        targets = targets.reshape(-1)
        loss = criterion(outputs, targets)

        # Backpropagation and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {total_loss / len(train_loader):.4f}")

    # Validation Loop
    encoder.eval()
    decoder.eval()
    total_val_loss = 0
    with torch.no_grad():
        for images, captions in val_loader:
            images, captions = images.to(device), captions.to(device)

            # Encode images
            features = encoder(images)

            # Prepare decoder inputs
            input_ids = captions[:, :-1]
            targets = captions[:, 1:]

            # Forward pass through decoder
            outputs = decoder(features, input_ids)

            # Calculate loss
            outputs = outputs.view(-1, outputs.size(-1))
            targets = targets.reshape(-1)
            loss = criterion(outputs, targets)

            total_val_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Val Loss: {total_val_loss / len(val_loader):.4f}")


In [None]:
# prompt: provide code to save trained model

# Save the trained models
torch.save(encoder.state_dict(), 'encoder_model_clip.pth')
torch.save(decoder.state_dict(), 'decoder_model_gpt2.pth')
print("Model saved successfully!")


In [None]:
# prompt: provide code to download the saved models

from google.colab import files
files.download('encoder_model_clip.pth')
files.download('decoder_model_gpt2.pth')


In [None]:
# prompt: provide code for test loop with generating the captions

# Test Loop
encoder.eval()
decoder.eval()
smoothing = SmoothingFunction().method1

total_bleu4 = 0

with torch.no_grad():
    for images, captions in test_loader:
        images, captions = images.to(device), captions.to(device)

        # Generate captions
        generated_captions = []
        for image in images:
            generated_caption = generate_caption(image, encoder, decoder)
            generated_captions.append(generated_caption)

        # Calculate BLEU scores
        for i, caption in enumerate(captions):
            reference = tokenizer.decode(caption, skip_special_tokens=True)
            candidate = generated_captions[i]
            reference_tokens = nltk.word_tokenize(reference.lower())
            candidate_tokens = nltk.word_tokenize(candidate.lower())
            bleu4_score = sentence_bleu([reference_tokens], candidate_tokens, smoothing_function=smoothing)
            total_bleu4 += bleu4_score
            print(f"Reference: {reference}")
            print(f"Generated: {candidate}")
            print(f"BLEU-4 score: {bleu4_score}")

average_bleu4 = total_bleu4 / len(test_dataset)
print(f"Average BLEU-4 score on test set: {average_bleu4:.4f}")


Reference: silver and blue car marked 1 0 4 raises dust on road as two
Generated: im divid Sundim Billperueosp Billgetue funding fielded fielded fielded fielded fielded fielded fielded fielded
BLEU-4 score: 0
Reference: a soccer team poses for a picture .
Generated: im divid Sundim Billperueosp Billgetue funding fielded fielded fielded fielded fielded fielded fielded fielded
BLEU-4 score: 0
Reference: a boy in a green shirt is jumping with his arms in the air
Generated: im divid Sundim Billperueosp Billgetue funding fielded fielded fielded fielded fielded fielded fielded fielded
BLEU-4 score: 0
Reference: a woman with dark hair walks through a grassy field on a path with
Generated: im divid Sundim Billperueosp Billgetue funding fielded fielded fielded fielded fielded fielded fielded fielded
BLEU-4 score: 0
Reference: a brown - haired child in green shoes swings on a swing in a
Generated: im divid Sundim Billperueosp Billgetue funding fielded fielded fielded fielded fielded fielded fiel