In [None]:
# Code adapted from:
#   - GradCAM: https://colab.research.google.com/github/kevinzakka/clip_playground/blob/main/CLIP_GradCAM_Visualization.ipynb#scrollTo=caPbAhFlRBwT   # noqa E501
#   - OpenCLIP: https://github.com/mlfoundations/open_clip

In [None]:
import urllib.request
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import open_clip
import torch as T
import torch.nn.functional as F
import torchvision.transforms as TVT
from open_clip import pretrained
from PIL import Image
from scipy.ndimage import gaussian_filter
from tqdm import trange

In [None]:
class Hook:
    """Capture a module's forward output and, after backward, its gradient.

    A forward hook is registered on ``module`` when the object is constructed
    and stays attached until the ``with`` block exits. Every forward pass
    overwrites the stored output and asks autograd to keep its ``.grad``.
    """

    def __init__(self, module: T.nn.Module):
        # Most recent output of `module`; None until the first forward pass.
        self.data: T.Tensor = None  # type: ignore
        # Removable handle, used by __exit__ to detach the hook again.
        self.hook = module.register_forward_hook(self.save_grad)

    # -- context-manager protocol ------------------------------------------
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Always detach the hook; returning None lets exceptions propagate.
        self.hook.remove()

    # -- forward-hook callback ---------------------------------------------
    def save_grad(self, module, input, output):
        """Remember ``output`` and make autograd keep its gradient."""
        self.data = output
        # Turn the output into a gradient target even when nothing upstream
        # requires grad (e.g. frozen parameters) ...
        output.requires_grad_(True)
        # ... and keep its .grad after backward() even if it is a non-leaf.
        output.retain_grad()

    # -- accessors ----------------------------------------------------------
    @property
    def activation(self) -> T.Tensor:
        """The most recently captured forward output."""
        return self.data

    @property
    def gradient(self) -> T.Tensor:
        """``.grad`` of the captured output (None before any backward pass)."""
        return self.data.grad   # type: ignore


# Reference: https://arxiv.org/abs/1610.02391
def gradCAM(
    model: T.nn.Module,
    input: T.Tensor,
    target: T.Tensor,
    layer: T.nn.Module
) -> T.Tensor:
    """Compute a Grad-CAM heatmap for ``layer`` with respect to ``model``'s output.

    Args:
        model: Module mapping ``input`` to an output tensor.
        input: Batch fed to ``model``. Its dims from index 2 on are the spatial
            size the heatmap is resized to (assumes NCHW — TODO confirm for
            other model families).
        target: Gradient seed passed to ``output.backward``. It must be
            broadcast-compatible with the model output (here: the text
            embedding, so the backward pass yields d<image_emb, text_emb>/d act).
        layer: Sub-module whose output is the activation map. That output is
            assumed to be (N, C, H, W), since the gradient is pooled over
            dims 2 and 3 and channels are summed over dim 1.

    Returns:
        Tensor of shape (N, 1, *input.shape[2:]), detached from autograd.

    Raises:
        RuntimeError: if ``layer`` received no gradient from ``output``.

    The ``requires_grad`` flag of every parameter of ``model`` is temporarily
    set to False and is restored even if the forward/backward pass raises.
    """
    assert isinstance(layer, T.nn.Module)

    # Zero out any gradients at the input.
    if input.grad is not None:
        input.grad.zero_()

    # Freeze parameters so backward only populates the hooked activation's
    # .grad instead of accumulating parameter gradients.
    requires_grad = {name: param.requires_grad for name, param in model.named_parameters()}
    for _, param in model.named_parameters():
        param.requires_grad_(False)

    try:
        with Hook(layer) as hook:
            # Do a forward and backward pass.
            output = model(input)
            output.backward(target)

            if hook.gradient is None:
                raise RuntimeError(
                    "gradCAM: `layer` received no gradient; is it on the path "
                    "from `input` to the model output?")
            grad = hook.gradient.float()
            # Detach so the heatmap arithmetic below does not build a graph.
            act = hook.activation.detach().float()

        with T.no_grad():
            # Global average pool gradient across spatial dimensions
            # to obtain per-channel importance weights.
            alpha = grad.mean(dim=(2, 3), keepdim=True)
            # Weighted combination of activation maps over the channel dim.
            gradcam = T.sum(act * alpha, dim=1, keepdim=True)
            # Keep only positive influence.
            gradcam = T.clamp(gradcam, min=0)
            # Resize to input resolution.
            gradcam = F.interpolate(
                gradcam,
                input.shape[2:],
                mode='bicubic',
                align_corners=False)
    finally:
        # Restore gradient settings even when an exception is raised.
        for name, param in model.named_parameters():
            param.requires_grad_(requires_grad[name])

    return gradcam

