In [None]:
# Code adapted from:
#   - GradCAM: https://colab.research.google.com/github/kevinzakka/clip_playground/blob/main/CLIP_GradCAM_Visualization.ipynb#scrollTo=caPbAhFlRBwT
#   - OpenCLIP: https://github.com/mlfoundations/open_clip

In [None]:
import urllib.request
import typing as t
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import open_clip
import torch as T
import torch.nn.functional as F
import torchvision.transforms.v2 as TVT
from PIL import Image
from scipy.ndimage import gaussian_filter
from tqdm import trange

In [None]:
from open_clip import pretrained

# (architecture name, pretrained-weights tag) pairs registered in open_clip.
available_checkpoints = pretrained.list_pretrained()

# Print every ConvNeXt checkpoint as a blank-line-separated entry.
for model_name, weights in available_checkpoints:
    if 'convnext' in model_name.lower():
        print(f"Model name: {model_name}")
        print(f"Weights: {weights}")
        print()

# Find vit types: keep only the first two '-'-separated parts of each ViT name.
vit_families = {
    "-".join(arch.split('-')[:2])
    for arch, _tag in available_checkpoints
    if 'vit' in arch.lower()
}
print(vit_families)

In [None]:
# Checkpoint selection. Larger alternative: ('convnext_xxlarge', 'laion2b_s34b_b82k_augreg').
model_name, model_weights_name = 'convnext_base', 'laion400m_s13b_b51k'

# Build the CLIP model on the GPU together with its preprocessing transforms.
# The middle return value (presumably the training-time transform) is unused here.
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name,
    pretrained=model_weights_name,
    device='cuda',
)

# Inference only. nn.Module.eval() switches mode in place and returns the same
# module, so no reassignment is needed.
model.eval()

# Text tokenizer matching the selected architecture.
tokenizer = open_clip.get_tokenizer(model_name)

In [None]:
from torchvision.io.image import read_image
from torchvision.transforms.functional import normalize, resize, to_pil_image
from torchcam.methods import GradCAM
from copy import deepcopy

In [None]:
# Alternative test images:
#   https://images2.minutemediacdn.com/image/upload/c_crop,h_706,w_1256,x_0,y_64/f_auto,q_auto,w_1100/v1554995050/shape/mentalfloss/516438-istock-637689912.jpg
#   https://live.staticflickr.com/2365/2238423921_1275e83f71_b.jpg
image_url = "https://static.toiimg.com/photo/79693966.cms"
image_path = 'image.jpg'
# Always re-download. image_path is a fixed name, so caching it would silently
# reuse a stale image after image_url is switched to one of the alternatives.
urllib.request.urlretrieve(image_url, image_path)

# Text prompts to localise; one Grad-CAM result is computed per prompt.
# Other prompts tried: 'hamburger', 'lettuce', 'tomato'.
texts = [
    'pommes frites',
]
assert texts, "at least one text prompt is required"
tokenized_text = tokenizer(texts)

# Work on a copy so the Grad-CAM hooks never touch the shared `model`.
cell_model = deepcopy(model)
visual_model = cell_model.visual
# Last stage of the ConvNeXt trunk: the deepest feature map, used as the CAM target.
output_stage = visual_model.trunk.stages[-1]

# T.autocast(device_type='cuda') is the non-deprecated form of T.cuda.amp.autocast().
with T.autocast(device_type='cuda'):
    with T.no_grad():
        # Text embeddings are constants for the CAM; no gradient is needed.
        text_features = cell_model.encode_text(tokenized_text.cuda())

    with Image.open(image_path) as img:
        image: T.Tensor = preprocess(img).unsqueeze(0).cuda()

    with GradCAM(visual_model, target_layer=output_stage) as cam_extractor:
        activation_maps = {}
        for text_idx, text in enumerate(texts):
            # Fresh forward per prompt. The CAM uses the activations and gradients
            # from the most recent forward/backward, and the backward graph is
            # presumably freed after each CAM computation.
            out = visual_model(image)
            assert out.shape[-1] == text_features.shape[-1], (
                f"image embedding dim {out.shape[-1]} != "
                f"text embedding dim {text_features.shape[-1]}"
            )
            # torchcam needs per-class scores of shape (batch, n_classes) plus an
            # integer class index, not raw embeddings. Use CLIP's cosine similarity
            # between the image and every prompt as the "class" scores, then
            # backpropagate the score of the current prompt.
            scores = F.normalize(out, dim=-1) @ F.normalize(text_features, dim=-1).T
            activation_maps[text] = cam_extractor(text_idx, scores)

# Backward-compatible name: the CAM result (one map per target layer) for the first prompt.
activation_map = activation_maps[texts[0]]