In [1]:
import cv2
import numpy as np
import re
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from skimage import transform as skimage_transform
from scipy.ndimage import filters
from matplotlib import pyplot as plt
from transformers import AutoProcessor, BlipForQuestionAnswering, BlipImageProcessor, BlipProcessor

# Local directory holding the fine-tuned BLIP VQA checkpoint passed to `from_pretrained` below.
vl_model_name = "/data/hyeongchanim/Fine_Tune_BLIP/Model/blip-saved-model"
# Hugging Face download cache; only the processor load below uses it.
cache_dir = "/data/huggingface_models"
# When True, later cells move the model and tensors to the GPU.
use_cuda = False

# The model weights come from the local fine-tuned checkpoint (cache_dir is not passed here).
model = BlipForQuestionAnswering.from_pretrained(vl_model_name)# , cache_dir=cache_dir)
# The processor (image preprocessing + tokenizer) is loaded from the base Hub
# checkpoint rather than from the fine-tuned directory.
# NOTE(review): confirm the fine-tuned model was trained with this same processor/tokenizer.
processor = BlipProcessor.from_pretrained('Salesforce/blip-vqa-base', cache_dir=cache_dir)


def pre_caption(caption,max_words=70):
    """Normalise a caption string before tokenisation.

    Steps, in order:
      1. lower-case the text and delete the characters , . ' ! ? " ( ) * # : ; ~
      2. turn every '-' and '/' into a single space
      3. squeeze each run of two or more whitespace characters into one space
      4. drop trailing newlines, then leading/trailing spaces
      5. keep at most ``max_words`` space-separated words

    Args:
        caption: raw caption text.
        max_words: maximum number of words kept in the result.

    Returns:
        The cleaned (and possibly truncated) caption.
    """
    lowered = caption.lower()
    without_punct = re.sub(r"([,.'!?\"()*#:;~])", '', lowered)
    separated = without_punct.replace('-', ' ').replace('/', ' ')
    squeezed = re.sub(r"\s{2,}", ' ', separated)
    cleaned = squeezed.rstrip('\n').strip(' ')

    # Truncate only when the caption is actually too long.
    words = cleaned.split(' ')
    if len(words) <= max_words:
        return cleaned
    return ' '.join(words[:max_words])

# Manual image preprocessing: resize to 384x384 with bicubic interpolation,
# convert to a float tensor, then normalise each channel with the mean/std below.
# The interpolation is given as torchvision's InterpolationMode enum; passing the
# PIL integer constant (Image.BICUBIC) is the deprecated form and emits a warning.
transform = transforms.Compose([
    transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])


# Shortcut to the text tokenizer bundled with the processor.
tokenizer = processor.tokenizer

# Switch to inference mode; as the cell's last expression this also echoes the model repr.
model.eval()

