In [1]:
import torch
import captum
from main import TextClassificationModel
from torchtext.datasets import AG_NEWS
from torch import nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from captum.attr import LimeBase, KernelShap
from captum._utils.models.linear_model import SkLearnLasso
import torch.nn.functional as F
from IPython.core.display import HTML, display
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer, GradientShap
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization, IntegratedGradients
# import shap


def num_to_text(text_nums, vocab) :
    return [vocab.vocab.itos_[i] for i in text_nums]

In [2]:
    PATH = "./text_classification.model"
    tokenizer = get_tokenizer('basic_english')
    train_iter = AG_NEWS(split='train')

    def yield_tokens(data_iter):
        for _, text in data_iter:
            yield tokenizer(text)

    vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
    vocab.set_default_index(vocab["<unk>"])

    text_pipeline = lambda x: vocab(tokenizer(x))
    label_pipeline = lambda x: int(x) - 1

    #

In [3]:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def collate_batch(batch):
        label_list, text_list, offsets = [], [], [0]
        for (_label, _text) in batch:
            label_list.append(label_pipeline(_label))
            processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
            text_list.append(processed_text)
            offsets.append(processed_text.size(0))
        label_list = torch.tensor(label_list, dtype=torch.int64)
        offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
        text_list = torch.cat(text_list)
        return label_list.to(device), text_list.to(device), offsets.to(device)



    train_iter = AG_NEWS(split='train')
    num_class = len(set([label for (label, text) in train_iter]))
    vocab_size = len(vocab)
    emsize = 64
    #model = TextClassificationModel(vocab_size, emsize, num_class).to(device)
    model = torch.load("text_classification.model")
    import time

In [4]:
    def train(dataloader):
        model.train()
        total_acc, total_count = 0, 0
        log_interval = 500
        start_time = time.time()

        for idx, (label, text, offsets) in enumerate(dataloader):
            optimizer.zero_grad()
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
            if idx % log_interval == 0 and idx > 0:
                elapsed = time.time() - start_time
                print('| epoch {:3d} | {:5d}/{:5d} batches '
                      '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                                  total_acc / total_count))
                total_acc, total_count = 0, 0
                start_time = time.time()

In [5]:
    def evaluate(dataloader):
        model.eval()
        total_acc, total_count = 0, 0

        with torch.no_grad():
            for idx, (label, text, offsets) in enumerate(dataloader):
                #print("One test label : ", text.shape, offsets.shape)
                predicted_label = model(text, offsets)
                text_converted = num_to_text(text, vocab)
                loss = criterion(predicted_label, label)
                total_acc += (predicted_label.argmax(1) == label).sum().item()
                total_count += label.size(0)
        return total_acc / total_count

    # Hyperparameters
    EPOCHS = 1  # epoch
    LR = 5  # learning rate
    BATCH_SIZE = 1  # batch size for training

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    total_accu = None
    train_iter, test_iter = AG_NEWS()

    train_dataset = to_map_style_dataset(train_iter)

    test_dataset = to_map_style_dataset(test_iter)
    num_train = int(len(train_dataset) * 0.95)
    split_train_, split_valid_ = \
        random_split(train_dataset, [num_train, len(train_dataset) - num_train])


    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                                 shuffle=True, collate_fn=collate_batch)

    
    



