#### Explainability of BERT using LayerIntegratedGradients

In [1]:
import os
import sys

if "google.colab" in sys.modules:
    workspace_dir = '/content/spam-detection'
    branch = 'feature/extended-explainability'
    current_dir = os.getcwd()
    if not os.path.exists(workspace_dir) and current_dir != workspace_dir:
        !git clone https://github.com/RationalEar/spam-detection.git
        os.chdir(workspace_dir)
        !git checkout $branch
        !ls -al
        !pip install -q transformers==4.48.0 scikit-learn pandas numpy
        !pip install -q torch --index-url https://download.pytorch.org/whl/cu126
        !pip install captum --no-deps --ignore-installed
    else:
        os.chdir(workspace_dir)
        !git pull origin $branch

    from google.colab import drive

    drive.mount('/content/drive')

remote: Enumerating objects: 9, done.[K
remote: Counting objects:  11% (1/9)[Kremote: Counting objects:  22% (2/9)[Kremote: Counting objects:  33% (3/9)[Kremote: Counting objects:  44% (4/9)[Kremote: Counting objects:  55% (5/9)[Kremote: Counting objects:  66% (6/9)[Kremote: Counting objects:  77% (7/9)[Kremote: Counting objects:  88% (8/9)[Kremote: Counting objects: 100% (9/9)[Kremote: Counting objects: 100% (9/9), done.[K
remote: Compressing objects: 100% (1/1)[Kremote: Compressing objects: 100% (1/1), done.[K
remote: Total 5 (delta 4), reused 5 (delta 4), pack-reused 0 (from 0)[K
Unpacking objects:  20% (1/5)Unpacking objects:  40% (2/5)Unpacking objects:  60% (3/5)Unpacking objects:  80% (4/5)Unpacking objects: 100% (5/5)Unpacking objects: 100% (5/5), 1.26 KiB | 645.00 KiB/s, done.
From https://github.com/RationalEar/spam-detection
 * branch            feature/extended-explainability -> FETCH_HEAD
   338ef0c..5d7141a  feature/extended-explainability -> 

In [2]:
import os
import torch

import pandas as pd
from utils.constants import DATA_PATH, MODEL_SAVE_PATH

DATA_PATH

'/content/drive/MyDrive/Projects/spam-detection-data'

In [3]:
# Load the data
train_df = pd.read_pickle(DATA_PATH + '/data/processed/train.pkl')
test_df = pd.read_pickle(DATA_PATH + '/data/processed/test.pkl')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
from utils.functions import set_seed, build_vocab

# Build vocabulary and load embeddings
set_seed(42)
word2idx, idx2word = build_vocab(train_df['text'])
embedding_dim = 300
max_len = 200
pretrained_embeddings = None

In [5]:
from models.bert import SpamBERT
from transformers import BertTokenizer

# Initialize BERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = SpamBERT(dropout=0.2)

# Load the trained model weights
model_path = os.path.join(MODEL_SAVE_PATH, 'spam_bert_final.pt')
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)
model.eval()

SpamBERT(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affi

In [6]:
from models.bert import tokenize_texts

# Tokenize test data
X_test_input_ids, X_test_attention_mask = tokenize_texts(test_df['text'].tolist(), tokenizer)
y_test_tensor = torch.tensor(test_df['label'].values, dtype=torch.float32)

# Move data to device
X_test_input_ids = X_test_input_ids.to(device)
X_test_attention_mask = X_test_attention_mask.to(device)
y_test_tensor = y_test_tensor.to(device)

print(f"Test data prepared: {X_test_input_ids.shape[0]} samples")

Test data prepared: 606 samples


In [7]:
# Get model predictions using BERT tokenized inputs
with torch.no_grad():
    model_output = model(
        input_ids=X_test_input_ids,
        attention_mask=X_test_attention_mask
    )
    # If model returns a tuple, use the first element (typically the predictions)
    if isinstance(model_output, tuple):
        y_pred_probs = model_output[0]
    else:
        y_pred_probs = model_output

    y_pred = (y_pred_probs > 0.5).float()

print(f"Model predictions computed for {len(y_pred)} samples")
print(f"Predicted spam samples: {(y_pred == 1).sum().item()}")
print(f"Predicted ham samples: {(y_pred == 0).sum().item()}")

Model predictions computed for 606 samples
Predicted spam samples: 194
Predicted ham samples: 412


#### LayerIntegratedGradients for BERT

In [8]:
from explainability.BertExplanationMetrics import BertExplanationMetrics

# Initialize the BERT explanation quality metrics calculator
quality_evaluator = BertExplanationMetrics(model, tokenizer, device)

print("BERT Explanation Quality Metrics Calculator initialized successfully!")

BERT Explanation Quality Metrics Calculator initialized successfully!


In [9]:
# Compute explanation quality metrics using Integrated Gradients
print("Computing explanation quality metrics using LayerIntegratedGradients...")
print("=" * 60)

ig_results = []

for (i, row) in test_df.iterrows():
    text = row['text']
    subject = row['subject']
    try:
        # Compute metrics using Integrated Gradients
        metrics = quality_evaluator.evaluate_explanation_quality(
            text,
            subject=subject,
            method='integrated_gradients',
            verbose=False
        )

        metrics['sample_id'] = i + 1
        metrics['text'] = subject
        metrics['label'] = 'Spam' if row['label'] == 1 else 'Ham'

        ig_results.append(metrics)

    except Exception as e:
        print(f"Error processing sample {i+1}: {e}")
        continue

print(f"\nCompleted processing {len(ig_results)} samples with Integrated Gradients.")

Computing explanation quality metrics using LayerIntegratedGradients...

Completed processing 606 samples with Integrated Gradients.


In [10]:
metrics_df = pd.DataFrame(ig_results)
metrics_df

Unnamed: 0,auc_deletion,auc_insertion,comprehensiveness,jaccard_stability,computation_time,sample_id,text,label
0,0.451565,0.107920,0.005028,0.334127,0 days 00:00:12.116649,1,"RE: Our friends the Palestinians, Our servants...",Ham
1,0.615704,0.149823,0.500781,0.210714,0 days 00:00:12.005684,2,"Re: Our friends the Palestinians, Our servants...",Ham
2,0.429045,0.112967,0.004510,0.535714,0 days 00:00:11.815992,3,xine src package,Ham
3,0.183767,0.067226,0.005919,0.346825,0 days 00:00:11.916729,4,Re: xine src package,Ham
4,0.218783,0.104998,0.005007,0.966667,0 days 00:00:11.820961,5,"Re: Our friends the Palestinians, Our servants...",Ham
...,...,...,...,...,...,...,...,...
601,0.594187,0.444710,0.564256,0.371825,0 days 00:00:11.835178,602,hurry,Spam
602,0.846331,0.767668,0.049100,0.229762,0 days 00:00:12.085660,603,[ILUG] WILSON KAMELA,Spam
603,0.755331,0.826747,0.008041,0.316270,0 days 00:00:12.136452,604,"How to get 10,000 FREE hits per day to any web...",Spam
604,0.787173,0.894576,0.001132,0.264683,0 days 00:00:13.493403,605,Cannabis Difference,Spam


In [15]:
metrics_df.describe()

Unnamed: 0,auc_deletion,auc_insertion,comprehensiveness,jaccard_stability,computation_time,sample_id
count,606.0,606.0,606.0,606.0,606,606.0
mean,0.476898,0.395897,0.09209,0.532393,0 days 00:00:11.924757250,303.5
std,0.234679,0.327079,0.145554,0.186247,0 days 00:00:01.247627005,175.08141
min,0.08829,0.0,3e-05,0.022222,0 days 00:00:11.262604,1.0
25%,0.280918,0.099329,0.003196,0.396825,0 days 00:00:11.775231750,152.25
50%,0.392662,0.304443,0.018076,0.531349,0 days 00:00:11.849422500,303.5
75%,0.719355,0.830316,0.120473,0.677083,0 days 00:00:11.944917,454.75
max,0.914714,0.934142,0.849144,1.0,0 days 00:00:41.811837,606.0


In [11]:
# Compute explanation quality metrics using Attention Heads
print("Computing explanation quality metrics using Attention Heads...")
print("=" * 60)

attention_results = []
for (i, row) in test_df.iterrows():
    text = row['text']
    try:
        # Compute metrics using Attention Weights
        metrics = quality_evaluator.evaluate_explanation_quality(
            text,
            subject=row['subject'],
            method='attention',
            verbose=True
        )

        metrics['sample_id'] = i + 1
        metrics['text'] = row['subject']
        metrics['label'] = 'Spam' if row['label'] == 1 else 'Ham'

        attention_results.append(metrics)

    except Exception as e:
        print(f"Error processing sample {i+1}: {e}")
        continue

print(f"\nCompleted processing {len(attention_results)} samples with Attention Heads.")



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Evaluating explanation quality for text: 'Wannabe fathers ramp up testosterone'
Using method: attention
Computing AUC-Del...
Computing AUC-Ins...
Computing Comprehensiveness...
Computing Jaccard Stability...

EXPLANATION QUALITY METRICS
Method:           attention
AUC-Deletion:     0.3300 (lower is better)
AUC-Insertion:    0.3721 (higher is better)
Comprehensiveness: 0.0196 (higher is better)
Jaccard Stability: 0.4921 (higher is better)
Computation Time: 0 days 00:00:01.031523
Evaluating explanation quality for text: 'Australia declares world's largest marine reserve'
Using method: attention
Computing AUC-Del...
Computing AUC-Ins...
Computing Comprehensiveness...
Computing Jaccard Stability...

EXPLANATION QUALITY METRICS
Method:           attention
AUC-Deletion:     0.5376 (lower is better)
AUC-Insertion:    0.1682 (higher is better)
Comprehensiveness: 0.1206 (higher is better)
Jaccard Stability: 0.6619 (higher is bette

In [12]:
attention_metrics_df = pd.DataFrame(attention_results)
attention_metrics_df

Unnamed: 0,auc_deletion,auc_insertion,comprehensiveness,jaccard_stability,computation_time,sample_id,text,label
0,0.553344,0.120436,0.068414,0.553571,0 days 00:00:01.272183,1,"RE: Our friends the Palestinians, Our servants...",Ham
1,0.539277,0.141726,0.027357,0.345238,0 days 00:00:01.239677,2,"Re: Our friends the Palestinians, Our servants...",Ham
2,0.379974,0.126468,0.003247,0.604762,0 days 00:00:01.037393,3,xine src package,Ham
3,0.280834,0.063366,0.007177,0.474206,0 days 00:00:01.148540,4,Re: xine src package,Ham
4,0.307057,0.072941,0.001085,0.586905,0 days 00:00:01.035079,5,"Re: Our friends the Palestinians, Our servants...",Ham
...,...,...,...,...,...,...,...,...
601,0.713385,0.343468,0.671166,0.547619,0 days 00:00:01.080186,602,hurry,Spam
602,0.814591,0.826935,0.284308,0.345238,0 days 00:00:01.289566,603,[ILUG] WILSON KAMELA,Spam
603,0.649499,0.891758,0.023864,0.257937,0 days 00:00:01.320918,604,"How to get 10,000 FREE hits per day to any web...",Spam
604,0.812971,0.909017,0.013985,0.652381,0 days 00:00:02.741550,605,Cannabis Difference,Spam


In [16]:
attention_metrics_df.describe()

Unnamed: 0,auc_deletion,auc_insertion,comprehensiveness,jaccard_stability,computation_time,sample_id
count,606.0,606.0,606.0,606.0,606,606.0
mean,0.454475,0.403134,0.113369,0.535107,0 days 00:00:01.138515914,303.5
std,0.219562,0.331169,0.161241,0.171167,0 days 00:00:01.245813205,175.08141
min,0.106232,0.0,3.5e-05,0.222222,0 days 00:00:00.495906,1.0
25%,0.27743,0.103772,0.005184,0.412698,0 days 00:00:00.983575500,152.25
50%,0.36478,0.286194,0.031868,0.49881,0 days 00:00:01.061230500,303.5
75%,0.679843,0.83689,0.160058,0.652381,0 days 00:00:01.157919250,454.75
max,0.893406,0.93056,0.860855,1.0,0 days 00:00:30.892934,606.0


In [13]:
# Create comprehensive results summary
import pandas as pd

# Convert results to DataFrames for better visualization
if ig_results:
    ig_df = pd.DataFrame(ig_results)
    ig_df['method'] = 'Integrated_Gradients'

if attention_results:
    attention_df = pd.DataFrame(attention_results)
    attention_df['method'] = 'Attention_Heads'

# Combine results if both methods were successful
if ig_results and attention_results:
    combined_df = pd.concat([ig_df, attention_df], ignore_index=True)

    print("COMPREHENSIVE EXPLANATION QUALITY RESULTS")
    print("=" * 60)

    # Summary statistics by method
    summary_stats = combined_df.groupby('method')[['auc_deletion', 'auc_insertion', 'comprehensiveness', 'jaccard_stability']].agg(['mean', 'std'])
    print("\nSummary Statistics by Method:")
    print(summary_stats)

    # Detailed results by sample
    print("\nDetailed Results by Sample:")
    display_cols = ['sample_id', 'method', 'auc_deletion', 'auc_insertion', 'comprehensiveness', 'jaccard_stability', 'label']
    print(combined_df[display_cols].to_string(index=False))

elif ig_results:
    print("Results using Integrated Gradients only:")
    display_cols = ['sample_id', 'auc_deletion', 'auc_insertion', 'comprehensiveness', 'jaccard_stability', 'label']
    print(ig_df[display_cols].to_string(index=False))

elif attention_results:
    print("Results using Attention Heads only:")
    display_cols = ['sample_id', 'auc_deletion', 'auc_insertion', 'comprehensiveness', 'jaccard_stability', 'label']
    print(attention_df[display_cols].to_string(index=False))

else:
    print("No results available for analysis.")

COMPREHENSIVE EXPLANATION QUALITY RESULTS

Summary Statistics by Method:
                     auc_deletion           auc_insertion            \
                             mean       std          mean       std   
method                                                                
Attention_Heads          0.454475  0.219562      0.403134  0.331169   
Integrated_Gradients     0.476898  0.234679      0.395897  0.327079   

                     comprehensiveness           jaccard_stability            
                                  mean       std              mean       std  
method                                                                        
Attention_Heads               0.113369  0.161241          0.535107  0.171167  
Integrated_Gradients          0.092090  0.145554          0.532393  0.186247  

Detailed Results by Sample:
 sample_id               method  auc_deletion  auc_insertion  comprehensiveness  jaccard_stability label
         1 Integrated_Gradients      0.451

In [14]:
# Save results for further analysis
import os

# Create results directory if it doesn't exist
results_dir = os.path.join(DATA_PATH, 'results', 'explanation_quality')
os.makedirs(results_dir, exist_ok=True)

# Save detailed results
if 'combined_df' in locals():
    results_file = os.path.join(results_dir, 'bert_explanation_quality_metrics.csv')
    combined_df.to_csv(results_file, index=False)
    print(f"Detailed results saved to: {results_file}")
elif 'ig_df' in locals():
    results_file = os.path.join(results_dir, 'bert_ig_metrics.csv')
    ig_df.to_csv(results_file, index=False)
    print(f"Integrated Gradients results saved to: {results_file}")
elif 'attention_df' in locals():
    results_file = os.path.join(results_dir, 'bert_attention_metrics.csv')
    attention_df.to_csv(results_file, index=False)
    print(f"Attention results saved to: {results_file}")

print("\n" + "="*60)
print("BERT EXPLANATION QUALITY ANALYSIS COMPLETE")
print("="*60)
print("\nKey Findings:")
print("- Successfully computed AUC-Del, AUC-Ins, Comprehensiveness, and Jaccard Stability")
print("- Compared LayerIntegratedGradients vs Attention Head explanations")
print("- Generated deletion and insertion curve visualizations")
print("\nNext Steps:")
print("- Analyze metric patterns across spam vs ham samples")
print("- Compare with other model explanation methods")
print("- Use insights to improve model interpretability")

print(f"\nResults saved in: {results_dir}")

Detailed results saved to: /content/drive/MyDrive/Projects/spam-detection-data/results/explanation_quality/bert_explanation_quality_metrics.csv

BERT EXPLANATION QUALITY ANALYSIS COMPLETE

Key Findings:
- Successfully computed AUC-Del, AUC-Ins, Comprehensiveness, and Jaccard Stability
- Compared LayerIntegratedGradients vs Attention Head explanations
- Generated deletion and insertion curve visualizations

Next Steps:
- Analyze metric patterns across spam vs ham samples
- Compare with other model explanation methods
- Use insights to improve model interpretability

Results saved in: /content/drive/MyDrive/Projects/spam-detection-data/results/explanation_quality


In [20]:
from explainability.BertExplanationMetrics import analyze_test_dataset_influential_words

# Analyze your test dataset
results = analyze_test_dataset_influential_words(
    model=model,
    tokenizer=tokenizer,
    test_texts=test_df['text'].tolist(),
    test_labels=test_df['label'].tolist(),
    device=device,
    top_k=20,
    method='attention'
)

Initializing BERT explanation analyzer...
Analyzing top 20 influential words using attention...
Analyzing 606 texts for influential words using attention...
Processed 10/606 texts...
Processed 20/606 texts...
Processed 30/606 texts...
Processed 40/606 texts...
Processed 50/606 texts...
Processed 60/606 texts...
Processed 70/606 texts...
Processed 80/606 texts...
Processed 90/606 texts...
Processed 100/606 texts...
Processed 110/606 texts...
Processed 120/606 texts...
Processed 130/606 texts...
Processed 140/606 texts...
Processed 150/606 texts...
Processed 160/606 texts...
Processed 170/606 texts...
Processed 180/606 texts...
Processed 190/606 texts...
Processed 200/606 texts...
Processed 210/606 texts...
Processed 220/606 texts...
Processed 230/606 texts...
Processed 240/606 texts...
Processed 250/606 texts...
Processed 260/606 texts...
Processed 270/606 texts...
Processed 280/606 texts...
Processed 290/606 texts...
Processed 300/606 texts...
Processed 310/606 texts...
Processed 320/6

In [22]:
from explainability.BertExplanationMetrics import analyze_test_dataset_influential_words

# Analyze your test dataset
results = analyze_test_dataset_influential_words(
    model=model,
    tokenizer=tokenizer,
    test_texts=test_df['text'].tolist(),
    test_labels=test_df['label'].tolist(),
    device=device,
    top_k=20,
    method='integrated_gradients'
)

Initializing BERT explanation analyzer...
Analyzing top 20 influential words using integrated_gradients...
Analyzing 606 texts for influential words using integrated_gradients...
Processed 10/606 texts...
Processed 20/606 texts...
Processed 30/606 texts...
Processed 40/606 texts...
Processed 50/606 texts...
Processed 60/606 texts...
Processed 70/606 texts...
Processed 80/606 texts...
Processed 90/606 texts...
Processed 100/606 texts...
Processed 110/606 texts...
Processed 120/606 texts...
Processed 130/606 texts...
Processed 140/606 texts...
Processed 150/606 texts...
Processed 160/606 texts...
Processed 170/606 texts...
Processed 180/606 texts...
Processed 190/606 texts...
Processed 200/606 texts...
Processed 210/606 texts...
Processed 220/606 texts...
Processed 230/606 texts...
Processed 240/606 texts...
Processed 250/606 texts...
Processed 260/606 texts...
Processed 270/606 texts...
Processed 280/606 texts...
Processed 290/606 texts...
Processed 300/606 texts...
Processed 310/606 te