# In [None]:
import torch
import torch.nn as nn
from transformers import CLIPProcessor, BertTokenizer, BertForMaskedLM
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import random
import numpy as np
from tqdm import tqdm

class ImageCaptionDataset(Dataset):
    """Dataset yielding processor-encoded images from a single folder.

    Each item is ``(inputs, image_name)``: ``inputs`` is the object returned
    by ``processor(images=<RGB PIL image>, return_tensors="pt")`` and
    ``image_name`` is the file name inside ``image_folder``.  No caption text
    is loaded; ``tokenizer`` is only stored on the instance.
    """

    def __init__(self, image_folder, processor, tokenizer):
        """Index the files in ``image_folder`` and store the helpers.

        Args:
            image_folder: Directory whose regular files are treated as images.
            processor: Callable invoked as
                ``processor(images=..., return_tensors="pt")``.
            tokenizer: Stored as ``self.tokenizer``; not used by this class.
        """
        self.image_folder = image_folder
        self.image_files = self._list_image_files(image_folder)
        self.processor = processor
        self.tokenizer = tokenizer
        # Retained as a public attribute for backward compatibility only.
        # ``__getitem__`` deliberately does NOT apply it: the processor is
        # expected to do its own resize/rescale/normalise (CLIPProcessor
        # does), so running this pipeline first would preprocess twice.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _list_image_files(image_folder):
        """Return sorted names of the regular files directly in ``image_folder``.

        Sorting makes index -> file stable across runs and platforms
        (``os.listdir`` order is arbitrary); sub-directories are skipped
        because they cannot be opened as images.
        """
        return sorted(
            name
            for name in os.listdir(image_folder)
            if os.path.isfile(os.path.join(image_folder, name))
        )

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load file ``idx`` as RGB and encode it with the processor.

        Returns:
            ``(inputs, image_name)`` where ``inputs`` is the processor output
            for this single image.
        """
        image_name = self.image_files[idx]
        image_path = os.path.join(self.image_folder, image_name)
        # ``convert`` returns a fully loaded copy, so the file handle can be
        # released as soon as the ``with`` block exits.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        # Feed the raw PIL image, matching how the sample-caption code in this
        # file calls the processor, so training and inference inputs agree.
        inputs = self.processor(images=image, return_tensors="pt")
        return inputs, image_name


class ImageCaptionModel(nn.Module):
    """Bundle a vision encoder, a text model and a tokenizer.

    NOTE(review): the image features are computed and returned by ``forward``
    but are never passed to the text model, so neither the loss nor the
    generated caption currently depends on the image.
    """

    def __init__(self, vision_model, text_model, tokenizer):
        """Store the three components.

        Args:
            vision_model: Callable taking a batch of pixel values as its first
                positional argument.  It may return a tensor or an object
                exposing ``pooler_output``.
            text_model: Called with ``input_ids``, ``attention_mask`` and
                ``labels``; must return an object with ``.loss`` and provide
                ``.generate``.
            tokenizer: Must provide ``encode``, ``decode`` and
                ``pad_token_id``.
        """
        super().__init__()
        self.vision_model = vision_model
        self.text_model = text_model
        self.tokenizer = tokenizer

    def _encode_image(self, pixel_values):
        """Run the vision encoder and return one feature tensor per batch."""
        # Positional call so that encoders whose ``forward`` has no
        # ``pixel_values`` keyword (e.g. the torchvision ResNet this file
        # builds) work as well as encoders that name the argument that way.
        outputs = self.vision_model(pixel_values)
        # Structured outputs expose ``pooler_output``; a bare tensor is
        # already the feature tensor.
        return getattr(outputs, "pooler_output", outputs)

    def forward(self, pixel_values, input_ids, attention_mask):
        """Encode the images and compute the text model's loss.

        ``input_ids`` is also passed as ``labels``, so the loss is computed
        against the text model's own inputs.

        Returns:
            ``(image_features, loss)``.
        """
        image_features = self._encode_image(pixel_values)
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
        return image_features, text_outputs.loss

    @torch.no_grad()
    def generate_caption(self, pixel_values):
        """Generate a caption string starting from the prompt "A photo of".

        ``pixel_values`` is only used to pick the device for the prompt ids;
        the previous implementation ran the vision encoder here and discarded
        the result, which cost a forward pass without affecting the output.

        Returns:
            The decoded text of the first generated sequence, with special
            tokens skipped.
        """
        text_inputs = self.tokenizer.encode("A photo of", return_tensors="pt")
        # The tokenizer creates CPU tensors; move them next to the inputs so
        # generation also works when the model lives on an accelerator.
        text_inputs = text_inputs.to(pixel_values.device)
        generated_ids = self.text_model.generate(input_ids=text_inputs, decoder_start_token_id=self.tokenizer.pad_token_id)
        generated_caption = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return generated_caption


# --- Model components -------------------------------------------------------
# ResNet-50 with its classification head replaced by an identity, so the
# network returns the features that used to feed the final linear layer.
vision_model = models.resnet50(pretrained=True)
vision_model.fc = nn.Identity()

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
text_model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# --- Data ---------------------------------------------------------------------
dataset = ImageCaptionDataset(image_folder="dataset/images", processor=processor, tokenizer=tokenizer)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ImageCaptionModel(vision_model=vision_model, text_model=text_model, tokenizer=tokenizer).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# --- Early-stopping bookkeeping -------------------------------------------
early_stopping_patience = 3
best_loss = float('inf')
epochs_without_improvement = 0

# The text prompt is the same for every step, so tokenise it once instead of
# once per batch.
input_ids = tokenizer.encode("A photo of", return_tensors="pt").to(device)
attention_mask = torch.ones(input_ids.shape, device=device)
ground_truth_caption = "A photo of"

epochs = 5
for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    correct_predictions = 0
    total_samples = 0
    accuracy = 0.0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}", unit="batch", total=len(dataloader))

    for batch in progress_bar:
        inputs, image_names = batch
        pixel_values = inputs["pixel_values"]
        # Each dataset item comes from a processor call on one image and so
        # (presumably) carries its own leading singleton axis; after collation
        # that gives (B, 1, C, H, W).  The previous ``squeeze(0)`` only
        # removed anything when B == 1.  Collapse every leading axis into the
        # batch axis instead.
        pixel_values = pixel_values.reshape(-1, *pixel_values.shape[-3:]).to(device)

        optimizer.zero_grad()
        image_features, loss = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        # Monitoring only: no gradients are needed for the caption.
        with torch.no_grad():
            generated_caption = model.generate_caption(pixel_values)

        # One caption is generated per batch, so this counter advances once
        # per batch and "accuracy" is the share of batches whose caption
        # equals the fixed prompt text.
        if generated_caption.strip().lower() == ground_truth_caption.strip().lower():
            correct_predictions += 1
        total_samples += 1

        accuracy = (correct_predictions / total_samples) * 100
        progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy)

    avg_epoch_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1} - Loss: {avg_epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")

    # Checkpoint whenever the mean epoch loss improves.
    if avg_epoch_loss < best_loss:
        best_loss = avg_epoch_loss
        epochs_without_improvement = 0
        torch.save(model.state_dict(), "cmodel.pth")
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= early_stopping_patience:
        print("Early stopping triggered. Training stopped.")
        break

    # Print captions for up to three images as a qualitative check.  The
    # bound guards against folders holding fewer than three files.
    model.eval()
    with torch.no_grad():
        for idx in range(min(3, len(dataset))):
            image_name = dataset.image_files[idx]
            image_path = os.path.join(dataset.image_folder, image_name)
            # ``convert`` returns a loaded copy, so the handle can be closed.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
            inputs = dataset.processor(images=image, return_tensors="pt").to(device)
            caption = model.generate_caption(inputs["pixel_values"])
            print(f"Generated caption for {image_name}: {caption}")

print("Training complete.")


He