In [6]:
    # Lime Code
    # remove the batch dimension for the embedding-bag model
    def forward_func(text, offsets):
        return model(text.squeeze(0), offsets)

    # encode text indices into latent representations & calculate cosine similarity
    def exp_embedding_cosine_distance(original_inp, perturbed_inp, _, **kwargs):
        original_emb = model.embedding(original_inp, None)
        perturbed_emb = model.embedding(perturbed_inp, None)
        distance = 1 - F.cosine_similarity(original_emb, perturbed_emb, dim=1)
        return torch.exp(-1 * (distance ** 2) / 2)

    # binary vector where each word is selected independently and uniformly at random
    def bernoulli_perturb(text, **kwargs):
        probs = torch.ones_like(text) * 0.5
        return torch.bernoulli(probs).long()

    # remove absenst token based on the intepretable representation sample
    def interp_to_input(interp_sample, original_input, **kwargs):
        return original_input[interp_sample.bool()].view(original_input.size(0), -1)


    
    lasso_lime_base = LimeBase(
        forward_func,
        interpretable_model=SkLearnLasso(alpha=0.08),
        similarity_func=exp_embedding_cosine_distance,
        perturb_func=bernoulli_perturb,
        perturb_interpretable_space=True,
        from_interp_rep_transform=interp_to_input,
        to_interp_rep_transform=None
    )
    
    test_label = 2  # {1: World, 2: Sports, 3: Business, 4: Sci/Tec}
    test_line = ('US Men Have Right Touch in Relay Duel Against Australia THENS, Aug. 17 '
                 '- So Michael Phelps is not going to match the seven gold medals won by Mark Spitz. '
                 'And it is too early to tell if he will match Aleksandr Dityatin, '
                 'the Soviet gymnast who won eight total medals in 1980.')

    test_labels, test_text, test_offsets = collate_batch([(test_label, test_line)])

    probs = F.softmax(model(test_text, test_offsets), dim=1).squeeze(0)
    print('Prediction probability:', round(probs[test_labels[0]].item(), 4), probs)

    attrs = lasso_lime_base.attribute(
        test_text.unsqueeze(0),  # add batch dimension for Captum
        target=test_labels,
        additional_forward_args=(test_offsets,),
        n_samples=32000,
        show_progress=True
    ).squeeze(0)
    attrs = F.normalize(attrs, p=2.0, dim=0, eps=1e-12, out=None)

    print(attrs)
    def show_text_attr(attrs):
        rgb = lambda x: '255,0,0' if x < 0 else '0,255,0'
        alpha = lambda x: abs(x) ** 0.5
        token_marks = [
            f'<mark style="background-color:rgba({rgb(attr)},{alpha(attr)})">{token}</mark>'
            for token, attr in zip(tokenizer(test_line), attrs.tolist())
        ]
    
        display(HTML('<p>' + ' '.join(token_marks) + '</p>'))
    
    show_text_attr(attrs)



Lime Base attribution:   0%|          | 0/32000 [00:00<?, ?it/s]

Prediction probability: 0.897 tensor([0.0930, 0.8970, 0.0064, 0.0037], grad_fn=<SqueezeBackward1>)


Lime Base attribution: 100%|██████████| 32000/32000 [01:31<00:00, 348.89it/s]
  "Must have sklearn version 0.23.0 or higher to use "
  "Sample weight is not supported for the provided linear model!"


tensor([-0.0972,  0.0000,  0.0000,  0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
         0.0000,  0.0876, -0.0000, -0.0000, -0.0000, -0.0759,  0.0000, -0.2955,
         0.0908, -0.0000, -0.1226, -0.0000,  0.0724,  0.0000, -0.0000,  0.6067,
         0.0000,  0.0000, -0.0000,  0.0000,  0.0000, -0.0000,  0.0000, -0.0000,
        -0.0845, -0.0000, -0.0000, -0.0000,  0.0000,  0.0000, -0.0000,  0.0000,
        -0.0000,  0.1421, -0.0000,  0.6023,  0.0000, -0.2976, -0.0000,  0.0000,
        -0.0873,  0.0065, -0.0000,  0.0000,  0.0236, -0.0000,  0.0000, -0.0000,
        -0.0000, -0.0788])


In [7]:
test_label = 2  # {1: World, 2: Sports, 3: Business, 4: Sci/Tec}
test_line = ('US Men Have Right Touch in Relay Duel Against Australia THENS, Aug. 17 '
                 '- So Michael Phelps is not going to match the seven gold medals won by Mark Spitz. '
                 'And it is too early to tell if he will match Aleksandr Dityatin, '
                 'the Soviet gymnast who won eight total medals in 1980.')

test_labels, test_text, test_offsets = collate_batch([(test_label, test_line)])


interpretable_embedding = configure_interpretable_embedding_layer(model, 'embedding')

  "In order to make embedding layers more interpretable they will "


