In [15]:
import os
import zipfile
from PIL import Image
from tqdm import tqdm
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from torchvision.transforms import Compose, Resize, ToTensor

from transformers import ViTModel, ViTConfig, GPT2Tokenizer, GPT2LMHeadModel

import gdown

# ----------------------------- #
#         Configuration          #
# ----------------------------- #

# Set random seeds for reproducibility
torch.manual_seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [2]:
def download_and_extract(url, output_zip, extract_to):
    """
    Downloads and extracts a zip file from the given URL.
    """
    if not os.path.exists(output_zip):
        print("Downloading dataset...")
        gdown.download(url, output_zip, quiet=False)
    else:
        print(f"{output_zip} already exists. Skipping download.")

    if not os.path.exists(extract_to):
        print("Extracting dataset...")
        os.makedirs(extract_to, exist_ok=True)
        with zipfile.ZipFile(output_zip, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
        print(f"Dataset extracted to: {extract_to}")
    else:
        print(f"{extract_to} already exists. Skipping extraction.")

# Flickr8k dataset Google Drive URL
# Ensure this URL points to a zip file containing the 'flicker8k' folder with images and text files
url = "https://drive.google.com/uc?id=1iFgG55ZUR1ZO-BrIc5PQ1AhrWa0NhQVZ"
output_zip = "flickr8k.zip"
extract_to = "./flickr8k"

# Download and extract the dataset
download_and_extract(url, output_zip, extract_to)

Downloading dataset...


Downloading...
From (original): https://drive.google.com/uc?id=1iFgG55ZUR1ZO-BrIc5PQ1AhrWa0NhQVZ
From (redirected): https://drive.google.com/uc?id=1iFgG55ZUR1ZO-BrIc5PQ1AhrWa0NhQVZ&confirm=t&uuid=05f23842-c96b-4df4-a2c4-018aa89ebe06
To: /content/flickr8k.zip
100%|██████████| 1.12G/1.12G [00:08<00:00, 134MB/s]


Extracting dataset...
Dataset extracted to: ./flickr8k


In [10]:
class Flickr8kDataset(Dataset):
    """
    Flickr8k Dataset for Image Captioning.
    """
    def __init__(self, image_dir, captions_file, image_list_file, transform=None):
        self.image_dir = image_dir
        self.transform = transform

        # Load the image list
        with open(image_list_file, "r") as file:
            self.image_list = set(line.strip() for line in file.readlines())

        # Load captions
        self.image_captions = {}
        with open(captions_file, "r") as file:
            for line in file.readlines():
                image_caption = line.strip().split("\t")
                if len(image_caption) != 2:
                    continue  # Skip malformed lines
                image_id_caption_number = image_caption[0]
                caption = image_caption[1].strip()
                image_id = image_id_caption_number.split("#")[0]
                if image_id in self.image_list:
                    if image_id in self.image_captions:
                        self.image_captions[image_id].append(caption)
                    else:
                        self.image_captions[image_id] = [caption]

        self.image_ids = list(self.image_captions.keys())
        print(f"Number of images in dataset: {len(self.image_ids)}")

    def get_image_ids(self):
        return self.image_ids

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.image_dir, image_id)

        # Load and transform image
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Get a single caption (randomly selected)
        captions = self.image_captions[image_id]
        caption = random.choice(captions)

        return image, caption  # Return a single caption

In [11]:
# Define paths
folder_path = os.path.join(extract_to, "flicker8k")
image_dir = os.path.join(folder_path, "images")
captions_file = os.path.join(folder_path, "Flickr8k.token.txt")
train_list_file = os.path.join(folder_path, "Flickr_8k.trainImages.txt")
test_list_file = os.path.join(folder_path, "Flickr_8k.testImages.txt")

# Define image transformations
transform = Compose([
    Resize((224, 224)),
    ToTensor(),
])

# Create datasets
train_dataset = Flickr8kDataset(image_dir, captions_file, train_list_file, transform)
test_dataset = Flickr8kDataset(image_dir, captions_file, test_list_file, transform)

# Define collate function
def collate_fn(batch):
    images, captions = zip(*batch)
    images = torch.stack(images)
    return {"pixel_values": images, "captions": captions}

# Create data loaders
batch_size = 8

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

Number of images in dataset: 6000
Number of images in dataset: 1000


