<a href="https://colab.research.google.com/github/Meshal6299/arabic-image-captioning/blob/main/notebooks/01_BLIP_Arabic_FineTuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Cell 1: Install the packages used below (the leading `!` runs a shell command in Colab/IPython).
# NOTE(review): `datasets` is not imported by any visible cell; drop it if nothing else needs it.
!pip install transformers datasets torch pillow

In [None]:
# Cell 2: Mount Google Drive so the dataset and the saved model are reachable under /content/drive
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Cell 3: Import All Libraries
import csv
import os

import pandas as pd
import torch
from PIL import Image
from torch.optim import AdamW  # torch's AdamW (not transformers.AdamW)
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm.auto import tqdm # For a nice progress bar
from transformers import BlipProcessor, BlipForConditionalGeneration

In [None]:
# Cell 4: Define the Custom Dataset Class
class ArabicImageCaptionDataset(Dataset):
    """Image/caption pairs for fine-tuning BLIP on Arabic captions.

    Each sample is returned as the BLIP processor's tensors (un-batched) plus a
    ``labels`` tensor in which padding positions are set to -100.

    NOTE(review): the tokenizer comes from whatever checkpoint the processor was
    loaded from; confirm its vocabulary actually covers Arabic text instead of
    mapping most of it to unknown/single-character pieces.
    """

    def __init__(self, dataset_file, image_dir, processor, max_length=128):
        """
        Args:
            dataset_file (str): Path to a comma-separated caption file whose rows are
                "image_name.jpg,arabic_text". Everything after the first field is
                the caption, so unquoted commas inside a caption are kept, and
                CSV-quoted captions are unquoted. A leading UTF-8 byte-order mark is
                ignored. Rows with fewer than two fields, an empty image name, or
                an empty caption are skipped.
            image_dir (str): Directory with all the images.
            processor (BlipProcessor): The BLIP processor for images and text.
            max_length (int): Max token length for the text.
        """
        self.image_dir = image_dir
        self.processor = processor
        self.max_length = max_length

        self.data = []
        # 'utf-8-sig' drops a BOM (common in CSVs saved by spreadsheet tools) that
        # would otherwise be glued onto the first image name. newline='' is what
        # the csv module requires for quoted fields and \r\n line endings.
        with open(dataset_file, 'r', encoding='utf-8-sig', newline='') as f:
            for row in csv.reader(f):
                if len(row) < 2:
                    continue  # blank or malformed line
                image_name = row[0].strip()
                # An unquoted comma inside the caption splits it into several CSV
                # fields; re-join them so the caption is kept whole.
                text = ",".join(row[1:]).strip()
                if image_name and text:
                    self.data.append({"image_name": image_name, "text": text})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the processed tensors for sample ``idx``.

        If the image for ``idx`` cannot be read, the following samples are tried in
        order (wrapping around), and the first readable one is returned together
        with its own caption.

        Raises:
            RuntimeError: if no image in the dataset can be read.
        """
        item, image = self._load_item(idx)

        # The processor normalizes the image and tokenizes the caption; it returns
        # tensors with a leading batch dimension of 1.
        inputs = self.processor(images=image,
                                text=item["text"],
                                return_tensors="pt",
                                padding="max_length",
                                truncation=True,
                                max_length=self.max_length)

        # Drop the batch dimension, e.g. (1, C, H, W) -> (C, H, W); the DataLoader's
        # collate function re-stacks samples into a batch.
        for key in ('pixel_values', 'input_ids', 'attention_mask'):
            inputs[key] = inputs[key].squeeze(0)

        # The caption tokens are also the training targets. Padding positions are
        # set to -100 so the loss ignores them.
        labels = inputs['input_ids'].clone()
        labels[labels == self.processor.tokenizer.pad_token_id] = -100
        inputs['labels'] = labels

        return inputs

    def _load_item(self, idx):
        """Return ``(item, rgb_image)`` for the first readable sample at or after ``idx``.

        Iterates instead of recursing, so a long run of missing files (or a dataset
        with no readable images) cannot exceed the recursion limit.
        """
        num_items = len(self.data)
        for offset in range(num_items):
            item = self.data[(idx + offset) % num_items]
            image_path = os.path.join(self.image_dir, item["image_name"])
            try:
                # The context manager closes the file handle. convert() returns a
                # fully loaded image that stays valid after the file is closed.
                with Image.open(image_path) as img:
                    return item, img.convert("RGB")
            except FileNotFoundError:
                print(f"Warning: Image file not found {image_path}. Skipping.")
            except OSError as err:
                # Corrupt or unrecognized image files.
                print(f"Warning: Could not read image {image_path} ({err}). Skipping.")
        raise RuntimeError(
            f"No readable images found in {self.image_dir} for {num_items} caption rows."
        )

In [None]:
# Cell 5: Load Processor and Model
# The processor (image preprocessing + tokenizer) and the captioning model are
# loaded from the same pretrained checkpoint name.
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-large"

print("Loading model and processor...")
processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT)
print("Done.")

In [None]:
# Cell 6: Define File Paths and Device

# Root of the project folder on the mounted Google Drive.
PROJECT_PATH = "/content/drive/MyDrive/PR Project/dataset"

# Caption file ("image_name,caption" rows) and the folder holding the images.
DATASET_FILE = os.path.join(PROJECT_PATH, "Arabic_Description_sample.csv")
IMAGE_DIR = os.path.join(PROJECT_PATH, "Images")

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)  # move the model's parameters onto the chosen device
print(f"Using device: {device}")

In [None]:
# Cell 7: Load and Split the Dataset
print("Loading dataset...")
full_dataset = ArabicImageCaptionDataset(dataset_file=DATASET_FILE,
                                         image_dir=IMAGE_DIR,
                                         processor=processor)

# Split 90% / 10% into training and validation.
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size

# With fewer than 2 rows the training split would be empty and the training loop
# would have nothing to learn from; fail here with a clear message instead.
if train_size == 0:
    raise ValueError(
        f"Need at least 2 captioned images for a train/validation split; "
        f"got {len(full_dataset)} from {DATASET_FILE}"
    )

# A seeded generator makes the split identical on every run, so validation losses
# are comparable across runs and held-out images stay held out.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

print(f"Dataset loaded. Training size: {len(train_dataset)}, Validation size: {len(val_dataset)}")

In [None]:
# Cell 8: Create DataLoaders
# The DataLoader needs a 'collate_fn' to merge the per-sample dicts into one batch.
def collate_fn(batch):
    """Stack per-sample tensors into batched tensors.

    ``batch`` is a list of dicts produced by the Dataset. The result holds exactly
    the four model inputs, in this order: pixel_values, input_ids, attention_mask,
    labels. Any other keys in the samples are ignored.
    """
    batch_keys = ('pixel_values', 'input_ids', 'attention_mask', 'labels')
    return {key: torch.stack([sample[key] for sample in batch]) for key in batch_keys}

# Create DataLoaders
BATCH_SIZE = 4  # Try 4 or 8. If you get "Out of Memory", lower this.

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)

In [None]:
# Cell 9: Set Up Optimizer
# AdamW over every parameter the model exposes (full fine-tuning).
LEARNING_RATE = 5e-5  # a common starting learning rate for fine-tuning
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Cell 10: The Training & Validation Loop
NUM_EPOCHS = 3 # Start with 3. You can increase this later if needed.


def _run_epoch(loader, *, train):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    With ``train=True`` the model is put in training mode and the optimizer updates
    the weights after every batch. Otherwise the model is put in evaluation mode and
    gradient tracking is disabled. An empty loader returns NaN instead of raising
    ZeroDivisionError.
    """
    model.train(train)  # train(False) is the same as model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.set_grad_enabled(train):
        for batch in tqdm(loader, desc="Training" if train else "Validation"):
            # Move every tensor in the batch to the model's device.
            inputs = {k: v.to(device) for k, v in batch.items()}

            # The model computes the loss itself because 'labels' is in the inputs.
            loss = model(**inputs).loss

            if train:
                optimizer.zero_grad()  # Clear old gradients
                loss.backward()        # Calculate new gradients
                optimizer.step()       # Update model weights

            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches if num_batches else float("nan")


print("Starting training...")

for epoch in range(NUM_EPOCHS):
    print(f"--- Epoch {epoch+1}/{NUM_EPOCHS} ---")

    avg_train_loss = _run_epoch(train_loader, train=True)
    print(f"Average Training Loss: {avg_train_loss:.4f}")

    avg_val_loss = _run_epoch(val_loader, train=False)
    print(f"Average Validation Loss: {avg_val_loss:.4f}")

print("Training complete!")

In [None]:
# Cell 11: Save the Final Model
print("Saving model...")

# The fine-tuned weights and the processor are written to the same folder inside
# the project directory, so both can be reloaded from one path.
SAVE_PATH = os.path.join(PROJECT_PATH, "arabic_blip_model")
os.makedirs(SAVE_PATH, exist_ok=True)

for artifact in (model, processor):
    artifact.save_pretrained(SAVE_PATH)

print(f"Model saved to {SAVE_PATH}")