In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import matplotlib.pyplot as plt
import PIL
import gc

import urllib.request
import logging
import os

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2
from scipy.ndimage import filters
from torch import nn

import clip

from matplotlib.backends.backend_pdf import PdfPages
import torch.distributed as dist
from lavis.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
from lavis.common.logger import MetricLogger, SmoothedValue
from lavis.common.registry import registry
from lavis.datasets.data_utils import prepare_sample
from lavis.models.clip_models.model import CLIP, load_state_dict, load_openai_model, build_model_from_openai_state_dict
from lavis.models.clip_models import tokenizer
from lavis.processors.clip_processors import ClipImageTrainProcessor
from lavis.processors.blip_processors import BlipCaptionProcessor
from transformers import CLIPProcessor, CLIPModel

from omnixai.data.text import Text
from omnixai.data.image import Image
from omnixai.data.multi_inputs import MultiInputs
from omnixai.preprocessing.image import Resize
from omnixai.explainers.vision_language.specific.gradcam import GradCAM

from lavis.models.clip_models.pretrained import (
    download_pretrained,
    get_pretrained_url,
    list_pretrained_tag_models,
)

In [None]:
def load_Fmodel_clip(chekpoint=True, vit=True, checkpoint_path="lavis/checkpoint_best (1).pth"):
    """Build a CLIP model together with an image-preprocessing transform.

    Args:
        chekpoint: (spelling kept for existing callers) only used when ``vit``
            is true. If true, a ``CLIP`` model is constructed from the
            hard-coded ViT-L/14 configuration below (image size 336) and the
            weights are loaded from ``checkpoint_path``. If false, the model
            is obtained from ``load_openai_model(name="ViT-L-14")`` on CPU
            and cast to float32.
        vit: if true, use the ViT-L/14 paths described above; otherwise load
            the ResNet-50 model through ``clip.load("RN50")``.
        checkpoint_path: state-dict file read in the checkpoint branch.

    Returns:
        A ``(model, preprocess)`` tuple.
    """
    if not vit:
        # Fall back to CPU instead of failing outright when no GPU is present.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, preprocess = clip.load("RN50", device=device)
        return model, preprocess

    if chekpoint:
        model_cfg = {
            "embed_dim": 768,
            "quick_gelu": True,
            "vision_cfg": {
                "image_size": 336,
                "layers": 24,
                "width": 1024,
                "patch_size": 14,
            },
            "text_cfg": {
                "context_length": 77,
                "vocab_size": 49408,
                "width": 768,
                "heads": 12,
                "layers": 12,
            },
        }
        model = CLIP(**model_cfg, add_cls_token=False)
        model.load_state_dict(load_state_dict(checkpoint_path))
    else:
        # The previous version first instantiated a full CLIP(**model_cfg)
        # here and immediately overwrote it; that throw-away model is gone.
        model = load_openai_model(name="ViT-L-14", device="cpu", jit=False)
        model = model.float()

    # Only the preprocessing transform is kept; the model returned by
    # clip.load is discarded.
    # NOTE(review): the checkpoint branch configures image_size=336 while the
    # transform comes from "ViT-L/14" (not an "@336px" variant) -- confirm the
    # input resolution produced by `preprocess` matches the model.
    _, preprocess = clip.load("ViT-L/14")
    return model, preprocess

# Use the ViT path without the local checkpoint, i.e. the model that
# load_Fmodel_clip obtains from load_openai_model, plus its preprocess transform.
m, p = load_Fmodel_clip(vit=True, chekpoint=False)

