In [8]:
import os
import pickle

import numpy as np
from sklearn.utils import shuffle
from saliency.core.base import *
from saliency.core import IntegratedGradients
import torch
import torch.nn as nn
import torch.nn.functional as F

import ig2
from ig2 import REP_LAYER_VALUES, REP_DISTANCE_GRADIENTS
from models.model import TREC_CNN

### Load TREC dataset

In [2]:
def read_TREC(data_dir="data/TREC"):
    """Load the TREC question-classification data from ``data_dir``.

    Reads ``{data_dir}/TREC_train.txt`` and ``{data_dir}/TREC_test.txt``.
    Every non-blank line is split on whitespace. The part of the first field
    before ``:`` is the label (e.g. ``LOC`` from ``LOC:city``), and the
    remaining fields are the question tokens. Blank lines are skipped.

    The first 10% of the training file (in file order, no shuffling) becomes
    the dev split, and the rest is the training split.

    Args:
        data_dir: directory holding the two TREC text files.

    Returns:
        dict with keys ``train_x``, ``train_y``, ``dev_x``, ``dev_y``,
        ``test_x`` and ``test_y``. ``*_x`` values are lists of token lists;
        ``*_y`` values are lists of label strings.
    """
    data = {}

    def read(mode):
        x, y = [], []

        with open(f"{data_dir}/TREC_{mode}.txt", "r", encoding="utf-8") as f:
            for line in f:
                # split() already drops the trailing newline; split once and reuse.
                fields = line.split()
                if not fields:
                    # Skip blank lines (e.g. a trailing empty line) instead of
                    # failing on fields[0].
                    continue
                y.append(fields[0].split(":")[0])
                x.append(fields[1:])

        if mode == "train":
            dev_idx = len(x) // 10
            data["dev_x"], data["dev_y"] = x[:dev_idx], y[:dev_idx]
            data["train_x"], data["train_y"] = x[dev_idx:], y[dev_idx:]
        else:
            data["test_x"], data["test_y"] = x, y

    read("train")
    read("test")

    return data

data = read_TREC()
# All three splits, used for the vocabulary and for the padded sentence length.
all_sentences = data["train_x"] + data["dev_x"] + data["test_x"]
data["vocab"] = sorted({w for sent in all_sentences for w in sent})
data["classes"] = sorted(set(data["train_y"]))
data["word_to_idx"] = {w: i for i, w in enumerate(data["vocab"])}
data["idx_to_word"] = {i: w for i, w in enumerate(data["vocab"])}
params = {
    "MODEL": "pretrained",
    "MAX_SENT_LEN": max(len(sent) for sent in all_sentences),
    "BATCH_SIZE": 50,
    "WORD_DIM": 300,
    "VOCAB_SIZE": len(data["vocab"]),
    "CLASS_SIZE": len(data["classes"]),
    "FILTERS": [3, 4, 5],
    "FILTER_NUM": [100, 100, 100],
    "DROPOUT_PROB": 0.5,
    "NORM_LIMIT": 3,
    "GPU": 'cuda:7',  # hard-coded device; change to match the machine running the notebook
}
cls_fullname_dict = {'ABBR': 'ABBREVIATION', 'DESC': 'DESCRIPTION', 'ENTY': 'ENTITY',
                     'HUM': 'HUMAN', 'LOC': 'LOCATION', 'NUM': 'NUMERIC'}

# Key under which call_model_args carries the target class index.
class_idx_str = 'class_idx_str'

# The first N_SAMPLES training sentences form the pool of explicands and references.
N_SAMPLES = 500
sentence_x_word, sentence_y_name = data["train_x"][:N_SAMPLES], data["train_y"][:N_SAMPLES]

# Words missing from the vocabulary map to VOCAB_SIZE; padding uses VOCAB_SIZE + 1.
UNK_IDX = params["VOCAB_SIZE"]
PAD_IDX = params["VOCAB_SIZE"] + 1
# Dict lookup: testing `w in data["vocab"]` scanned the whole sorted list per token.
sentence_x = [
    [data["word_to_idx"].get(w, UNK_IDX) for w in sent]
    + [PAD_IDX] * (params["MAX_SENT_LEN"] - len(sent))
    for sent in sentence_x_word
]
sentence_y = np.array([data["classes"].index(c) for c in sentence_y_name])

assert len(sentence_x) == len(sentence_y)
assert all(len(row) == params["MAX_SENT_LEN"] for row in sentence_x)

### Load question classification model

