#### Explainability of BiLSTM using Attention

In [1]:
import os
import sys

if "google.colab" in sys.modules:
    workspace_dir = '/content/spam-detection'
    branch = 'feature/extended-explainability'
    current_dir = os.getcwd()
    if not os.path.exists(workspace_dir) and current_dir != workspace_dir:
        !git clone https://github.com/RationalEar/spam-detection.git
        os.chdir(workspace_dir)
        !git checkout $branch
        !ls -al
        !pip install -q transformers==4.48.0 scikit-learn pandas numpy
        !pip install -q torch --index-url https://download.pytorch.org/whl/cu126
        !pip install captum --no-deps --ignore-installed
    else:
        os.chdir(workspace_dir)
        !git pull origin $branch

    from google.colab import drive

    drive.mount('/content/drive')

Cloning into 'spam-detection'...
remote: Enumerating objects: 520, done.[K
remote: Counting objects: 100% (189/189), done.[K
remote: Compressing objects: 100% (127/127), done.[K
remote: Total 520 (delta 100), reused 119 (delta 56), pack-reused 331 (from 1)[K
Receiving objects: 100% (520/520), 8.88 MiB | 18.37 MiB/s, done.
Resolving deltas: 100% (258/258), done.
Branch 'feature/extended-explainability' set up to track remote branch 'feature/extended-explainability' from 'origin'.
Switched to a new branch 'feature/extended-explainability'
total 76
drwxr-xr-x 11 root root 4096 Aug 16 15:37 .
drwxr-xr-x  1 root root 4096 Aug 16 15:37 ..
-rw-r--r--  1 root root  584 Aug 16 15:37 docker-compose.yml
-rw-r--r--  1 root root  879 Aug 16 15:37 Dockerfile
-rw-r--r--  1 root root   92 Aug 16 15:37 .dockerignore
drwxr-xr-x  2 root root 4096 Aug 16 15:37 docs
drwxr-xr-x  2 root root 4096 Aug 16 15:37 explainability
drwxr-xr-x  8 root root 4096 Aug 16 15:37 .git
-rw-r--r--  1 root root   38 Aug 1

In [2]:
import torch

import pandas as pd
from utils.constants import DATA_PATH, GLOVE_PATH

# Echo the configured data root as the cell output; the data and model paths in later cells are built from it.
DATA_PATH

'/content/drive/MyDrive/Projects/spam-detection-data'

In [3]:
# Locate and read the pre-processed train/test splits under the data root
train_path = DATA_PATH + '/data/processed/train.pkl'
test_path = DATA_PATH + '/data/processed/test.pkl'
train_df = pd.read_pickle(train_path)
test_df = pd.read_pickle(test_path)

# Use the GPU when PyTorch can see one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [4]:
from utils.functions import set_seed, build_vocab

# Settings reused by the later cells (embedding loader, model, encoder, metrics)
max_len = 200
embedding_dim = 300

# Fix the random seed, then build the two vocabulary mappings from the training texts
set_seed(42)
word2idx, idx2word = build_vocab(train_df['text'])

In [5]:
from preprocess.data_loader import load_glove_embeddings

# Load GloVe vectors for the vocabulary; the result is passed to the model as `pretrained_embeddings`.
# NOTE(review): presumably a matrix aligned with word2idx indices — confirm in preprocess.data_loader.
pretrained_embeddings = load_glove_embeddings(GLOVE_PATH, word2idx, embedding_dim)

In [6]:
# Load the trained BiLSTM model
from models.bilstm import BiLSTMSpam

# Instantiate the network with the vocabulary/embedding settings defined above
model_path = DATA_PATH + '/trained-models/spam_bilstm_final.pt'
model = BiLSTMSpam(
    vocab_size=len(word2idx),
    embedding_dim=embedding_dim,
    pretrained_embeddings=pretrained_embeddings,
)

# Load the saved checkpoint mapped to the CPU, then move the model to `device`
cpu_device = torch.device('cpu')
model.load(model_path, map_location=cpu_device)
model = model.to(device)

# Switch to evaluation mode; the returned module is what the cell displays
model.eval()

BiLSTMSpam(
  (embedding): Embedding(25245, 300)
  (lstm): LSTM(300, 128, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)
  (attention): Attention(
    (attn): Linear(in_features=256, out_features=1, bias=True)
  )
  (fc1): Linear(in_features=256, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=1, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)

In [7]:
from utils.functions import encode

# Encode every test text and put the resulting tensor on the target device
X_test_tensor = torch.tensor(
    [encode(text, word2idx, max_len) for text in test_df['text']]
).to(device)

# Labels as float32, on the same device as the inputs
y_test_tensor = torch.tensor(
    test_df['label'].values,
    dtype=torch.float32,
).to(device)

In [8]:
# Score the whole test set in one forward pass with gradient tracking disabled
with torch.no_grad():
    model_output = model(X_test_tensor)
    # A tuple return is unpacked to its first element, which is taken as the predictions
    y_pred_probs = model_output[0] if isinstance(model_output, tuple) else model_output

    # Threshold at 0.5 to obtain hard 0/1 predictions as floats
    y_pred = y_pred_probs.gt(0.5).float()

#### Attention for BiLSTM

In [9]:
# Generate Attention explanations and compute quality metrics
from explainability.BiLSTMAttentionMetrics import BiLSTMAttentionMetrics

# Gather the shared objects once and hand them to the metrics calculator
metrics_config = dict(
    model=model,
    word2idx=word2idx,
    idx2word=idx2word,
    max_len=max_len,
    device=device,
)
attention_metrics = BiLSTMAttentionMetrics(**metrics_config)

# Sanity check: report where the model parameters and the calculator live
print("BiLSTM Attention Metrics Calculator initialized successfully!")
model_device = next(model.parameters()).device
print(f"Model device: {model_device}")
print(f"Metrics calculator device: {attention_metrics.device}")


BiLSTM Attention Metrics Calculator initialized successfully!
Model device: cuda:0
Metrics calculator device: cuda


In [10]:
from explainability.BiLSTMAttentionMetrics import analyze_test_dataset_influential_words_bilstm

# Pull the raw texts and labels out of the test frame as plain lists
test_texts = test_df['text'].tolist()
test_labels = test_df['label'].tolist()

# Run the attention-based influential-word analysis over the test set (top 20 words)
results = analyze_test_dataset_influential_words_bilstm(
    model=model, word2idx=word2idx, idx2word=idx2word,
    test_texts=test_texts, test_labels=test_labels,
    device=device, max_len=max_len, top_k=20,
)

Initializing BiLSTM attention analyzer...
Analyzing top 20 influential words using BiLSTM attention...
Analyzing 606 texts for influential words using BiLSTM attention...
Processed 10/606 texts...
Processed 20/606 texts...
Processed 30/606 texts...
Processed 40/606 texts...
Processed 50/606 texts...
Processed 60/606 texts...
Processed 70/606 texts...
Processed 80/606 texts...
Processed 90/606 texts...
Processed 100/606 texts...
Processed 110/606 texts...
Processed 120/606 texts...
Processed 130/606 texts...
Processed 140/606 texts...
Processed 150/606 texts...
Processed 160/606 texts...
Processed 170/606 texts...
Processed 180/606 texts...
Processed 190/606 texts...
Processed 200/606 texts...
Processed 210/606 texts...
Processed 220/606 texts...
Processed 230/606 texts...
Processed 240/606 texts...
Processed 250/606 texts...
Processed 260/606 texts...
Processed 270/606 texts...
Processed 280/606 texts...
Processed 290/606 texts...
Processed 300/606 texts...
Processed 310/606 texts...
P

In [11]:
# Wrap each ranked word list from the analysis results in its own DataFrame
top_spam_words, top_ham_words, top_overall_words = (
    pd.DataFrame(results[key])
    for key in ('top_spam_words', 'top_ham_words', 'top_overall_words')
)

# Display the spam ranking as the cell output
top_spam_words

Unnamed: 0,word,frequency,mean_importance,std_importance,max_importance,total_importance
0,48,5,0.227455,0.004948,0.237351,1.137277
1,72,6,0.221338,0.080654,0.304654,1.328026
2,re,3,0.144532,0.086675,0.240295,0.433597
3,allow,6,0.141002,0.048872,0.166862,0.846014
4,hours,6,0.116046,0.035748,0.158533,0.696275
5,subject,2,0.107624,0.011142,0.118765,0.215247
6,cut,5,0.106666,0.059002,0.177831,0.533331
7,again,3,0.096196,0.066388,0.175948,0.288588
8,sale,3,0.092219,0.104924,0.240599,0.276656
9,combined,2,0.09184,0.000197,0.092036,0.183679


In [12]:
# Compute overall explanation quality metrics.
# The call returns four values: a per-example metrics frame, an overall summary,
# and separate summaries for the spam and ham examples (displayed in the cells below).
# NOTE(review): the recorded output logs "Processing example 1/100" and describe()
# reports count 98, so the helper appears to evaluate a subset of test_df and to
# drop a couple of examples — confirm in BiLSTMAttentionMetrics.
metrics_df, overall_metrics, spam_metrics, ham_metrics = attention_metrics.calculate_overall_metrics(test_df)
# Summary statistics of the per-example metrics become the cell output.
metrics_df.describe()

Processing example 1/100Evaluating metrics for text: when i receive message line starting s broken two ...
AUC-Del: 0.8112
AUC-Ins: 0.3567
Comprehensiveness: 0.0188
Jaccard Stability: 0.3056
Evaluating metrics for text: need find something free mortgage quote removed li...
AUC-Del: 1.0000
AUC-Ins: 1.0000
Comprehensiveness: 0.0000
Jaccard Stability: 0.9333
Evaluating metrics for text: join get 4 dvds 49 ea shipping processing details ...
AUC-Del: 0.9511
AUC-Ins: 0.1462
Comprehensiveness: 1.0000
Jaccard Stability: 0.6500
Evaluating metrics for text: <EMAIL> your use yahoo groups subject <URL>...
AUC-Del: 0.7857
AUC-Ins: 0.7962
Comprehensiveness: 1.0000
Jaccard Stability: 1.0000
Evaluating metrics for text: re ilug interesting article free software licences...
AUC-Del: 0.8552
AUC-Ins: 0.1449
Comprehensiveness: 0.0000
Jaccard Stability: 0.9333
Evaluating metrics for text: skinny acoustic bass url <URL> planes case little ...
AUC-Del: 0.5657
AUC-Ins: 0.2410
Comprehensiveness: 0.9996
Jaccar

Unnamed: 0,auc_deletion,auc_insertion,comprehensiveness,jaccard_stability,true_label,predicted_prob
count,98.0,98.0,98.0,98.0,98.0,98.0
mean,0.748655,0.611234,0.2477866,0.632977,0.346939,0.3469491
std,0.220755,0.351215,0.4215891,0.261693,0.478443,0.4784326
min,0.151844,0.043859,0.0,0.219048,0.0,1.166331e-10
25%,0.640217,0.248367,1.890562e-09,0.395238,0.0,2.099531e-09
50%,0.785732,0.716685,1.13232e-06,0.592857,0.0,1.465831e-08
75%,0.950762,0.999999,0.4407776,0.933333,1.0,1.0
max,1.0,1.0,1.0,1.0,1.0,1.0


In [15]:
# Display the spam summary (third value returned by calculate_overall_metrics).
spam_metrics

Unnamed: 0,0
auc_deletion,0.879023
auc_insertion,0.974886
comprehensiveness,0.076911
jaccard_stability,0.592717
true_label,1.0
predicted_prob,0.970585


In [16]:
# Display the ham summary (fourth value returned by calculate_overall_metrics).
ham_metrics

Unnamed: 0,0
auc_deletion,0.679398
auc_insertion,0.418043
comprehensiveness,0.338564
jaccard_stability,0.654365
true_label,0.0
predicted_prob,0.015643


In [19]:
# Display the overall summary (second value returned by calculate_overall_metrics);
# the recorded output is a dict with 'mean', 'median' and 'std' entries.
overall_metrics

{'mean': auc_deletion         0.748655
 auc_insertion        0.611234
 comprehensiveness    0.247787
 jaccard_stability    0.632977
 true_label           0.346939
 predicted_prob       0.346949
 dtype: float64,
 'median': auc_deletion         7.857320e-01
 auc_insertion        7.166847e-01
 comprehensiveness    1.132320e-06
 jaccard_stability    5.928571e-01
 true_label           0.000000e+00
 predicted_prob       1.465831e-08
 dtype: float64,
 'std': auc_deletion         0.220755
 auc_insertion        0.351215
 comprehensiveness    0.421589
 jaccard_stability    0.261693
 true_label           0.478443
 predicted_prob       0.478433
 dtype: float64}