In [8]:
def construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                        token_type_ids=None, ref_token_type_ids=None, \
                                        position_ids=None, ref_position_ids=None):
        input_embeddings = interpretable_embedding.indices_to_embeddings(input_ids, test_offsets)
        ref_input_embeddings = interpretable_embedding.indices_to_embeddings(ref_input_ids, test_offsets)
        return input_embeddings, ref_input_embeddings

In [9]:



layer_grad_shap = KernelShap(model) 
baselines = torch.zeros(test_text.shape[0]).to(torch.int64)
input_embeddings, ref_input_embeddings = construct_whole_bert_embeddings(test_text, baselines)

   
attribution = layer_grad_shap.attribute(input_embeddings, ref_input_embeddings,
                                            target=test_labels)

  "Must have sklearn version 0.23.0 or higher to use "


In [12]:
attribution

tensor([[-1.0123,  1.2972,  0.4730, -0.1684,  0.3876,  0.5670, -0.4691, -0.4729,
          0.3510, -0.2553,  2.2656, -0.2304, -0.2541, -0.1504,  0.0196,  0.0541,
         -1.6132,  0.3716, -0.0845, -0.1136,  0.4707, -0.4951,  0.7380, -0.2249,
         -1.2641, -0.2667, -0.2680, -0.2820,  0.4002, -0.3239,  0.0404,  0.1696,
          0.2374,  0.1086, -0.4270,  0.7694,  0.4425, -0.3525,  0.7189,  0.7148,
         -1.0513, -0.5707,  0.6970, -0.2568,  0.0979, -0.1607, -0.0972,  0.3368,
          0.8625, -0.2194,  2.1601,  0.1669,  0.1086, -0.1429, -0.3077,  0.2106,
          0.0404,  0.0227, -0.6609, -0.4729,  2.1292,  0.3120, -0.1139,  1.3582]])

In [11]:
#attribution2 = F.normalize(attribution, p=2.0, dim=0)

def show_text_attr(attrs):
        rgb = lambda x: '255,0,0' if x < 0 else '0,255,0'
        alpha = lambda x: abs(x) ** 0.5
        token_marks = [
            f'<mark style="background-color:rgba({rgb(attr)},{alpha(attr)})">{token}</mark>'
            for token, attr in zip(tokenizer(test_line), attrs.tolist())
        ]
        print(token_marks)
    
        display(HTML('<p>' + ' '.join(token_marks) + '</p>'))
    
show_text_attr(attribution.squeeze(0))

['<mark style="background-color:rgba(255,0,0,1.006107159066216)">us</mark>', '<mark style="background-color:rgba(0,255,0,1.1389625662462224)">men</mark>', '<mark style="background-color:rgba(0,255,0,0.6877239469435061)">have</mark>', '<mark style="background-color:rgba(255,0,0,0.4103759177386793)">right</mark>', '<mark style="background-color:rgba(0,255,0,0.6225579235596883)">touch</mark>', '<mark style="background-color:rgba(0,255,0,0.7529947571894896)">in</mark>', '<mark style="background-color:rgba(255,0,0,0.6848868211121851)">relay</mark>', '<mark style="background-color:rgba(255,0,0,0.6876618021520073)">duel</mark>', '<mark style="background-color:rgba(0,255,0,0.5924409439738563)">against</mark>', '<mark style="background-color:rgba(255,0,0,0.5052244111743958)">australia</mark>', '<mark style="background-color:rgba(0,255,0,1.5051755626591437)">thens</mark>', '<mark style="background-color:rgba(255,0,0,0.48004185700206536)">,</mark>', '<mark style="background-color:rgba(255,0,0,0.5

In [45]:
for i in zip(tokenizer(test_line), attribution.squeeze(0).tolist()):
    print(i)

('we', -1.9076507091522217)
('talk', 0.6008592844009399)
('about', 0.8334792852401733)
('sports', -0.34259605407714844)


In [46]:
attribution

tensor([[-1.9077,  0.6009,  0.8335, -0.3426,  0.1981,  0.7604, -0.6456,  0.6146,
          0.2343,  0.0832,  0.7104,  0.1022, -0.7301,  0.4617,  0.8813, -0.1982,
          0.4625,  0.2343,  5.7208,  0.3786,  0.3361, -0.1519, -0.0339,  0.2571,
          0.3089,  1.0197, -0.3654,  0.1447,  0.8239,  0.4731, -0.2086,  0.5020,
         -0.0367, -0.1019,  0.2885,  0.0327,  0.7253,  0.3786,  2.2384,  0.0968,
          0.3095,  0.7376,  0.6863, -0.9640,  0.5611,  0.1504, -0.2511,  0.2455,
          0.5627,  0.7795,  0.7376,  0.4433, -0.0506, -0.0910,  1.3215,  0.3361,
          0.7376, -0.3436, -0.2265,  0.3786,  0.3843,  0.9029,  0.8754, -0.1075]])

In [1]:
attribution

In [6]:
from captum.attr import IntegratedGradients

In [7]:
def construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                        token_type_ids=None, ref_token_type_ids=None, \
                                        position_ids=None, ref_position_ids=None):
        input_embeddings = interpretable_embedding.indices_to_embeddings(input_ids, test_offsets)
        ref_input_embeddings = interpretable_embedding.indices_to_embeddings(ref_input_ids, test_offsets)
        return input_embeddings, ref_input_embeddings

