In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy.ndimage import gaussian_filter
import open_clip
import urllib.request

In [None]:
class Hook:
    """Context manager that captures a module's forward output and its gradient.

    A forward hook is registered on ``module`` at construction. Every forward
    pass stores the module's output tensor, flags it as requiring grad and
    asks autograd to retain its ``.grad``, so that after ``backward()`` both
    the activation and its gradient can be read back. Leaving the ``with``
    block unregisters the hook.
    """

    def __init__(self, module: nn.Module):
        # Most recent forward output of ``module``; None until it first runs.
        self.data = None
        # Handle returned by register_forward_hook, kept so the hook can be removed.
        self.hook = module.register_forward_hook(self.save_grad)

    def save_grad(self, module, input, output):
        """Forward-hook callback: keep ``output`` and make its grad readable."""
        self.data = output
        # Make autograd track this tensor even when upstream parameters are
        # frozen, and keep its .grad once backward() has run.
        output.requires_grad_(True)
        output.retain_grad()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        handle = self.hook
        handle.remove()

    @property
    def activation(self) -> torch.Tensor:
        """The captured forward output (None if the module has not run yet)."""
        captured = self.data
        return captured

    @property
    def gradient(self) -> torch.Tensor:
        """``.grad`` of the captured output, populated by ``backward()``."""
        return self.activation.grad


# Reference: https://arxiv.org/abs/1610.02391
def gradCAM(
    model: nn.Module,
    input: torch.Tensor,
    target: torch.Tensor,
    layer: nn.Module
) -> torch.Tensor:
    """Compute a Grad-CAM heat map for ``model`` at ``layer``.

    ``model(input)`` is back-propagated with ``target`` as the upstream
    gradient, so the map explains the dot product between the model output
    and ``target``. The gradient at ``layer``'s output is average-pooled over
    its two spatial axes to get one weight per channel. The activations are
    summed with those weights and negative evidence is clamped to zero. The
    result is then upsampled to the input resolution.

    Args:
        model: network to explain. Its parameters are temporarily frozen.
            Their original ``requires_grad`` flags are restored before
            returning, also when the forward/backward pass raises.
        input: 4-D batch; ``input.shape[2:]`` is used as the output spatial
            size (presumably (N, C, H, W)).
        target: upstream gradient for ``model(input)``. It must have the same
            shape as the model output.
        layer: submodule of ``model`` whose output is 4-D, with channels on
            dim 1 and spatial dims on dims 2 and 3.

    Returns:
        Float tensor of shape (N, 1, H, W). Values are non-negative before
        interpolation; bicubic resampling can overshoot slightly below zero.
    """
    # Validate before touching any state, so a bad call leaves the model intact.
    assert isinstance(layer, nn.Module)

    # Zero out any gradients at the input.
    if input.grad is not None:
        input.grad.data.zero_()

    # Freeze parameters so backward() does not accumulate parameter grads.
    # Remember each flag so it can be restored afterwards.
    requires_grad = {}
    for name, param in model.named_parameters():
        requires_grad[name] = param.requires_grad
        param.requires_grad_(False)

    try:
        # Attach a hook to the model at the desired layer.
        with Hook(layer) as hook:
            # Do a forward and backward pass.
            output = model(input)
            output.backward(target)

            grad = hook.gradient.float()
            act = hook.activation.float()

            # Global average pool gradient across spatial dimensions
            # to obtain per-channel importance weights.
            alpha = grad.mean(dim=(2, 3), keepdim=True)
            # Weighted combination of activation maps over the channel dimension.
            gradcam = torch.sum(act * alpha, dim=1, keepdim=True)
            # Keep only positive influence.
            gradcam = torch.clamp(gradcam, min=0)

        # Resize gradcam to input resolution.
        gradcam = F.interpolate(
            gradcam,
            input.shape[2:],
            mode='bicubic',
            align_corners=False)
    finally:
        # Restore gradient settings even if the forward/backward pass failed;
        # otherwise the caller's model would stay silently frozen.
        for name, param in model.named_parameters():
            param.requires_grad_(requires_grad[name])

    return gradcam

