# Integrated Gradients calculations

## Set-up

In [1]:
from datasets.mqnli import get_collate_fxn
from modeling.pretrained_bert import PretrainedBertModule
from modeling.lstm import LSTMModule
import os
from trainer import load_model
import torch
from torch.utils.data import DataLoader

from feature_importance import IntegratedGradientsBERT, IntegratedGradientsLSTM

## Loading models

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
def ig_load_model(src_basename, src_dirname="mqnli_data"):
    """Load a trained checkpoint for Integrated Gradients analysis.

    The model class is inferred from the file name: a basename containing
    "lstm" (case-insensitive) is restored as ``LSTMModule``; any other name
    is restored as ``PretrainedBertModule``.

    Args:
        src_basename: Checkpoint file name, e.g. ``"bert-easy-best.pt"``.
        src_dirname: Directory that holds the checkpoint.

    Returns:
        The model returned by ``load_model`` (its second return value is
        discarded), loaded onto the module-level ``device``.
    """
    path = os.path.join(src_dirname, src_basename)
    # Case-insensitive so names like "LSTM-hard.pt" are not silently
    # mistaken for BERT checkpoints.
    is_lstm = 'lstm' in src_basename.lower()
    model_class = LSTMModule if is_lstm else PretrainedBertModule

    model, _ = load_model(model_class, path, device=device)
    return model

In [4]:
# The name has no "lstm", so ig_load_model restores it as a PretrainedBertModule.
bert_model_easy = ig_load_model(src_basename="bert-easy-best.pt")

In [5]:
# The name has no "lstm", so ig_load_model restores it as a PretrainedBertModule.
bert_model_hard = ig_load_model(src_basename="bert-hard-best.pt")

In [6]:
# The name contains "lstm", so ig_load_model restores it as an LSTMModule.
lstm_model_easy = ig_load_model(src_basename="lstm-easy-best.pt")

In [7]:
# The name contains "lstm", so ig_load_model restores it as an LSTMModule.
lstm_model_hard = ig_load_model(src_basename="lstm-hard-best.pt")

## Loading data

In [8]:
def ig_load_data(src_basename, src_dirname="mqnli_data"):
    """Load a preprocessed dataset object that was saved with ``torch.save``.

    Args:
        src_basename: File name of the saved dataset.
        src_dirname: Directory that holds the file.

    Returns:
        The unpickled object; the analysis code reads its ``dev`` split.

    Security:
        This performs a full pickle load, which can execute arbitrary code.
        Only load files from a trusted source.
    """
    import inspect

    path = os.path.join(src_dirname, src_basename)
    load_kwargs = {}
    # The dataset is a pickled custom object (callers access ``data.dev``).
    # PyTorch 2.6+ defaults to ``weights_only=True``, which refuses to
    # unpickle such objects, so opt out explicitly where the keyword exists.
    # Older releases have no such keyword and always do a full unpickle.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(path, **load_kwargs)

In [9]:
# Dataset paired with the BERT models in the analyses below.
bert_data = ig_load_data(src_basename="bert-preprocessed-data.pt")

In [10]:
# Dataset paired with the LSTM models in the analyses below.
lstm_data = ig_load_data(src_basename="lstm-preprocessed-data.pt")

## Calculations