In [None]:
def normalize(x: np.ndarray) -> np.ndarray:
    """Linearly rescale ``x`` so its minimum becomes 0 and its maximum 1.

    A constant input has zero range after shifting; it is then returned as
    all zeros instead of being divided by zero.
    """
    shifted = x - x.min()
    peak = shifted.max()
    return shifted / peak if peak > 0 else shifted

# Modified from: https://github.com/salesforce/ALBEF/blob/main/visualization.ipynb


def getAttMap(img, attn_map, blur=True):
    """Blend a jet-colormapped saliency map over ``img``.

    Args:
        img: Image array whose first two dims match ``attn_map``; presumably
            (H, W, 3) float in [0, 1], since it is mixed linearly with RGB
            colormap values — TODO confirm.
        attn_map: 2-D saliency map.
        blur: If true, smooth ``attn_map`` with a Gaussian whose sigma is 2% of
            the larger image side before normalizing.

    Returns:
        Per-pixel blend ``(1 - w) * img + w * jet(a)``, where ``a`` is the
        map rescaled to [0, 1] and ``w = a ** 0.7``.
    """
    heat = gaussian_filter(attn_map, 0.02*max(img.shape[:2])) if blur else attn_map
    heat = normalize(heat)
    # Drop the colormap's alpha channel, keeping RGB only.
    overlay = plt.get_cmap('jet')(heat)[..., :3]
    # The 0.7 exponent boosts mid-strength regions in the blend weight.
    weight = (heat ** 0.7)[..., None]
    return (1 - weight) * img + weight * overlay


