<a href="https://colab.research.google.com/github/TamannaAhmad/research-paper-analyser/blob/main/image_captioning_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
from huggingface_hub import login, snapshot_download

# Authenticate with the Hugging Face Hub. HF_TOKEN is taken from the environment first and,
# failing that, from Colab's Secrets panel (Colab secrets are read through
# google.colab.userdata, not through os.environ). If neither yields a token,
# login(token=None) falls back to its interactive prompt.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    try:
        from google.colab import userdata
        hf_token = userdata.get("HF_TOKEN")
    except Exception:
        # Not running in Colab, or the secret is missing / not shared with this notebook.
        hf_token = None
login(token=hf_token)

In [None]:
import torch
import json
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import BlipProcessor, BlipForConditionalGeneration


In [None]:
 # download the SciCap dataset
dataset_path = snapshot_download(repo_id="CrowdAILab/scicap", repo_type='dataset')

In [None]:
# Mount Google Drive so main() can write the trained weights under /content/drive/MyDrive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
class SciCapDataset(Dataset):
    """Image/caption pairs for one split of the SciCap dataset.

    Locates a ``{split}_captions.json`` annotation file and an image directory under
    ``data_dir`` (several layouts are probed), then pairs every annotated image that exists
    on disk with one caption. Images are only decoded lazily in ``__getitem__``.

    Args:
        data_dir: Root directory of the downloaded dataset.
        split: Split name used to build candidate paths (e.g. "train", "val").
        transform: Optional callable applied to each loaded PIL image.
        max_samples: If not None, keep at most this many pairs (the first ones found).
        placeholder_size: Side length of the black placeholder returned when an image fails
            to load. It should match the size produced by ``transform`` so that placeholder
            items can be collated with real ones.

    Raises:
        FileNotFoundError: If no annotation file or no image directory can be found.
    """

    def __init__(self, data_dir, split="train", transform=None, max_samples=None, placeholder_size=384):
        self.data_dir = data_dir
        self.transform = transform
        self.placeholder_size = placeholder_size
        self.samples = []

        # The on-disk layout of the snapshot is not fixed, so probe several candidate locations.
        possible_paths = [
            os.path.join(data_dir, f'{split}_captions.json'),  # Root directory
            os.path.join(data_dir, split, f'{split}_captions.json'),  # Split subdirectory
            os.path.join(data_dir, 'annotations', f'{split}_captions.json')  # Annotations subdirectory
        ]
        annotations_path = self._first_existing(possible_paths)
        if annotations_path is None:
            raise FileNotFoundError(f"Could not find annotation file for {split} split. Searched: {possible_paths}")

        print(f"Using annotations from: {annotations_path}")

        with open(annotations_path, 'r', encoding='utf-8') as f:
            captions_data = json.load(f)

        possible_img_dirs = [
            os.path.join(data_dir, split, 'images'),
            os.path.join(data_dir, 'images', split),
            os.path.join(data_dir, 'images')
        ]
        img_dir = self._first_existing(possible_img_dirs)
        if img_dir is None:
            raise FileNotFoundError(f"Could not find images directory for {split} split. Searched: {possible_img_dirs}")

        print(f"Using images from: {img_dir}")

        for img_name, caption in self._iter_annotations(captions_data):
            img_path = os.path.join(img_dir, img_name)
            # Annotations whose image file is absent from the snapshot are dropped.
            if os.path.exists(img_path):
                self.samples.append((img_path, caption))

        # `is not None` so that max_samples=0 is honoured rather than treated as "no limit".
        if max_samples is not None:
            self.samples = self.samples[:max_samples]

        print(f"Dataset created with {len(self.samples)} image-caption pairs from {split} set")

    @staticmethod
    def _first_existing(paths):
        """Return the first path in ``paths`` that exists on disk, or None."""
        return next((path for path in paths if os.path.exists(path)), None)

    @staticmethod
    def _iter_annotations(captions_data):
        """Yield ``(image_file_name, caption)`` pairs from parsed annotation JSON.

        Supported layouts:
          * COCO-style ``{"images": [{"id", "file_name"}, ...],
            "annotations": [{"image_id", "caption"}, ...]}``; the first caption listed
            for each image is used and images are yielded in ``images`` order;
          * a mapping ``{image_file_name: caption_info}`` where ``caption_info`` is a list
            of captions (first one used, empty lists skipped), a dict with a ``"caption"``
            key, or any other value (converted with ``str``).
        Any other structure yields nothing and prints a warning instead of silently
        producing an empty dataset.
        """
        if (isinstance(captions_data, dict)
                and isinstance(captions_data.get('images'), list)
                and isinstance(captions_data.get('annotations'), list)):
            first_caption = {}
            for ann in captions_data['annotations']:
                first_caption.setdefault(ann.get('image_id'), ann.get('caption'))
            for image in captions_data['images']:
                caption = first_caption.get(image.get('id'))
                if caption is not None and 'file_name' in image:
                    yield image['file_name'], caption
            return

        if isinstance(captions_data, dict):
            for img_name, caption_info in captions_data.items():
                if isinstance(caption_info, list):
                    if not caption_info:
                        continue  # nothing to train on for this image
                    caption = caption_info[0]  # Use the first caption
                elif isinstance(caption_info, dict) and 'caption' in caption_info:
                    caption = caption_info['caption']
                else:
                    caption = str(caption_info)
                yield img_name, caption
            return

        print(f"Warning: unrecognized annotation format ({type(captions_data).__name__}); "
              f"no image-caption pairs loaded")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{'image', 'caption', 'path'}`` for sample ``idx``.

        If the image cannot be loaded, a black placeholder is returned with the caption
        "Error loading image". This is a best-effort fallback so one bad file does not crash
        the DataLoader. The placeholder is a 3xSxS zero tensor when a transform is set and a
        PIL image otherwise, where S is ``placeholder_size``.
        """
        img_path, caption = self.samples[idx]

        try:
            # Decode lazily; the context manager closes the underlying file handle.
            with Image.open(img_path) as img:
                image = img.convert('RGB')

            if self.transform:
                image = self.transform(image)

            return {'image': image, 'caption': caption, 'path': img_path}

        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            size = self.placeholder_size
            if self.transform:
                # NOTE(review): assumes the transform yields a 3-channel CHW tensor of this size.
                placeholder = torch.zeros(3, size, size)
            else:
                placeholder = Image.new('RGB', (size, size), color='black')

            return {'image': placeholder, 'caption': "Error loading image", 'path': img_path}

In [None]:
# data preparation
def create_dataloaders(data_dir, batch_size=4, image_size=384, max_samples=None, num_workers=0):
    """Build the train and validation DataLoaders over the SciCap dataset.

    Images are resized to ``image_size`` x ``image_size`` and converted with ``ToTensor``,
    which yields float CHW tensors already scaled to [0, 1]. BLIP normalization is applied
    later by the image processor, so it must not rescale these tensors a second time.

    Args:
        data_dir: Dataset root passed to SciCapDataset.
        batch_size: Samples per batch for both loaders.
        image_size: Side length after resizing.
        max_samples: Optional cap on pairs per split; useful for quick debugging runs.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and the validation
        loader does not.
    """
    # NOTE(review): SciCapDataset's error placeholder is 3x384x384 by default; if image_size
    # differs, a batch containing a failed image load cannot be collated.
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])

    train_dataset = SciCapDataset(data_dir, split='train', transform=transform, max_samples=max_samples)
    val_dataset = SciCapDataset(data_dir, split='val', transform=transform, max_samples=max_samples)

    # Pinned host memory only helps (and only avoids a warning) when a CUDA device is present.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)

    return train_loader, val_loader

In [None]:
def _preprocess_images(processor, images, device):
    """Turn a batch of image tensors into normalized BLIP ``pixel_values`` on ``device``.

    Assumes ``images`` are float tensors already scaled to [0, 1] (e.g. from
    ``transforms.ToTensor``). ``do_rescale=False`` stops the processor from dividing by 255
    a second time, which would otherwise make every image nearly black.
    """
    return processor.image_processor(
        images, do_rescale=False, do_normalize=True, return_tensors="pt"
    ).pixel_values.to(device)


def _caption_loss(model, processor, pixel_values, caption, device, max_length):
    """Return the model's captioning loss for one image (batch dim of 1) and one caption."""
    tokenized = processor.tokenizer(
        caption,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(device)
    # Padding positions become -100 (the CrossEntropyLoss ignore index), so the loss is not
    # dominated by learning to predict padding tokens.
    labels = tokenized.input_ids.masked_fill(tokenized.attention_mask == 0, -100)
    outputs = model(
        pixel_values=pixel_values,
        input_ids=tokenized.input_ids,
        attention_mask=tokenized.attention_mask,
        labels=labels,
    )
    return outputs.loss


def _train_batch(model, processor, images, captions, device, max_length, gradient_accumulation_steps):
    """Back-propagate one batch sample by sample and return its mean per-sample loss.

    Each loss is divided by ``len(captions) * gradient_accumulation_steps``. For full windows
    of equal-sized batches, the accumulated gradient is therefore that of the mean per-sample
    loss. Returns 0.0 for an empty batch or when the model produced no loss.
    """
    if not captions:
        return 0.0
    pixel_values = _preprocess_images(processor, images, device)
    scale = len(captions) * gradient_accumulation_steps
    batch_loss = 0.0
    for i, caption in enumerate(captions):
        # One sample at a time keeps peak memory low.
        loss = _caption_loss(model, processor, pixel_values[i].unsqueeze(0), caption, device, max_length)
        if loss is not None:
            (loss / scale).backward()
            batch_loss += loss.item()
    return batch_loss / len(captions)


def _optimizer_step(model, optimizer, max_grad_norm=1.0):
    """Clip the gradient norm, apply one optimizer update, then clear the gradients."""
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()


def _validation_loss(model, processor, val_loader, device, max_length, desc):
    """Return the average over usable batches of each batch's mean per-sample loss (inf if none)."""
    model.eval()
    total_loss = 0.0
    usable_batches = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=desc):
            try:
                captions = batch['caption']
                # Skip batches containing the dataset's error placeholder.
                if "Error loading image" in captions:
                    continue
                pixel_values = _preprocess_images(processor, batch['image'], device)
                batch_loss = 0.0
                for i, caption in enumerate(captions):
                    loss = _caption_loss(model, processor, pixel_values[i].unsqueeze(0), caption, device, max_length)
                    if loss is not None:
                        batch_loss += loss.item()
                if batch_loss > 0:
                    total_loss += batch_loss / len(captions)
                    usable_batches += 1
            except Exception as e:
                print(f"Error during validation: {e}")
    return total_loss / usable_batches if usable_batches > 0 else float('inf')


