# Integrated Gradients calculations

## Set-up

In [None]:
from datasets.mqnli import get_collate_fxn
from modeling.pretrained_bert import PretrainedBertModule
from modeling.lstm import LSTMModule
import os
from trainer import load_model
import torch
from torch.utils.data import DataLoader

from feature_importance import IntegratedGradientsBERT, IntegratedGradientsLSTM

## Loading models

In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
def ig_load_model(src_basename, src_dirname="mqnli_models"):
    """Load a saved model and switch it to evaluation mode.

    Args:
        src_basename: File name of the saved model. If it contains the
            substring ``'lstm'`` the model is loaded as an ``LSTMModule``;
            otherwise as a ``PretrainedBertModule``.
        src_dirname: Directory that contains ``src_basename``.

    Returns:
        The first element returned by ``load_model``, after ``eval()`` has
        been called on it. The second element is discarded.
    """
    checkpoint_path = os.path.join(src_dirname, src_basename)
    # The file name is the only hint about which architecture was saved.
    model_class = LSTMModule if 'lstm' in src_basename else PretrainedBertModule

    model, _unused = load_model(model_class, checkpoint_path, device=device)
    model.eval()
    return model

In [None]:
# Load "bert-easy-best.pt" from the default model directory.
bert_model_easy = ig_load_model(src_basename="bert-easy-best.pt")

In [None]:
# Load "bert-hard-best.pt" from the default model directory.
bert_model_hard = ig_load_model(src_basename="bert-hard-best.pt")

In [None]:
#lstm_model_easy = ig_load_model("lstm-easy-best.pt")

In [None]:
#lstm_model_hard = ig_load_model("lstm-hard-best.pt")

## Loading data

In [None]:
def ig_load_dev(src_basename, src_dirname="mqnli_models"):
    """Load a saved data object and return its ``dev`` attribute.

    Args:
        src_basename: File name of the saved data object.
        src_dirname: Directory that contains ``src_basename``.

    Returns:
        The ``dev`` attribute of the object produced by ``torch.load``.
    """
    # NOTE(review): torch.load unpickles the file, so only point this at
    # trusted files.
    loaded = torch.load(os.path.join(src_dirname, src_basename))
    return loaded.dev

In [None]:
# Dev portion of "bert-preprocessed-data.pt" from the default directory.
bert_dev = ig_load_dev(src_basename="bert-preprocessed-data.pt")

In [None]:
#lstm_dev = ig_load_dev("lstm-preprocessed-data.pt")

## Calculations

