In [None]:
import csv
import os

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BlipProcessor, BlipForConditionalGeneration


class ImageCaptionDataset(Dataset):
    """Pairs image files on disk with their captions and encodes them with a processor.

    The caption file is read as CSV: its first line is skipped as a header and
    every following non-blank row is ``image_name,caption``. Quoted fields are
    unquoted by the csv module; unquoted commas inside a caption are kept.
    """

    def __init__(self, image_folder, caption_file, processor, max_length=50):
        """
        Args:
            image_folder: Directory the image names in ``caption_file`` are relative to.
            caption_file: Path to the CSV caption file (first line is a header).
            processor: Callable used by ``__getitem__`` to encode an image + caption
                (e.g. a ``BlipProcessor``); it is not used during construction.
            max_length: Token length each caption is padded/truncated to.

        Raises:
            ValueError: If a non-blank row has no caption column.
        """
        self.image_folder = image_folder
        self.processor = processor
        self.max_length = max_length
        self.captions = []
        # newline="" is what the csv module requires; an explicit encoding avoids
        # depending on the platform default.
        with open(caption_file, "r", encoding="utf-8", newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            for row in reader:
                if not any(field.strip() for field in row):
                    continue  # tolerate blank / whitespace-only lines
                if len(row) < 2:
                    raise ValueError(
                        f"{caption_file}:{reader.line_num}: expected 'image,caption', got {row!r}"
                    )
                img_name, *caption_parts = row
                # Re-join unquoted commas so a caption containing commas stays whole.
                caption = ",".join(caption_parts)
                self.captions.append([img_name.strip(), caption.strip()])

    def __len__(self):
        """Number of (image, caption) rows parsed from the caption file."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Load one image/caption pair and encode it with the processor.

        Returns:
            dict: the processor's tensors with the leading batch dimension
            squeezed away, plus ``"caption_str"`` holding the raw caption text.
        """
        img_name, caption = self.captions[idx]
        img_path = os.path.join(self.image_folder, img_name)
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        inputs = self.processor(
            images=image,
            text=caption,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True
        )
        output = {key: val.squeeze(0) for key, val in inputs.items()}
        output["caption_str"] = caption
        return output


# Fine-tune BLIP on (image, caption) pairs, tracking loss and exact-match accuracy.
image_folder = "dataset/Images"
caption_file = "dataset/captions.txt"

MODEL_NAME = "Salesforce/blip-image-captioning-base"
CHECKPOINT_PATH = "model.pth"
BATCH_SIZE = 4
LEARNING_RATE = 5e-5
NUM_EPOCHS = 5
SEED = 42
# Label value skipped by PyTorch's cross-entropy loss (its default ignore_index).
IGNORE_INDEX = -100

# Makes DataLoader shuffling and dropout repeatable across runs.
torch.manual_seed(SEED)

processor = BlipProcessor.from_pretrained(MODEL_NAME, use_fast=True)
# Pad on the right so every caption starts at the beginning of the token window,
# as it does when generate() decodes. Left padding would push short captions to
# the end of the window during training only.
# NOTE(review): relies on BLIP's text position ids not being derived from the
# attention mask. Verify this for the installed transformers version.
processor.tokenizer.padding_side = "right"

dataset = ImageCaptionDataset(image_folder, caption_file, processor)
if len(dataset) == 0:
    # Fail early instead of dividing by zero when averaging the epoch metrics.
    raise ValueError(f"No caption rows found in {caption_file!r}")
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

model = BlipForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

early_stopping_patience = 2
best_loss = float('inf')
epochs_without_improvement = 0

for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}", unit="batch", total=len(dataloader))

    for batch in progress_bar:
        inputs = {key: val.to(device) for key, val in batch.items() if key != "caption_str"}
        # Mask padding out of the labels so the loss covers caption tokens only,
        # instead of also teaching the model to emit pad tokens.
        labels = inputs["input_ids"].masked_fill(
            inputs["input_ids"] == processor.tokenizer.pad_token_id, IGNORE_INDEX
        )
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_loss += loss.item()

        # Caption each image from its pixels only. Passing the batch's input_ids
        # would prompt the decoder with the ground-truth caption and leak the
        # answer. eval() disables dropout for this pass; train() restores it.
        model.eval()
        with torch.no_grad():
            generated_ids = model.generate(pixel_values=inputs["pixel_values"])
        model.train()
        generated_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)
        ground_truth_captions = batch["caption_str"]

        # Case-insensitive exact string match: a very strict proxy for caption quality.
        for gen, gt in zip(generated_captions, ground_truth_captions):
            if gen.strip().lower() == gt.strip().lower():
                correct_predictions += 1
            total_samples += 1

        progress_bar.set_postfix(loss=loss.item())

    # The dataset is non-empty (checked above), so both divisors are > 0.
    avg_epoch_loss = epoch_loss / len(dataloader)
    accuracy = (correct_predictions / total_samples) * 100
    print(f"Epoch {epoch + 1}, Loss: {avg_epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")

    # NOTE(review): early stopping monitors the *training* loss, so it cannot
    # detect overfitting. A held-out validation split would be needed for that.
    if avg_epoch_loss < best_loss:
        best_loss = avg_epoch_loss
        epochs_without_improvement = 0
        torch.save(model.state_dict(), CHECKPOINT_PATH)
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= early_stopping_patience:
        print("Early stopping triggered. Training stopped.")
        break

print("Training complete.")


In [None]:
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image

# Caption a single test image with the fine-tuned checkpoint ("model.pth").
CHECKPOINT_PATH = "model.pth"
TEST_IMAGE_PATH = "testImage/000000011699.jpg"

# Pick the device first so the checkpoint can be mapped onto it when loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base", use_fast=True)
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
# weights_only restricts unpickling to tensors and plain containers.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()
model.to(device)

# The context manager closes the file; convert() returns an independent loaded copy.
with Image.open(TEST_IMAGE_PATH) as img:
    image = img.convert("RGB")

inputs = processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
    generated_ids = model.generate(**inputs)
caption = processor.decode(generated_ids[0], skip_special_tokens=True)

print(caption)