In [3]:
model = TREC_CNN(**params).to(params["GPU"])
# map_location loads the checkpoint tensors straight onto params["GPU"], so loading
# does not depend on which device the checkpoint was originally saved from.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load("models/TREC_CNN.pkl", map_location=params["GPU"]))
model.eval()  # with model.training False, the dropout in get_prediction is inactive

# The input of the final linear layer is the representation layer that
# call_model_function reads for REP_LAYER_VALUES / REP_DISTANCE_GRADIENTS.
rep_layer = model.fc
rep_layer_outputs = {}


def rep_layer_forward(m, i, o):
    """Forward hook on ``rep_layer``: store the layer's first positional input.

    ``m`` is the module, ``i`` is the tuple of positional inputs and ``o`` is
    the output (unused). The stored tensor is overwritten on every forward pass
    through ``model.fc``.
    """
    rep_layer_outputs[REP_LAYER_VALUES] = i[0]


forward_hook = rep_layer.register_forward_hook(rep_layer_forward)
# Global value range of the embedding table; passed to IG2 as clip_min_max.
emb_min, emb_max = model.embedding.weight.min().item(), model.embedding.weight.max().item()

def get_embedding(sentence):
    """Map token indices to word vectors using the model's own embedding layer.

    Args:
        sentence: LongTensor of token indices. Callers pass either one padded
            sentence or a batch of padded sentences.

    Returns:
        The output of ``model.embedding``, with one embedding vector per index.
    """
    embed = model.embedding
    return embed(sentence)

def get_prediction(embedding):
    """Run the CNN classifier on precomputed word embeddings.

    Starting from embeddings, rather than token indices, lets the explainers
    differentiate with respect to the embedding values. Presumably this matches
    TREC_CNN's forward pass after its embedding lookup, but that is not
    verified here.

    Steps:
      * reshape the input to (-1, 1, WORD_DIM * MAX_SENT_LEN);
      * for each conv branch ``model.get_conv(k)``, apply ReLU, then max-pool
        with window ``MAX_SENT_LEN - FILTERS[k] + 1``, then flatten to
        (-1, FILTER_NUM[k]);
      * concatenate the branches, apply dropout (only when ``model.training``),
        and finish with ``model.fc``.

    Returns:
        Logits from ``model.fc``, one row per sentence in the batch.
    """
    word_dim, sent_len = model.WORD_DIM, model.MAX_SENT_LEN
    flat = embedding.view(-1, 1, word_dim * sent_len)

    pooled = []
    for branch, width in enumerate(model.FILTERS):
        activation = F.relu(model.get_conv(branch)(flat))
        # Window covers the whole conv output length, assuming get_conv(branch)
        # strides by WORD_DIM (not visible here) -- TODO confirm.
        branch_max = F.max_pool1d(activation, sent_len - width + 1)
        pooled.append(branch_max.view(-1, model.FILTER_NUM[branch]))

    features = torch.cat(pooled, 1)
    features = F.dropout(features, p=model.DROPOUT_PROB, training=model.training)
    return model.fc(features)

### Utility methods

In [4]:
# Key under which call_model_args carries the target class index (same value as in the
# data-loading cell; repeated here so this cell can be re-run on its own).
class_idx_str = 'class_idx_str'