In [None]:
class GradCamViT:
    """Gradient-based relevance maps for a CLIP-like model with a ViT image tower.

    The class back-propagates the image/text cosine similarity and then reads
    the ``attn_grad`` attribute of the last residual block of each tower.
    Nothing in this class sets ``attn_grad``; the model implementation must
    populate it during the backward pass.
    """

    def __init__(self, model, target, height=24):
        """Keep a reference to the model (switched to eval mode).

        Args:
            model: module exposing ``visual``, ``transformer``,
                ``token_embedding``, ``positional_embedding``, ``attn_mask``,
                ``ln_final`` and ``text_projection`` (all read below).
            target: stored as ``self.target``; not read by any method here.
            height: stored as ``self.height``; not read by any method here
                (presumably the patch-grid side length -- TODO confirm).
        """
        self.model = model.eval()
        self.handlers = []  # never populated inside this class
        self.target = target
        self.height = height

    @staticmethod
    def _normalize(cam):
        """Min-max scale ``cam`` to [0, 1]; a constant map becomes all zeros.

        Dividing by a zero range would otherwise fill the map with NaNs.
        """
        low = cam.min()
        span = cam.max() - low
        if span == 0:
            return torch.zeros_like(cam)
        return (cam - low) / span

    def encode_text(self, text):
        """Run the text tower step by step and return projected text features.

        Args:
            text: tensor of token ids (presumably ``(batch, context_length)``
                -- TODO confirm against the tokenizer).

        Returns:
            One projected feature vector per row of ``text``.
        """
        x = self.model.token_embedding(text)
        x = x + self.model.positional_embedding

        # Swap the first two axes around the transformer call
        # (presumably batch-first <-> sequence-first -- TODO confirm).
        x = x.permute(1, 0, 2)
        x = self.model.transformer(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)

        x = self.model.ln_final(x)
        # For each sequence take the position holding the largest token id
        # (presumably the end-of-text token) and project it.
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection
        return x

    def image_attn_blocks(self, size, out_size=224):
        """Turn the last image block's ``attn_grad`` into a square heat map.

        Args:
            size: side length of the patch grid; ``size * size`` must equal
                the number of entries left after dropping the first column.
            out_size: side length of the bilinearly upsampled map
                (defaults to the previously hard-coded 224).

        Returns:
            Tensor of shape ``(out_size, out_size)`` scaled to [0, 1].
        """
        blocks = [module for _, module in self.model.visual.transformer.resblocks.named_children()]

        cam = blocks[-1].attn_grad.detach()
        # Average over the leading axis (presumably attention heads for a
        # single-image batch -- TODO confirm).
        cam = cam.mean(dim=0)
        # Row 0 without its first column (presumably the class token's
        # gradients towards the patch tokens).
        cam = cam[0, 1:]
        cam = cam.reshape(1, 1, size, size)
        cam = torch.nn.functional.interpolate(cam, size=out_size, mode='bilinear')
        cam = cam[0, 0]
        return self._normalize(cam)

    def txt_attn_blocks(self, index):
        """Return the normalised ``attn_grad`` row ``index`` of the last text block.

        Args:
            index: row selected after averaging over the leading axis.

        Returns:
            1-D tensor scaled to [0, 1].
        """
        blocks = [module for _, module in self.model.transformer.resblocks.named_children()]
        cam = blocks[-1].attn_grad.detach()
        cam = cam.mean(dim=0)
        cam = cam[index]
        return self._normalize(cam)

    def __call__(self, inputs, val, index, size, out_size=224):
        """Back-propagate the image/text similarity and collect both maps.

        Args:
            inputs: mapping with the image tensor under ``"im"`` and the
                token tensor under ``"txt"``.
            val: unused; kept so existing call sites keep working.
            index: forwarded to :meth:`txt_attn_blocks`.
            size: forwarded to :meth:`image_attn_blocks`.
            out_size: forwarded to :meth:`image_attn_blocks`.

        Returns:
            ``(image_map, text_map)`` as NumPy arrays.
        """
        image = inputs["im"]
        text = inputs["txt"]

        self.model.eval()

        image_features = self.model.visual(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        target_features = self.encode_text(text)
        target_features = target_features / target_features.norm(dim=-1, keepdim=True)

        # Only the first image and the first text of the batch are compared.
        similarity = image_features[0].dot(target_features[0])
        self.model.zero_grad()
        similarity.backward(retain_graph=True)

        cam_img = self.image_attn_blocks(size, out_size=out_size)
        cam_txt = self.txt_attn_blocks(index)

        # Both maps are already detached from the autograd graph.
        return cam_img.cpu().numpy(), cam_txt.cpu().numpy()

In [None]:
def _get_text_xticks(sentence):
    tokens = [word_.strip() for word_ in sentence.split('<\w>')][:77]
    return tokens

In [None]:
def _plot_score(vec, pred_text, xticks, baton=True):
    """Plot per-token scores either as a bar chart or as a one-row heat strip.

    Args:
        vec: sequence of scores; the viridis colour map and the 0..1 colour
            limits used below assume values in [0, 1].
        pred_text: text shown after ``'Prediction: '`` on the figure.
        xticks: token labels. Everything after the first ``'<end_of_text>'``
            label is dropped; if that label is absent all labels are used
            (this previously raised ``ValueError``).
        baton: if truthy draw the bar chart with a colour bar, otherwise draw
            the heat strip.
    """
    end_marker = '<end_of_text>'
    xticks = list(xticks)
    if end_marker in xticks:
        xticks = xticks[:xticks.index(end_marker) + 1]

    # Scores and labels must have the same length, otherwise the number of
    # tick positions and tick labels disagree.
    count = min(len(vec), len(xticks))
    vec = vec[:count]
    xticks = xticks[:count]

    _axis_fontsize = 13
    # One extra, blank tick to the right of the last token (as before).
    tick_positions = list(range(count + 1))
    tick_labels = xticks + [" "]
    caption = 'Prediction: {}'.format(pred_text)

    if baton:
        fig, ax = plt.subplots(figsize=(count + 1, 10))

        colors = plt.cm.viridis(vec)
        ax.bar(range(count), vec, color=colors, edgecolor='black', align='center')
        ax.set_yticks([0, 0.5, 1])
        ax.tick_params(axis='y', labelsize=_axis_fontsize)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, fontsize=_axis_fontsize)

        sm = plt.cm.ScalarMappable(cmap='viridis', norm=plt.Normalize(vmin=0, vmax=1))
        sm.set_array([])
        # The mappable is not attached to any axes, so the axes to take space
        # from has to be named explicitly.
        cbar = fig.colorbar(sm, ax=ax, orientation='horizontal', pad=0.2)
        cbar.set_label('Value Intensity', fontsize=12)

        fig.text(0.13, 0.54, caption, fontsize=15, fontname='sans-serif')
    else:
        # A single, explicitly created axes; the old code configured ticks on
        # an implicit axes and then called fig.add_subplot(1, 1, 1) again.
        fig, ax = plt.subplots(figsize=(count + 1, 2))

        ax.set_yticks([])
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, fontsize=_axis_fontsize)
        fig.text(0.13, 0.54, caption, fontsize=15, fontname='sans-serif')

        # Drawn after the ticks are set, mirroring the original call order.
        ax.imshow([vec], vmin=0, vmax=1)

    plt.show()