def train_model(train_loader, val_loader, learning_rate=5e-5, num_epochs=5, device="cuda",
                gradient_accumulation_steps=4, max_length=75,
                checkpoint_path="best_blip_captioning_model.pth"):
    """Fine-tune ``Salesforce/blip-image-captioning-large`` and return the best model seen.

    Every batch is processed one sample at a time to bound memory. Gradients are accumulated
    over ``gradient_accumulation_steps`` batches before each clipped AdamW step. After each
    epoch the validation loss is computed, and whenever it improves the weights are saved to
    ``checkpoint_path``. Before returning, that best checkpoint (if any was saved) is loaded
    back, so the returned model is the lowest-validation-loss one. A train/validation loss plot
    is written to ``training_curve.png``.

    Args:
        train_loader, val_loader: Loaders yielding dicts with ``'image'`` and ``'caption'``.
            Images are assumed to be float tensors in [0, 1], as produced by ToTensor.
        learning_rate: AdamW learning rate.
        num_epochs: Number of epochs; also the cosine-annealing period (stepped per epoch).
        device: Torch device string.
        gradient_accumulation_steps: Batches per optimizer step.
        max_length: Maximum caption length in tokens (longer captions are truncated).
        checkpoint_path: File the best weights are written to.

    Returns:
        The fine-tuned BlipForConditionalGeneration model.
    """
    import gc

    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_val_loss = float('inf')
    saved_best = False
    train_losses = []
    val_losses = []
    num_batches = len(train_loader)

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        epoch_loss = 0.0
        valid_batches = 0

        for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training")):
            try:
                captions = batch['caption']
                # Batches containing the error placeholder are skipped, but the accumulation
                # boundary below is still honoured so gradients never leak across windows.
                if "Error loading image" not in captions:
                    batch_loss = _train_batch(model, processor, batch['image'], captions, device,
                                              max_length, gradient_accumulation_steps)
                    if batch_loss > 0:
                        epoch_loss += batch_loss
                        valid_batches += 1
            except Exception as e:
                # Best-effort: gradients from samples that already back-propagated are kept.
                print(f"Error during training: {e}")

            if (batch_idx + 1) % gradient_accumulation_steps == 0 or batch_idx == num_batches - 1:
                try:
                    _optimizer_step(model, optimizer)
                except Exception as e:
                    print(f"Error during optimizer step: {e}")
                    optimizer.zero_grad()

            # Periodically release cached CUDA blocks and Python garbage to ease memory pressure.
            if batch_idx % 10 == 0:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()

        scheduler.step()

        avg_train_loss = epoch_loss / valid_batches if valid_batches > 0 else float('inf')
        train_losses.append(avg_train_loss)

        avg_val_loss = _validation_loss(model, processor, val_loader, device, max_length,
                                        desc=f"Epoch {epoch+1}/{num_epochs} - Validation")
        val_losses.append(avg_val_loss)

        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), checkpoint_path)
            saved_best = True
            print(f"Model saved with validation loss: {best_val_loss:.4f}")

    # Hand back the lowest-validation-loss weights rather than the last epoch's.
    if saved_best:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(train_losses, label='Train Loss')
    ax.plot(val_losses, label='Validation Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title('BLIP fine-tuning loss')
    ax.legend()
    fig.savefig('training_curve.png')

    return model

In [None]:
def generate_captions(model, processor, test_loader, device="cuda", num_samples=5):
    """Generate captions for up to ``num_samples`` images drawn from ``test_loader``.

    Images are normalized by the BLIP image processor as in training. They are assumed to be
    float tensors already in [0, 1], so they are not rescaled again. Captions are produced with
    beam search combined with sampling. Batches containing the dataset's "Error loading image"
    placeholder are skipped, and errors in a batch are printed and that batch is dropped.

    Returns:
        List of dicts with keys ``'image'`` (PIL RGB image re-read from disk),
        ``'image_path'``, ``'original_caption'`` and ``'generated_caption'``.
    """
    model.eval()
    results = []

    with torch.no_grad():
        for batch in test_loader:
            remaining = num_samples - len(results)
            if remaining <= 0:
                break

            if "Error loading image" in batch['caption']:
                continue

            # Only caption as many images as are still needed.
            images = batch['image'][:remaining]
            original_captions = batch['caption'][:remaining]
            image_paths = batch['path'][:remaining]

            try:
                pixel_values = processor.image_processor(
                    images, do_rescale=False, do_normalize=True, return_tensors="pt"
                ).pixel_values.to(device)

                # generate captions with sampling enabled
                generated_ids = model.generate(
                    pixel_values=pixel_values,
                    max_length=100,
                    num_beams=5,
                    no_repeat_ngram_size=2,
                    temperature=1.2,
                    do_sample=True
                )
                generated_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)

                for img_path, orig_cap, gen_cap in zip(image_paths, original_captions, generated_captions):
                    with Image.open(img_path) as img:
                        pil_image = img.convert('RGB')
                    results.append({
                        'image': pil_image,
                        'image_path': img_path,
                        'original_caption': orig_cap,
                        'generated_caption': gen_cap
                    })
            except Exception as e:
                print(f"Error generating captions: {e}")
                continue
    return results