In [12]:
class ViTGPT2Captioner(nn.Module):
    """
    Vision Transformer (ViT) + GPT-2 based Image Captioning Model.
    """
    def __init__(self, vit_model_name='google/vit-base-patch16-224', gpt2_model_name='gpt2'):
        super(ViTGPT2Captioner, self).__init__()

        # Load pre-trained ViT model
        print("Loading ViT model...")
        self.vit = ViTModel.from_pretrained(vit_model_name)
        self.vit_config = ViTConfig.from_pretrained(vit_model_name)
        self.vit_dim = self.vit_config.hidden_size  # Typically 768 for vit-base

        # Load pre-trained GPT-2 model
        print("Loading GPT-2 model...")
        self.gpt2 = GPT2LMHeadModel.from_pretrained(gpt2_model_name)
        self.gpt2.resize_token_embeddings(len(GPT2Tokenizer.from_pretrained(gpt2_model_name)))

        # Project ViT output to GPT-2's embedding size if necessary
        self.gpt2_embedding_dim = self.gpt2.config.n_embd  # Typically 768 for gpt2
        if self.vit_dim != self.gpt2_embedding_dim:
            self.vit_to_gpt2 = nn.Linear(self.vit_dim, self.gpt2_embedding_dim)
            print(f"Projecting ViT output from {self.vit_dim} to GPT-2 embedding size {self.gpt2_embedding_dim}")
        else:
            self.vit_to_gpt2 = nn.Identity()
            print("ViT and GPT-2 embedding dimensions match. No projection needed.")

        # Tokenizer
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token  # Set pad token to eos_token

    def forward(self, images, captions_input_ids, attention_mask=None, labels=None):
        """
        Forward pass through the model.
        """
        # Extract image features
        vit_outputs = self.vit(images)
        image_features = vit_outputs.last_hidden_state[:, 0, :]  # [batch_size, hidden_dim]
        image_features = self.vit_to_gpt2(image_features)  # [batch_size, gpt2_embedding_dim]

        # Prepare GPT-2 inputs by prepending image features
        # Expand image_features to sequence length 1
        image_features = image_features.unsqueeze(1)  # [batch_size, 1, gpt2_embedding_dim]

        # Get GPT-2 embeddings
        gpt2_embeddings = self.gpt2.transformer.wte(captions_input_ids)  # [batch_size, seq_len, gpt2_embedding_dim]

        # Concatenate image features and text embeddings
        inputs_embeds = torch.cat([image_features, gpt2_embeddings], dim=1)  # [batch_size, 1 + seq_len, gpt2_embedding_dim]

        # Adjust attention mask if provided
        if attention_mask is not None:
            image_attention_mask = torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype, device=attention_mask.device)
            attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)  # [batch_size, 1 + seq_len]

        # Pass through GPT-2
        outputs = self.gpt2(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)

        return outputs

In [13]:
model = ViTGPT2Captioner().to(device)

# Define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Number of epochs
epochs = 5

# Function to save model checkpoints
def save_checkpoint(model, optimizer, epoch, loss, filename):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }
    torch.save(checkpoint, filename)
    print(f"Checkpoint saved to {filename}")

Loading ViT model...


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading GPT-2 model...
ViT and GPT-2 embedding dimensions match. No projection needed.


In [17]:


print("Starting training...")

for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0
    for batch in tqdm(train_dataloader, desc=f"Training Epoch {epoch}/{epochs}"):
        images = batch['pixel_values'].to(device)
        captions = batch['captions']

        # Tokenize captions
        encoding = model.tokenizer(
            captions,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=30
        )
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        # Prepare labels by prepending -100 to ignore the image feature token
        batch_size = input_ids.size(0)
        padding = torch.full((batch_size, 1), -100, dtype=input_ids.dtype, device=device)
        labels = torch.cat([padding, input_ids], dim=1)  # Shape: (batch_size, 1 + seq_len)

        # Forward pass
        outputs = model(images, input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_dataloader)
    print(f"Epoch {epoch}/{epochs} - Loss: {avg_loss:.4f}")

    # Save checkpoint after each epoch
    checkpoint_filename = f"vitgpt2_epoch{epoch}.pth"
    save_checkpoint(model, optimizer, epoch, avg_loss, checkpoint_filename)

print("Training completed.")


Starting training...


Training Epoch 1/5: 100%|██████████| 750/750 [04:59<00:00,  2.50it/s]


Epoch 1/5 - Loss: 2.1819
Checkpoint saved to vitgpt2_epoch1.pth


Training Epoch 2/5: 100%|██████████| 750/750 [05:11<00:00,  2.41it/s]


Epoch 2/5 - Loss: 1.9058
Checkpoint saved to vitgpt2_epoch2.pth


