In [None]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from transformers import BlipProcessor, BlipForConditionalGeneration
import os
from PIL import Image
from tqdm import tqdm

class ModelDataset(Dataset):
    """Image/caption pairs for fine-tuning a captioning model.

    The captions file is read as UTF-8 text. Its first line is treated as a header
    and skipped. Every following line is split on the FIRST comma only into
    ``<image file name>,<caption>``, so captions may themselves contain commas.
    Lines without any comma are ignored; both fields are whitespace-stripped.

    Args:
        image_folder: Directory that image file names are resolved against.
        captions_file: Path to the captions file described above.
        processor: Callable (e.g. a ``BlipProcessor``) turning ``images`` + ``text``
            into a dict with ``pixel_values``, ``input_ids`` and ``attention_mask``.
        transform: Optional callable applied to the RGB PIL image before
            ``processor`` (e.g. a torchvision ``Compose``).
    """

    def __init__(self, image_folder, captions_file, processor, transform=None):
        self.image_folder = image_folder
        self.processor = processor
        self.transform = transform
        # List of (image file name, caption) tuples, one per data row.
        self.image_captions = []

        with open(captions_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()[1:]  # drop the header row
            for line in lines:
                parts = line.strip().split(',', 1)
                if len(parts) == 2:
                    self.image_captions.append((parts[0].strip(), parts[1].strip()))

    def __len__(self):
        """Number of (image, caption) pairs parsed from the captions file."""
        return len(self.image_captions)

    def __getitem__(self, idx):
        """Load, transform and tokenize sample ``idx``.

        Returns:
            dict with ``pixel_values``, ``input_ids`` and ``attention_mask``, each
            with the leading batch dimension added by ``return_tensors="pt"``
            squeezed away so the DataLoader can re-batch them. Text is padded /
            truncated to 50 tokens.
        """
        img_name, caption = self.image_captions[idx]
        img_path = os.path.join(self.image_folder, img_name)
        # Use a context manager so the underlying file is closed immediately;
        # convert() yields a new image object that stays valid after the close.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        # A tensor from transforms.ToTensor() is already scaled to [0, 1], so the
        # processor must not rescale it again. A PIL image (no transform, or a
        # transform that keeps PIL) still holds 0-255 values and must be rescaled.
        # NOTE(review): assumes tensor-producing transforms output [0, 1] floats
        # like ToTensor — confirm if a different tensor transform is plugged in.
        already_scaled = torch.is_tensor(image)

        inputs = self.processor(
            images=image,
            text=caption,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=50,
            do_rescale=not already_scaled,
        )

        return {
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0)
        }

# Seed torch's global RNG so DataLoader shuffling is reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

# Resize to 384x384 and convert to a float tensor in [0, 1]; the dataset passes
# tensors like this to the processor without rescaling them a second time.
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

# Build relative paths with os.path.join: backslash literals such as
# r"dataset\Images" only resolve as directories on Windows.
image_folder = os.path.join("dataset", "Images")
captions_file = os.path.join("dataset", "captions.txt")

dataset = ModelDataset(image_folder, captions_file, processor, transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
# NOTE: not used by the loop below — the model returns its own loss in
# `outputs.loss` when `labels` is passed. Kept so the name stays defined.
loss_fn = torch.nn.CrossEntropyLoss()

num_epochs = 3
model.train()

for epoch in range(num_epochs):
    total_loss = 0.0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch in progress_bar:
        pixel_values = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Captions are padded to max_length, so most positions are padding.
        # Replace them with -100 (CrossEntropyLoss's default ignore_index) so the
        # model is not trained to predict pad tokens.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        progress_bar.set_postfix(loss=batch_loss)

    # Guard against an empty dataloader (e.g. an empty captions file).
    average_loss = total_loss / max(len(dataloader), 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {average_loss}")

# Write one checkpoint file holding the fine-tuned weights plus the processor
# object (pickled as-is) so the inference cell can restore both from it.
torch.save(
    dict(
        model_state_dict=model.state_dict(),
        processor=processor,
        use_fast=False,
    ),
    "model.pth",
)

print("Model training complete.")


In [None]:
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The training cell stores a pickled BlipProcessor object next to the weights.
# Since PyTorch 2.6, torch.load defaults to weights_only=True and refuses to
# unpickle such objects, so the opt-out must be explicit.
# SECURITY: weights_only=False executes arbitrary pickle code — only load
# checkpoint files you created yourself or otherwise trust.
checkpoint = torch.load("model.pth", map_location=device, weights_only=False)

processor = checkpoint["processor"]
# Rebuild the architecture from the base checkpoint, then overwrite it with the
# fine-tuned weights.
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()  # inference mode (e.g. disables dropout)

def generate_caption(image_path, **generate_kwargs):
    """Generate a caption for a single image with the loaded BLIP model.

    Uses the module-level ``processor``, ``model`` and ``device``.

    Args:
        image_path: Path to an image file readable by PIL.
        **generate_kwargs: Optional extra keyword arguments forwarded verbatim to
            ``model.generate`` (e.g. ``max_new_tokens=50``, ``num_beams=3``).
            With none given, behaviour matches a plain ``model.generate(**inputs)``.

    Returns:
        The decoded caption string with special tokens removed.
    """
    # Close the file handle right away; convert() returns a new image object.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = model.generate(**inputs, **generate_kwargs)

    # batch_decode returns one string per sequence; there is a single image here.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Smoke test: caption one held-out image with the fine-tuned model.
image_path = "testImage/000000011699.jpg"
caption = generate_caption(image_path=image_path)
print("Generated Caption:", caption)
