# Interpreting Basic LSTM, BiLSTM, and Hybrid CNN-LSTM Models Using XAI

This notebook was mainly inspired by this [Captum Tutorial](https://captum.ai/tutorials/IMDB_TorchText_Interpret).

In [2]:
# Reload every imported module before each cell runs, so edits to project files
# (utils, xai_utils, ...) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import sys

# Record which interpreter backs this kernel (useful when several environments are installed).
interpreter = sys.executable
print("Python interpreter:", interpreter)

Python interpreter: /Users/ikea/miniforge3/bin/python


In [4]:
import os
import sys

# The project root sits two directory levels above the notebook's working directory.
cwd = os.path.abspath('.')
module_path = os.path.dirname(os.path.dirname(cwd))
# Make project packages importable from this notebook.
if module_path not in sys.path:
    print('Add root path to system path: ', module_path)
    sys.path.append(module_path)
# Keep a trailing '/' so later cells can build file paths by plain concatenation.
module_path += '/'

Add root path to system path:  /Users/ikea/Documents/Brown/course/CS2470/Hate Speech Detection


### Imports

In [19]:
import tqdm
import argparse
import numpy as np
import datetime
import time

import spacy
import pandas as pd
from sklearn.metrics import f1_score

# `torch` is used directly throughout the notebook (torch.device, torch.tensor, ...);
# import it explicitly instead of relying on the star imports below to provide it.
import torch
from torch import optim
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from utils.preprocess_utils import *
from train_utils import train_model, test_model
from src.evaluation.test_save_stats import *

from src.utils.utils import *
from xai_utils import *

import captum
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization

from typing import Any, Iterable, List, Tuple, Union
# IPython.display is the public home of these names; importing them from
# IPython.core.display is deprecated and emits a DeprecationWarning.
from IPython.display import HTML, display


spacy_en = spacy.load("en_core_web_sm")

  from IPython.core.display import HTML, display


## ⚠️ Before running the cells below, make sure to run:

- `test_save_stats.py --model=MODEL_NAME --saved_model_path=PATH_TO_MODEL` (see the source code for more details), plus any model parameters needed.

The script saves the samples for which the model is confident in its prediction (i.e. when the probability is either very close to 1 (Hate) or very close to 0 (Neutral)). <br>
We now visualize the explainability of the model (i.e. the importance of each word in the model's decision) for True Positives (TP), False Positives (FP), True Negatives (TN) and False Negatives (FN), respectively.

### Hyperparameters

In [6]:
# Show the resolved project root (with its trailing '/') used to build the paths below.
module_path

'/Users/ikea/Documents/Brown/course/CS2470/Hate Speech Detection/'

In [53]:
## Put your model hyperparameters here
model_type = 'BasicLSTM'
# Identifier of the training run. The checkpoint and the stats CSV must come from the
# same run, so the identifier is written once instead of being repeated in each path.
run_timestamp = '2024-04-27_21-46-28'
saved_model_path = module_path + f'saved-models/{model_type}_{run_timestamp}_trained_testAcc=0.7182.pth'
# Per-sample statistics CSV produced by test_save_stats.py for that checkpoint.
stats_path = module_path + f"stats-results/stats_{model_type}_{run_timestamp}_test_bcelosswithlogits.csv"

In [54]:
# Specific model parameters
fix_length = None  # passed to get_datasets / load_model below
context_size, batch_norm, alpha = 0, 0, 0
pyramid, fcs = [], []

### Data Import

In [55]:
# OffensEval data files; relative paths, presumably resolved against module_path
# inside get_datasets (it receives module_path below) — confirm.
training_data = "data/training_data/offenseval-training-v1.tsv"
testset_data = "data/test_data/testset-levela.tsv"
test_labels_data = "data/test_data/labels-levela.csv"

# Alternative (French) dataset:
#training_data = "data/french_train.csv"
#testset_data = "data/french_test.csv"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

# Only the field (which carries the vocabulary) is needed here; the other outputs are discarded.
field, _, _, _, _ = get_datasets(
    training_data, testset_data, test_labels_data, model_type, 'en', fix_length,
    module_path=module_path,
)

print("Loading vocabulary...")
vocab_stoi, vocab_itos = get_vocab_stoi_itos(field)
print("Vocabulary Loaded")

Device: cpu
file loaded and formatted..
data split into train/val/test
field objects created
fields and dataset object created
vocabulary built..
Loading vocabulary...
Vocabulary Loaded


In [56]:
from utils.utils import load_model, load_trained_model

print("Loading Model...")
# Build the architecture, then hand it to load_trained_model together with the checkpoint path.
# NOTE(review): the model is not switched to eval() here; confirm whether load_trained_model does it
# (beware: on CUDA, cuDNN RNN backward requires train mode, so eval() may break the LIG cells).
model = load_trained_model(
    load_model(model_type, field, device, fix_length=fix_length),
    saved_model_path,
    device,
)
print("Model Loaded.")

Loading Model...
/Users/ikea/Documents/Brown/course/CS2470/Hate Speech Detection/saved-models/BasicLSTM_2024-04-27_21-46-28_trained_testAcc=0.7182.pth loaded.
Model Loaded.


In [33]:
print("Loading Stats Data..")
# Per-sample test statistics; the unnamed leading column (presumably the index written
# by to_csv) is dropped right away.
df = pd.read_csv(stats_path).drop(columns=["Unnamed: 0"])
df.head()

Loading Stats Data..


Unnamed: 0,original_index,text,true_label,pred_label,prob,loss
0,0,<unk> <unk> <unk> <unk> democrats support anti...,1,0,0.001329,6.623433
1,1,"constitutionday <unk> conservatives , hated pr...",0,0,0.000837,0.000837
2,2,foxnews nra maga potus trump <unk> rnc <unk> v...,0,0,2.2e-05,2.3e-05
3,3,watching <unk> getting news still <unk> always...,0,0,0.000703,0.000704
4,4,<unk> : unity demo oppose far - right london –...,1,0,0.192771,1.64625


In [65]:
## Selecting TP, FP, TN, FN
# Label 1 = Hate, 0 = Neutral.
true_hate, true_neutral = df['true_label'] == 1, df['true_label'] == 0
pred_hate, pred_neutral = df['pred_label'] == 1, df['pred_label'] == 0

df_tp = df[true_hate & pred_hate]
df_fp = df[true_neutral & pred_hate]
df_tn = df[true_neutral & pred_neutral]
df_fn = df[true_hate & pred_neutral]

print("TP, FP, TN, FN selected from loaded data.")

TP, FP, TN, FN selected from loaded data.


### Definition of methods to Visualize Importance of Words

We modified and adapted code from Captum (in particular, visualization.visualize_text) to fit our context.

In [58]:
def interpret_sentence(model, field, pad_ind, input_data, sentence, vocab_stoi, vocab_itos, 
                       device, original_idx, vis_data_records_ig,
                       token_reference, lig, min_len = 7, label = 0, class_names=("Neutral", "Hate"),
                       n_steps=200):
    """Compute Layer Integrated Gradients attributions for one sample and queue it for display.

    The sample is padded with ``pad_ind`` up to ``min_len`` tokens. The model's prediction is
    computed on exactly that padded tensor, attributions are computed for the same tensor
    against the baseline from ``token_reference``, and a record is appended to
    ``vis_data_records_ig`` via ``add_attributions_to_visualizer``.

    Args:
        model: classifier called as ``model(indices)``; its output is passed through
            ``torch.sigmoid`` and ``.item()``, so it must be a single logit.
        field: unused; kept for call-site compatibility.
        pad_ind: vocabulary index used to pad short samples up to ``min_len``.
        input_data: token-index tensor; only column 0 is read (``input_data[i, 0]``), so it
            is presumably ``(seq_len, batch)`` with a batch of one — TODO confirm.
        sentence: unused; kept for call-site compatibility.
        vocab_stoi: token -> index mapping.
        vocab_itos: index -> token mapping.
        device: torch device on which the index tensors are built.
        original_idx: identifier of the sample, forwarded to the visualization record.
        vis_data_records_ig: list that receives the record (mutated in place).
        token_reference: provides ``generate_reference(seq_len, device=...)`` for the baseline
            (e.g. captum ``TokenReferenceBase``).
        lig: provides ``attribute(inputs, baselines, n_steps=..., return_convergence_delta=True)``
            (e.g. captum ``LayerIntegratedGradients``).
        min_len: minimum sequence length; shorter samples are padded.
        label: true class index into ``class_names``.
        class_names: display names indexed by class id.
        n_steps: number of integration steps used by Integrated Gradients.
    """
    # Token ids of the sample, padded to min_len.
    indexed = [int(input_data[i, 0]) for i in range(input_data.shape[0])]
    if len(indexed) < min_len:
        indexed += [pad_ind] * (min_len - len(indexed))

    # Tokens shown in the visualization (already at least min_len long, padding included).
    text = [vocab_itos[tok] for tok in indexed]

    # Map the displayed tokens back to ids so the attributed ids match what is shown.
    indexed = [vocab_stoi[t] for t in text]
    # Shape (seq_len, 1): sequence first, a single sample in the batch dimension.
    input_indices = torch.tensor(indexed, device=device).unsqueeze(1)

    model.zero_grad()
    seq_length = input_indices.shape[0]

    # Predict on the same padded tensor that is attributed below, so the reported
    # prediction matches the forward pass the attributions explain.
    with torch.no_grad():
        pred = torch.sigmoid(model(input_indices)).item()
    pred_ind = round(pred)

    # Baseline of the same length and in the same (seq_len, 1) layout.
    reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(1)

    # Attributions w.r.t. the layer passed to `lig`, plus the convergence delta.
    attributions_ig, delta = lig.attribute(input_indices, reference_indices,
                                           n_steps=n_steps, return_convergence_delta=True)

    add_attributions_to_visualizer(attributions_ig, vocab_itos, text, pred, pred_ind, label, delta,
                                   original_idx, vis_data_records_ig,
                                   class_names)

def add_attributions_to_visualizer(attributions, vocab_itos, text, pred, pred_ind, label, delta, 
                                   original_idx, vis_data_records,
                                   class_names):
    """Reduce raw layer attributions to per-token scores and append a visualization record.

    Args:
        attributions: 3-D attribution tensor whose dimension 2 is summed away (the embedding
            dimension for LayerIntegratedGradients on an embedding layer). It is assumed to
            hold a single sample, e.g. ``(seq_len, 1, emb_dim)`` or ``(1, seq_len, emb_dim)``
            — TODO confirm for other models.
        vocab_itos: unused; kept for call-site compatibility.
        text: tokens of the sample, in order.
        pred: predicted probability of the positive class.
        pred_ind: predicted class index into ``class_names``.
        label: true class index into ``class_names``.
        delta: convergence delta returned by the attribution method.
        original_idx: identifier of the sample.
        vis_data_records: list that receives the record (mutated in place).
        class_names: display names indexed by class id.
    """
    # One score per token: sum over dim 2, then flatten the single-sample batch axis wherever
    # it sits. (With a (seq_len, 1, emb) layout, squeeze(0) would leave a (seq_len, 1) array.)
    word_attributions = attributions.sum(dim=2).reshape(-1)

    # L2-normalise; an all-zero attribution vector stays zero instead of becoming NaN (0 / 0).
    norm = torch.norm(word_attributions)
    if norm > 0:
        word_attributions = word_attributions / norm
    word_attributions = word_attributions.detach().cpu().numpy()

    # NOTE(review): `label` fills both the true-class and attribution-class fields; the
    # attributions themselves presumably explain the model's raw output rather than `label` — confirm.
    vis_data_records.append(VisualizationDataRecordCustom(
                            word_attributions,
                            pred,
                            class_names[pred_ind],
                            class_names[label],
                            class_names[label],
                            word_attributions.sum(),
                            text,
                            delta,
                            original_idx))


# Data Visualization

We now visualize the importance of each word in the model's decision. <br>
For each category (TP, FP, TN, FN), we retrieve the samples with both the highest and the lowest scores; the cells below display the highest scores for TP and FP, and the lowest scores for TN and FN.

## True Positives

First we retrieve the highest and lowest scores.

In [63]:
# Attribute the model output to the output of its embedding layer.
lig = LayerIntegratedGradients(forward_func=model, layer=model.emb)

In [64]:
# Presumably the `stats_topk` TP rows with the lowest / highest 'prob' (project helper) — TODO confirm ordering.
lowest_stats_df_tp, highest_stats_df_tp = get_highest_lowest_metric_indexes(df_tp, stats_metric='prob', stats_topk=10)

In [69]:
# Hand-picked rows of `df` to inspect.
# NOTE(review): .iloc selects by row position, which matches 'original_index' only while the
# stats CSV keeps its original row order (as the head() above suggests) — confirm.
hand_picked_rows = [701, 791, 371, 730, 406, 853, 767, 488, 259, 433]
df_hate = df.iloc[hand_picked_rows]

In [70]:
# Explain and render the hand-picked samples.
dataset_visualization(
    interpret_sentence, lig, visualize_text, model, vocab_stoi, vocab_itos, df_hate,
    field, device,
    max_samples=10, partial_vis=True, class_names=["Neutral", "Hate"],
)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



Original Index,True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
701.0,Hate,Hate (0.98),Hate,-0.76,@user nigga stupid trash nt play play bitch : face_with_tears_of_joy : #unk
,,,,,
791.0,Hate,Hate (0.98),Hate,-1.21,$ #unk phone . fucking dumb . #unk
,,,,,
371.0,Hate,Hate (0.98),Hate,-0.15,bitch thinking niggas money tf n’t . #unk
,,,,,
730.0,Hate,Hate (0.99),Hate,-0.92,pet ? ? ? fucking disgusting url #unk
,,,,,
406.0,Hate,Hate (0.99),Hate,0.88,alright let get right god bc mother nature like fuck humans url #unk
,,,,,


#### Highest Scores 

In [66]:
# Explain and render the highest-scoring true positives.
dataset_visualization(
    interpret_sentence, lig, visualize_text, model, vocab_stoi, vocab_itos, highest_stats_df_tp,
    field, device,
    max_samples=10, partial_vis=True, class_names=["Neutral", "Hate"],
)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



Original Index,True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
491.0,Hate,Hate (0.76),Hate,1.87,". grown ass woman , probably 10 years older currently spreading #unk rather talking , nice work got satan : red_heart : #unk"
,,,,,
637.0,Hate,Neutral (0.35),Hate,1.5,#unk bitches : party_popper : : blue_heart : #unk
,,,,,
736.0,Hate,Hate (0.97),Hate,1.07,#unk #unk n’t hate .... hate . worst enemy cause let anger get way legit feel like one kids work job #unk
,,,,,
264.0,Hate,Hate (0.62),Hate,0.86,"democrat controlled city strict gun control laws @user wants blame gop um , n’t rich democrats trying keep blacks poverty ? walkaway url #unk"
,,,,,
202.0,Hate,Neutral (0.39),Hate,1.39,@user @user #unk #unk cute ! old ? #unk
,,,,,


## False Positives

In [67]:
# Presumably the `stats_topk` FP rows with the lowest / highest 'prob' (project helper) — TODO confirm ordering.
lowest_stats_df_fp, highest_stats_df_fp = get_highest_lowest_metric_indexes(df_fp, stats_metric='prob', stats_topk=10)

#### Highest Scores

In [68]:
# Explain and render the highest-scoring false positives.
dataset_visualization(
    interpret_sentence, lig, visualize_text, model, vocab_stoi, vocab_itos, highest_stats_df_fp,
    field, device,
    max_samples=10, partial_vis=True, class_names=["Neutral", "Hate"],
)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



Original Index,True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
245.0,Neutral,Neutral (0.09),Neutral,2.06,@user @user ... show us hot ? #unk
,,,,,
267.0,Neutral,Neutral (0.29),Neutral,2.28,@user @user @user @user penny believes intelligent knows ..... think may #unk canadian ..... ever . canadians fun watch . get twitter us find anyone opinion . like rush ... hockey #unk
,,,,,
60.0,Neutral,Hate (0.97),Neutral,-0.45,@user @user @user got pretty deep debate friend told #unk trump blacks trump paid supporters : face_with_tears_of_joy : said mean antifa paid domestic terrorist said anti - fascist said fascist kidding ? ! #unk
,,,,,
808.0,Neutral,Neutral (0.00),Neutral,3.09,"9 ) #unk , thoughtful people , taking #unk category #unk communities _ _ _ _ _ #unk"
,,,,,
837.0,Neutral,Hate (0.89),Neutral,2.41,@user laws law abiding citizens . second amendment . n’t need gun control . criminals n’t care use guns end time ! figure stop using guns ! #unk
,,,,,


## True Negatives

In [None]:
# Presumably the `stats_topk` TN rows with the lowest / highest 'prob' (project helper) — TODO confirm ordering.
lowest_stats_df_tn, highest_stats_df_tn = get_highest_lowest_metric_indexes(df_tn, stats_metric='prob', stats_topk=10)

#### Lowest Scores

In [None]:
# Explain and render the lowest-scoring true negatives.
dataset_visualization(
    interpret_sentence, lig, visualize_text, model, vocab_stoi, vocab_itos, lowest_stats_df_tn,
    field, device,
    max_samples=10, partial_vis=True, class_names=["Neutral", "Hate"],
)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



Original Index,True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
671.0,Neutral,Neutral (0.00),Neutral,1.39,"celebrities #unk #unk pregnant ! bachelor star expecting first child boyfriend #unk : #unk #unk whole new reality — going mom ! bachelor #unk , 23 , .. via url url url #unk"
,,,,,
316.0,Neutral,Neutral (0.00),Neutral,0.11,#unk security incidents rarely emerge fully formed #unk lights alert . see ’re prepared testing skills following scenario . url url #unk
,,,,,
545.0,Neutral,Neutral (0.00),Neutral,1.67,"art culture #unk #unk #unk : apology tour : #unk #unk #unk taken social issues . last season 's gun control episode stands shining example , show always ability .. via url url #unk"
,,,,,
388.0,Neutral,Neutral (0.00),Neutral,-0.44,#unk #unk #unk x #unk 2018 f / w @user : link : url : link : url : link : url : link : url url #unk
,,,,,
321.0,Neutral,Neutral (0.00),Neutral,0.06,"#unk definitely tad high puerto rico . democrats provide proof ? providing names , #unk , next #unk n’t hard . hope n’t mind , holding breath . url via @user url #unk"
,,,,,


## False Negatives

In [None]:
# Presumably the `stats_topk` FN rows with the lowest / highest 'prob' (project helper) — TODO confirm ordering.
lowest_stats_df_fn, highest_stats_df_fn = get_highest_lowest_metric_indexes(df_fn, stats_metric='prob', stats_topk=10)

#### Lowest Scores

In [None]:
# Explain and render the lowest-scoring false negatives.
dataset_visualization(
    interpret_sentence, lig, visualize_text, model, vocab_stoi, vocab_itos, lowest_stats_df_fn,
    field, device,
    max_samples=10, partial_vis=True, class_names=["Neutral", "Hate"],
)



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



Original Index,True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
275.0,Hate,Neutral (0.01),Hate,1.01,"california patriots ... especially #unk obama / soros attempting take house seat get control congress . say #unk #unk . shill candidate #unk #unk #unk . please get vote david #unk house , #unk 21 ! follow @user url #unk"
,,,,,
312.0,Hate,Neutral (0.02),Hate,2.81,"christian america – go trump ’s example , liberals support open borders , guess conservatives support school shootings . please explain makes america great . #unk"
,,,,,
281.0,Hate,Neutral (0.01),Hate,0.5,conservatives govt run debt #unk austerity cuts rich #unk wealth . #unk url via @user #unk
,,,,,
778.0,Hate,Neutral (0.01),Hate,-0.33,"brexit deal reached - #unk special #unk november , @user sold uk eu ? ? ? better @user finished ! ! @user url #unk"
,,,,,
468.0,Hate,Neutral (0.01),Hate,0.1,nigeria #unk #unk ' incompetent leader nigeria ’s history ' – #unk #unk #unk url #unk via url #unk
,,,,,


# Visualize Sentences by Index

### DistilBERT True Positives (explained with this model)

In [None]:
# Row positions in `df` of the samples to inspect.
list_indexes = [433, 730, 259, 406]
df_by_indexes = df.take(list_indexes)

In [None]:
%%time
dataset_visualization(interpret_sentence, lig, visualize_text, model, vocab_stoi, vocab_itos, df_by_indexes,\
                      field, device, max_samples=10,partial_vis=True,class_names=["Neutral","Hate"])



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



Original Index,True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
433.0,Hate,Hate (0.98),Hate,0.1,@user damn felt shit . loud lol #unk
,,,,,
730.0,Hate,Hate (0.98),Hate,-0.58,pet ? ? ? fucking disgusting url #unk
,,,,,
259.0,Hate,Hate (0.98),Hate,-0.43,! ! ! ! bitch ’m fucking coming back url #unk
,,,,,
406.0,Hate,Hate (0.99),Hate,0.97,alright let get right god bc mother nature like fuck humans url #unk
,,,,,


Wall time: 848 ms


### DistilBERT False Positives (explained with this model)

In [None]:
# Row positions in `df` of the samples to inspect.
list_indexes = [674, 599, 278, 700]
df_by_indexes = df.take(list_indexes)

In [None]:
%%time
dataset_visualization(interpret_sentence, lig, visualize_text, model, vocab_stoi, vocab_itos, df_by_indexes,\
                      field, device, max_samples=10,partial_vis=True,class_names=["Neutral","Hate"])



**MODEL EXPLAINABILITY**

Computing words importance for each sample... 
Computations completed.


**LOADING VISUALIZATION**



Original Index,True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
674.0,Neutral,Neutral (0.38),Neutral,2.44,alex jones #unk #unk mans really got supporters : face_with_tears_of_joy : #unk
,,,,,
599.0,Neutral,Hate (0.69),Neutral,1.99,"#unk actually incredible , #unk shit , always , ’m #unk like 5 days #unk . life good . nice day . #unk"
,,,,,
278.0,Neutral,Hate (0.94),Neutral,1.84,@user exactly ’s bc slick woods #unk look ’s yea ai n’t attractive us shit n’t matter @ lol #unk
,,,,,
700.0,Neutral,Hate (0.92),Neutral,1.61,american #unk really one underrated #unk ever ever ever . fuck cried scene #unk
,,,,,


Wall time: 1.4 s
