In [11]:
import torch
import torch.nn as nn
from transformers import BlipForConditionalGeneration, AutoProcessor, ViTModel

class CustomBLIP(nn.Module):
    def __init__(self, vision_encoder_name="vit-small-patch16-224",
                 base_blip_model="blip-image-captioning-base"):
        super().__init__()

        # Load pretrained BLIP
        self.blip = BlipForConditionalGeneration.from_pretrained(base_blip_model)

        # Load a new smaller vision encoder
        self.new_vision_encoder = ViTModel.from_pretrained(vision_encoder_name)
        
        blip_proj_in = self.blip.vision_model.encoder.config.projection_dim
        # Replace vision encoder
        self.blip.vision_model = self.new_vision_encoder

        # Match dimensions if needed
        vision_hidden_size = self.new_vision_encoder.config.hidden_size

        if vision_hidden_size != blip_proj_in:
            print(f"Projecting vision features from {vision_hidden_size} -> {blip_proj_in}")
            self.blip.vision_model.encoder.config.projection_dim = vision_hidden_size

    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        return self.blip(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True
        )

    def generate(self, pixel_values, **gen_kwargs):
        return self.blip.generate(pixel_values=pixel_values, **gen_kwargs)

In [15]:
from PIL import Image
from transformers import AutoProcessor
import requests

#image = Image.open("example.jpg").convert("RGB")
url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
processor = AutoProcessor.from_pretrained("blip-image-captioning-base")
processor.image_processor.size = {"height": 224, "width": 224}

inputs = processor(images=image, return_tensors="pt").to("cuda")
model = CustomBLIP().to("cuda")

Some weights of ViTModel were not initialized from the model checkpoint at vit-small-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Projecting vision features from 384 -> 512


In [13]:
out = model.generate(pixel_values=inputs["pixel_values"])
caption = processor.tokenizer.batch_decode(out, skip_special_tokens=True)[0]
print("Caption:", caption)

ValueError: Input image size (384*384) doesn't match model (224*224).

In [10]:
blip = BlipForConditionalGeneration.from_pretrained("blip-image-captioning-base")
blip.vision_model.encoder.config.projection_dim

512