In [None]:
def analyze_sample(model, examples, n=8, batch_size=4, output_filename=None, layer=None, shuffle=True):
    """Run Integrated Gradients on up to ``n`` examples and collect the results.

    Args:
        model: Model to analyze. If its class name contains ``'LSTM'``,
            ``IntegratedGradientsLSTM`` is used together with the collate
            function from ``get_collate_fxn(examples, batch_first=False)``;
            otherwise ``IntegratedGradientsBERT`` is used with the
            ``DataLoader`` default collation.
        examples: Dataset handed to ``DataLoader``.
        n: Maximum number of examples to analyze. Values <= 0 analyze
            nothing. If the dataset runs out first, fewer are analyzed.
        batch_size: Number of examples per batch.
        output_filename: If truthy, the collected results are also written
            with ``ig.to_json(data, output_filename)``.
        layer: Forwarded to the Integrated Gradients class constructor.
        shuffle: Forwarded to ``DataLoader``.

    Returns:
        A list built by concatenating the return values of
        ``ig.predict_with_ig`` for each processed batch, truncated to at
        most ``n`` entries (assumes one entry per example -- TODO confirm).
    """
    limit = max(int(n), 0)
    # Ceiling division. Flooring (int(n / batch_size)) yields 0 batches when
    # n < batch_size; the stop test below would then never fire, because the
    # batch index starts at 1, and the whole dataset would be processed. It
    # also dropped the final partial batch when n was not a multiple of
    # batch_size.
    n_batches = (limit + batch_size - 1) // batch_size

    if 'LSTM' in model.__class__.__name__:
        ig_class = IntegratedGradientsLSTM
        collate_fn = get_collate_fxn(examples, batch_first=False)
    else:
        ig_class = IntegratedGradientsBERT
        collate_fn = None
    ig = ig_class(model, layer=layer)
    dataloader = DataLoader(examples, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

    data = []
    if n_batches > 0:
        # NOTE(review): Integrated Gradients needs gradients, so
        # predict_with_ig presumably re-enables autograd internally despite
        # this no_grad context -- confirm.
        with torch.no_grad():
            for i, input_tuple in enumerate(dataloader, start=1):
                if i % 100 == 0:
                    print(f"Batch {i} of {n_batches}")
                input_tuple = tuple(x.to(device) for x in input_tuple)
                data += ig.predict_with_ig(input_tuple)
                if i >= n_batches:
                    break
    # The last batch may overshoot when n is not a multiple of batch_size.
    del data[limit:]

    if output_filename:
        ig.to_json(data, output_filename)
    return data

### Demos

In [None]:
# Demo run with the default sample size (n=8, batch_size=4), in dataset
# order, leaving `layer` at its default of None.
bert_data_emb = analyze_sample(
    model=bert_model_easy,
    examples=bert_dev,
    shuffle=False,
)

In [None]:
# Visualize the first three results from the default-layer demo run.
IntegratedGradientsBERT.visualize(bert_data_emb[: 3])

In [None]:
# Same demo run, with attributions taken at encoder layer 0.
first_encoder_layer = bert_model_easy.bert.encoder.layer[0]
bert_data_enc0 = analyze_sample(
    model=bert_model_easy,
    examples=bert_dev,
    shuffle=False,
    layer=first_encoder_layer,
)

In [None]:
# Visualize the first three results from the encoder-layer-0 demo run.
IntegratedGradientsBERT.visualize(bert_data_enc0[: 3])

In [None]:
# Same demo run, with attributions taken at encoder layer 11.
encoder_layer_11 = bert_model_easy.bert.encoder.layer[11]
bert_data_enc11 = analyze_sample(
    model=bert_model_easy,
    examples=bert_dev,
    shuffle=False,
    layer=encoder_layer_11,
)

In [None]:
# Visualize the first three results from the encoder-layer-11 demo run.
IntegratedGradientsBERT.visualize(bert_data_enc11[: 3])

In [None]:
# lstm_data_emb = analyze_sample(
#     lstm_model_easy,
#     lstm_dev,
#     shuffle=False)

### Big experiments

In [None]:
# 10,000-example run at the embeddings module of the "easy" BERT model.
# shuffle keeps its default (True); results are also written to JSON.
bert_data_easy = analyze_sample(
    model=bert_model_easy,
    examples=bert_dev,
    n=10000,
    layer=bert_model_easy.bert.embeddings,
    output_filename='../ig-bert-easy-emb-10k.json',
)

In [None]:
# 10,000-example run at the embeddings module of the "hard" BERT model.
# shuffle keeps its default (True); results are also written to JSON.
bert_data_hard = analyze_sample(
    model=bert_model_hard,
    examples=bert_dev,
    n=10000,
    layer=bert_model_hard.bert.embeddings,
    output_filename='../ig-bert-hard-emb-10k.json',
)

In [None]:
# 10,000-example run at encoder layer 11 of the "easy" BERT model.
# NOTE: this rebinds `bert_data_easy`, a name this notebook also uses for
# the embeddings-layer run; the JSON files use distinct names.
bert_data_easy = analyze_sample(
    model=bert_model_easy,
    examples=bert_dev,
    n=10000,
    layer=bert_model_easy.bert.encoder.layer[11],
    output_filename='../ig-bert-easy-layer11-10k.json',
)

In [None]:
# 10,000-example run at encoder layer 11 of the "hard" BERT model.
# NOTE: this rebinds `bert_data_hard`, a name this notebook also uses for
# the embeddings-layer run; the JSON files use distinct names.
bert_data_hard = analyze_sample(
    model=bert_model_hard,
    examples=bert_dev,
    n=10000,
    layer=bert_model_hard.bert.encoder.layer[11],
    output_filename='../ig-bert-hard-layer11-10k.json',
)