#### Explainability of BERT using LayerIntegratedGradients

In [None]:
import os
import sys

# Colab-only bootstrap: everything below runs only when the google.colab module
# is already loaded in this interpreter; on a local kernel the cell is a no-op.
if "google.colab" in sys.modules:
    workspace_dir = '/content/spam-detection'
    branch = 'feature/extended-explainability'
    current_dir = os.getcwd()
    # First run on a fresh VM: the workspace directory does not exist yet.
    if not os.path.exists(workspace_dir) and current_dir != workspace_dir:
        # `git clone` writes into the current working directory; the chdir that
        # follows assumes that directory is /content -- TODO confirm.
        !git clone https://github.com/RationalEar/spam-detection.git
        os.chdir(workspace_dir)
        !git checkout $branch
        !ls -al
        # Dependencies are installed only on this first-run path.
        # NOTE(review): only transformers is version-pinned; the other packages float.
        !pip install -q transformers==4.48.0 scikit-learn pandas numpy
        !pip install -q torch --index-url https://download.pytorch.org/whl/cu126
        !pip install captum --no-deps --ignore-installed
    else:
        # Workspace already present: enter it and pull the latest commits of the branch.
        os.chdir(workspace_dir)
        !git pull origin $branch

    from google.colab import drive

    # Mount Google Drive at /content/drive.
    drive.mount('/content/drive')

In [1]:
import os
import torch

import pandas as pd
from utils.constants import DATA_PATH, MODEL_SAVE_PATH

# Echo the configured data root as the cell output so the reader can see which
# location the pickles below are read from.
DATA_PATH

'/home/michael/PycharmProjects/spam-detection-data'

In [2]:
# Load the processed train/test splits.
# NOTE(review): read_pickle unpickles arbitrary objects; only point DATA_PATH at
# files produced by this project's own preprocessing step.
train_df = pd.read_pickle(DATA_PATH + '/data/processed/train.pkl')
test_df = pd.read_pickle(DATA_PATH + '/data/processed/test.pkl')

# Device selection. The notebook previously detected CUDA and then immediately
# overwrote the result with 'cpu', which left the detection as dead code and hid
# the override. The flag makes the CPU pin explicit while keeping the same
# default ('cpu'); set FORCE_CPU = False to use CUDA when it is available.
FORCE_CPU = True
device = 'cuda' if (torch.cuda.is_available() and not FORCE_CPU) else 'cpu'

In [3]:
from utils.functions import set_seed, build_vocab

# Seed via the project helper, then build the vocabulary from the training text.
# No pretrained embeddings are loaded here: pretrained_embeddings is set to None.
# NOTE(review): word2idx, idx2word, embedding_dim, max_len and
# pretrained_embeddings are not referenced by any later cell visible in this
# notebook -- confirm whether the BERT workflow still needs them.
set_seed(42)
word2idx, idx2word = build_vocab(train_df['text'])
embedding_dim = 300
max_len = 200
pretrained_embeddings = None

In [4]:
from models.bert import SpamBERT
from transformers import BertTokenizer

# Initialize the BERT tokenizer and an untrained SpamBERT instance
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = SpamBERT(dropout=0.2)

# Load the trained weights. The checkpoint is consumed directly as a state_dict,
# so weights_only=True restricts unpickling to tensors and plain containers
# instead of executing arbitrary pickled objects from the file.
model_path = os.path.join(MODEL_SAVE_PATH, 'spam_bert_final.pt')
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)

# Switch to inference mode; as the last expression, the returned module's repr
# is shown as the cell output.
model.eval()