Some weights of BlipForQuestionAnswering were not initialized from the model checkpoint at /data/hyeongchanim/Fine_Tune_BLIP/Model/blip-saved-model and are newly initialized: ['text_encoder.embeddings.position_ids', 'text_decoder.bert.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BlipForQuestionAnswering(
  (vision_model): BlipVisionModel(
    (embeddings): BlipVisionEmbeddings(
      (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (encoder): BlipEncoder(
      (layers): ModuleList(
        (0-11): 12 x BlipEncoderLayer(
          (self_attn): BlipAttention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (projection): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): BlipMLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((768,), eps=1e-05, e

In [2]:
# Run everything on one device. Previously only the model (and `image`/`query`)
# were moved to the GPU, so with use_cuda=True the processor outputs and labels
# stayed on the CPU and the model calls below failed with a device mismatch.
device = torch.device("cuda" if use_cuda else "cpu")
model.to(device)


image_path = '/data/hyeongchanim/0a2b797d08.jpg'
image_pil = Image.open(image_path).convert('RGB')
# Manually preprocessed copy of the image with a leading batch dimension.
image = transform(image_pil).unsqueeze(0).to(device)

# --- 1) Free-running generation for the first prompt -------------------------
caption = "~~~"
text = pre_caption(caption)
query = tokenizer(text, return_tensors="pt")
print(query)

inputs = processor(image_pil, text, return_tensors="pt").to(device)
outputs = model.generate(**inputs)
print(processor.decode(outputs[0], skip_special_tokens=True))

# --- 2) Teacher-forced forward pass used for the attention visualisation -----
block_num = 4
# Ask the chosen cross-attention block to keep its attention map.
# NOTE(review): `save_attention` is set as a plain attribute here; confirm that the
# installed (possibly patched) transformers BLIP text attention actually reads it.
model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.save_attention = True
caption = "1, 2"
text = pre_caption(caption)
query = tokenizer(text, return_tensors="pt")
print(query)
query = query.to(device)


encoding = processor(image_pil, text, return_tensors="pt").to(device)

answer = ['1, 2']
print(answer)
labels = processor.tokenizer(answer, return_tensors="pt")["input_ids"].to(device)
outputs = model(**encoding, labels=labels)

{'input_ids': tensor([[  101,  3830,  1996,  5344,  1999,  1996,  2445,  6302,  2004,  4076,
         23911, 10873,  2000,  1996,  5344,  2013,  2187,  2000,  2157,  3225,
          2007,  1015,  2079,  2025, 23911, 10873,  2000,  1996, 10457,  1997,
          4641,  2069,  3830,  5710,  5344,  2065,  2045,  2015,  2053, 15681,
          2854,  2227,  2074,  2069,  2507,  3437,  2066,  2023,  3904,  2065,
          2045,  2003,  1037, 15681,  2854,  2227,  2074,  2069,  2507,  3437,
          2066,  2023,  1996,  8275,  2227,  2003,  1063,  3830,  3616,  1997,
          1996,  6302,  2007, 15681,  2854,  5344,  1065,   102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1]])}




the fake face is 1, 2
{'input_ids': tensor([[ 101, 1996, 8275, 2227, 2003, 1015, 1016,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
['The fake face is 1, 2']


In [23]:
# Echo the model repr to inspect the architecture (bare expression, so the cell displays it).
model

BlipForQuestionAnswering(
  (vision_model): BlipVisionModel(
    (embeddings): BlipVisionEmbeddings(
      (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (encoder): BlipEncoder(
      (layers): ModuleList(
        (0-11): 12 x BlipEncoderLayer(
          (self_attn): BlipAttention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (projection): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): BlipMLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((768,), eps=1e-05, e

In [None]:
import torch.nn.functional as F
from scipy.ndimage import gaussian_filter  # `scipy.ndimage.filters` is a deprecated namespace

NUM_HEADS = 12    # attention heads assumed by the reshape below
PATCH_GRID = 24   # image patches per side; 24 * 24 patch tokens remain after dropping the first one


def getAttMap(img, attMap, blur=True, overlap=True):
    """Turn a 2-D attention map into a heat map, optionally blended over ``img``.

    Args:
        img: image as an (H, W, 3) array or tensor. This notebook passes a float32
            RGB array scaled to [0, 1].
        attMap: 2-D attention map (array or tensor). It is never modified in place.
        blur: if True, smooth the resized map with a Gaussian filter and rescale it.
        overlap: if True, return the jet-coloured map alpha-blended over ``img``;
            otherwise return the normalised 2-D map itself.

    Returns:
        An (H, W, 3) blended image when ``overlap`` is True, else an (H, W) map.
    """
    if isinstance(attMap, torch.Tensor):
        attMap = attMap.detach().cpu().numpy()
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()

    def _rescale(values):
        # Out-of-place min/max scaling: the array may share memory with a tensor
        # owned by the caller, and a constant map must not divide by zero.
        values = values - values.min()
        peak = values.max()
        return values / peak if peak > 0 else values

    attMap = _rescale(attMap)
    attMap = skimage_transform.resize(attMap, img.shape[:2], order=3, mode='constant')
    if blur:
        attMap = gaussian_filter(attMap, 0.02*max(img.shape[:2]))
        attMap = _rescale(attMap)
    cmap = plt.get_cmap('jet')
    attMapV = cmap(attMap)
    attMapV = np.delete(attMapV, 3, 2)  # drop the alpha channel -> (H, W, 3)
    if overlap:
        attMap = 1*(1-attMap**0.7).reshape(attMap.shape + (1,))*img + (attMap**0.7).reshape(attMap.shape+(1,)) * attMapV
    return attMap


# ---- Loop-invariant setup ----------------------------------------------------
device = next(model.parameters()).device
text_layers = model.text_encoder.base_model.base_model.encoder.layer

# Model inputs must live on the same device as the model.
encoding = encoding.to(device)
labels = processor.tokenizer(answer, return_tensors="pt")["input_ids"].to(device)

rgb_image = cv2.imread(image_path)
if rgb_image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
rgb_image = np.float32(rgb_image[:, :, ::-1]) / 255  # BGR -> RGB, scaled to [0, 1]

# Visualise the two tokens just before the final token (same slice as `[-3:-1]`).
token_ids = query.input_ids[0]
shown_positions = list(range(len(token_ids)))[-3:-1]

# ---- One Grad-CAM figure per cross-attention block ---------------------------
for block_num, layer in enumerate(text_layers):
    target_layer = layer.crossattention
    target_layer.self.save_attention = True

    outputs = model(**encoding, labels=labels)
    loss = outputs[0]

    model.zero_grad()
    loss.backward()

    with torch.no_grad():
        # grads and cams: [bsz, num_head, seq_len, 1 + image_patches]
        grads = F.relu(target_layer.self.get_attn_gradients())
        cams = F.relu(target_layer.self.get_attention_map())

        attn_mask = query.attention_mask.to(cams.device)
        mask = attn_mask.view(attn_mask.size(0), 1, -1, 1, 1)  # (bsz, 1, token_len, 1, 1)
        # Number of tokens excluding the first and last positions.
        token_length = (attn_mask.sum(dim=-1) - 2).cpu()

        batch_size = cams.size(0)
        # Drop the first image token and lay the remaining patches out on the grid.
        cams = cams[:, :, :, 1:].reshape(batch_size, NUM_HEADS, -1, PATCH_GRID, PATCH_GRID) * mask
        grads = grads[:, :, :, 1:].reshape(batch_size, NUM_HEADS, -1, PATCH_GRID, PATCH_GRID) * mask
        gradcams = cams * grads

        per_sample = []
        for ind in range(batch_size):
            n_tok = int(token_length[ind])
            head_mean = gradcams[ind].mean(0).cpu().detach()
            # Rows: [first token, average over the n_tok tokens, every token from position 1 on]
            per_sample.append(
                torch.cat(
                    (
                        head_mean[0:1, :],
                        head_mean[1 : n_tok + 1, :].sum(dim=0, keepdim=True)
                        / max(n_tok, 1),  # avoid 0/0 for empty text
                        head_mean[1:, :],
                    )
                )
            )
        # `token_ids` comes from sample 0, so plot sample 0.
        gradcam = per_sample[0]

    num_rows = 1 + len(shown_positions)
    fig, ax = plt.subplots(num_rows, 1, figsize=(15, 5*num_rows), squeeze=False)
    ax = ax[:, 0]

    ax[0].imshow(rgb_image)
    ax[0].set_yticks([])
    ax[0].set_xticks([])
    ax[0].set_xlabel("Image")
    ax[0].set_title(f"Cross-attention block {block_num}")

    for row, pos in enumerate(shown_positions, start=1):
        word = tokenizer.decode([token_ids[pos]])
        # Row 1 of `gradcam` is the token average, so token position p >= 1 is
        # stored at row p + 1 (position 0 stays at row 0).
        cam_row = pos if pos == 0 else pos + 1
        gradcam_image = getAttMap(rgb_image, gradcam[cam_row])
        ax[row].imshow(gradcam_image)
        ax[row].set_yticks([])
        ax[row].set_xticks([])
        ax[row].set_xlabel(word)