In [1]:
import clip
import torch
import torch.nn as nn
from typing import List
from PIL import Image
from captum.attr import IntegratedGradients, visualization

# Directory passed to clip.load as `download_root` (machine-specific path).
backbone_ckpt = "/home/ksas/Public/model_zoo/clip"
backbone_name = "ViT-B/32"
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

backbone, preprocess = clip.load(backbone_name, device=device, download_root=backbone_ckpt)
# Cast the weights with .float(), keep the model on the selected `device`
# (the move was previously hard-coded to "cuda", which fails on CPU-only
# machines despite the fallback above), and switch to eval mode.
backbone = backbone.float().to(device).eval()

In [None]:
# Load the example image and turn it into a single-image batch. The context
# manager closes the underlying file once the pixels are in a tensor.
with Image.open("data/images/glasses.png") as pil_image:
    image: torch.Tensor = preprocess(pil_image).unsqueeze(0).to(device)

# Candidate captions to score against the image.
text = clip.tokenize(["African with sunglasses", "Asian", "European", "Dog"]).to(device)

# Pure inference: skip building an autograd graph for these forward passes.
with torch.no_grad():
    image_features = backbone.encode_image(image)
    text_features = backbone.encode_text(text)
    logits_per_image, logits_per_text = backbone(image, text)

print(logits_per_image)
print(logits_per_text)

In [3]:
class ClipVisualWithSimilarity(nn.Module):
    """Wrap a CLIP model so it maps an image batch to image-text logits.

    The comparison texts are tokenized once at construction time, which turns
    the two-input CLIP forward pass into a single-input module (image only),
    as needed when the module is handed to an attribution method.

    Args:
        clip_model: Model called as ``clip_model(image, tokens)`` and returning
            a ``(logits_per_image, logits_per_text)`` pair.
        comparison_text: Texts to compare every image against; the position of
            a text in this list is its column index in the output.
    """

    def __init__(self, clip_model, comparison_text:List[str]):
        super().__init__()
        self.clip_model = clip_model

        # Place the tokens on the same device as the wrapped model instead of
        # relying on a notebook-level `device` global, so the wrapper also
        # works when the model lives on a different device.
        try:
            model_device = next(clip_model.parameters()).device
        except StopIteration:
            # Model without parameters: nothing to infer from, default to CPU.
            model_device = torch.device("cpu")

        tokens = clip.tokenize(comparison_text).to(model_device)
        # Non-persistent buffer: follows `.to(...)` / `.cuda()` calls on this
        # wrapper but is kept out of the state_dict.
        self.register_buffer("comparison_text", tokens, persistent=False)

    def forward(self, image):
        """Return ``logits_per_image`` of the wrapped model for ``image``.

        The per-text logits (second element of the model output) are dropped.
        """
        logits_per_image, _ = self.clip_model(image, self.comparison_text)
        return logits_per_image

In [None]:
# Prompts to attribute against; the list index doubles as the IG `target`.
comparison_prompts = ["a man with eyeglasses", "smile"]
clip_sim = ClipVisualWithSimilarity(backbone, comparison_prompts)
ig = IntegratedGradients(clip_sim)


def _to_hwc_numpy(batch: torch.Tensor):
    """Squeeze a single-image batch and return it as a channels-last NumPy array."""
    return batch.squeeze().permute((1, 2, 0)).detach().cpu().numpy()


# Integrated Gradients attribution with respect to the image, once per prompt.
# NOTE(review): `image` is still the preprocessed tensor here; it is passed
# only as the backdrop of the blended heat map. Confirm the rendering is OK.
for target_index, prompt in enumerate(comparison_prompts):
    attributions = ig.attribute(image, target=target_index)
    _ = visualization.visualize_image_attr(
        _to_hwc_numpy(attributions),
        _to_hwc_numpy(image),
        "blended_heat_map",
        title=f'Integrated Gradients: "{prompt}"',
    )