# Integrated Gradients calculations

## Set-up

In [1]:
import itertools
import math

import torch
from torch.utils.data import DataLoader

from feature_importance import IntegratedGradients
from modeling.pretrained_bert import PretrainedBertModule
from trainer import load_model

## Loading the model

In [2]:
bert_model_path = "mqnli_models/bert-easy-best.pt"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Only the model is needed; load_model's second return value is discarded.
bert_model, _ = load_model(PretrainedBertModule, bert_model_path, device=device)
bert_model = bert_model.to(device)

# Evaluation mode disables dropout. The trailing `;` keeps Jupyter from
# echoing the module's repr as the cell output.
bert_model.eval();

## Loading data (BERT)

In [3]:
bert_easy_data_path = "mqnli_models/bert-preprocessed-data.pt"

# The file holds a pickled preprocessed-data object (it exposes a `.dev`
# split), not just tensors. torch >= 2.6 defaults to weights_only=True, which
# refuses to unpickle arbitrary objects, so opt out explicitly.
# Only load trusted files this way: full unpickling can execute arbitrary code.
try:
    bert_data = torch.load(bert_easy_data_path, weights_only=False)
except TypeError:
    # Older torch releases do not accept a `weights_only` argument.
    bert_data = torch.load(bert_easy_data_path)

bert_dev_set = bert_data.dev

## Calculations

In [4]:
def analyze_sample(model, examples, n=8, batch_size=4, output_filename=None, layer=None, shuffle=True,
                   device=None):
    """Run Integrated Gradients over (up to) the first ``n`` examples of a dataset.

    Args:
        model: model to attribute; forwarded to ``IntegratedGradients``.
        examples: dataset wrapped in a ``DataLoader``. Each batch must be a
            sequence of tensors, since every element is moved with ``.to(device)``.
        n: number of examples to analyze. The final batch is trimmed so that at
            most ``n`` examples are processed. ``None`` analyzes the whole dataset.
        batch_size: number of examples per ``predict_with_ig`` call.
        output_filename: if truthy, results are saved with ``ig.to_json``.
        layer: optional sub-module forwarded to ``IntegratedGradients(layer=...)``.
        shuffle: whether the ``DataLoader`` shuffles the examples.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter.

    Returns:
        A list that concatenates the per-batch results of ``ig.predict_with_ig``.
    """
    if device is None:
        device = next(model.parameters()).device
    # Round up so a partial final batch is still analyzed. The old
    # int(n / batch_size) dropped it, and evaluated to 0 for n < batch_size,
    # which silently disabled the limit.
    n_batches = None if n is None else max(0, math.ceil(n / batch_size))
    ig = IntegratedGradients(model, layer=layer)
    dataloader = DataLoader(examples, batch_size=batch_size, shuffle=shuffle)
    data = []
    remaining = n
    # NOTE(review): IG needs gradients. This presumably works only because
    # IntegratedGradients re-enables autograd internally; confirm before relying on it.
    with torch.no_grad():
        for i, input_tuple in enumerate(itertools.islice(dataloader, n_batches), start=1):
            if i % 100 == 0:
                print(f"Batch {i} of {n_batches}" if n_batches is not None else f"Batch {i}")
            if remaining is not None:
                # Trim the last batch so exactly n examples are analyzed.
                input_tuple = [x[:remaining] for x in input_tuple]
                remaining -= batch_size
            input_tuple = tuple(x.to(device) for x in input_tuple)
            data += ig.predict_with_ig(input_tuple)
    if output_filename:
        ig.to_json(data, output_filename)
    return data

### Big experiment (full run over 20,000 dev examples)

In [5]:
# Full-size run over 20,000 dev examples; the results are saved to
# ig-batch01.json via IntegratedGradients.to_json. Left commented out,
# presumably because of its runtime. Uncomment to regenerate the file.
#data = analyze_sample(bert_model, bert_dev_set, n=20000, output_filename='ig-batch01.json')

In [6]:
# Preview the first five results of the full run (requires `data` from that run).
#IntegratedGradients.visualize(data[: 5])

### Demos (default IG target, encoder layer 0, encoder layer 11)

In [7]:
# IG with the default target (layer=None; the `_emb` name suggests the input
# embeddings), on dev examples taken in dataset order.
data_emb = analyze_sample(model=bert_model, examples=bert_dev_set, shuffle=False)

In [8]:
# Render the first three attributed examples.
IntegratedGradients.visualize(data_emb[:3])

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
entailment,entailment (1.00),,11.92,[CLS] emptystring every emptystring producer does not emptystring receives emptystring every shiny piano [SEP] emptystring every slippery producer does not blindly receives emptystring every shiny piano [SEP]
,,,,
entailment,entailment (1.00),,4.06,[CLS] emptystring no emptystring veteran does not occasionally calls emptystring no fragile trumpet [SEP] emptystring some ugly veteran emptystring emptystring occasionally calls not every emptystring trumpet [SEP]
,,,,
contradiction,contradiction (1.00),,16.38,[CLS] emptystring every emptystring robber does not emptystring pushes emptystring some black harp [SEP] emptystring every japanese robber does not violently pushes not every emptystring harp [SEP]
,,,,


In [13]:
# Same examples, with IG applied at the first BERT encoder layer.
data_enc0 = analyze_sample(
    bert_model,
    bert_dev_set,
    shuffle=False,
    layer=bert_model.bert.encoder.layer[0],
)

In [14]:
# Render the first three layer-0 attributions.
IntegratedGradients.visualize(data_enc0[:3])

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
entailment,entailment (1.00),,12.11,[CLS] emptystring every emptystring producer does not emptystring receives emptystring every shiny piano [SEP] emptystring every slippery producer does not blindly receives emptystring every shiny piano [SEP]
,,,,
entailment,entailment (1.00),,8.3,[CLS] emptystring no emptystring veteran does not occasionally calls emptystring no fragile trumpet [SEP] emptystring some ugly veteran emptystring emptystring occasionally calls not every emptystring trumpet [SEP]
,,,,
contradiction,contradiction (1.00),,12.01,[CLS] emptystring every emptystring robber does not emptystring pushes emptystring some black harp [SEP] emptystring every japanese robber does not violently pushes not every emptystring harp [SEP]
,,,,


In [15]:
# Same examples, with IG applied at encoder layer index 11
# (the last layer, assuming a 12-layer BERT-base; TODO confirm).
data_enc11 = analyze_sample(
    bert_model,
    bert_dev_set,
    shuffle=False,
    layer=bert_model.bert.encoder.layer[11],
)

In [17]:
# Render the first three layer-11 attributions. This previously referenced
# `data_enc`, a name defined nowhere in the notebook (NameError on Restart & Run All).
IntegratedGradients.visualize(data_enc11[:3])

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
entailment,entailment (1.00),,11.04,[CLS] emptystring every emptystring producer does not emptystring receives emptystring every shiny piano [SEP] emptystring every slippery producer does not blindly receives emptystring every shiny piano [SEP]
,,,,
entailment,entailment (1.00),,10.97,[CLS] emptystring no emptystring veteran does not occasionally calls emptystring no fragile trumpet [SEP] emptystring some ugly veteran emptystring emptystring occasionally calls not every emptystring trumpet [SEP]
,,,,
contradiction,contradiction (1.00),,9.67,[CLS] emptystring every emptystring robber does not emptystring pushes emptystring some black harp [SEP] emptystring every japanese robber does not violently pushes not every emptystring harp [SEP]
,,,,