In [8]:
from captum.attr import IntegratedGradients, TokenReferenceBase, visualization

token_reference = TokenReferenceBase(reference_token_idx=0)
vis_data_records_ig = []



def interpret_sentence(model, test_text, test_offsets, test_labels,  pred, pred_ind, min_len=7, interpretable_embedding = None, label=0):
        model.zero_grad()
        # input_indices dim: [sequence_length]
        seq_length = test_text.shape[0]
        # predic

        # generate reference indices for each sample
        
        reference_indices = token_reference.generate_reference(seq_length, device=device)
        input_embeddings, ref_input_embeddings = construct_whole_bert_embeddings(test_text, reference_indices, test_offsets, interpretable_embedding)
        # compute attributions and approximation delta using layer integrated gradients
        ig = IntegratedGradients(model)
        attributions_ig = ig.attribute(input_embeddings, ref_input_embeddings, #additional_forward_args=(test_offsets,),\
                                               n_steps=500, return_convergence_delta=True, target=test_labels)
        return attributions_ig
    #     add_attributions_to_visualizer(attributions_ig, test_text, pred, pred_ind, label, vis_data_records_ig)
    #
    #
    # def add_attributions_to_visualizer(attributions, text, pred, pred_ind, label, vis_data_records):
    #     attributions = attributions.sum(dim=2).squeeze(0)
    #     attributions = attributions / torch.norm(attributions)
    #     attributions = attributions.cpu().detach().numpy()
    #
    #     # storing couple samples in an array for visualization purposes
    #     vis_data_records.append(visualization.VisualizationDataRecord(
    #         attributions,
    #         pred,
    #         attributions.sum(),
    #         text))

model = torch.load("text_classification.model")
test_label = 2  # {1: World, 2: Sports, 3: Business, 4: Sci/Tec}
test_line = ('We talk about sports')
test_labels, test_text, test_offsets = collate_batch([(test_label, test_line)])
pred = F.softmax(model(test_text, test_offsets), dim=1)
pred_ind = torch.round(pred)
interpretable_embedding = configure_interpretable_embedding_layer(model, 'embedding')
attrs = interpret_sentence(model, test_text, test_offsets, test_labels,  pred,pred_ind,  label=2, interpretable_embedding=interpretable_embedding)


RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

In [22]:
attrs = attrs[0]
show_text_attr(attrs.squeeze(0))

In [92]:
attrs

