<a href="https://colab.research.google.com/github/TamannaAhmad/research-paper-analyser/blob/main/image_captioning_evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install evaluate
!pip install rouge_score

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting datasets>=2.0.0 (from evaluate)
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting dill (from evaluate)
  Downloading dill-0.3.9-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from evaluate)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.17-py311-none-any.whl.metadata (7.2 kB)
Collecting dill (from evaluate)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading datasets-3.3.2-py3-none-any.whl (485 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from transformers import BlipProcessor, BlipForConditionalGeneration
from evaluate import load
from PIL import Image
from tqdm import tqdm
import os
from google.colab import drive

In [None]:
# Mount Google Drive; the dataset directory and checkpoint read later live under it.
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
class PaperImageCaptionDataset(Dataset):
    """Image/caption pairs collected from per-paper subdirectories.

    Layout scanned by ``__init__``::

        data_dir/
            <paper>/
                <name>.png | <name>.jpg | <name>.jpeg
                <name>_caption.txt

    An image is kept only when a matching ``<name>_caption.txt`` exists and its
    stripped text is non-empty. Only paths and caption strings are stored;
    images are opened lazily in ``__getitem__``.

    Args:
        data_dir: Root directory whose immediate subdirectories are papers.
        transform: Optional callable applied to the loaded RGB PIL image.
        max_samples: If truthy, keep only the first ``max_samples`` pairs
            (in sorted paper / file-name order).
        placeholder_size: Side length of the black placeholder returned when
            an image cannot be loaded. Defaults to 384 (the previous hard-coded
            value); set it to match the transform's output size.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, data_dir, transform=None, max_samples=None, placeholder_size=384):
        self.data_dir = data_dir
        self.transform = transform
        self.placeholder_size = placeholder_size
        self.samples = []

        # os.listdir returns entries in arbitrary order. Sorting makes the
        # sample order - and therefore max_samples truncation and any later
        # random split - reproducible across machines and runs.
        papers = sorted(
            d for d in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, d))
        )

        for paper in papers:
            paper_dir = os.path.join(data_dir, paper)
            files = sorted(os.listdir(paper_dir))
            file_set = set(files)  # O(1) caption lookups instead of list scans

            for img_file in files:
                base_name, ext = os.path.splitext(img_file)
                if ext.lower() not in self.IMAGE_EXTENSIONS:
                    continue

                caption_file = f"{base_name}_caption.txt"
                if caption_file not in file_set:
                    continue

                img_path = os.path.join(paper_dir, img_file)
                caption_path = os.path.join(paper_dir, caption_file)
                try:
                    with open(caption_path, 'r', encoding='utf-8') as f:
                        caption = f.read().strip()
                except Exception as e:
                    # Best-effort scan: report and skip unreadable caption files.
                    print(f"Error reading caption {caption_path}: {e}")
                    continue

                if caption:
                    self.samples.append((img_path, caption))

        if max_samples and len(self.samples) > max_samples:
            self.samples = self.samples[:max_samples]

        print(f"Dataset created with {len(self.samples)} image-caption pairs")

    def __len__(self):
        """Number of collected image/caption pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` as a dict with 'image', 'caption' and 'path'.

        If the image (or the transform) fails, a black placeholder is returned
        with caption ``"Error loading image"`` so batching can continue;
        consumers are expected to filter on that caption.
        """
        img_path, caption = self.samples[idx]

        try:
            # The context manager closes the file handle; a bare Image.open()
            # keeps it open until garbage collection, which can exhaust file
            # descriptors over a long DataLoader run. convert() returns a new
            # image, so it stays valid after the file is closed.
            with Image.open(img_path) as img:
                image = img.convert('RGB')

            if self.transform:
                image = self.transform(image)

            return {'image': image, 'caption': caption, 'path': img_path}

        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            size = self.placeholder_size
            if self.transform:
                # Assumes the transform yields a (3, H, W) tensor, like the
                # Resize + ToTensor pipeline used in this notebook.
                placeholder = torch.zeros(3, size, size)
            else:
                placeholder = Image.new('RGB', (size, size), color='black')

            return {'image': placeholder, 'caption': "Error loading image", 'path': img_path}

In [None]:
# data preparation
def create_dataloaders(data_dir, batch_size=4, image_size=384, train_split=0.9, seed=42):
    """Build train/validation DataLoaders over PaperImageCaptionDataset.

    Args:
        data_dir: Root directory passed to PaperImageCaptionDataset.
        batch_size: Batch size for both loaders.
        image_size: Images are resized to (image_size, image_size).
        train_split: Fraction of samples used for training; must be in (0, 1).
        seed: Seed for the train/val split so every run sees the same
            validation subset. Pass None for an unseeded, non-reproducible
            split (the previous behaviour).

    Returns:
        (train_loader, val_loader)

    Raises:
        ValueError: if train_split is outside (0, 1), the dataset is empty,
            or either side of the split would be empty.
    """
    if not 0 < train_split < 1:
        raise ValueError(f"train_split must be in (0, 1), got {train_split}")

    # NOTE(review): there is no Normalize step, although BLIP's own image
    # processor normally normalizes pixel values. Presumably this mirrors the
    # transform the checkpoint was fine-tuned with - confirm against training.
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])

    dataset = PaperImageCaptionDataset(data_dir, transform=transform)

    if len(dataset) == 0:
        raise ValueError("Dataset is empty. Check that your image-caption pairs were properly extracted.")

    train_size = int(train_split * len(dataset))
    val_size = len(dataset) - train_size

    if train_size == 0 or val_size == 0:
        raise ValueError(f"Not enough data for splitting. Found {len(dataset)} samples.")

    # An unseeded random_split yields a different validation subset on every
    # run, so metrics are not comparable and may include samples the model was
    # trained on. TODO confirm the training notebook used the same seed/split.
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], **split_kwargs
    )

    # Pinned host memory only helps (and only avoids a warning) when a GPU exists.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory)

    return train_loader, val_loader

In [None]:
def evaluate_model(model, processor, dataloader, device="cuda", generation_kwargs=None):
    """Generate a caption for every valid sample and score against references.

    Args:
        model: Captioning model exposing ``eval()`` and ``generate(pixel_values=..., **kw)``.
        processor: Object exposing ``batch_decode(ids, skip_special_tokens=True)``.
        dataloader: Iterable of batches with 'image' (tensor batch) and
            'caption' (list of str) entries.
        device: Device the image batch is moved to before generation.
        generation_kwargs: Optional overrides merged into the default
            generation settings (deterministic 5-beam search, max_length=100,
            no_repeat_ngram_size=2).

    Returns:
        dict with keys 'bleu', 'rouge', 'meteor' holding each metric's result dict.

    Raises:
        ValueError: if no caption was generated, so no metric can be computed.
    """
    model.eval()
    references = []  # one list of reference captions per sample
    hypotheses = []  # generated captions, aligned with `references`

    # Deterministic beam search by default. The previous do_sample=True with
    # temperature=1.2 turned beam search into beam *sampling*: scores changed
    # from run to run and the added randomness can only hurt reference metrics.
    gen_kwargs = {
        "max_length": 100,
        "num_beams": 5,
        "no_repeat_ngram_size": 2,
        "do_sample": False,
    }
    if generation_kwargs:
        gen_kwargs.update(generation_kwargs)

    bleu = load('bleu')
    rouge = load('rouge')
    meteor = load('meteor')

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            captions = list(batch['caption'])

            # The dataset marks images that failed to load with this caption.
            # Drop only those samples; skipping the whole batch would also
            # discard the valid samples alongside them.
            keep = [i for i, cap in enumerate(captions) if cap != "Error loading image"]
            if not keep:
                continue
            images = batch['image'][keep].to(device)
            kept_captions = [captions[i] for i in keep]

            try:
                generated_ids = model.generate(pixel_values=images, **gen_kwargs)
                generated_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)
            except Exception as e:
                # Best-effort: report and move on to the next batch.
                print(f"Error generating captions: {e}")
                continue

            hypotheses.extend(generated_captions)
            references.extend([[cap] for cap in kept_captions])

    if not hypotheses:
        raise ValueError("No captions were generated; cannot compute metrics.")

    bleu_results = bleu.compute(predictions=hypotheses, references=references)
    rouge_results = rouge.compute(predictions=hypotheses, references=references)
    meteor_results = meteor.compute(predictions=hypotheses, references=references)

    print("\nEvaluation Results:")
    print(f"BLEU: {bleu_results['bleu']:.4f}")
    print(f"ROUGE-L: {rouge_results['rougeL']:.4f}")
    print(f"METEOR: {meteor_results['meteor']:.4f}")

    return {
        'bleu': bleu_results,
        'rouge': rouge_results,
        'meteor': meteor_results
    }

In [None]:
def main(data_dir="/content/drive/MyDrive/training_data_images",
         model_path="/content/drive/MyDrive/best_blip_captioning_model.pth"):
    """Evaluate the fine-tuned BLIP checkpoint on the validation split.

    Args:
        data_dir: Root of the per-paper image/caption folders.
        model_path: Path to the saved state_dict of the fine-tuned model.

    Returns:
        The metrics dict produced by ``evaluate_model``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    _, val_loader = create_dataloaders(data_dir, batch_size=4)

    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

    # The checkpoint is passed straight to load_state_dict, so it should be a
    # plain state_dict of tensors. weights_only=True then suffices and avoids
    # unpickling arbitrary objects from the file. Loading to CPU first avoids
    # holding a second copy of the weights on the GPU while they are copied in.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)

    return evaluate_model(model, processor, val_loader, device)

In [None]:
# Start the evaluation when run as the top-level program. In an IPython/Colab
# kernel __name__ is "__main__", so executing this cell runs it.
if "__main__" == __name__:
    main()

Using device: cuda
Dataset created with 440 image-caption pairs


preprocessor_config.json:   0%|          | 0.00/445 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/527 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.60k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.88G [00:00<?, ?B/s]

  model.load_state_dict(torch.load(model_path, map_location=device))


Downloading builder script:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.02k [00:00<?, ?B/s]

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
Evaluating: 100%|██████████| 11/11 [02:43<00:00, 14.85s/it]



Evaluation Results:
BLEU: 0.0000
ROUGE-L: 0.1447
METEOR: 0.1037