In [11]:
def analyze_sample(model, data, n=8, batch_size=4, output_filename=None, layer=None, shuffle=True):
    """Run Integrated Gradients over ``n`` examples from the dev split.

    Args:
        model: A loaded model. If its class name contains "LSTM",
            ``IntegratedGradientsLSTM`` is used; otherwise
            ``IntegratedGradientsBERT`` is used.
        data: Preprocessed dataset object. Its ``dev`` split is iterated, and
            the whole object is passed to the IG wrapper.
        n: Number of examples to analyze. ``ceil(n / batch_size)`` batches
            are drawn (fewer if the dev split runs out), and the collected
            results are truncated to at most ``n`` entries.
        batch_size: Examples per DataLoader batch. Must be positive.
        output_filename: If given, the results are written with ``ig.to_json``.
        layer: Forwarded unchanged to the IG wrapper's constructor.
        shuffle: Whether the DataLoader shuffles the dev split.

    Returns:
        A list built by concatenating the outputs of ``ig.predict_with_ig``
        for each batch, truncated to at most ``n`` entries. Presumably this
        is one record per example; the notebook's ``visualize(results[:3])``
        renders three examples.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    examples = data.dev
    # Ceiling division, so a partial final batch is still drawn. The former
    # int(n / batch_size) was 0 for n < batch_size; its equality-based break
    # then never fired and the entire dev set was processed.
    n_batches = max((n + batch_size - 1) // batch_size, 0)
    if 'LSTM' in model.__class__.__name__:
        ig_class = IntegratedGradientsLSTM
    else:
        ig_class = IntegratedGradientsBERT
    ig = ig_class(model, data, layer=layer)
    dataloader = DataLoader(examples, batch_size=batch_size, shuffle=shuffle, collate_fn=None)
    results = []
    # zip() stops once the range is exhausted, before pulling another batch,
    # so at most n_batches batches are drawn and none when n_batches == 0.
    for i, input_tuple in zip(range(1, n_batches + 1), dataloader):
        if i % 100 == 0:
            print(f"Batch {i} of {n_batches}")
        # NOTE(review): this puts the model in training mode (so any dropout
        # is active) before every batch. It may be deliberate, e.g. because
        # cuDNN RNN backward requires training mode, but it can make
        # attributions stochastic for BERT. Confirm the intent.
        ig.model.train()
        input_tuple = tuple(x.to(device) for x in input_tuple)
        results += ig.predict_with_ig(input_tuple)
    # The final batch may overshoot n when n is not a multiple of batch_size.
    results = results[:max(n, 0)]
    if output_filename:
        ig.to_json(results, output_filename)
    return results

### BERT demos

In [12]:
# IG at the wrapper's default layer (layer=None) for 12 dev examples, in order.
bert_data_emb = analyze_sample(
    model=bert_model_easy,
    data=bert_data,
    n=12,
    shuffle=False,
)

In [13]:
# Render the first three attributed examples.
IntegratedGradientsBERT.visualize(bert_data_emb[:3])

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
entailment,entailment (1.00),,7.8,[CLS] emptystring every emptystring producer does not emptystring receives emptystring every shiny piano [SEP] emptystring every slippery producer does not blindly receives emptystring every shiny piano [SEP]
,,,,
entailment,entailment (1.00),,26.94,[CLS] emptystring no emptystring veteran does not occasionally calls emptystring no fragile trumpet [SEP] emptystring some ugly veteran emptystring emptystring occasionally calls not every emptystring trumpet [SEP]
,,,,
contradiction,contradiction (1.00),,18.66,[CLS] emptystring every emptystring robber does not emptystring pushes emptystring some black harp [SEP] emptystring every japanese robber does not violently pushes not every emptystring harp [SEP]
,,,,


In [15]:
# IG with respect to BERT encoder layer 0 (default n and batch size, no shuffling).
bert_data_enc0 = analyze_sample(
    model=bert_model_easy,
    data=bert_data,
    shuffle=False,
    layer=bert_model_easy.bert.encoder.layer[0],
)

In [16]:
# Render the first three examples attributed at encoder layer 0.
IntegratedGradientsBERT.visualize(bert_data_enc0[:3])

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
entailment,entailment (1.00),,12.39,[CLS] emptystring every emptystring producer does not emptystring receives emptystring every shiny piano [SEP] emptystring every slippery producer does not blindly receives emptystring every shiny piano [SEP]
,,,,
entailment,entailment (1.00),,11.76,[CLS] emptystring no emptystring veteran does not occasionally calls emptystring no fragile trumpet [SEP] emptystring some ugly veteran emptystring emptystring occasionally calls not every emptystring trumpet [SEP]
,,,,
contradiction,contradiction (1.00),,7.99,[CLS] emptystring every emptystring robber does not emptystring pushes emptystring some black harp [SEP] emptystring every japanese robber does not violently pushes not every emptystring harp [SEP]
,,,,


In [17]:
# IG with respect to BERT encoder layer index 11 (default n and batch size, no shuffling).
bert_data_enc11 = analyze_sample(
    model=bert_model_easy,
    data=bert_data,
    shuffle=False,
    layer=bert_model_easy.bert.encoder.layer[11],
)

In [18]:
# Render the first three examples attributed at encoder layer index 11.
IntegratedGradientsBERT.visualize(bert_data_enc11[:3])

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
entailment,entailment (1.00),,10.28,[CLS] emptystring every emptystring producer does not emptystring receives emptystring every shiny piano [SEP] emptystring every slippery producer does not blindly receives emptystring every shiny piano [SEP]
,,,,
entailment,entailment (1.00),,10.99,[CLS] emptystring no emptystring veteran does not occasionally calls emptystring no fragile trumpet [SEP] emptystring some ugly veteran emptystring emptystring occasionally calls not every emptystring trumpet [SEP]
,,,,
contradiction,contradiction (1.00),,9.19,[CLS] emptystring every emptystring robber does not emptystring pushes emptystring some black harp [SEP] emptystring every japanese robber does not violently pushes not every emptystring harp [SEP]
,,,,


### LSTM demos

In [19]:
# LSTM IG at the wrapper's default layer (default n and batch size, no shuffling).
lstm_data_emb = analyze_sample(
    model=lstm_model_easy,
    data=lstm_data,
    shuffle=False,
)

In [20]:
# Render the first three LSTM-attributed examples.
IntegratedGradientsLSTM.visualize(lstm_data_emb[:3])

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
entailment,neutral (1.00),,17.83,every emptystring producer doesnot emptystring receives every shiny piano [SEP] every slimy producer doesnot blindly receives every shiny piano
,,,,
entailment,neutral (1.00),,20.74,no emptystring zookeeper doesnot occasionally buries no flimsy trumpet [SEP] some ugly zookeeper emptystring occasionally buries notevery emptystring trumpet
,,,,
contradiction,neutral (1.00),,43.83,every emptystring robber doesnot emptystring pushes some black harp [SEP] every Japanese robber doesnot unabashedly pushes notevery emptystring harp
,,,,


### Big experiments

In [None]:
# Easy BERT model: IG at its embeddings module over up to 10,000 shuffled
# dev examples, saved as JSON one directory up.
bert_data_easy = analyze_sample(
    model=bert_model_easy,
    data=bert_data,
    n=10000,
    layer=bert_model_easy.bert.embeddings,
    output_filename='../ig-bert-easy-emb-10k.json',
)

In [None]:
# Hard BERT model: IG at its embeddings module over up to 10,000 shuffled
# dev examples, saved as JSON one directory up.
bert_data_hard = analyze_sample(
    model=bert_model_hard,
    data=bert_data,
    n=10000,
    layer=bert_model_hard.bert.embeddings,
    output_filename='../ig-bert-hard-emb-10k.json',
)

In [None]:
# Easy BERT model: IG at encoder layer index 11 over up to 10,000 shuffled
# dev examples. NOTE: this rebinds bert_data_easy, replacing the in-memory
# embedding-run results; those remain in their own JSON file.
bert_data_easy = analyze_sample(
    model=bert_model_easy,
    data=bert_data,
    n=10000,
    layer=bert_model_easy.bert.encoder.layer[11],
    output_filename='../ig-bert-easy-layer11-10k.json',
)

In [None]:
# Hard BERT model: IG at encoder layer index 11 over up to 10,000 shuffled
# dev examples. NOTE: this rebinds bert_data_hard, replacing the in-memory
# embedding-run results; those remain in their own JSON file.
bert_data_hard = analyze_sample(
    model=bert_model_hard,
    data=bert_data,
    n=10000,
    layer=bert_model_hard.bert.encoder.layer[11],
    output_filename='../ig-bert-hard-layer11-10k.json',
)

In [None]:
# Easy LSTM model: IG over up to 10,000 shuffled dev examples, saved as JSON.
lstm_data_easy = analyze_sample(
    model=lstm_model_easy,
    data=lstm_data,
    n=10000,
    layer=None,  # Uses the embedding.
    output_filename='../ig-lstm-easy-emb-10k.json',
)

In [None]:
# Hard LSTM model: IG over up to 10,000 shuffled dev examples, saved as JSON.
lstm_data_hard = analyze_sample(
    model=lstm_model_hard,
    data=lstm_data,
    n=10000,
    layer=None,  # Uses the embedding.
    output_filename='../ig-lstm-hard-emb-10k.json',
)