In [None]:
def main(data_dir=None, batch_size=8, learning_rate=5e-5, num_epochs=5,
         model_save_path="/content/drive/MyDrive/best_blip_captioning_model.pth"):
    """Run the whole pipeline: build loaders, fine-tune BLIP, save weights, show sample captions.

    Args:
        data_dir: Dataset root. Defaults to the notebook-level ``dataset_path`` produced by
            ``snapshot_download``.
        batch_size: Batch size for both DataLoaders.
        learning_rate: Learning rate passed to ``train_model``.
        num_epochs: Number of training epochs.
        model_save_path: Where the trained weights are written. The default is on Google
            Drive, so Drive must already be mounted.

    Errors raised inside the pipeline are printed with a traceback instead of propagating.
    """
    import traceback

    # check if CUDA is available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    if data_dir is None:
        data_dir = dataset_path

    try:
        train_loader, val_loader = create_dataloaders(data_dir, batch_size=batch_size)

        print(f"Training on {len(train_loader.dataset)} samples, validating on {len(val_loader.dataset)} samples")

        model = train_model(train_loader, val_loader, learning_rate=learning_rate,
                            num_epochs=num_epochs, device=device)

        torch.save(model.state_dict(), model_save_path)
        print(f"Trained model saved to {model_save_path}")

        # load processor for inference
        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")

        results = generate_captions(model, processor, val_loader, device=device, num_samples=5)

        for item in results:
            fig, ax = plt.subplots(figsize=(10, 10))
            ax.imshow(item['image'])
            ax.axis('off')
            ax.set_title(f"Image: {os.path.basename(item['image_path'])}")
            plt.show()
            print(f"Original: {item['original_caption']}")
            print(f"Generated: {item['generated_caption']}")
            print("-" * 80)

    except Exception as e:
        print(f"An error occurred in the main function: {e}")
        traceback.print_exc()