SpamBERT(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affi

In [5]:
from models.bert import tokenize_texts

# Tokenize the test texts and place both resulting tensors on the target device
X_test_input_ids, X_test_attention_mask = (
    encoded.to(device)
    for encoded in tokenize_texts(test_df['text'].tolist(), tokenizer)
)

# Labels as a float32 tensor on the same device
y_test_tensor = torch.tensor(test_df['label'].values, dtype=torch.float32).to(device)

num_test_samples = X_test_input_ids.shape[0]
print(f"Test data prepared: {num_test_samples} samples")

Test data prepared: 606 samples


In [6]:
# Run the tokenized test set through the model with gradient tracking disabled
with torch.no_grad():
    model_output = model(
        input_ids=X_test_input_ids,
        attention_mask=X_test_attention_mask
    )
    # A tuple return is unwrapped to its first element (typically the predictions)
    y_pred_probs = model_output[0] if isinstance(model_output, tuple) else model_output

    # Threshold at 0.5; the prints below count 1.0 as spam and 0.0 as ham
    y_pred = (y_pred_probs > 0.5).float()

spam_count = (y_pred == 1).sum().item()
ham_count = (y_pred == 0).sum().item()

print(f"Model predictions computed for {len(y_pred)} samples")
print(f"Predicted spam samples: {spam_count}")
print(f"Predicted ham samples: {ham_count}")

Model predictions computed for 606 samples
Predicted spam samples: 194
Predicted ham samples: 412


In [None]:
from explainability.BertExplanationMetrics import analyze_test_dataset_influential_words

# Run the project's influential-word analysis over every test text, requesting
# the top 20 words with the 'integrated_gradients' method.
# Later cells read results['top_overall_words'], ['top_spam_words'],
# ['top_ham_words'], ['top_discriminative_words'] and
# results['statistics']['spam_unique_words'].
results = analyze_test_dataset_influential_words(
    model=model,
    tokenizer=tokenizer,
    test_texts=test_df['text'].tolist(),
    test_labels=test_df['label'].tolist(),
    device=device,
    top_k=20,
    method='integrated_gradients'
)

Initializing BERT explanation analyzer...
Analyzing top 20 influential words using integrated_gradients...
Analyzing 606 texts for influential words using integrated_gradients...
Processed 10/606 texts...
Processed 20/606 texts...


In [None]:
# One DataFrame per word ranking returned by the influential-word analysis
overall_df, spam_df, ham_df, discriminative_df = (
    pd.DataFrame(results[ranking_key])
    for ranking_key in (
        'top_overall_words',
        'top_spam_words',
        'top_ham_words',
        'top_discriminative_words',
    )
)

In [None]:
# Report the 'spam_unique_words' entry from the analysis statistics
print(f"Unique words in spam: {results['statistics']['spam_unique_words']}")

#### LayerIntegratedGradients for BERT

In [None]:
from explainability.BertExplanationMetrics import BertExplanationMetrics

# Initialize the BERT explanation quality metrics calculator.
# It is reused by both metric loops below (Integrated Gradients and attention).
quality_evaluator = BertExplanationMetrics(model, tokenizer, device)

# Printed unconditionally once the constructor call above has returned.
print("BERT Explanation Quality Metrics Calculator initialized successfully!")

In [None]:
# Per-sample explanation quality metrics with Integrated Gradients
print("Computing explanation quality metrics using LayerIntegratedGradients...")
print("=" * 60)

ig_results = []

for sample_idx, sample in test_df.iterrows():
    # Column lookups stay outside the try so a missing column still raises
    body_text, subject_line = sample['text'], sample['subject']
    try:
        sample_metrics = quality_evaluator.evaluate_explanation_quality(
            body_text,
            subject=subject_line,
            method='integrated_gradients',
            verbose=False
        )

        # Tag the metrics with identifying fields; note that the 'text' key
        # carries the subject line, not the full body
        sample_metrics['sample_id'] = sample_idx + 1
        sample_metrics['text'] = subject_line
        sample_metrics['label'] = 'Spam' if sample['label'] == 1 else 'Ham'

        ig_results.append(sample_metrics)
    except Exception as err:
        # Best effort: report the failure and move on to the next sample
        print(f"Error processing sample {sample_idx+1}: {err}")

print(f"\nCompleted processing {len(ig_results)} samples with Integrated Gradients.")

In [None]:
# Collect the per-sample Integrated Gradients metrics into one DataFrame;
# the bare name on the last line renders it as the cell output.
metrics_df = pd.DataFrame(ig_results)
metrics_df

In [None]:
# Per-sample explanation quality metrics with attention weights
print("Computing explanation quality metrics using Attention Heads...")
print("=" * 60)

attention_results = []
for sample_idx, sample in test_df.iterrows():
    # The body-text lookup stays outside the try so a missing column still raises
    body_text = sample['text']
    try:
        sample_metrics = quality_evaluator.evaluate_explanation_quality(
            body_text,
            subject=sample['subject'],
            method='attention',
            verbose=True
        )

        # Tag the metrics with identifying fields; note that the 'text' key
        # carries the subject line, not the full body
        sample_metrics['sample_id'] = sample_idx + 1
        sample_metrics['text'] = sample['subject']
        sample_metrics['label'] = 'Spam' if sample['label'] == 1 else 'Ham'

        attention_results.append(sample_metrics)
    except Exception as err:
        # Best effort: report the failure and move on to the next sample
        print(f"Error processing sample {sample_idx+1}: {err}")

print(f"\nCompleted processing {len(attention_results)} samples with Attention Heads.")

In [None]:
# Collect the per-sample attention metrics into one DataFrame;
# the bare name on the last line renders it as the cell output.
attention_metrics_df = pd.DataFrame(attention_results)
attention_metrics_df

In [None]:
# Summarize explanation quality per method
import pandas as pd

# The four quality metrics reported in every table of this cell
metric_cols = ['auc_deletion', 'auc_insertion', 'comprehensiveness', 'jaccard_stability']

has_ig = bool(ig_results)
has_attention = bool(attention_results)

# One DataFrame per method that produced results, tagged with its method name
if has_ig:
    ig_df = pd.DataFrame(ig_results)
    ig_df['method'] = 'Integrated_Gradients'

if has_attention:
    attention_df = pd.DataFrame(attention_results)
    attention_df['method'] = 'Attention_Heads'

if has_ig and has_attention:
    # Both methods succeeded: stack them and compare side by side
    combined_df = pd.concat([ig_df, attention_df], ignore_index=True)

    print("COMPREHENSIVE EXPLANATION QUALITY RESULTS")
    print("=" * 60)

    # Mean and standard deviation of each metric, grouped by method
    summary_stats = combined_df.groupby('method')[metric_cols].agg(['mean', 'std'])
    print("\nSummary Statistics by Method:")
    print(summary_stats)

    # Full per-sample listing
    print("\nDetailed Results by Sample:")
    display_cols = ['sample_id', 'method', *metric_cols, 'label']
    print(combined_df[display_cols].to_string(index=False))

elif has_ig or has_attention:
    # Exactly one method produced results: print that method's table alone
    if has_ig:
        heading, single_method_df = "Results using Integrated Gradients only:", ig_df
    else:
        heading, single_method_df = "Results using Attention Heads only:", attention_df
    print(heading)
    display_cols = ['sample_id', *metric_cols, 'label']
    print(single_method_df[display_cols].to_string(index=False))

else:
    print("No results available for analysis.")

In [None]:
# Save results for further analysis
import os

# Create results directory if it doesn't exist
results_dir = os.path.join(DATA_PATH, 'results', 'explanation_quality')
os.makedirs(results_dir, exist_ok=True)

# Save the most complete table that exists: combined first, then either single
# method. results_file stays None when no table was produced, so the closing
# message cannot claim a save that never happened.
results_file = None
if 'combined_df' in locals():
    results_file = os.path.join(results_dir, 'bert_explanation_quality_metrics.csv')
    combined_df.to_csv(results_file, index=False)
    print(f"Detailed results saved to: {results_file}")
elif 'ig_df' in locals():
    results_file = os.path.join(results_dir, 'bert_ig_metrics.csv')
    ig_df.to_csv(results_file, index=False)
    print(f"Integrated Gradients results saved to: {results_file}")
elif 'attention_df' in locals():
    results_file = os.path.join(results_dir, 'bert_attention_metrics.csv')
    attention_df.to_csv(results_file, index=False)
    print(f"Attention results saved to: {results_file}")

print("\n" + "="*60)
print("BERT EXPLANATION QUALITY ANALYSIS COMPLETE")
print("="*60)
print("\nKey Findings:")
print("- Successfully computed AUC-Del, AUC-Ins, Comprehensiveness, and Jaccard Stability")
print("- Compared LayerIntegratedGradients vs Attention Head explanations")
print("- Generated deletion and insertion curve visualizations")
print("\nNext Steps:")
print("- Analyze metric patterns across spam vs ham samples")
print("- Compare with other model explanation methods")  
print("- Use insights to improve model interpretability")

# Only report a save location when a file was actually written above
if results_file is not None:
    print(f"\nResults saved in: {results_dir}")
else:
    print("\nNo result tables were available; nothing was saved.")