In [5]:
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch
import os
from accelerate.test_utils.testing import get_backend

# Checkpoint shared by the processor and the model.
MODEL_ID = "Salesforce/blip2-opt-2.7b"

# Prefer the second GPU (the original target), but fall back to the first GPU or
# the CPU so Restart & Run All works on machines without two CUDA devices.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# float16 halves GPU memory; float16 support on CPU is limited, so use float32 there.
model_dtype = torch.float16 if device.startswith("cuda") else torch.float32

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype=model_dtype)
model = model.to(device)
model.eval()  # inference only: make eval mode explicit

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.23s/it]


In [6]:
def generate_embedding(image) -> torch.Tensor:
    """Return a mean-pooled BLIP-2 vision-encoder embedding for one image.

    The image is preprocessed with the module-level ``processor`` and passed
    through ``model.vision_model`` only. Each token of the encoder's last hidden
    state is L2-normalized, and the tokens are averaged.

    Only the vision encoder is run. The Q-Former and language model are skipped
    because their outputs are not used, and a text prompt could not influence
    the vision encoder's output anyway.

    Args:
        image: A PIL image (or anything ``processor`` accepts as an image).

    Returns:
        The pooled embedding, moved to the CPU. Presumably shape
        ``(1, hidden_size)`` for a single image; this assumes
        ``last_hidden_state`` is ``(batch, tokens, hidden)`` (TODO confirm).
    """
    pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
    # Match the model's device and dtype (float16 on GPU) explicitly.
    pixel_values = pixel_values.to(device=device, dtype=model.dtype)
    with torch.no_grad():
        # Call the submodule (not .forward) so module hooks still run.
        vision_outputs = model.vision_model(pixel_values=pixel_values)
    token_embeddings = vision_outputs.last_hidden_state
    normalized_embeddings = torch.nn.functional.normalize(token_embeddings, p=2, dim=-1)
    avg_embeddings = normalized_embeddings.mean(dim=1)  # average over tokens
    return avg_embeddings.cpu()

In [7]:
from PIL import Image

def generate_embeddings(folder:str, setting: str, key: str, obj:str, prefixes: list[str], images_per_prompt: int) -> None:
    """Embed every ``.png`` image in ``{folder}/{setting}/{key}/{obj}``.

    Each image's embedding from ``generate_embedding`` is written with
    ``torch.save`` to ``embeddings/{setting}/{key}/{obj}/<stem>.pt``. The output
    root is the fixed relative directory ``embeddings``, not ``folder``.

    Args:
        folder: Root directory containing the images.
        setting: First path component under ``folder``.
        key: Second path component under ``folder``.
        obj: Third path component under ``folder``.
        prefixes: Unused here. Presumably kept so the signature matches the
            callback expected by ``for_each_prompt``; confirm against src.utils.
        images_per_prompt: Unused here, for the same reason as ``prefixes``.

    Returns:
        None. Results are written to disk only.
    """
    input_folder_base = os.path.join(folder, setting, key, obj)
    output_folder_base = os.path.join("embeddings", setting, key, obj)
    # exist_ok avoids the check-then-create race.
    os.makedirs(output_folder_base, exist_ok=True)

    # Sorted for a deterministic processing order.
    for file_name in sorted(os.listdir(input_folder_base)):
        if not file_name.endswith(".png"):
            continue
        # splitext touches only the real extension, unlike str.replace('.png', ...).
        stem, _ = os.path.splitext(file_name)
        # The context manager closes the file once the embedding has been computed.
        with Image.open(os.path.join(input_folder_base, file_name)) as image:
            emb = generate_embedding(image)
        torch.save(emb, os.path.join(output_folder_base, f"{stem}.pt"))

In [8]:
from src.utils import for_each_prompt


# Generate embeddings for every prompt, first for the "work" setting, then "home".
for setting_name in ("work", "home"):
    for_each_prompt("prompts.json", "images", setting_name, generate_embeddings)