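# Integrated Gradients calculations

Computes Integrated Gradients (IG) attributions for the BERT and LSTM models trained on the MQNLI "easy" and "hard" settings: small demos are visualised inline, and the larger runs write their attributions to JSON files.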

## Set-up

In [1]:
from datasets.mqnli import get_collate_fxn
from modeling.pretrained_bert import PretrainedBertModule
from modeling.lstm import LSTMModule
import os
from trainer import load_model
import torch
from torch.utils.data import DataLoader

from feature_importance import IntegratedGradientsBERT, IntegratedGradientsLSTM

In [2]:
# Label names for the LSTM model trained on the "hard" MQNLI setting.
# Passed as `classes=` to `analyze_sample` for that model; presumably the
# order matches the model's output indices -- confirm against training code.
lstm_hard_classes = (
    "neutral",
    "entailment",
    "contradiction",
    "independence",
    "equivalence",
    "entails",
    "reverse entails",
    "contradiction2",
    "alternation",
    "cover",
)

## Loading models

In [3]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
def ig_load_model(src_basename, src_dirname="mqnli_data"):
    """Load a trained checkpoint for Integrated Gradients analysis.

    The model class is inferred from the file name: a basename containing
    the (case-sensitive) substring 'lstm' is loaded as ``LSTMModule``;
    any other name is loaded as ``PretrainedBertModule``. The checkpoint
    is loaded onto the notebook-level ``device``.

    Args:
        src_basename: checkpoint file name, e.g. "bert-easy-best.pt".
        src_dirname: directory containing the checkpoint.

    Returns:
        The model, i.e. the first element of the pair returned by
        ``load_model``; the second element is discarded.
    """
    model_class = LSTMModule if 'lstm' in src_basename else PretrainedBertModule
    checkpoint_path = os.path.join(src_dirname, src_basename)
    loaded_model, _unused = load_model(model_class, checkpoint_path, device=device)
    return loaded_model

In [None]:
# Easy-task BERT checkpoint, read from the default "mqnli_data" directory.
bert_model_easy = ig_load_model(src_basename="bert-easy-best.pt")

In [None]:
# Hard-task BERT checkpoint, read from the default "mqnli_data" directory.
# NOTE(review): LSTM checkpoints are read from "mqnli_models" instead -- confirm the split is intended.
bert_model_hard = ig_load_model(src_basename="bert-hard-best.pt")

In [5]:
# Easy-task LSTM checkpoint; LSTM checkpoints are read from "mqnli_models".
lstm_model_easy = ig_load_model(src_basename="lstm-easy-best.pt", src_dirname="mqnli_models")

In [6]:
# Hard-task LSTM checkpoint; LSTM checkpoints are read from "mqnli_models".
lstm_model_hard = ig_load_model(src_basename="lstm-hard-best.pt", src_dirname="mqnli_models")

## Loading data

In [7]:
def ig_load_data(src_basename, src_dirname="mqnli_data"):
    """Load a preprocessed dataset object that was saved with ``torch.save``.

    Args:
        src_basename: file name of the saved dataset.
        src_dirname: directory containing the file.

    Returns:
        Whatever object ``torch.load`` deserializes from the file.

    NOTE(review): ``torch.load`` unpickles the file, so only load trusted
    files. Recent PyTorch releases default to ``weights_only=True``, which
    may reject non-tensor dataset objects -- confirm against the installed
    PyTorch version.
    """
    return torch.load(os.path.join(src_dirname, src_basename))

In [8]:
# Alternative source (unused): ig_load_data("bert-preprocessed-data.pt")
bert_data = ig_load_data(src_basename="mqnli-bert-easy.pt")
# NOTE(review): the LSTM cells reuse the very same object as the BERT cells.
# Confirm the LSTM models are meant to consume BERT-preprocessed inputs.
lstm_data = bert_data

In [None]:
# Alternative (currently unused): LSTM-specific preprocessed data. As written,
# `lstm_data` is instead an alias of `bert_data`.
# lstm_data = ig_load_data("lstm-preprocessed-data.pt")

## Calculations

In [9]:
def analyze_sample(model, data, n=8, batch_size=4, output_filename=None, layer=None, shuffle=True, classes=None):
    """Compute Integrated Gradients attributions for up to ``n`` dev examples.

    The IG wrapper is chosen from the model's class name: names containing
    'LSTM' use ``IntegratedGradientsLSTM``; all others use
    ``IntegratedGradientsBERT``.

    Args:
        model: trained model to explain.
        data: dataset object; its ``dev`` split is iterated, and the object
            itself is handed to the IG wrapper.
        n: number of dev examples to attribute. Batches of ``batch_size``
            are processed until at least ``n`` examples are covered, and the
            surplus from a partial final batch is dropped. ``n <= 0``
            processes nothing.
        batch_size: DataLoader batch size.
        output_filename: if truthy, results are written via ``ig.to_json``.
        layer: forwarded unchanged to the IG wrapper (this notebook passes
            ``None``, a module for BERT, or an index for the LSTM).
        shuffle: whether the DataLoader shuffles the dev split.
        classes: label names; defaults to the three easy-task labels.

    Returns:
        List of results collected from ``ig.predict_with_ig``, truncated
        to at most ``n`` entries (fewer if the dev split is smaller).
    """
    if classes is None:
        classes = ('neutral', 'entailment', 'contradiction')
    examples = data.dev
    n_examples = max(0, int(n))
    # Ceiling division. The previous int(n / batch_size) truncated (n=6,
    # batch_size=4 -> only 4 examples) and produced 0 batches for
    # n < batch_size; since `i` starts at 1 the break never fired and the
    # entire dev split was processed.
    n_batches = -(-n_examples // batch_size)
    if 'LSTM' in model.__class__.__name__:
        ig_class = IntegratedGradientsLSTM
    else:
        ig_class = IntegratedGradientsBERT
    ig = ig_class(model, data, classes=classes, layer=layer)
    # The default collate is used for both model types (a custom LSTM
    # collate via get_collate_fxn was considered but is not applied).
    dataloader = DataLoader(examples, batch_size=batch_size, shuffle=shuffle, collate_fn=None)
    # Separate name so the `data` argument is not shadowed.
    results = []
    if n_batches > 0:
        for i, input_tuple in enumerate(dataloader, start=1):
            if i % 100 == 0:
                print(f"Batch {i} of {n_batches}")
            # NOTE(review): train mode is re-applied every batch as in the
            # original. Presumably this is needed because cuDNN RNN backward
            # only runs in training mode, but it also enables dropout, which
            # makes BERT attributions stochastic -- confirm whether eval()
            # is intended for BERT.
            ig.model.train()
            input_tuple = tuple(x.to(device) for x in input_tuple)
            results += ig.predict_with_ig(input_tuple)
            if i == n_batches:
                break
    # A partial final batch can overshoot n; keep exactly the requested count.
    results = results[:n_examples]
    if output_filename:
        ig.to_json(results, output_filename)
    return results

### BERT demos

In [None]:
# Easy BERT with the default layer (None): the first 12 dev examples in
# order (shuffle=False), i.e. 3 batches at the default batch_size of 4.
bert_data_emb = analyze_sample(bert_model_easy, bert_data, n=12, shuffle=False)

In [None]:
# Render the attributions of the first three analysed examples.
IntegratedGradientsBERT.visualize(bert_data_emb[0:3])

In [None]:
# Easy BERT, attributions taken at encoder layer 0 (default n=8, unshuffled).
bert_data_enc0 = analyze_sample(
    bert_model_easy,
    bert_data,
    layer=bert_model_easy.bert.encoder.layer[0],
    shuffle=False,
)

In [None]:
# Render the encoder-layer-0 attributions of the first three examples.
IntegratedGradientsBERT.visualize(bert_data_enc0[0:3])

In [None]:
# Easy BERT, attributions taken at encoder layer index 11 (default n=8, unshuffled).
bert_data_enc11 = analyze_sample(
    bert_model_easy,
    bert_data,
    layer=bert_model_easy.bert.encoder.layer[11],
    shuffle=False,
)

In [None]:
# Render the encoder-layer-11 attributions of the first three examples.
IntegratedGradientsBERT.visualize(bert_data_enc11[0:3])

### LSTM demos

In [10]:
# Easy LSTM with the default layer (None), default n=8, unshuffled.
lstm_easy_data_emb = analyze_sample(lstm_model_easy, lstm_data, shuffle=False)

type of emb_x <class 'torch.Tensor'>
type of x <class 'torch.Tensor'>
type of x <class 'torch.Tensor'>
type of x <class 'torch.Tensor'>
type of x <class 'torch.Tensor'>
type of emb_x <class 'torch.Tensor'>
type of x <class 'torch.Tensor'>
type of x <class 'torch.Tensor'>
type of x <class 'torch.Tensor'>
type of x <class 'torch.Tensor'>


In [11]:
# Render the easy-LSTM attributions of the first three examples.
IntegratedGradientsLSTM.visualize(lstm_easy_data_emb[0:3])

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
entailment,entailment (1.00),,4.43,[CLS] emptystring every emptystring producer does not emptystring receives emptystring every shiny piano [SEP] emptystring every slippery producer does not blindly receives emptystring every shiny piano [SEP]
,,,,
entailment,entailment (1.00),,8.44,[CLS] emptystring no emptystring veteran does not occasionally calls emptystring no fragile trumpet [SEP] emptystring some ugly veteran emptystring emptystring occasionally calls not every emptystring trumpet [SEP]
,,,,
contradiction,contradiction (1.00),,19.17,[CLS] emptystring every emptystring robber does not emptystring pushes emptystring some black harp [SEP] emptystring every japanese robber does not violently pushes not every emptystring harp [SEP]
,,,,


In [None]:
# Hard LSTM: pass the 10 hard-task label names instead of the 3-class default.
lstm_hard_data_emb = analyze_sample(
    lstm_model_hard,
    lstm_data,
    classes=lstm_hard_classes,
    shuffle=False,
)

In [None]:
# Render the hard-LSTM attributions of the first three examples.
IntegratedGradientsLSTM.visualize(lstm_hard_data_emb[0:3])

In [None]:
# layer takes in index rather than pytorch module
lstm_easy_data_layer0 = analyze_sample(
    lstm_model_easy,
    lstm_data,
    layer=0,
    shuffle=False
)

In [None]:
# Render the layer-0 LSTM attributions of the first three examples.
IntegratedGradientsLSTM.visualize(lstm_easy_data_layer0[0:3])

### Big experiments

In [None]:
# Large run: easy BERT, attributions at the embeddings module, 10,000 dev
# examples (shuffled: `shuffle` keeps its default of True, and no random
# seed is set in this cell), written to JSON.
bert_data_easy = analyze_sample(
    bert_model_easy,
    bert_data,
    layer=bert_model_easy.bert.embeddings,
    n=10000,
    output_filename='../ig-bert-easy-emb-10k.json',
)

In [None]:
# Large run: hard BERT, attributions at the embeddings module, 10,000 dev
# examples (shuffled by default), written to JSON. The default 3 classes are
# used here.
bert_data_hard = analyze_sample(
    bert_model_hard,
    bert_data,
    layer=bert_model_hard.bert.embeddings,
    n=10000,
    output_filename='../ig-bert-hard-emb-10k.json',
)

In [None]:
# Large run: easy BERT, attributions at encoder layer index 11, 10,000 dev
# examples (shuffled by default), written to JSON.
# Stored under its own name: this cell previously reassigned `bert_data_easy`,
# silently overwriting the embedding-layer results of the earlier run.
bert_data_easy_layer11 = analyze_sample(
    bert_model_easy,
    bert_data,
    n=10000,
    layer=bert_model_easy.bert.encoder.layer[11],
    output_filename='../ig-bert-easy-layer11-10k.json')

In [None]:
# Large run: hard BERT, attributions at encoder layer index 11, 10,000 dev
# examples (shuffled by default), written to JSON.
# Stored under its own name: this cell previously reassigned `bert_data_hard`,
# silently overwriting the embedding-layer results of the earlier run.
bert_data_hard_layer11 = analyze_sample(
    bert_model_hard,
    bert_data,
    n=10000,
    layer=bert_model_hard.bert.encoder.layer[11],
    output_filename='../ig-bert-hard-layer11-10k.json')

In [None]:
# Large run: easy LSTM, 10,000 dev examples (shuffled by default), written to JSON.
lstm_data_easy = analyze_sample(
    lstm_model_easy,
    lstm_data,
    layer=None,  # Uses the embedding.
    n=10000,
    output_filename='../ig-lstm-easy-emb-10k.json',
)

In [None]:
# Large run: hard LSTM with its 10 label names, 10,000 dev examples
# (shuffled by default), written to JSON.
lstm_data_hard = analyze_sample(
    lstm_model_hard,
    lstm_data,
    classes=lstm_hard_classes,
    layer=None,  # Uses the embedding.
    n=10000,
    output_filename='../ig-lstm-hard-emb-10k.json',
)