In [None]:
from transformers import AutoConfig
from model import Features2WordsModel, TranslationTransformerConfig
import os
from argparse import Namespace # needed for reading saved argparse parameters
import torch
from torch.nn import functional as F
from PIL import Image
from itertools import product
from evaluation import get_saliency_map
import matplotlib.pyplot as plt

In [None]:
# PARAMETERS TO DEFINE
# MODEL_CKPT must be set before running the notebook. The loading cell reads
# 'translation_model_config.json', 'lm_model_config.json' and 'model.pt' from this
# directory, and 'train_params.txt' from its parent directory.
MODEL_CKPT = 'path/to/model_checkpoint' # REQUIRED: replace with the path to your model checkpoint directory
IMG = '../assets/catdog.png' # image to compute saliency maps for
QUERIES = ['cat', 'dog', 'animal'] # one saliency map is produced per query
LAYER = -1 # e.g. -1; if left None will compute metrics with features for all layers.
HRES = False # set to True if you want to generate saliency maps for a given layer but with the resolution of the lowest layer.
POOL_LOCS = 'keep_dims' # choose between None, 'reduce_dims', 'keep_dims'. Pools lower layer locations.
KERNEL_SIZE = -1 # kernel size for pooling lower layer locations. We typically use 3 for layer -2 and 7 for layer -3.
GPU_ID = 0 # device the model and the image tensor are moved to

In [None]:
# load model
# The configs and the weights are read from MODEL_CKPT; the training parameters
# are read from 'train_params.txt' in the parent directory of MODEL_CKPT.
translation_config = TranslationTransformerConfig.from_json_file(os.path.join(MODEL_CKPT, 'translation_model_config.json'))
lm_model_config = AutoConfig.from_pretrained(os.path.join(MODEL_CKPT, 'lm_model_config.json'))
max_length = lm_model_config.max_length

# train_params.txt is evaluated as a Python expression (the saved argparse
# Namespace, hence the `Namespace` import at the top of the notebook).
# SECURITY NOTE(review): eval() executes whatever the file contains, so only
# point MODEL_CKPT at checkpoints from a trusted source.
with open(os.path.join(MODEL_CKPT, '..', 'train_params.txt'), 'r') as f:
    namespace_str = f.read()
train_args = eval(namespace_str)
vision_backbone = train_args.vision_backbone
vision_feat_func = train_args.vision_feat_func
vision_feat_layers = train_args.vision_feat_layers
language_model = train_args.language_model

model = Features2WordsModel(
    translation_config=translation_config,
    cnn_name=vision_backbone,
    vision_feat_func=vision_feat_func,
    vision_feat_layers=vision_feat_layers,
    lm_model_name=language_model,
    max_length=max_length,
)

# Load the checkpoint onto the CPU first so it does not depend on the device it
# was saved from; the whole model is moved to GPU_ID afterwards.
model_checkpoint = torch.load(os.path.join(MODEL_CKPT, 'model.pt'), map_location='cpu')
# Only the translation sub-module's weights are restored from the checkpoint.
model.translation.load_state_dict(model_checkpoint["MODEL_STATE"])
model.eval()
model.to(GPU_ID)
transform = model.transform

In [None]:
def grid(n, dim):
    """Enumerate all integer coordinates of a lattice with `n` points per axis.

    Args:
        n: number of points along each axis (coordinates run over range(n)).
        dim: number of axes.

    Returns:
        A list of `dim`-length tuples in ascending lexicographic order, with
        no duplicates.
    """
    axis = range(n)
    # product() over an ascending range already yields unique tuples in
    # lexicographic order, so no extra de-duplication or sorting is required.
    return list(product(axis, repeat=dim))