def viz_attn(img, attn_map, blur=True):
    """Show ``img`` next to its saliency overlay in a 1x2 figure.

    ``attn_map`` is resized with PIL to the image's (width, height) before
    being blended by ``getAttMap``. Returns the ``(fig, axes)`` pair from
    ``plt.subplots``, with axis decorations hidden.
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    left, right = axes
    left.imshow(img)
    height, width = img.shape[0], img.shape[1]
    resized_map = np.array(Image.fromarray(attn_map).resize((width, height)))
    right.imshow(getAttMap(img, resized_map, blur))
    for panel in axes:
        panel.axis("off")
    return fig, axes


def load_image(img_path, resize=None):
    """Load an image file as an RGB float32 array scaled to [0, 1].

    Args:
        img_path: Anything ``PIL.Image.open`` accepts (path or file object).
        resize: Optional target size. An int ``s`` resizes to ``(s, s)``; note
            that this ignores the aspect ratio. A ``(width, height)`` pair is
            passed to ``Image.resize`` as-is.

    Returns:
        Array of shape (H, W, 3), dtype float32.
    """
    # Use a context manager so the underlying file is closed deterministically
    # (convert() returns a new, fully loaded image that outlives the file).
    with Image.open(img_path) as raw:
        image = raw.convert("RGB")
    if resize is not None:
        size = (resize, resize) if np.ndim(resize) == 0 else tuple(resize)
        image = image.resize(size)
    return np.asarray(image).astype(np.float32) / 255.

In [None]:
# List the ConvNeXt architectures open_clip ships pretrained weights for.
# Loop variables are deliberately NOT named `model_name`: re-running this cell
# after the model-selection cell would otherwise overwrite that global, and the
# plotting cell would report the wrong model in its title.
available = list(pretrained.list_pretrained())
for arch, tag in available:
    if 'convnext' not in arch.lower():
        continue
    print(f"Model name: {arch}")
    print(f"Weights: {tag}")
    print()

# Find vit types (first two dash-separated parts of the name, e.g. "ViT-B").
print(sorted({"-".join(arch.split('-')[:2]) for arch, _ in available if 'vit' in arch.lower()}))

In [None]:
# Alternative, much larger checkpoint (swap in to compare):
# model_name, model_weights_name = 'convnext_xxlarge', 'laion2b_s34b_b82k_augreg'
model_name, model_weights_name = 'convnext_base', 'laion400m_s13b_b51k'

# The middle return value is unused here; `preprocess` is the transform
# applied to images before they are fed to the model below.
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name,
    pretrained=model_weights_name,
)
# Inference mode on the GPU (both calls return the module itself).
model = model.eval().to('cuda')  # type: ignore

In [None]:
# Use the tokenizer registered for the selected model so the text input always
# matches the loaded text tower. A name hard-coded from a different
# architecture can silently select a different tokenizer when `model_name`
# is changed.
tokenizer_name = model_name
tokenizer = open_clip.get_tokenizer(tokenizer_name)

In [None]:
# image_url = 'https://images2.minutemediacdn.com/image/upload/c_crop,h_706,w_1256,x_0,y_64/f_auto,q_auto,w_1100/v1554995050/shape/mentalfloss/516438-istock-637689912.jpg'  # noqa E501
image_url = "https://static.toiimg.com/photo/79693966.cms"
image_path = 'image.png'
urllib.request.urlretrieve(image_url, image_path)

texts = [
    'cheese',
    'cutting board',
    'ground beef',
    'hamburger bun',
    'hamburger',
    'lettuce',
    'pizza',
    'pommes frites',
    'tomato',
]
tokenized_text = tokenizer(texts)

# Text embeddings are only used as fixed Grad-CAM targets, so no graph is needed.
with T.no_grad(), T.autocast('cuda'):
    text_features = model.encode_text(tokenized_text.cuda())  # type: ignore

# Decode the photo once and keep it as a PIL image: augmentations must run
# BEFORE `preprocess` normalizes it. torchvision's ColorJitter treats float
# tensors as [0, 1] images and clamps its output to that range, so jittering
# the already mean/std-normalized tensor zeroed every negative value and fed
# the model a distorted image.
with Image.open(image_path) as raw_image:
    pil_image = raw_image.convert("RGB")

# Un-augmented model input (1, C, H, W), kept for reference/inspection.
image: T.Tensor = preprocess(pil_image).unsqueeze(0).cuda()  # type: ignore

augment = TVT.Compose([TVT.ColorJitter(0.5, 0.5, 0.5, 0.5)])
# augment = T.nn.Identity()

n_samples = 30
# Seed torch's RNG so the ColorJitter draws, and hence the averaged maps,
# are reproducible across runs.
T.manual_seed(0)

attn_maps: dict[str, list[np.ndarray]] = defaultdict(list)
model_visual = model.visual  # type: ignore
model_layers = model_visual.trunk.stages  # type: ignore
with T.autocast('cuda'):
    for text, text_feature in zip(texts, text_features):
        samples = attn_maps[f"{text} - layer_4"]
        for _sample_idx in trange(n_samples, desc=f"Text: {text}"):
            img = preprocess(augment(pil_image)).unsqueeze(0).cuda()  # type: ignore
            cam = gradCAM(model_visual, img, text_feature.unsqueeze(0), model_layers[3])
            samples.append(cam.squeeze().detach().cpu().numpy())

In [None]:
# Rebuild the exact pixels the model saw so the heatmaps line up with them.
# A square resize of the full photo (e.g. `load_image(..., resize=224)`)
# squashes non-square images, while the maps live in the coordinate frame of
# the preprocessed crop.
# NOTE(review): assumes `preprocess` does a shortest-side resize followed by a
# center crop (open_clip's default eval transform) — confirm if configured
# differently.
with Image.open(image_path) as raw_view:
    display_pil = raw_view.convert("RGB")
crop_h, crop_w = preprocess(display_pil).shape[-2:]  # type: ignore
to_model_view = TVT.Compose([
    TVT.Resize(min(crop_h, crop_w), interpolation=TVT.InterpolationMode.BICUBIC),
    TVT.CenterCrop((crop_h, crop_w)),
])
image_np = np.asarray(to_model_view(display_pil)).astype(np.float32) / 255.

# Near-square grid sized from the maps actually computed; squeeze=False keeps
# `axes` a 2-D array even for a single panel so .flatten() always works.
n_maps = len(attn_maps)
n_cols = max(1, int(np.ceil(np.sqrt(n_maps))))
n_rows = max(1, int(np.ceil(n_maps / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols*5, n_rows*5), squeeze=False)
fig.suptitle(f"Model: {model_name}\nWeights: {model_weights_name}\nTokenizer: {tokenizer_name}")
axes = axes.flatten()
for ax, (text, attn_map_array) in zip(axes, attn_maps.items()):
    # Average over the augmented samples, then match the display resolution.
    atmap = np.stack(attn_map_array, axis=0).mean(axis=0)
    atmap = np.array(Image.fromarray(atmap).resize((image_np.shape[1], image_np.shape[0])))
    ax.imshow(getAttMap(image_np, atmap))
    ax.set_title(text)
    ax.axis("off")
# Hide any unused grid cells.
for ax in axes[n_maps:]:
    ax.axis("off")
plt.show()