# In [None]:
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, BlipForConditionalGeneration, BlipProcessor

# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Locations of the fine-tuned CLIP weights and processor config.
# Adjust these paths to wherever the fine-tuned model files were saved.
CLIP_MODEL_DIR = "clip_rl_finetuned"
CLIP_PROCESSOR_DIR = "clip_rl_finetuned_processor"

# Checkpoint identifier of the pretrained BLIP captioning model.
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"

# Fine-tuned CLIP model (moved to the selected device) and its processor.
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_DIR)
clip_model = clip_model.to(device)
clip_processor = CLIPProcessor.from_pretrained(CLIP_PROCESSOR_DIR)

# BLIP model used for caption generation, plus the matching processor.
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT).to(device)
blip_processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)

def process_image(image_path):
    """
    Load an image from disk and embed it with the fine-tuned CLIP model.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A ``(image, clip_features)`` tuple: the image converted to RGB, and the
        result of ``clip_model.get_image_features`` for that image, computed
        without gradient tracking.
    """
    # Open the file in a context manager so its handle is released
    # deterministically; Pillow may otherwise keep it open for some formats.
    # ``convert`` returns a new, fully loaded image, so it stays usable after
    # the source file is closed.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    clip_inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        clip_features = clip_model.get_image_features(**clip_inputs)
    return image, clip_features

def generate_caption(clip_features=None, image=None, *, max_length=50, num_beams=5):
    """
    Generate a caption for an image with BLIP.

    BLIP's ``generate`` consumes pixel values produced by ``blip_processor``
    from the image itself; it cannot decode a CLIP image embedding, so the
    image must be supplied.

    NOTE(review): the fine-tuned CLIP embedding is currently not used for
    captioning. If it was meant to guide or rerank BLIP's output, that logic
    still needs to be written.

    Args:
        clip_features: CLIP image embedding. Accepted for backward
            compatibility with existing callers; not used by BLIP.
        image: Image to caption (as returned by ``process_image``).
        max_length: Maximum generated sequence length passed to ``generate``.
        num_beams: Beam-search width passed to ``generate``.

    Returns:
        The decoded caption string.

    Raises:
        ValueError: If ``image`` is not provided.
    """
    if image is None:
        raise ValueError(
            "generate_caption requires the source image: BLIP captions from "
            "image pixels, not from CLIP features."
        )

    # Let BLIP's own processor resize/normalize the image into pixel_values.
    blip_inputs = blip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = blip_model.generate(
            **blip_inputs, max_length=max_length, num_beams=num_beams
        )
    return blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

def main():
    """
    Prompt for an image path, caption the image, and display it.

    This is the interactive entry point, so any failure while loading the
    image or running the models is reported on stdout instead of raised.
    """
    image_path = input("Enter the path to your image: ").strip()
    if not image_path:
        print("Error processing the image: no path was entered")
        return

    try:
        # Load the image and compute its CLIP embedding.
        image, clip_features = process_image(image_path)

        # BLIP captions from the image itself; the CLIP embedding is passed
        # along only for interface compatibility.
        caption = generate_caption(clip_features, image=image)

        # Display results
        print(f"Generated Caption: {caption}")
        image.show()  # Opens the image in the default viewer
    except Exception as e:
        # Top-level boundary: report the failure instead of dumping a traceback.
        print(f"Error processing the image: {e}")

if __name__ == "__main__":
    main()
