In [None]:
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, Blip2Processor, Blip2ForConditionalGeneration

# Pipeline: load one image, compute its CLIP image embedding, then caption it
# with BLIP-2 and print the caption.

# Inputs and checkpoints. Each checkpoint id is defined once so that a model
# and its processor are always loaded from the same place.
IMAGE_PATH = "../test_images/n02128757_snow_leopard.JPEG"
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
BLIP2_CHECKPOINT = "Salesforce/blip2-opt-2.7b"

# Load the reconstructed image (will eventually be the MinD-Vis output).
# Image.open is lazy, so read the pixels inside a context manager and keep an
# in-memory RGB copy; the file is released as soon as the block exits, and
# both processors below receive a 3-channel image whatever the file's mode.
with Image.open(IMAGE_PATH) as source_image:
    reconstructed_image = source_image.convert("RGB")

# Load CLIP model and processor
clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

# Process the image and generate CLIP embedding.
# Inference only, so gradient tracking is disabled.
clip_inputs = clip_processor(images=reconstructed_image, return_tensors="pt")
with torch.no_grad():
    image_embedding = clip_model.get_image_features(**clip_inputs)

# Load BLIP-2 model and processor
blip2_processor = Blip2Processor.from_pretrained(BLIP2_CHECKPOINT)
blip2_model = Blip2ForConditionalGeneration.from_pretrained(BLIP2_CHECKPOINT)

# Generate captions directly using BLIP-2 (image-only input, no text prompt)
blip2_inputs = blip2_processor(images=reconstructed_image, return_tensors="pt")
with torch.no_grad():
    generated_ids = blip2_model.generate(**blip2_inputs)

# Decoding is plain token-to-text conversion, so it does not need no_grad.
# The decoded text may carry surrounding whitespace; strip it (as the BLIP-2
# reference usage does) so `caption` is clean for printing and later use.
caption = blip2_processor.batch_decode(
    generated_ids, skip_special_tokens=True
)[0].strip()

print("Generated Caption:", caption)


preprocessor_config.json: 100%|██████████| 432/432 [00:00<00:00, 1.37MB/s]
tokenizer_config.json: 100%|██████████| 904/904 [00:00<00:00, 2.86MB/s]
vocab.json: 100%|██████████| 798k/798k [00:00<00:00, 7.12MB/s]
merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 14.6MB/s]
tokenizer.json: 100%|██████████| 2.11M/2.11M [00:00<00:00, 12.1MB/s]
special_tokens_map.json: 100%|██████████| 548/548 [00:00<00:00, 1.39MB/s]
config.json: 100%|██████████| 6.96k/6.96k [00:00<00:00, 13.7MB/s]
model.safetensors.index.json: 100%|██████████| 127k/127k [00:00<00:00, 75.5MB/s]
model-00001-of-00002.safetensors: 100%|██████████| 10.0G/10.0G [03:43<00:00, 44.8MB/s]
model-00002-of-00002.safetensors: 100%|██████████| 5.50G/5.50G [02:02<00:00, 45.0MB/s]
Downloading shards: 100%|██████████| 2/2 [05:46<00:00, 173.26s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [01:32<00:00, 46.21s/it]


Generated Caption: a snow leopard is resting on a rock