Training Epoch 3/5: 100%|██████████| 750/750 [05:01<00:00,  2.49it/s]


Epoch 3/5 - Loss: 1.8263
Checkpoint saved to vitgpt2_epoch3.pth


Training Epoch 4/5: 100%|██████████| 750/750 [05:00<00:00,  2.50it/s]


Epoch 4/5 - Loss: 1.7317
Checkpoint saved to vitgpt2_epoch4.pth


Training Epoch 5/5: 100%|██████████| 750/750 [04:59<00:00,  2.51it/s]


Epoch 5/5 - Loss: 1.6594
Checkpoint saved to vitgpt2_epoch5.pth
Training completed.


In [64]:
import re  # Import the regular expressions module

def generate_caption(model, image, tokenizer, device, max_length=30, sampling_strategy='greedy', top_k=50, top_p=0.95, temperature=1.0):
    """
    Generates a caption for a given image using the trained model.

    Args:
        model (nn.Module): The trained ViT-GPT2 captioning model.
        image (PIL.Image or Tensor): The input image.
        tokenizer (GPT2Tokenizer): The tokenizer used during training.
        device (torch.device): The device to perform computations on.
        max_length (int): Maximum length of the generated caption.
        sampling_strategy (str): Decoding strategy ('greedy', 'top_k', 'top_p').
        top_k (int): Number of top tokens to consider for top-k sampling.
        top_p (float): Cumulative probability threshold for top-p sampling.
        temperature (float): Sampling temperature.

    Returns:
        str: The generated caption.
    """
    model.eval()
    with torch.no_grad():
        # Preprocess the image
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
            image = transform(image)  # Ensure 'transform' is defined globally
        elif isinstance(image, torch.Tensor):
            pass  # Image is already a tensor
        else:
            raise ValueError("Unsupported image type. Provide a PIL.Image or torch.Tensor.")

        image = image.to(device).unsqueeze(0)  # Add batch dimension

        # Extract image features using ViT
        vit_outputs = model.vit(image)
        image_features = vit_outputs.last_hidden_state[:, 0, :]  # [1, hidden_dim]
        image_features = model.vit_to_gpt2(image_features)        # [1, gpt2_embedding_dim]
        image_features = image_features.unsqueeze(1)             # [1, 1, gpt2_embedding_dim]

        # Initialize generated tokens with a shorter prompt, e.g., "Image:"
        prompt = "A photo of"
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)  # [1, len(prompt)]
        attention_mask = torch.ones_like(input_ids).to(device)              # [1, len(prompt)]

        generated_captions = []

        for _ in range(max_length):
            # Get GPT-2 embeddings for the current input_ids
            gpt2_embeddings = model.gpt2.transformer.wte(input_ids)  # [1, seq_len, gpt2_embedding_dim]

            # Concatenate image features with text embeddings
            inputs_embeds = torch.cat([image_features, gpt2_embeddings], dim=1)  # [1, 1 + seq_len, gpt2_embedding_dim]

            # Update attention mask to account for image features
            current_attention_mask = torch.cat([
                torch.ones((1, image_features.size(1)), device=device),
                attention_mask
            ], dim=1)  # [1, 1 + seq_len]

            # Forward pass through GPT-2
            outputs = model.gpt2(inputs_embeds=inputs_embeds, attention_mask=current_attention_mask)
            logits = outputs.logits  # [1, 1 + seq_len, vocab_size]

            # Get the logits for the last token
            next_token_logits = logits[:, -1, :]  # [1, vocab_size]

            # Apply temperature
            next_token_logits = next_token_logits / temperature

            # Apply sampling strategy
            if sampling_strategy == 'greedy':
                next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)  # [1, 1]
            elif sampling_strategy == 'top_k':
                probabilities = torch.softmax(next_token_logits, dim=-1)
                top_k_probs, top_k_indices = torch.topk(probabilities, top_k, dim=-1)
                top_k_probs = top_k_probs.squeeze(0)
                top_k_indices = top_k_indices.squeeze(0)
                next_token = torch.multinomial(top_k_probs, num_samples=1).unsqueeze(0)
                next_token = top_k_indices[next_token]
            elif sampling_strategy == 'top_p':
                sorted_probs, sorted_indices = torch.sort(torch.softmax(next_token_logits, dim=-1), descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = 0  # Always keep the first token
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                next_token_logits[indices_to_remove] = -float('Inf')
                probabilities = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probabilities, num_samples=1)
            else:
                raise ValueError("Unsupported sampling strategy. Choose from 'greedy', 'top_k', 'top_p'.")

            # Decode the predicted token (for debugging purposes)
            predicted_token = tokenizer.decode(next_token.squeeze(), skip_special_tokens=True)
            # Uncomment the following lines for debugging
            # print(f"Predicted Token ID: {next_token.item()} | Token: {predicted_token}")

            # Append the predicted token to the generated caption
            if next_token.item() == tokenizer.eos_token_id:
                break  # Stop generation if <EOS> token is predicted
            generated_captions.append(predicted_token)

            # Update input_ids and attention_mask for the next iteration
            input_ids = torch.cat([input_ids, next_token], dim=1)            # [1, seq_len + 1]
            attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=1)  # [1, seq_len + 1]

        # Join all predicted tokens to form the final caption
        caption = ' '.join(generated_captions).strip()

        # Eliminate multiple consecutive spaces using regex
        caption = re.sub(r'\s+', ' ', caption)

        # Optionally, capitalize the first letter and add a period at the end
        caption = caption.capitalize()
        if not caption.endswith('.'):
            caption += '.'

        return caption

