In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import os
import time
from tqdm import tqdm

class ImageCaptionDataset(Dataset):
    """Image/caption pairs read from an image folder and a caption file.

    Every non-empty line of ``caption_file`` must look like
    ``<image file name>,<caption>``. Only the first comma separates the two
    fields, so the caption itself may contain commas.

    Args:
        image_folder: Directory that image file names are joined onto.
        caption_file: Path of the text file holding one pair per line.
        processor: Callable invoked as ``processor(images=..., text=...,
            return_tensors="pt", padding="max_length", ...)``; its result
            must offer ``.items()`` yielding tensors whose leading
            dimension can be squeezed away.
        max_length: Length the caption is padded/truncated to by the
            processor (defaults to the previously hard-coded 50).

    Raises:
        ValueError: If a non-empty line of ``caption_file`` has no comma.
    """

    def __init__(self, image_folder, caption_file, processor, max_length=50):
        self.image_folder = image_folder
        self.processor = processor
        self.max_length = max_length
        # Kept as a list of [image_name, caption] lists, as before.
        self.captions = []
        with open(caption_file, 'r', encoding='utf-8') as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                # Split on the FIRST comma only: captions are free text and
                # routinely contain commas themselves.
                img_name, sep, caption = line.partition(',')
                if not sep:
                    raise ValueError(
                        f"{caption_file}:{line_no}: expected "
                        f"'<image>,<caption>', got {line!r}"
                    )
                # NOTE(review): caption CSVs often start with an
                # "image,caption" header row; it is skipped here because it
                # can never name a real sample -- confirm against the data.
                is_header = (
                    not self.captions
                    and img_name.strip().lower() == "image"
                    and caption.strip().lower() == "caption"
                )
                if is_header:
                    continue
                self.captions.append([img_name, caption])

    def __len__(self):
        """Return the number of image/caption pairs."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return the processor output per key.

        The processor is called with ``return_tensors="pt"`` for a single
        sample, and ``squeeze(0)`` drops the leading dimension of each
        returned tensor so the DataLoader can stack samples itself.
        """
        img_name, caption = self.captions[idx]
        img_path = os.path.join(self.image_folder, img_name)
        # The context manager releases the file handle as soon as the pixel
        # data has been converted, instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        inputs = self.processor(
            images=image,
            text=caption,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        return {key: val.squeeze(0) for key, val in inputs.items()}

# Locations of the training data, relative to the working directory.
image_folder = "dataset/Images"
caption_file = "dataset/captions.txt"

# The processor prepares both the image and the caption text for the model.
processor = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-base"
)
dataset = ImageCaptionDataset(
    image_folder=image_folder,
    caption_file=caption_file,
    processor=processor,
)
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained captioning model and move it onto the chosen device.
model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
)
model = model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=5e-5)

# Early-stopping state: training stops once the average epoch loss has
# failed to improve for `early_stopping_patience` consecutive epochs.
early_stopping_patience = 2
best_loss = float('inf')
epochs_without_improvement = 0

for epoch in range(5):
    model.train()
    epoch_loss = 0
    correct_predictions = 0
    total_samples = 0

    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}", unit="batch"):
        inputs = {key: val.to(device) for key, val in batch.items()}

        # Captions are padded to a fixed length; set the padded positions
        # to -100 (the index ignored by the cross-entropy loss) so the loss
        # is computed over real caption tokens only.
        labels = inputs["input_ids"].masked_fill(
            inputs["attention_mask"] == 0, -100
        )
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        epoch_loss += loss.item()

        # Exact-match "accuracy" on the batch just trained on. Generation is
        # conditioned on the image alone -- passing the reference input_ids
        # would hand the model the answer as its prompt -- and runs in eval
        # mode without gradient tracking so dropout does not perturb it.
        model.eval()
        with torch.no_grad():
            generated_ids = model.generate(pixel_values=inputs["pixel_values"])
        model.train()

        generated_captions = processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        # Decode the references from their token ids so both sides of the
        # comparison go through the same tokenizer round trip.
        ground_truth_captions = processor.batch_decode(
            inputs["input_ids"], skip_special_tokens=True
        )

        for gen, gt in zip(generated_captions, ground_truth_captions):
            if gen.strip().lower() == gt.strip().lower():
                correct_predictions += 1
            total_samples += 1

    avg_epoch_loss = epoch_loss / len(dataloader)
    accuracy = correct_predictions / total_samples * 100
    print(f"Epoch {epoch + 1}, Loss: {avg_epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")

    if avg_epoch_loss < best_loss:
        # New best epoch: checkpoint the weights and reset the patience counter.
        best_loss = avg_epoch_loss
        epochs_without_improvement = 0
        torch.save(model.state_dict(), "best_model.pth")
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= early_stopping_patience:
        print("Early stopping triggered. Training stopped.")
        break

print("Training complete.")