def prepare_grid(model, layer, hres, pool_locs):
    """Select the spatial grid over which saliency is computed.

    Args:
        model: object exposing `grid_sizes`, a sequence of grid side lengths.
            It is read in reverse order, so index 0 of the reversed sequence
            is the lowest layer the model was trained on.
        layer: layer index used to index the reversed grid sizes (e.g. -1),
            or None to use all layers.
        hres: if True (and `layer` is not None), use the grid size of the
            lowest layer and append '_hres' to the layer name.
        pool_locs: pooling mode; only the value 'reduce_dims' changes the
            result of this function.

    Returns:
        Tuple `(gs, loc_ids, layer_id, layer_name, stride)` where `gs` is the
        grid side length, `loc_ids` the list of 2-D grid locations,
        `layer_id` is `-layer - 1` (None when `layer` is None), `layer_name`
        a label for the layer, and `stride` the ratio between the selected
        layer's grid size and the last grid size (1 unless 'reduce_dims').
    """
    stride = 1
    # prepare spatial locations
    sizes = model.grid_sizes[::-1]
    if layer is None:
        layer_id = None
        layer_name = 'all'
        gs = sizes[0]  # size of the lowest layer the model was trained on
    else:
        layer_name = str(layer)
        layer_id = -layer - 1

        if hres:
            gs = sizes[0]  # size of the lowest layer the model was trained on
            layer_name += '_hres'
        elif pool_locs == 'reduce_dims':
            # The stride has to be derived from the list of sizes *before*
            # the scalar grid size is picked; indexing the scalar raised a
            # TypeError in the previous version of this branch.
            stride = int(sizes[layer] / sizes[-1])
            gs = sizes[-1]
        else:
            gs = sizes[layer]

    loc_ids = grid(gs, 2)
    return gs, loc_ids, layer_id, layer_name, stride

def get_saliency(model, layer, hres, pool_locs, kernel_size, img, query):
    """Compute an image-sized saliency map for `query`.

    The grid is chosen with `prepare_grid`, the per-location values come from
    `get_saliency_map` (run on the module-level `GPU_ID`), and the result is
    bilinearly upsampled to `img.shape[1:]` and negated.

    Args:
        model, layer, hres, pool_locs: forwarded to `prepare_grid`; `model`,
            `hres` and `pool_locs` are also forwarded to `get_saliency_map`.
        kernel_size: forwarded to `get_saliency_map`.
        img: image tensor without a batch dimension
            (assumes a (C, H, W) layout - TODO confirm).
        query: query forwarded to `get_saliency_map`.

    Returns:
        Tuple `(smap_loss, layer_name)`: the negated, upsampled map as a
        NumPy array with the channel axis last, and the layer label.
    """
    grid_size, locations, layer_id, layer_name, stride = prepare_grid(model, layer, hres, pool_locs)

    batch = img.unsqueeze(0)  # add the batch dimension expected downstream
    loss_map, _ = get_saliency_map(
        model,
        locations,
        batch,
        query,
        layer_id,
        gpu_id=GPU_ID,
        hres=hres,
        pool_locs=pool_locs,
        kernel_size=kernel_size,
        stride=stride,
    )

    # Lay the per-location values out on the square grid and add a channel axis
    # so the tensor is 4-D, as required by bilinear interpolation.
    loss_map = loss_map.view(-1, grid_size, grid_size).unsqueeze(dim=1)
    target_size = tuple(img.shape[1:])
    upsampled = F.interpolate(loss_map, target_size, mode='bilinear')

    # First map only, channel axis moved last, then handed to NumPy.
    smap_loss = upsampled[0].permute(1, 2, 0).cpu().numpy()
    # Sign is flipped (presumably so that a lower loss shows up as higher saliency - confirm).
    return -smap_loss, layer_name

In [None]:
# preprocess image
img = Image.open(IMG).convert("RGB")
img_wo_transform = img.copy() # untransformed copy kept for display
img = transform(img)
img = img.to(GPU_ID)

# One figure per query: the saliency map is drawn semi-transparently on top of
# the untransformed image.
# NOTE(review): the saliency map is upsampled to the size of the *transformed*
# tensor, while the background is the untransformed image; if `transform`
# changes the image size the two layers may not line up - confirm.
for i, q in enumerate(QUERIES):
    smap, layer_name = get_saliency(model, LAYER, HRES, POOL_LOCS, KERNEL_SIZE, img, q)
    plt.figure(i)
    plt.imshow(img_wo_transform)
    plt.axis('off')
    # Include the query in the title; otherwise every figure carries the same
    # 'Layer ...' title and the maps cannot be told apart.
    plt.title(f'Query "{q}" - Layer {layer_name}')
    plt.imshow(smap, cmap='jet', alpha=0.5)
    plt.show()