In [1]:
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch
from accelerate.test_utils.testing import get_backend

processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
device, _, _ = get_backend()
model.to(device)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.38s/it]


Blip2ForConditionalGeneration(
  (vision_model): Blip2VisionModel(
    (embeddings): Blip2VisionEmbeddings(
      (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (encoder): Blip2Encoder(
      (layers): ModuleList(
        (0-38): 39 x Blip2EncoderLayer(
          (self_attn): Blip2Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=1408, out_features=4224, bias=True)
            (projection): Linear(in_features=1408, out_features=1408, bias=True)
          )
          (layer_norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
          (mlp): Blip2MLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1408, out_features=6144, bias=True)
            (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          )
          (layer_norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((

In [2]:
def extract_embeddings(image, question) -> torch.Tensor:
    inputs = processor(image, text=question, return_tensors="pt").to(device)
    with torch.no_grad():
        encoder_outputs = model.forward(**inputs)
    token_embeddings =  encoder_outputs.vision_outputs.last_hidden_state
    normalized_embeddings = torch.nn.functional.normalize(token_embeddings, p=2, dim=-1)
    avg_embeddings = torch.mean(normalized_embeddings, dim=1)
    return avg_embeddings.cpu()
    

In [3]:
import os
from PIL import Image

question = "What is shown in the image?"
prompt = f"Question: {question} Answer:"

for filename in os.listdir("images"):
    if filename.endswith(".png"):
        with open(os.path.join("images", filename), "rb") as f:
            image = Image.open(f)
            emb = extract_embeddings(image, prompt)
            torch.save(emb, os.path.join("embeddings", filename.replace(".png", ".pt")))
    else:
        continue