(tensor([[ 3.5102e-01,  4.0210e-02,  1.1089e+00, -4.2984e-01, -2.9637e-01,
           5.4363e-01,  2.4149e-03,  2.4999e-01,  2.0240e+00,  3.4554e-01,
           7.5385e-01, -4.6866e-01,  3.8687e-01,  5.0212e-01,  9.8996e-01,
          -2.7457e-01,  3.1099e-01, -1.0682e-02,  2.8944e+00,  2.2775e-04,
          -9.0748e-01,  4.7193e-01,  1.2137e+00, -1.7071e+00, -2.1126e-01,
           1.0765e+00, -8.3939e-02, -5.1048e-01, -3.8243e-01, -1.3251e+00,
          -3.7028e-01,  2.6928e+00, -1.4850e-01, -8.1190e-01, -1.8261e-01,
          -4.5970e-02, -1.7341e-02, -2.2869e-02,  6.8485e-01, -1.9277e-02,
          -3.7351e-01,  3.2794e-01, -5.3542e-01, -1.2330e+00,  6.0880e-01,
           4.3693e-01, -8.8090e-01,  2.1899e-02,  1.3545e+00, -1.1141e-01,
           1.2586e+00, -6.1608e-01, -1.5722e-01, -9.5145e-03, -1.7878e-01,
          -5.2722e-01, -1.2834e-01,  2.3180e-01, -5.3157e-01,  5.7616e-01,
          -1.7502e+00, -1.7291e+00,  7.6840e-02, -1.4986e-01]],
        dtype=torch.float64, grad_fn

In [18]:
test_text.shape[0]

58

In [10]:
reference_indices = token_reference.generate_reference(30, device = 'cpu').unsqueeze(0)

In [11]:
reference_indices

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0]])

In [84]:
attribution

tensor([[-0.2379,  0.8794, -0.2978, -0.4298,  0.1009, -0.0480,  0.4133,  0.4213,
          2.0240, -0.1612,  0.3858, -0.4868, -0.2254,  0.8605,  0.9900, -0.2248,
          0.7363,  0.0640,  0.9061,  0.4915, -0.1679, -0.9358,  0.6868, -0.7764,
         -0.2968,  0.0337, -0.0742,  0.2649, -0.6827,  0.0084,  0.1198,  0.6426,
         -0.4876, -0.8119, -0.8248,  0.3090, -0.2078, -0.1020,  0.6513,  0.0124,
          0.3199, -0.2214, -0.7981, -0.4747,  0.4360,  0.2398,  0.0172,  0.2905,
          0.7160, -0.1704,  0.8771,  0.1892,  0.0333,  0.2531,  0.1302,  0.2241,
          0.3600,  0.3199,  0.0842, -0.3762, -0.5210, -0.3472, -0.3215, -0.3837]])

In [33]:
attribution == attrs

tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False]])

In [34]:
attrs

tensor([[ 8.3473e-01, -2.6957e-01,  1.8882e+00, -3.4260e-01, -1.6242e-01,
         -1.5000e-01, -4.2731e-03,  2.7737e-01,  3.0920e+00,  6.0682e-01,
          5.3031e-01, -4.5921e-01,  3.4753e-01,  4.6166e-01,  1.6990e+00,
         -1.9817e-01,  2.9082e-01,  8.8086e-03,  5.7208e+00,  5.0812e-05,
         -4.1451e-01,  6.2801e-01,  1.0184e-01, -9.8312e-01, -2.7794e-01,
          1.9893e+00, -1.3318e-01, -4.0041e-01,  5.7171e-01,  1.7516e-01,
         -2.8365e-01,  4.0791e+00, -3.6710e-02,  4.1080e-04, -1.5838e-01,
         -3.2150e-01, -7.4263e-03, -1.0595e-02,  6.6097e-01, -4.4583e-02,
         -2.8630e-01, -3.1165e-02,  3.7180e-02, -5.2194e-01,  8.1334e-01,
          7.4355e-01,  1.4040e-01,  2.5691e-01,  1.5124e+00, -5.7763e-02,
          4.5929e+00, -9.9076e-01, -5.0612e-02, -1.9402e-03,  1.5653e-01,
         -7.2541e-01,  9.4780e-02,  4.1349e-01, -6.0663e-01,  6.4594e-01,
         -1.0336e+00, -1.0643e+00,  7.1016e-02, -1.1892e-01]],
       dtype=torch.float64, grad_fn=<MulBackward0