In [None]:
def normalize(x: np.ndarray) -> np.ndarray:
    """Linearly rescale ``x`` so that its minimum becomes 0 and its maximum 1.

    A constant array is only shifted (to all zeros), because there is no
    range to divide by. A new array is returned; ``x`` itself is not modified.
    """
    shifted = x - x.min()
    peak = shifted.max()
    return shifted / peak if peak > 0 else shifted

# Modified from: https://github.com/salesforce/ALBEF/blob/main/visualization.ipynb
def getAttMap(img, attn_map, blur=True, sigma_scale=0.02, gamma=0.7):
    """Blend a 'jet'-coloured rendering of ``attn_map`` over ``img``.

    Args:
        img: H x W x 3 image. The blend assumes float values in [0, 1]
            (e.g. the output of ``load_image``).
        attn_map: 2-D map with the same H x W as ``img``.
        blur: if True, Gaussian-smooth ``attn_map`` before colouring.
        sigma_scale: blur sigma as a fraction of ``max(H, W)`` (default 0.02).
        gamma: exponent turning the normalized map into a per-pixel blend
            weight (default 0.7). Values below 1 increase the overlay weight
            of mid-range values.

    Returns:
        H x W x 3 array ``(1 - w) * img + w * jet(attn_map)`` with
        ``w = normalize(attn_map) ** gamma``.
    """
    if blur:
        attn_map = gaussian_filter(attn_map, sigma_scale * max(img.shape[:2]))
    attn_map = normalize(attn_map)
    cmap = plt.get_cmap('jet')
    # The colormap yields RGBA; drop the alpha channel (index 3 of the last axis).
    attn_map_c = np.delete(cmap(attn_map), 3, 2)
    # Per-pixel blend weight, given a trailing axis to broadcast over RGB.
    weight = (attn_map ** gamma)[..., np.newaxis]
    return (1 - weight) * img + weight * attn_map_c

