In [None]:
def get_grad(img_path, text, std=0.0):
    """Return d cos(image, text) / d image for one (optionally noisy) image.

    Args:
        img_path: path of the image; preprocessed with the global `transform`.
        text: caption tokenized with clip.tokenize (truncated if too long).
        std: standard deviation of Gaussian noise added to the preprocessed
            image before the forward pass; 0 disables the noise.

    Returns:
        Gradient tensor with the same shape as the batched (1, ...) image.
    """
    # Close the file handle as soon as the image has been transformed.
    with Image.open(img_path) as pil_img:
        img = transform(pil_img).unsqueeze(0).to(device).to(torch.float)
    if std != 0:
        img = img + torch.normal(0.0, std, list(img.shape), device=img.device)
    img.requires_grad = True
    img_feats = model.encode_image(img)

    tokenized_text = clip.tokenize(text, truncate=True).to(device)
    # The text embedding is constant w.r.t. the image: no graph needed.
    with torch.no_grad():
        text_feats = model.encode_text(tokenized_text)

    sim = torch.cosine_similarity(img_feats, text_feats).unsqueeze(-1)
    # autograd.grad yields the input gradient directly; .backward() would also
    # accumulate .grad on every model parameter on each call (sg calls this
    # nt times), wasting memory and leaving stale gradients on the model.
    (grad,) = torch.autograd.grad(sim.sum(), img)
    return grad

def sg(img_path, text, std=1.0, nt=10, pos=False, absolute=False):
    """SmoothGrad: average `nt` noisy input gradients produced by get_grad.

    Args:
        img_path: image passed through to get_grad.
        text: caption passed through to get_grad.
        std: standard deviation of the noise get_grad adds per sample.
        nt: number of noisy samples to average.
        pos: zero out negative gradient entries before averaging.
        absolute: take magnitudes before averaging (applied after `pos`).

    Returns:
        Element-wise mean over the samples, shaped like a single gradient.
    """
    stacked = torch.stack(
        tuple(get_grad(img_path, text, std=std) for _ in range(nt))
    )
    if pos:
        stacked = stacked.masked_fill(stacked < 0, 0)
    if absolute:
        stacked = stacked.abs()
    return stacked.mean(dim=0)
    
    

#grad = get_grad(img_path, text, std=1.0)
#mean_grad = sg(img_path, text, std=1.0, nt=10, pos=False, absolute=True)
#p = plt.hist(mean_grad.cpu().flatten().numpy(), bins=1000)

In [None]:
from captum.attr import IntegratedGradients, NoiseTunnel, Saliency
import torchvision

# Shared tensor -> PIL image converter, used to display images and saliency maps.
to_pil = torchvision.transforms.ToPILImage()

def minmax(a):
    """Linearly rescale `a` (tensor or ndarray) so its values span [0, 1].

    A constant input has zero range. It is mapped to all zeros instead of
    the NaNs a plain 0/0 division would produce.
    """
    min_ = a.min()
    span = a.max() - min_
    if span == 0:
        span = 1  # constant input: (a - min_) is all zeros, avoid 0/0
    return (a - min_) / span