In [None]:
# Inside a notebook kernel __name__ is "__main__", so executing this cell runs the pipeline.
_run_as_script = __name__ == "__main__"
if _run_as_script:
    main()

Using device: cuda
An error occurred in the main function: Could not find annotation file for train split. Searched: ['/root/.cache/huggingface/hub/datasets--CrowdAILab--scicap/snapshots/60e504baa94423f63cda87d5442e73a696b953d3/train_captions.json', '/root/.cache/huggingface/hub/datasets--CrowdAILab--scicap/snapshots/60e504baa94423f63cda87d5442e73a696b953d3/train/train_captions.json', '/root/.cache/huggingface/hub/datasets--CrowdAILab--scicap/snapshots/60e504baa94423f63cda87d5442e73a696b953d3/annotations/train_captions.json']


Traceback (most recent call last):
  File "<ipython-input-27-b0b52a2609ba>", line 14, in main
    train_loader, val_loader = create_dataloaders(data_dir, batch_size=batch_size)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<ipython-input-24-b4da2fc243f0>", line 10, in create_dataloaders
    train_dataset = SciCapDataset(data_dir, split = 'train', transform = transform)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<ipython-input-23-f27a3e028ec7>", line 24, in __init__
    raise FileNotFoundError(f"Could not find annotation file for {split} split. Searched: {possible_paths}")
FileNotFoundError: Could not find annotation file for train split. Searched: ['/root/.cache/huggingface/hub/datasets--CrowdAILab--scicap/snapshots/60e504baa94423f63cda87d5442e73a696b953d3/train_captions.json', '/root/.cache/huggingface/hub/datasets--CrowdAILab--scicap/snapshots/60e504baa94423f63cda87d5442e73a696b953d3/train/t