def viz_attn(img, attn_map, blur=True):
    """Show ``img`` beside the same image overlaid with ``attn_map``.

    The attention map is resized with PIL to the image's width and height,
    then blended by ``getAttMap``. Both panels have their axes hidden.

    Returns:
        ``(fig, axes)`` from ``plt.subplots(1, 2)``.
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    left, right = axes
    left.imshow(img)
    # PIL's resize takes (width, height), i.e. (shape[1], shape[0]).
    height, width = img.shape[0], img.shape[1]
    resized = np.array(Image.fromarray(attn_map).resize((width, height)))
    right.imshow(getAttMap(img, resized, blur))
    for panel in (left, right):
        panel.axis("off")
    return fig, axes
    
def load_image(img_path, resize=None):
    """Load an image as an H x W x 3 float32 array scaled to [0, 1].

    Args:
        img_path: anything accepted by ``PIL.Image.open`` (path or file object).
        resize: if given, the image is resized to a ``resize`` x ``resize``
            square; the aspect ratio is not preserved.

    Returns:
        ``np.float32`` array of RGB values divided by 255.
    """
    # Open in a context manager so the file handle is closed deterministically.
    # convert() loads the pixel data into a new, file-independent image first.
    with Image.open(img_path) as src:
        image = src.convert("RGB")
    if resize is not None:
        image = image.resize((resize, resize))
    return np.asarray(image).astype(np.float32) / 255.

In [None]:
from open_clip import pretrained

# Print every ConvNeXt architecture together with the pretrained tag that
# open_clip can load for it.
checkpoints = pretrained.list_pretrained()
for model_name, weights in checkpoints:
    if 'convnext' in model_name.lower():
        print(f"Model name: {model_name}\nWeights: {weights}\n")

# Distinct ViT families: the first two dash-separated tokens of each ViT
# model name (e.g. "ViT-B-32" -> "ViT-B").
print({"-".join(m.split('-')[:2]) for m, _ in checkpoints if 'vit' in m.lower()})

In [None]:
# create_model_and_transforms returns (model, train_transform, eval_transform);
# only the eval-time preprocessing is kept.
model, _, preprocess = open_clip.create_model_and_transforms('convnext_xxlarge', pretrained='laion2b_s34b_b82k_augreg')
# Per the open_clip README, the model is returned in train mode. Switch to eval
# so dropout / stochastic-depth layers (if the config enables them) do not make
# the Grad-CAM maps below vary between runs. eval() returns the module itself.
model = model.to('cuda').eval()

In [None]:
# Fetch the tokenizer registered for the same architecture as the model above,
# so the text pipeline always matches the loaded text tower.
tokenizer = open_clip.get_tokenizer('convnext_xxlarge')

In [None]:
# image_url = 'https://images2.minutemediacdn.com/image/upload/c_crop,h_706,w_1256,x_0,y_64/f_auto,q_auto,w_1100/v1554995050/shape/mentalfloss/516438-istock-637689912.jpg'
image_url = "https://static.toiimg.com/photo/79693966.cms"
# PIL detects the format from the file contents, so the .png name is only a label.
image_path = 'image.png'
urllib.request.urlretrieve(image_url, image_path)

# NOTE(review): gradCAM below passes `text_features` to output.backward(), which
# requires it to match the image-embedding shape (batch of 1 here). Keep a
# single prompt active, or loop over prompts one at a time.
texts = [
    # "the pommes frites", 
    'the hamburger', 
    # 'hamburger', 
    # 'the pizza',
    # 'the lettuce',
    # 'tomato', 
    # 'the tomato', 
    # 'hamburger bun', 
    # 'cheese',
    # "food",
    # "cutting board",
    # "the meat patty",
    # "ground beef",
]
tokenized_text = tokenizer(texts)

with torch.no_grad(), torch.cuda.amp.autocast():
    text_features = model.encode_text(tokenized_text.cuda())

# Close the file handle once `preprocess` has consumed the image.
with Image.open(image_path) as pil_image:
    image = preprocess(pil_image).unsqueeze(0).cuda()
image_np = load_image(image_path, resize=224)

In [None]:
import torchvision.transforms as TVT
from tqdm.auto import trange

# Optional test-time augmentation. When enabled, each run sees a differently
# colour-jittered copy of the input, and the maps are averaged downstream.
# NOTE(review): `image` has already gone through `preprocess` (which appears to
# normalize it), while ColorJitter is designed for [0, 1] images. Confirm
# before enabling.
USE_COLOR_JITTER = False
augment = (
    TVT.Compose([TVT.ColorJitter(0.5, 0.5, 0.5, 0.5)])
    if USE_COLOR_JITTER
    else torch.nn.Identity()
)

N_RUNS = 10
attn_maps_3 = []
attn_maps_4 = []
# Index into model.visual.trunk.stages -> list collecting that stage's maps.
stage_maps = {2: attn_maps_3, 3: attn_maps_4}

with torch.cuda.amp.autocast():
    for _ in trange(N_RUNS):
        # Feed the augmented copy to gradCAM so `augment` actually takes effect.
        img = augment(image)
        for stage_idx, maps in stage_maps.items():
            attn_map = gradCAM(model.visual, img, text_features, model.visual.trunk.stages[stage_idx])
            maps.append(attn_map.squeeze().detach().cpu().numpy())

np_att_3 = np.stack(attn_maps_3, axis=0)
np_att_4 = np.stack(attn_maps_4, axis=0)

In [None]:
# Alternative: visualize the trunk.stages[2] maps instead.
# pred_map = np.mean(np_att_3, axis=0)
# fig, axes = viz_attn(image_np, pred_map, blur=False)

# Fraction of the peak value a pixel must exceed to be displayed; weaker
# responses are zeroed out.
CAM_THRESHOLD = 0.5

# Square each run's map (emphasising larger values), then average over runs.
pred_map = np.mean(np.power(np_att_4, 2), axis=0)
peak = pred_map.max()
# Guard against an all-zero map, which would otherwise divide by zero into NaNs.
if peak > 0:
    pred_map = pred_map / peak
pred_map = pred_map * (pred_map > CAM_THRESHOLD)
fig, axes = viz_attn(image_np, pred_map, blur=False)