In [65]:

# Generate captions for the first 5 test images
for i in range(5):
    image, ref = test_dataset[i]
    caption = generate_caption(model, image, model.tokenizer, device=device)
    print(f"Image {i+1} Caption: {caption}")
    print(f"Image {i+1} Reference Caption: {ref}")

print("Caption generation completed.")

Image 1 Caption: A woman in a bikini is shown on a street .
Image 1 Reference Caption: A woman is signaling is to traffic , as seen from behind .
Image 2 Caption: A boy in a swimming pool .
Image 2 Reference Caption: Children playing on the beach .
Image 3 Caption: A man and a woman sitting on a bench .
Image 3 Reference Caption: A man and a woman sitting on a dock .
Image 4 Caption: A dog with a red collar is shown .
Image 4 Reference Caption: A white dog is resting its head on a tiled floor with its eyes open .
Image 5 Caption: A little boy with a red shirt and blue jeans .
Image 5 Reference Caption: A boy with a toy gun .
Caption generation completed.


In [26]:
!pip install evaluate rouge_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24935 sha256=39eeef5dd7e3437a1bc73ee1a1c961389b04e87672c18c7607db3a7566e64e4d
  Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [58]:
from evaluate import load

def compute_evaluation_metrics(generated_captions, reference_captions):
    """
    Compute METEOR and ROUGE scores for generated captions.

    Args:
        generated_captions (list of str): Captions generated by the model.
        reference_captions (list of list of str): Reference captions for each image.

    Returns:
        dict: Dictionary containing METEOR, BLEU and ROUGE scores.
    """
    meteor = load("meteor")
    rouge = load("rouge")
    bleu = load("bleu")

    meteor_score = meteor.compute(predictions=generated_captions, references=reference_captions)
    rouge_score = rouge.compute(predictions=generated_captions, references=reference_captions)
    bleu_score = bleu.compute(predictions=generated_captions, references=reference_captions)

    results = {
        "meteor": meteor_score["meteor"],
        "bleu": bleu_score["bleu"],
        "rouge1": rouge_score["rouge1"],
        "rouge2": rouge_score["rouge2"],
        "rougeL": rouge_score["rougeL"],
    }
    return results

In [66]:
print("Generating captions for test images and evaluating...")

generated_captions = []
reference_captions = []

# Since test_dataset is set to return all captions, we can access them directly
for idx in tqdm(range(len(test_dataset)), desc="Evaluating"):
    image, captions = test_dataset[idx]  # captions is a list of 5 reference captions
    # Generate caption
    generated_caption = generate_caption(model, image, model.tokenizer, device=device)
    generated_captions.append(generated_caption)
    reference_captions.append(captions)  # List of 5 reference captions

# Compute evaluation metrics
evaluation_results = compute_evaluation_metrics(generated_captions, reference_captions)

# Display the results
print("\nEvaluation Metrics:")
print(f"METEOR: {evaluation_results['meteor']:.4f}")
print(f"ROUGE-1: {evaluation_results['rouge1']:.4f}")
print(f"ROUGE-L: {evaluation_results['rougeL']:.4f}")

Generating captions for test images and evaluating...


Evaluating: 100%|██████████| 1000/1000 [02:09<00:00,  7.73it/s]
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!



Evaluation Metrics:
METEOR: 0.2310
ROUGE-1: 0.2972
ROUGE-L: 0.2752
