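### Based on the Captum tutorial "Multimodal VQA with Captum Insights": [captum.ai/tutorials/Multimodal_VQA_Captum_Insights](https://captum.ai/tutorials/Multimodal_VQA_Captum_Insights)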

In [12]:
import os, sys

# Checkouts of pytorch-vqa and pytorch-resnet are expected in the current working
# directory, under the package names the later imports use (`pytorch_vqa`, `pytorch_resnet`).
PYTORCH_VQA_DIR = os.path.realpath(os.path.join(os.getcwd(), "pytorch_vqa"))
PYTORCH_RESNET_DIR = os.path.realpath(os.path.join(os.getcwd(), "pytorch_resnet"))

# Path to the pretrained pytorch-vqa checkpoint; modify it to match your machine.
# The checkpoint can be downloaded from:
# https://github.com/Cyanogenoid/pytorch-vqa/releases/download/v1.0/2017-08-04_00.55.19.pth
VQA_MODEL_PATH = "models/2017-08-04_00.55.19.pth"

# Fail fast with a message naming what is missing. A bare `assert(...)` reports no
# detail and is skipped entirely when Python runs with -O.
for _label, _path in (
    ("pytorch-vqa checkout", PYTORCH_VQA_DIR),
    ("pytorch-resnet checkout", PYTORCH_RESNET_DIR),
    ("VQA model checkpoint", VQA_MODEL_PATH),
):
    if not os.path.exists(_path):
        raise FileNotFoundError(f"{_label} not found at {_path!r}")
del _label, _path  # keep loop temporaries out of the notebook namespace

In [13]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

# pytorch-resnet is expected to be importable as the package `pytorch_resnet`
# (e.g. a checkout in the working directory). On failure, print a hint and continue.
try:
    from pytorch_resnet import resnet  # from pytorch-resnet
except ImportError as err:
    # Catch only import failures: a bare `except:` would also swallow
    # KeyboardInterrupt and errors raised while executing the imported module.
    print(f"please provide a valid path to pytorch-resnet ({err})")

# pytorch-vqa is expected to be importable as the package `pytorch_vqa`
# (e.g. a checkout in the working directory). On failure, print a hint and continue.
try:
    from pytorch_vqa.model import Net, apply_attention, tile_2d_over_nd  # from pytorch-vqa
    from pytorch_vqa.utils import get_transform  # from pytorch-vqa
except ImportError as err:
    # Catch only import failures: a bare `except:` would also swallow
    # KeyboardInterrupt and errors raised while executing the imported modules.
    print(f"please provide a valid path to pytorch-vqa ({err})")
    
from captum.insights import AttributionVisualizer, Batch
from captum.insights.attr_vis.features import ImageFeature, TextFeature
from captum.attr import TokenReferenceBase, configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

In [14]:
# Device used for model inference: the first CUDA GPU when one is available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [15]:
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary
# code -- only load checkpoints obtained from a trusted source.
# map_location places the loaded tensors on the inference device chosen above.
saved_state = torch.load(VQA_MODEL_PATH, map_location=device)

# The checkpoint also carries the vocabularies the model was trained with:
# "question" maps question tokens to indices, "answer" maps answers to indices.
vocab = saved_state["vocab"]
token_to_index, answer_to_index = vocab["question"], vocab["answer"]

# Embedding vocabulary size: one slot per known question token plus one extra
# (presumably index 0 reserved for padding -- TODO confirm against pytorch-vqa).
num_tokens = len(token_to_index) + 1

# Invert answer_to_index into a list so that answer_words[i] is the answer string
# for class index i; any index not present in the vocabulary stays "unk".
answer_words = ["unk" for _ in range(len(answer_to_index))]
for answer, position in answer_to_index.items():
    answer_words[position] = answer

# Wrap Net in DataParallel before loading so the module structure matches the saved
# weights (presumably saved from a DataParallel model, i.e. "module."-prefixed keys).
# device_ids is left at its default (all visible GPUs). The previously hard-coded
# [0, 1] raised an invalid-device error on single-GPU machines. Without CUDA,
# DataParallel simply wraps the module.
vqa_net = torch.nn.DataParallel(Net(num_tokens))
vqa_net.load_state_dict(saved_state["weights"])
vqa_net = vqa_net.to(device)

In [17]:
 # for visualization to convert indices to tokens for questions
# question_words[i] is the token with index i. Slots that no token maps to
# (presumably the extra padding slot) keep the placeholder "unk".
question_words = ["unk" for _ in range(num_tokens)]
for token, position in token_to_index.items():
    question_words[position] = token