class MeanModelWrapper(torch.nn.Module):
    """Reduce a model's image embedding to one score per example.

    get_saliency uses this as the forward function for Captum: the output
    keeps a trailing singleton axis so that `target=0` selects the score.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        feats = self.model.encode_image(x)
        pooled = feats.mean(dim=1)  # average over feature axis 1
        return pooled.unsqueeze(-1)

    
class CosineModel(torch.nn.Module):
    """Score images by cosine similarity to a fixed, pre-encoded text.

    The text is encoded once, at construction. Only the image then flows
    through forward(), so token ids never pass through Captum's attribute().

    Args:
        model: object exposing encode_image and encode_text (e.g. CLIP).
        tokenized_text: token ids accepted by model.encode_text.
    """

    def __init__(self, model, tokenized_text):
        super().__init__()
        self.model = model
        # The text embedding is a constant w.r.t. the image, so build it
        # without autograd. Otherwise its graph (and the text encoder's
        # activations) would stay alive for the lifetime of this module.
        with torch.no_grad():
            self.text_feats = model.encode_text(tokenized_text)

    def forward(self, img):
        """Cosine similarity of each image embedding to the text embedding.

        Similarity is taken along dim 1 and a trailing axis is appended; for
        (batch, dim) embeddings the result is (batch, 1).
        """
        img_feats = self.model.encode_image(img)
        return torch.cosine_similarity(img_feats, self.text_feats).unsqueeze(-1)


def get_saliency(model, img_path, transform, nt_samples=10):
    """SmoothGrad-wrapped Integrated Gradients of the mean image embedding.

    Args:
        model: object exposing encode_image. If it is an nn.Module it becomes
            a submodule of the wrapper, so the `.cpu().float()` below moves
            and casts the caller's model in place.
        img_path: path of the image to explain.
        transform: preprocessing callable applied to the PIL image.
        nt_samples: number of noisy samples drawn by NoiseTunnel.

    Returns:
        NoiseTunnel attribution for the batched (1, ...) image.
    """
    # Wrap the model so it yields a single score per image (mean feature).
    model = MeanModelWrapper(model).cpu().float()
    with Image.open(img_path) as pil_img:
        img = transform(pil_img).unsqueeze(0).cpu().float()
    img.requires_grad = True
    nt = NoiseTunnel(IntegratedGradients(model))
    # target=0 selects the wrapper's single output column.
    return nt.attribute(img, nt_type='smoothgrad',
                        nt_samples=nt_samples, target=0)


def get_saliency_cosine(model, img_path, text, transform, std=0.5, 
                        nt_samples=100, use_ig=False, nt_type="smoothgrad",
                        abs=False, pos=False, t=5, perc=1):
    """Attribute the image-text cosine similarity of `model` to image pixels.

    Args:
        model: CLIP-style model exposing encode_image / encode_text. It is
            moved to the chosen device and cast to float32 in place.
        img_path: path of the image to explain.
        text: caption passed to clip.tokenize (truncated if too long).
        transform: preprocessing callable applied to the PIL image.
        std: noise standard deviation (NoiseTunnel `stdevs`).
        nt_samples: number of noisy samples drawn by NoiseTunnel.
        use_ig: use Integrated Gradients instead of plain Saliency.
        nt_type: NoiseTunnel aggregation, e.g. "smoothgrad" or "vargrad".
        abs: Saliency `abs` flag (ignored when use_ig is True). It also
            disables the final [-1, 1] rescaling. It shadows the builtin
            inside this function; the name is kept for caller compatibility.
        pos: replace negative attributions with the smallest positive one
            (with 0 if there is no positive value).
        t: unused. It is kept for backward compatibility; it was the softmax
            temperature of an earlier variant.
        perc: percentage of values clipped at each tail before rescaling.

    Returns:
        Channel-averaged attribution map in [0, 1], or in [-1, 1] when both
        `abs` and `pos` are False. It is 2-D assuming `transform` yields a
        (C, H, W) image; TODO confirm.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float
    tokenized = clip.tokenize(text, truncate=True).to(device)
    # Feed the text in via the wrapper so the token ids are not cast to float
    # together with the image inside nt.attribute.
    model.to(device).to(dtype)
    model = CosineModel(model, tokenized)

    with Image.open(img_path) as pil_img:
        img = transform(pil_img).unsqueeze(0).to(device).to(dtype)
    img.requires_grad = True

    if use_ig:
        explainer = IntegratedGradients(model, multiply_by_inputs=True)
        kwargs = {"internal_batch_size": 8, "n_steps": 50}
    else:
        explainer = Saliency(model)
        kwargs = {"abs": abs}
    nt = NoiseTunnel(explainer)
    attribution = nt.attribute(img, nt_type=nt_type, nt_samples=nt_samples,
                               target=0, stdevs=std, nt_samples_batch_size=4,
                               **kwargs)
    # Average over the colour channels to get one value per pixel.
    attribution = attribution.squeeze().mean(dim=0)

    # Debug view of the raw attribution distribution.
    plt.hist(attribution.cpu().flatten().numpy(), bins=1000)
    plt.show()

    if pos:
        positive = attribution[attribution > 0]
        # Tensor.min() instead of builtin min(): no per-element Python loop,
        # and the guard avoids "min() arg is an empty sequence".
        if positive.numel() > 0:
            floor = positive.min()
        else:
            floor = attribution.new_zeros(())
        print("Min val: ", floor)
        attribution[attribution < 0] = floor

    attribution = minmax(attribution)

    # Clip `perc` percent of the values at each tail so a few extreme pixels
    # do not dominate the rescaled map.
    upper = torch.quantile(attribution, 1 - perc / 100)
    attribution[attribution > upper] = upper
    lower = torch.quantile(attribution, perc / 100)
    attribution[attribution < lower] = lower

    attribution = minmax(attribution)

    if not abs and not pos:
        # Stretch [0, 1] to [-1, 1]; the original zero is not preserved.
        attribution = attribution * 2 - 1
    return attribution

In [None]:
# Experiment settings shared by the attribution calls that follow.
idx = 0  # index for the img_paths / captions alternatives noted below

std = 1.0        # noise standard deviation
nt_samples = 30  # noisy samples per attribution
# NOTE(review): shadows the builtin abs() for the rest of the notebook;
# rename together with every use in the later cells.
abs = True
pos = False
t = 1

# Alternatives: img_paths[idx]; captions["caption_text"].iloc[idx * 5],
# "a backyard", "wait a second, this is just an apple with a label saying ipod".
img_path = "ipod_note_apple.png"
caption = "an apple stem"

# Apply every preprocessing step except the last one so the image can be
# displayed (presumably the dropped step is normalisation; TODO confirm).
# The `with` closes the file handle once the image has been transformed.
with Image.open(img_path) as pil_img:
    loaded_img = torchvision.transforms.Compose(transform.transforms[:-1])(pil_img).squeeze()
