In [1]:
import sys
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [11]:
import torch
import torch.nn.functional as F
from datasets import load_from_disk

In [3]:
sys.path.append("/mnt/storage2/arafat_shovon/flow_matching_and_llm/notebooks")

In [4]:
from transformers import(
    Blip2ForConditionalGeneration,
    Blip2Processor
)

**Load the Model**

In [5]:
model_name = "Salesforce/blip2-opt-2.7b"

processor = Blip2Processor.from_pretrained(model_name, 
                                        cache_dir="../data/cache")

model = Blip2ForConditionalGeneration.from_pretrained(model_name,
                                                    torch_dtype=torch.float16,
                                                    cache_dir="../data/cache",
                                                    device_map="auto")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

**Load the Dataset**

In [6]:
data = load_from_disk("../data/flickr8k/train")

**GradCam Class**

In [33]:
class Blip2GradCam():
    def __init__(self, model:Blip2ForConditionalGeneration, processor:Blip2Processor):
        
        self.model = model
        self.processor = processor
        self.device = next(model.parameters()).device
        
        self.target_layer = model.vision_model.encoder.layers[-1]
        self.activations = None
        self.gradients = None
        
        self.target_layer.register_forward_hook(self.forward_hook)
        self.target_layer.register_full_backward_hook(self.backward_hook)

        for params in self.model.language_model.parameters():
            params.requires_grad = False
            
        for params in self.model.qformer.parameters():
            params.requires_grad = False
        
        
    def forward_hook(self, module:torch.nn.Module, input, output):
        self.activations = output[0].detach()
    
    
    def backward_hook(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()
        
    
    def generate_cam(self, image, prompt:str=None):
        orig_size = image.size
        print(orig_size)
        text = prompt if prompt else None
        inputs = self.processor(images=image, text=text, return_tensors="pt").to(self.device, dtype=torch.float16)
        
        for param in self.model.vision_model.parameters():
            param.requires_grad = True
        
        self.model.eval()
        with torch.enable_grad():
            generated_ids = self.model.generate(**inputs)
            output = self.model(pixel_values=inputs['pixel_values'],
                                input_ids=generated_ids,
                                labels=generated_ids
                            )
            
            loss = output.loss
            self.model.zero_grad()
            loss.backward()
            
            cam = self.compute_cam()
            print(cam.shape)
            cam_resized = cv2.resize(cam, orig_size)
            caption = self.processor.decode(generated_ids[0], skip_special_tokens=True)
        
        return cam_resized, caption
    
    
    def compute_cam(self):
        weights = torch.mean(self.gradients, dim=-1, keepdim=True)
        cam = torch.sum(weights * self.activations, dim=-1)
        cam = cam[0]
        cam = cam[1:]
        num_patches = cam.shape[0]
        grid_size = int(np.sqrt(num_patches))
        cam = cam.reshape(grid_size, grid_size)
        cam = F.relu(cam)
        cam = cam - cam.min()
        if cam.max() > 0:
            cam = cam / cam.max()
        
        return cam.detach().cpu().numpy()
        

In [34]:
grad_cam_model = Blip2GradCam(model=model, processor=processor)

**Generate Grad-Cam Image**

In [35]:
index = 203
image = data[index]['image']
cam, caption = grad_cam_model.generate_cam(image)

plt.subplot(1,2,1)
plt.imshow(image)
plt.title("Real Image")

plt.subolot(1,2,2)
plt.imshow(cam)
plt.title("Grad Cam Image")

print(f"Generated Caption: {caption}")

(500, 333)
(16, 16)


error: OpenCV(4.12.0) /io/opencv/modules/imgproc/src/resize.cpp:4086: error: (-215:Assertion failed) func != 0 in function 'resize'


In [27]:
image.size

# inputs = processor(images=image, return_tensors="pt").to("cuda", dtype=torch.float16)
# generated_ids = model.generate(**inputs, max_length=50)
# print(processor.decode(generated_ids[0], skip_special_tokens=True))

(500, 333)