#### Explainability of BiLSTM using Attention

In [None]:
import os
import sys

if "google.colab" in sys.modules:
    workspace_dir = '/content/spam-detection'
    branch = 'feature/extended-explainability'
    current_dir = os.getcwd()
    if not os.path.exists(workspace_dir) and current_dir != workspace_dir:
        !git clone https://github.com/RationalEar/spam-detection.git
        os.chdir(workspace_dir)
        !git checkout $branch
        !ls -al
        !pip install -q transformers==4.48.0 scikit-learn pandas numpy
        !pip install -q torch --index-url https://download.pytorch.org/whl/cu126
        !pip install captum --no-deps --ignore-installed
    else:
        os.chdir(workspace_dir)
        !git pull origin $branch

    from google.colab import drive

    drive.mount('/content/drive')

In [None]:
# Third-party libraries used throughout the notebook.
import pandas as pd
import torch

# Project-level path constants.
from utils.constants import DATA_PATH, GLOVE_PATH

# Echo the data root so the reader can see which location is being used.
DATA_PATH

In [None]:
# Pick the compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Read the pre-processed train/test splits from their pickle files.
train_df, test_df = (
    pd.read_pickle(DATA_PATH + split_file)
    for split_file in ('/data/processed/train.pkl', '/data/processed/test.pkl')
)

device

In [3]:
from utils.functions import build_vocab, set_seed

# Embedding width and maximum sequence length reused by the cells below.
embedding_dim, max_len = 300, 200

# Seed before building the vocabulary (presumably for reproducibility; see
# utils.functions.set_seed), then derive both lookup tables from the
# training texts only.
set_seed(42)
word2idx, idx2word = build_vocab(train_df['text'])

In [4]:
from preprocess.data_loader import load_glove_embeddings

# Load the GloVe vectors for the vocabulary built above; the result is handed
# to the model constructor as `pretrained_embeddings`.
pretrained_embeddings = load_glove_embeddings(
    GLOVE_PATH,
    word2idx,
    embedding_dim,
)

In [5]:
from models.bilstm import BiLSTMSpam

# Rebuild the BiLSTM with the same vocabulary size / embedding setup, then
# restore the trained checkpoint.
model_path = DATA_PATH + '/trained-models/spam_bilstm_final.pt'
model = BiLSTMSpam(
    vocab_size=len(word2idx),
    embedding_dim=embedding_dim,
    pretrained_embeddings=pretrained_embeddings,
)

# The checkpoint is mapped onto the CPU first and only afterwards moved to the
# selected device.
model.load(model_path, map_location=torch.device('cpu'))
model = model.to(device)

# Switch to evaluation mode; as the last expression this also displays the
# model.
model.eval()

BiLSTMSpam(
  (embedding): Embedding(25245, 300)
  (lstm): LSTM(300, 128, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)
  (attention): Attention(
    (attn): Linear(in_features=256, out_features=1, bias=True)
  )
  (fc1): Linear(in_features=256, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=1, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)

In [6]:
from utils.functions import encode

# Encode every test text with the training vocabulary and place the resulting
# tensor on the selected device.
X_test_tensor = torch.tensor(
    [encode(text, word2idx, max_len) for text in test_df['text']]
).to(device)

# Labels as float32 on the same device.
y_test_tensor = torch.tensor(
    test_df['label'].values,
    dtype=torch.float32,
).to(device)

In [7]:
# Score the full test set in one forward pass with gradient tracking disabled.
with torch.no_grad():
    model_output = model(X_test_tensor)

    # The model may return a tuple; in that case its first element is taken
    # as the predictions.
    y_pred_probs = model_output[0] if isinstance(model_output, tuple) else model_output

    # Threshold at 0.5 to obtain hard 0/1 predictions (as floats).
    y_pred = (y_pred_probs > 0.5).float()

#### Attention for BiLSTM

In [8]:
from explainability.BiLSTMAttentionMetrics import BiLSTMAttentionMetrics

# Build the attention-explanation metrics calculator around the loaded model
# and the vocabulary lookups.
attention_metrics = BiLSTMAttentionMetrics(
    model=model, word2idx=word2idx, idx2word=idx2word,
    max_len=max_len, device=device,
)

# Sanity check: report where the model parameters and the calculator live.
print("BiLSTM Attention Metrics Calculator initialized successfully!")
print(f"Model device: {next(model.parameters()).device}")
print(f"Metrics calculator device: {attention_metrics.device}")


BiLSTM Attention Metrics Calculator initialized successfully!
Model device: cpu
Metrics calculator device: cpu


In [None]:
from explainability.BiLSTMAttentionMetrics import analyze_test_dataset_influential_words_bilstm

# Run the attention-based influential-word analysis over the whole test set,
# passing texts and labels as plain Python lists and requesting top_k=20.
results = analyze_test_dataset_influential_words_bilstm(
    model=model, word2idx=word2idx, idx2word=idx2word,
    test_texts=test_df['text'].tolist(), test_labels=test_df['label'].tolist(),
    device=device, max_len=max_len, top_k=20,
)

In [None]:
# Explanation-quality metrics for the test set: a per-sample frame plus three
# aggregate objects (overall, spam, ham), in the order the calculator returns
# them.
(
    metrics_df,
    overall_metrics,
    spam_metrics,
    ham_metrics,
) = attention_metrics.calculate_overall_metrics(test_df)

# Summary of the per-sample metrics.
metrics_df.describe()

In [16]:
# Show the aggregate summary over all test samples (the recorded output holds 'mean', 'median' and 'std' entries).
overall_metrics.describe()

{'mean': auc_deletion         0.748638
 auc_insertion        0.611234
 comprehensiveness    0.247781
 jaccard_stability    0.661087
 true_label           0.346939
 predicted_prob       0.346949
 dtype: float64,
 'median': auc_deletion         7.857322e-01
 auc_insertion        7.166749e-01
 comprehensiveness    1.132333e-06
 jaccard_stability    6.396825e-01
 true_label           0.000000e+00
 predicted_prob       1.465470e-08
 dtype: float64,
 'std': auc_deletion         0.220761
 auc_insertion        0.351215
 comprehensiveness    0.421584
 jaccard_stability    0.232576
 true_label           0.478443
 predicted_prob       0.478433
 dtype: float64}

In [17]:
# Show the metric summary for the spam subset (true_label is 1.0 in the recorded output).
spam_metrics.describe()

auc_deletion         0.878995
auc_insertion        0.974886
comprehensiveness    0.076916
jaccard_stability    0.604318
true_label           1.000000
predicted_prob       0.970585
dtype: float64

In [18]:
# Show the metric summary for the ham subset (true_label is 0.0 in the recorded output).
ham_metrics.describe()

auc_deletion         0.679386
auc_insertion        0.418044
comprehensiveness    0.338553
jaccard_stability    0.691245
true_label           0.000000
predicted_prob       0.015643
dtype: float64