def call_model_function(explicand_emb, call_model_args=None, expected_keys=None):
    """Adapter between the explainers (saliency IntegratedGradients, ig2.IG2) and the CNN.

    Runs ``get_prediction`` on ``explicand_emb``, converted to a float32 tensor on
    ``params["GPU"]``. It returns whichever quantity ``expected_keys`` requests,
    checking keys in this order:

    * ``INPUT_OUTPUT_GRADIENTS`` -> ``{INPUT_OUTPUT_GRADIENTS: ndarray}``: the
      gradient of the target-class logit with respect to the input embedding.
      The target comes from ``call_model_args[class_idx_str]``; a single target
      is broadcast over the batch.
    * ``REP_LAYER_VALUES`` -> the module-level ``rep_layer_outputs`` dict filled
      by the forward hook on ``model.fc``. This is the live dict, overwritten on
      every forward pass.
    * ``REP_DISTANCE_GRADIENTS`` -> ``{REP_DISTANCE_GRADIENTS: ndarray, 'loss': tensor}``:
      the gradient, with respect to the input embedding, of
      ``-MSE(rep, call_model_args['layer_baseline'])``.

    Returns ``None`` when none of these keys is requested (including when
    ``expected_keys`` is None).
    """
    if expected_keys is None:
        expected_keys = ()

    emb = torch.tensor(explicand_emb, dtype=torch.float32).to(params["GPU"])
    emb.requires_grad = True
    # The forward pass also refreshes rep_layer_outputs through the hook on model.fc.
    logits = get_prediction(emb)

    if INPUT_OUTPUT_GRADIENTS in expected_keys:
        target_idx = torch.LongTensor(
            np.array(call_model_args[class_idx_str]).reshape(-1, 1)).to(params["GPU"])
        if logits.size(0) > 1 and target_idx.size(0) == 1:
            target_idx = target_idx.repeat(logits.size(0), 1)
        # One-hot grad_outputs selects only the target-class logit of each row.
        one_hot = torch.zeros_like(logits).scatter_(1, target_idx, 1).detach()
        (grads,) = torch.autograd.grad(logits, emb, grad_outputs=one_hot)
        return {INPUT_OUTPUT_GRADIENTS: grads.cpu().detach().numpy()}

    if REP_LAYER_VALUES in expected_keys:
        return rep_layer_outputs

    if REP_DISTANCE_GRADIENTS in expected_keys:
        loss_fn = torch.nn.MSELoss()
        baseline_conv = call_model_args['layer_baseline']
        input_conv = rep_layer_outputs[REP_LAYER_VALUES]
        loss = -1 * loss_fn(input_conv, baseline_conv)
        # autograd.grad rather than loss.backward(): only d(loss)/d(emb) is needed.
        # backward() would also accumulate .grad into every leaf requiring grad
        # (including the model parameters) on each of the hundreds of calls.
        (grads,) = torch.autograd.grad(loss, emb)
        return {REP_DISTANCE_GRADIENTS: grads.cpu().detach().numpy(),
                'loss': loss}

    return None

### IG2 attribution for ten samples

The attributions are calculated on the word embeddings.

In [10]:
# References: only samples whose label differs from the explained prediction are used.
def select_counter_refs(explained_y, start=100, stop=120):
    """Collect reference sentences whose label differs from ``explained_y``.

    Scans the fixed index window ``[start, stop)`` of ``sentence_y`` and keeps
    every index whose label is not ``explained_y``. The selection is
    deterministic, not a random draw.

    Args:
        explained_y: class index being explained, either a scalar or a
            1-element array such as the argmax prediction.
        start, stop: index window to scan. The defaults reproduce the original
            window of indices 100..119.

    Returns:
        np.ndarray holding the selected rows of ``sentence_x`` (token indices,
        one padded sentence per row). It has zero rows if every sentence in the
        window shares the label.
    """
    ref_idx = [i for i in range(start, stop) if sentence_y[i] != explained_y]
    return np.array(sentence_x)[ref_idx, :]

ig2_mask_list = []
exp_idx_list = np.arange(100, 110)  # explain sentences 100..109 (a fixed slice, not a random draw)
device = params["GPU"]
token_matrix = np.array(sentence_x)

for exp_idx in exp_idx_list:
    print([sentence_y_name[exp_idx]] + sentence_x_word[exp_idx])
    explicand = token_matrix[exp_idx, :]
    explicand_emb = get_embedding(torch.LongTensor(explicand).to(device)).detach().cpu().numpy()
    logits = get_prediction(torch.tensor(explicand_emb).to(device))
    prediction = torch.argmax(logits, dim=1).cpu().numpy()
    # Sanity check: only sentences the model classifies correctly are explained.
    assert prediction == sentence_y[exp_idx]

    # References are sentences labelled differently from the prediction.
    ref_sentence = select_counter_refs(prediction)
    reference_emb = get_embedding(torch.LongTensor(ref_sentence).to(device)).detach().cpu().numpy()

    call_model_args = {class_idx_str: prediction}
    explainer = ig2.IG2()
    ig2_mask = explainer.GetMask(
        explicand_emb,
        reference_emb,
        call_model_function,
        call_model_args,
        steps=501,
        step_size=0.01,
        clip_min_max=[emb_min, emb_max],
    )
    # Sum over the last axis (the embedding dimension, assuming the mask has
    # explicand_emb's shape) to get one score per token position.
    ig2_mask = np.sum(ig2_mask, axis=-1)
    ig2_mask_list.append(ig2_mask)