print("Caption:", caption)
to_pil(loaded_img).show()

import cv2
def show_cam_on_image(img, mask):
    """Blend a JET colour map of `mask` over `img`, scaled so the max is 1.

    Args:
        img: H x W x C image; assumed float RGB in [0, 1] (TODO confirm).
        mask: relevance map expected in [0, 1]. NaNs become 0, and values
            outside [0, 1] are clipped before the uint8 quantisation.

    Returns:
        float32 array `heatmap + img` divided by its maximum.
    """
    # Clip before the uint8 cast. Out-of-range values (e.g. a signed map in
    # [-1, 1]) would otherwise wrap modulo 256, and NaN casts are undefined.
    mask = np.clip(np.nan_to_num(mask, nan=0.0), 0.0, 1.0)
    mask = (mask * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
    # NOTE: OpenCV colour maps are BGR; the heatmap is added to `img` as-is,
    # and callers convert the blended result with cv2.cvtColor afterwards.
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return cam

# --- Custom SmoothGrad (sg) ---------------------------------------------------
print("Custom SG:")
sal = sg(img_path, caption, std=std, nt=nt_samples, pos=pos, absolute=abs)
sal = minmax(sal.squeeze().mean(dim=0))  # channel mean, rescaled to [0, 1]
p = plt.hist(sal.cpu().flatten().numpy(), bins=1000)
plt.show()
# Unsigned maps use the default colormap, signed ones a diverging one.
cmap = None if abs or pos else "coolwarm"
sns.heatmap(sal.squeeze().cpu().numpy(), cmap=cmap)
plt.axis("off")
plt.show()
to_pil(sal.squeeze()).show()
# Move the image to the map's device instead of assuming CUDA.
to_pil(sal.squeeze() * loaded_img.squeeze().to(sal.device)).show()
# Overlay the relevance map on the image as a JET heatmap.
image = np.moveaxis(np.array(loaded_img), 0, -1)  # CHW -> HWC
image_relevance = sal.cpu().unsqueeze(-1).float().numpy()
vis = show_cam_on_image(image, image_relevance)
vis = np.uint8(255 * vis)
vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
plt.imshow(vis)
plt.show()

# --- SmoothGrad via Captum ----------------------------------------------------
print("Smoothgrad")
cosine_sal = get_saliency_cosine(model, img_path, caption, transform, std=std,
                                 nt_samples=nt_samples, abs=abs, t=t, pos=pos)
p = plt.hist(cosine_sal.cpu().flatten().numpy(), bins=1000)
plt.show()
sns.heatmap(cosine_sal.squeeze().cpu().numpy(), cmap=cmap)
plt.axis("off")
plt.show()
# Zero everything below 0.5 (including any negatives of a signed map); the
# remaining values are >= 0.5, so no further abs() is needed.
cosine_sal[cosine_sal < 0.5] = 0
to_pil(cosine_sal.squeeze()).show()
to_pil(cosine_sal.squeeze() * loaded_img.squeeze().to(cosine_sal.device)).show()

# Overlay on the same HWC `image` computed in the section above.
image_relevance = cosine_sal.cpu().unsqueeze(-1).float().numpy()
vis = show_cam_on_image(image, image_relevance)
vis = np.uint8(255 * vis)
vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
plt.imshow(vis)
plt.show()

# --- Optional: VarGrad --------------------------------------------------------
vargrad = True
if vargrad:
    print("VarGrad")
    cosine_sal = get_saliency_cosine(model, img_path, caption, transform, std=std,
                                     nt_samples=nt_samples, nt_type="vargrad",
                                     t=t, abs=abs)
    # Normalise once and reuse for all three views.
    vargrad_sal = minmax(cosine_sal)
    p = plt.hist(vargrad_sal.cpu().flatten().numpy(), bins=1000)
    plt.show()
    to_pil(vargrad_sal.squeeze()).show()
    to_pil(vargrad_sal.squeeze() * loaded_img.squeeze().to(vargrad_sal.device)).show()

# --- Optional: Integrated Gradients -------------------------------------------
ig = False
if ig:
    print("IG")
    cosine_sal = get_saliency_cosine(model, img_path, caption, transform, std=std,
                                     nt_samples=nt_samples // 5, use_ig=True)
    # NOTE(review): with abs=pos=False, get_saliency_cosine returns values in
    # [-1, 1], so centring on 0.5 looks like a leftover from a [0, 1] version;
    # confirm whether torch.abs(cosine_sal) was intended.
    cosine_sal = minmax(torch.abs(cosine_sal - 0.5))
    p = plt.hist(cosine_sal.cpu().flatten().numpy(), bins=1000)
    plt.show()
    to_pil(cosine_sal.squeeze()).show()
    to_pil(cosine_sal.squeeze() * loaded_img.squeeze().to(cosine_sal.device)).show()