['LOC', 'What', 'country', 'was', 'A', 'Terrible', 'Beauty', 'to', 'Leon', 'Uris', '?']
GradPath search...
0 iterations, rep distance Loss -0.42010876536369324
100 iterations, rep distance Loss -0.15807844698429108
200 iterations, rep distance Loss -0.04529042914509773
300 iterations, rep distance Loss -0.019004816189408302
400 iterations, rep distance Loss -0.008226913399994373
500 iterations, rep distance Loss -0.002988226944580674
Integrate gradients on GradPath...
['DESC', 'What', 'is', 'typhoid', 'fever', '?']
GradPath search...
0 iterations, rep distance Loss -0.3230636417865753
100 iterations, rep distance Loss -0.12282904982566833
200 iterations, rep distance Loss -0.04098369553685188
300 iterations, rep distance Loss -0.01772630587220192
400 iterations, rep distance Loss -0.007620504125952721
500 iterations, rep distance Loss -0.002859528176486492
Integrate gradients on GradPath...
['LOC', 'Where', 'can', 'I', 'buy', 'a', 'hat', 'like', 'the', 'kind', 'Jay', 'Kay', 'from', 'Ja

### Vanilla IG attribution with zero baseline
The baseline benchmark that IG2 is compared against.

In [11]:
ig_mask_list = []
device = params["GPU"]
token_matrix = np.array(sentence_x)

for exp_idx in exp_idx_list:
    print([sentence_y_name[exp_idx]] + sentence_x_word[exp_idx])
    explicand = token_matrix[exp_idx, :]
    explicand_emb = get_embedding(torch.LongTensor(explicand).to(device)).detach().cpu().numpy()
    logits = get_prediction(torch.tensor(explicand_emb).to(device))
    prediction = torch.argmax(logits, dim=1).cpu().numpy()
    # Sanity check: only sentences the model classifies correctly are explained.
    assert prediction == sentence_y[exp_idx]

    call_model_args = {class_idx_str: prediction}
    # Vanilla IG baseline: the all-zero embedding.
    black = np.zeros_like(explicand_emb)
    ig_mask_r = IntegratedGradients().GetMask(
        explicand_emb,
        call_model_function,
        call_model_args,
        x_steps=200,
        x_baseline=black,
        batch_size=32,
    )
    # Sum over the last axis (embedding dimension) to get one score per token position.
    ig_mask_r = np.sum(ig_mask_r, axis=-1)
    ig_mask_list.append(ig_mask_r)

['LOC', 'What', 'country', 'was', 'A', 'Terrible', 'Beauty', 'to', 'Leon', 'Uris', '?']
['DESC', 'What', 'is', 'typhoid', 'fever', '?']
['LOC', 'Where', 'can', 'I', 'buy', 'a', 'hat', 'like', 'the', 'kind', 'Jay', 'Kay', 'from', 'Jamiroquai', 'wears', '?']
['DESC', 'What', 'are', 'the', 'Twin', 'Cities', '?']
['ENTY', 'What', 'is', 'Beethoven', "'s", '9th', 'symphony', 'called', '?']
['HUM', 'Name', 'the', 'lawyer', 'for', 'Randy', 'Craft', '.']
['ENTY', 'Which', 'breakfast', 'cereal', 'brought', 'you', '``', 'the', 'best', 'each', 'morning', "''", '?']
['ABBR', 'What', 'does', 'Ms.', ',', 'Miss', ',', 'and', 'Mrs.', 'stand', 'for', '?']
['LOC', 'What', 'is', 'the', 'Homelite', 'Inc.', 'home', 'page', '?']
['ENTY', 'What', 'explosive', 'do', 'you', 'get', 'by', 'mixing', 'charcoal', ',', 'sulfur', 'and', 'saltpeter', '?']


### Visualize word attributions in sentence


In [14]:
# Visualization code adapted from the Alibi integrated-gradients IMDB example:
# https://docs.seldon.io/projects/alibi/en/stable/examples/integrated_gradients_imdb.html
from matplotlib.colors import Normalize, rgb2hex
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt
from IPython.display import HTML
from matplotlib.backends.backend_pgf import FigureCanvasPgf
# Matplotlib setup for LaTeX-rendered text. The word highlights use xcolor's
# \colorbox[HTML]{...}, so both TeX preambles below load xcolor.
mpl.use('ps')  # select the non-interactive PostScript backend
# Save '.pdf' files through the PGF (LaTeX) canvas instead of the default PDF backend.
mpl.backend_bases.register_backend('pdf', FigureCanvasPgf)
from matplotlib import rc
rc('text',usetex=True)  # render all text through LaTeX
rc('pgf', preamble=r'\usepackage{xcolor}')  # preamble for the PGF canvas used for .pdf output
rc('text.latex', preamble=r'\usepackage{xcolor}')  # preamble for the usetex path (separate from pgf)

# LaTeX text-mode special characters and their escaped forms.
_LATEX_ESCAPES = {
    '\\': r'\textbackslash{}',
    '&': r'\&',
    '%': r'\%',
    '$': r'\$',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
}
_HEX_DIGITS = set('0123456789abcdefABCDEF')


def hlstr(string, color='white'):
    """Wrap ``string`` in a LaTeX ``\\colorbox`` highlight (3pt padding).

    LaTeX special characters in ``string`` are escaped character by character,
    so tokens such as ``&``, ``AT&T``, ``$`` or ``%`` render instead of breaking
    the TeX run.

    Args:
        string: the word to highlight.
        color: either a 6-digit hex color without ``#`` (as produced by
            ``colorize``), which is emitted with xcolor's ``[HTML]`` model, or a
            named xcolor color such as the default ``'white'``.

    Returns:
        The LaTeX snippet as a string.
    """
    color_str = str(color)
    escaped = ''.join(_LATEX_ESCAPES.get(ch, ch) for ch in string)
    is_hex = len(color_str) == 6 and all(ch in _HEX_DIGITS for ch in color_str)
    # "[HTML]" accepts only hex codes; named colors must omit the model argument.
    box = r"\colorbox[HTML]" if is_hex else r"\colorbox"
    return r"{\setlength{\fboxsep}{3pt}" + box + "{" + color_str + "}{" + escaped + "}}"

def colorize(attrs, cmap='PiYG'):
    """Compute hex colors for the attributions of a single instance.

    Uses a diverging colormap, symmetric around zero: the normalization bound is
    ``max(|attrs|)`` over every entry of ``attrs``, so zero maps to the
    colormap's midpoint. All-zero attributions are mapped to that midpoint,
    rather than to the extreme color that ``vmin == vmax`` would produce.

    Args:
        attrs: array-like of attribution scores.
        cmap: name of a registered matplotlib colormap.

    Returns:
        List with one uppercase 6-digit hex string (no leading ``#``) per entry
        of ``attrs``; empty for empty input.
    """
    attrs = np.asarray(attrs)
    if attrs.size == 0:
        return []
    cmap_bound = np.max(np.abs(attrs))
    if cmap_bound == 0:
        # Normalize(vmin=0, vmax=0) maps every value to 0, the extreme end of the colormap.
        cmap_bound = 1.0
    norm = Normalize(vmin=-cmap_bound, vmax=cmap_bound)
    try:
        colormap = mpl.colormaps[cmap]  # Matplotlib >= 3.5
    except AttributeError:
        colormap = mpl.cm.get_cmap(cmap)  # older Matplotlib; removed in 3.9

    return [rgb2hex(colormap(norm(value))).lstrip('#').upper() for value in attrs.tolist()]

fig = plt.figure()
plt.axis([0, 9, 0, 9])
plt.axis('off')

# Column headers, placed above the y-axis limit at y=10.
plt.text(0.3, 10, "[Predictions]",
         va='top', ha='left', fontsize=6.5, wrap=True)
plt.text(2.0, 10, "[Questions with explanations]",
         va='top', ha='left', fontsize=6.5, wrap=True)

for n, exp_idx in enumerate(exp_idx_list):
    # Each example takes a 1.1-unit band: the IG2 row at `base`, the IG row 0.45 below it.
    base = 9.5 - 1.1 * n
    sent = sentence_x_word[exp_idx]
    # Raw f-string: "\#" is LaTeX's escaped '#', not a (invalid) Python escape sequence.
    plt.text(0.0, base - 0.35, rf"\#{n+1}",
             va='top', ha='left', fontsize=5, wrap=True)
    plt.text(0.3, base - 0.35, f"[{cls_fullname_dict[sentence_y_name[exp_idx]]}]",
             va='top', ha='left', fontsize=5, wrap=True)

    plt.text(1.35, base, r"{\setlength{\fboxsep}{3pt}\colorbox{white}{[IG\textsuperscript{2}]}}",
             va='top', ha='left', fontsize=6.3, wrap=True)
    plt.text(1.8, base, "".join(map(hlstr, sent, colorize(ig2_mask_list[n]))),
             va='top', ha='left', fontsize=6.3, wrap=True)
    plt.text(1.35, base - 0.45, r"{\setlength{\fboxsep}{3pt}\colorbox{white}{[IG]}}",
             va='top', ha='left', fontsize=6.3, wrap=True)
    plt.text(1.8, base - 0.45, "".join(map(hlstr, sent, colorize(ig_mask_list[n]))),
             va='top', ha='left', fontsize=6.3, wrap=True)

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs('results', exist_ok=True)
plt.savefig('results/TREC_IG2toIG.pdf')
plt.close(fig)