### Notebook for Explainable AI for Transformer Architectures

This notebook contains code to get you started on applying explainable AI techniques on your Transformer Model. The code is based on the work described in the paper "XAI for Transformers: Better Explanations through Conservative Propagation" by Ali et al. (2022), to be found here: https://proceedings.mlr.press/v162/ali22a/ali22a.pdf and this is the link to the official repository on Github: https://github.com/AmeenAli/XAI_Transformers

### Imports & loading data

In [None]:
import pandas as pd
import torch
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import os

# Create the 'figures' directory if it does not exist (exist_ok avoids an error on re-runs).
# NOTE(review): the plotting code further down in this file saves into 'figuresi1/',
# not 'figures/' — confirm which directory is actually intended.
os.makedirs('figures', exist_ok=True)

from transformers import BertTokenizer, BertForSequenceClassification

# Load your test dataset (adjust the path accordingly)
df = pd.read_csv("../Task5/Datasets/test.csv")

# Confirm expected columns exist. An explicit raise is used instead of `assert`
# so the check is not skipped when Python runs with optimizations (-O).
required_columns = ('main_emotion', 'Corrected Sentence')
if any(column not in df.columns for column in required_columns):
    raise ValueError("Dataset must include 'main_emotion' and 'Corrected Sentence' columns.")

# Get up to 3 samples per emotion (fewer when an emotion has fewer than 3 rows).
# random_state is fixed so the same sentences are drawn on every run.
emotion_samples = {}
for emotion in df['main_emotion'].unique():
    # Filter once and reuse the subset for both the size check and the sampling.
    emotion_rows = df[df['main_emotion'] == emotion]
    samples = emotion_rows.sample(n=min(3, len(emotion_rows)), random_state=42)['Corrected Sentence'].tolist()
    emotion_samples[emotion] = samples

# Flatten into a list of (sentence, emotion) tuples.
selected_sentences = [(text, label) for label, texts in emotion_samples.items() for text in texts]
print(f"Collected {len(selected_sentences)} samples across {len(emotion_samples)} emotions.")

# Map from numerical model predictions to emotion names.
# Based on your test results + educated guesses for missing labels
label_map = {
    0: "anger",
    1: "happiness",    # Assuming LABEL_1 is happiness (most common remaining)
    2: "fear",
    3: "surprise",
    4: "neutral",
    5: "sadness",
    6: "disgust"
}


# Convert numeric labels to strings.
# NOTE(review): this runs AFTER sampling, so the labels stored in
# `selected_sentences` / `emotion_samples` are the raw values from the CSV,
# not the mapped names. Series.map also turns any value that is not a key of
# `label_map` into NaN — confirm the CSV really holds the integers 0-6.
df['main_emotion'] = df['main_emotion'].map(label_map)



❗NB: This code works for models based on the BERT architecture. If you used a different transformer architecture (DistilBERT, RoBERTa, etc.), you might need to adapt some of the code to fit that architecture's requirements.

In [None]:
from transformers import BertTokenizer, BertForSequenceClassification

# Tokenizer: fetched by its model identifier.
model_name = "wietsedv/bert-base-dutch-cased"
tokenizer = BertTokenizer.from_pretrained(model_name)

# Classifier: read from the local directory below (adjust the path to your setup).
# The number of labels is whatever the saved checkpoint defines; no override is passed here.
classifier_dir = "../Task5/bertje_emotion_classifier"
model = BertForSequenceClassification.from_pretrained(classifier_dir)

# Put the model in evaluation mode before running any explanations.
model.eval()





## Gradient x Input

The gradient_x_input function calculates token relevance using the Gradient × Input method. It converts input IDs to embeddings and enables gradient tracking. The embeddings are passed through the model, generating logits for prediction. The function computes gradients with respect to the predicted class, indicating how sensitive the prediction is to each token. If no gradients are computed, it raises an error to signal that the embeddings are not properly linked to the output. The gradients are multiplied by the input embeddings to calculate relevance scores, showing how much each token influenced the model’s prediction.

In [None]:
def gradient_x_input(model, tokenizer, text):
    """
    Compute the relevance of each token via Gradient x Input.

    The text is tokenized, its input embeddings are copied into a leaf tensor
    that tracks gradients, and the logit of the predicted class is
    differentiated with respect to those embeddings. The relevance of a token
    is the sum over the embedding dimension of (embedding * gradient).

    Parameters:
        model: classification model that exposes ``get_input_embeddings()``,
            accepts ``inputs_embeds`` and ``attention_mask`` keyword arguments
            and returns an object with a ``logits`` attribute.
        tokenizer: tokenizer whose call result offers ``input_ids`` and
            ``'attention_mask'`` and which provides ``convert_ids_to_tokens``.
        text: a single input string (the function handles one sequence).

    Returns:
        tokens: list of token strings for the encoded input.
        relevance: 1-D numpy array with one relevance score per token.
        pred_class: index of the predicted class as a Python int.

    Raises:
        RuntimeError: if no gradient reaches the input embeddings, i.e. the
            predicted logit does not depend on them.
    """
    model.eval()
    # Tokenize and retrieve input IDs and attention mask
    inputs = tokenizer(text, return_tensors='pt')

    # Look up the embeddings, then cut them loose from the embedding layer so
    # they become a leaf tensor we can differentiate with respect to.
    inputs_embeds = model.get_input_embeddings()(inputs.input_ids)
    inputs_embeds = inputs_embeds.clone().detach().requires_grad_(True)

    # Pass embeddings to the model.
    outputs = model(inputs_embeds=inputs_embeds, attention_mask=inputs['attention_mask'])
    logits = outputs.logits

    # Predicted class of the (single) sequence and its scalar logit.
    pred_class = int(logits[0].argmax(dim=-1))
    target_logit = logits[0, pred_class]

    # autograd.grad returns only the gradient w.r.t. the embeddings; unlike
    # .backward() it does not accumulate gradients into the model parameters,
    # so repeated calls leave the model untouched.
    (grads,) = torch.autograd.grad(target_logit, inputs_embeds, allow_unused=True)
    if grads is None:
        raise RuntimeError(
            "No gradient was computed for the input embeddings; "
            "they are not connected to the model output."
        )

    # Element-wise product summed over the embedding dimension. Indexing the
    # batch axis explicitly keeps the result 1-D even for a one-token input.
    relevance = (inputs_embeds * grads).sum(dim=-1)[0]
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

    return tokens, relevance.detach().cpu().numpy(), pred_class

import os

def plot_relevance(tokens, relevance, title, emotion_label, filename, output_dir='figuresi1'):
    """
    Plot a bar graph of token relevance, save it as a PNG and display it.

    Bars whose absolute relevance is at or above the 75th percentile of the
    absolute relevance values are drawn in red; all other bars are grey.

    Parameters:
        tokens: sequence of token strings used as x-axis labels.
        relevance: sequence of relevance scores, one per token.
        title: text placed at the start of the figure title.
        emotion_label: emotion name appended to the title.
        filename: name of the output file, without the '.png' extension.
        output_dir: directory the figure is written to; it is created when
            missing. Defaults to 'figuresi1'.
    """
    # Ensure the output directory exists before saving.
    os.makedirs(output_dir, exist_ok=True)

    fig = plt.figure(figsize=(12, 3))
    # Mark tokens with high absolute relevance with red
    threshold = np.percentile(np.abs(relevance), 75)
    colors = ['red' if abs(r) >= threshold else 'grey' for r in relevance]
    positions = range(len(tokens))
    plt.bar(positions, relevance, color=colors)
    plt.xticks(positions, tokens, rotation=45, fontsize=10)
    plt.title(f"{title} (Emotion: {emotion_label})", fontsize=12)
    plt.ylabel("Relevance Score", fontsize=10)
    plt.tight_layout()
    plt.savefig(f"{output_dir}/{filename}.png", bbox_inches='tight')
    plt.show()
    # Release the figure; this function is called in a loop and open figures
    # would otherwise pile up in memory.
    plt.close(fig)



# Run Gradient x Input once per selected sentence and keep the results, so the
# plotting pass and the printing pass below do not each repeat the
# forward/backward computation.
gxi_results = [gradient_x_input(model, tokenizer, sentence) for sentence, _ in selected_sentences]

# Inverse of label_map (emotion name -> class index). It is not used in this
# cell; it is built once here instead of on every loop iteration.
id2label = {v: k for k, v in label_map.items()}

# First, generate all plots
for idx, (tokens, relevance, pred_class) in enumerate(gxi_results):
    # Fall back to a generic name when the class index is not in label_map.
    predicted_emotion = label_map.get(pred_class, f'class_{pred_class}')
    plot_relevance(tokens, relevance, f"Sentence {idx+1}", predicted_emotion, f"gradient_x_input_{predicted_emotion}_{idx}")

# Then, print all the details
for idx, ((text, true_label), (tokens, relevance, pred_class)) in enumerate(zip(selected_sentences, gxi_results)):
    # Tokens whose absolute relevance exceeds the fixed 0.2 cut-off.
    key_tokens = [tok for tok, rel in zip(tokens, relevance) if abs(rel) > 0.2]
    predicted_emotion = label_map.get(pred_class, f'class_{pred_class}')
    print(f"Emotion: {predicted_emotion}\nSentence: {text}\nKey Tokens: {key_tokens}\n")


## Conservative Propagation

The ``modified_attention_forward`` function adjusts attention propagation by calculating attention scores using scaled dot product attention. It computes attention probabilities through softmax, then multiplies them with hidden states to get the context layer. Relevance scores are propagated back through the attention weights, to ensure a conservative flow of information. The ``modified_layernorm_forward`` function normalizes hidden states using the mean and variance, then redistributes relevance based on the normalization values. This follows the paper’s conservative propagation strategy, which stabilizes relevance flow through transformers by correcting how relevance is handled in attention and layer normalization layers.

In [None]:
def modified_attention_forward(attention_layer, hidden_states, relevance_scores, tokens, save_filename=None):
    """
    Simplified attention forward pass with Conservative Propagation of relevance.

    Attention scores are the scaled dot product of ``hidden_states`` with
    itself; the softmax probabilities are detached so they act as constants.
    Relevance is redistributed by multiplying with the transposed
    probabilities, and the probabilities of the first batch element are shown
    as a heatmap.

    Parameters:
        attention_layer: the specific attention layer
            (e.g., model.bert.encoder.layer[0].attention.self). It is accepted
            for interface purposes but not read by this function; the scores
            are computed from ``hidden_states`` alone.
        hidden_states: the input hidden states for the layer
        relevance_scores: the current relevance scores (starting at the output)
        tokens: list of tokens used as heatmap axis labels
        save_filename: if provided, saves the heatmap figure under
            'figuresi1/{save_filename}.png' (the directory is created if missing)

    Returns:
        context_layer: the result after attention
        propagated_relevance: the updated relevance scores after propagation
    """
    # Compute raw attention scores and apply softmax to get probabilities.
    scale = hidden_states.size(-1) ** 0.5
    attention_scores = (hidden_states @ hidden_states.transpose(-1, -2)) / scale
    attention_probs = attention_scores.softmax(dim=-1).detach()
    context_layer = attention_probs @ hidden_states

    # Propagate relevance back using the attention probabilities.
    propagated_relevance = attention_probs.transpose(-1, -2) @ relevance_scores

    # Visualize the attention probabilities as a heatmap.
    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(attention_probs[0].detach().cpu().numpy(), cmap='coolwarm', xticklabels=tokens, yticklabels=tokens)
    plt.title('Attention Scores Heatmap')
    plt.tight_layout()

    # Save the figure if a filename is provided.
    if save_filename:
        # Do not rely on another cell having created the directory already.
        os.makedirs('figuresi1', exist_ok=True)
        plt.savefig(f"figuresi1/{save_filename}.png", bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    return context_layer, propagated_relevance

def modified_layernorm_forward(layernorm_layer, hidden_states, relevance_scores):
    """
    Layer-normalization forward pass with Conservative Propagation of relevance.

    The hidden states are standardized along the last dimension (biased
    variance), and the relevance scores are divided by the same standard
    deviation. Only ``layernorm_layer.eps`` is read from the layer; its
    weight and bias are not applied here.

    Parameters:
        layernorm_layer: object providing the ``eps`` stabilizer.
        hidden_states: tensor to normalize over its last dimension.
        relevance_scores: relevance tensor to rescale.

    Returns:
        A pair ``(normed, propagated_relevance)``; ``normed`` is detached
        from the autograd graph.
    """
    feature_mean = hidden_states.mean(-1, keepdim=True)
    feature_var = hidden_states.var(-1, keepdim=True, unbiased=False)

    # Shared denominator for both the activations and the relevance.
    std = (feature_var + layernorm_layer.eps).sqrt()

    centered = hidden_states - feature_mean
    return (centered / std).detach(), relevance_scores / std

# Process each selected sentence for Step 2 (Modified Attention Propagation):
# for every sentence, seed a relevance tensor from the predicted logit, push it
# through modified_attention_forward (which also draws/saves a heatmap) and
# then through modified_layernorm_forward.
for idx, (text, true_label) in enumerate(selected_sentences):
    # Tokenize the text and get input embeddings.
    inputs = tokenizer(text, return_tensors='pt')
    inputs_embeds = model.get_input_embeddings()(inputs.input_ids)
    # Detached copy of the embedding-layer output, used as the "hidden states" below.
    hidden_states = inputs_embeds.clone().detach().requires_grad_(True)
    
    # Run a forward pass to compute model outputs.
    outputs = model(inputs_embeds=hidden_states, attention_mask=inputs['attention_mask'])
    predicted_class = outputs.logits.argmax(dim=-1)
    
    # Initialize relevance scores (all zeros, same shape as hidden_states).
    relevance_scores = torch.zeros_like(hidden_states)
    # A simple initialization: assign the model's logit for the predicted class to the last token
    # (position -1), replicated across the whole embedding dimension.
    # NOTE(review): every other position starts with zero relevance — confirm that the last
    # position is the intended place to seed the relevance.
    relevance_scores[:, -1, :] = outputs.logits[:, predicted_class].unsqueeze(-1)
    
    # Convert token IDs to tokens.
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    
    # Run modified attention propagation on the first encoder layer.
    # The heatmap file name is 1-based (sample_1, sample_2, ...).
    heatmap_filename = f"modified_attention_sample_{idx+1}"
    context_layer, propagated_relevance = modified_attention_forward(
        model.bert.encoder.layer[0].attention.self,
        hidden_states,
        relevance_scores,
        tokens,
        save_filename=heatmap_filename
    )
    
    # Apply modified_layernorm_forward to further process the propagated relevance.
    # This call runs on every iteration; its results (normed, relevance_normed) are not
    # used again inside this loop.
    normed, relevance_normed = modified_layernorm_forward(model.bert.encoder.layer[0].output.LayerNorm, context_layer, propagated_relevance)


### Perturbation

Perturbation means deliberately modifying or disturbing the input data to observe how the model's prediction changes. It helps identify which parts of the input are most influential by measuring the model's sensitivity to these changes. The ``perturb_input_and_evaluate`` function measures how sensitive the model’s prediction is to specific tokens by progressively modifying the input. It first tokenizes the input text and sorts tokens based on their relevance scores. If ``perturb_type`` is ``'remove'``, the **least** relevant tokens are removed first; otherwise, the **most relevant** tokens are removed first. Tokens are replaced with [PAD] to maintain input length consistency. The modified texts are converted back into input tensors and passed through the model. The model’s confidence in its prediction is recorded after each perturbation using softmax probabilities. This allows evaluation of how much each token contributes to the final prediction by observing how the model’s confidence changes as tokens are progressively masked.

In [None]:
def perturb_input_and_evaluate(model, tokenizer, text, relevance_scores, perturb_type='remove'):
    """
    Perturb the input text by masking tokens based on relevance scores and evaluate the model's confidence.

    Tokens are replaced by '[PAD]' one at a time, in relevance order. After
    each replacement the masked tokens are turned back into a string, re-encoded
    and passed through the model; the highest softmax probability is recorded.

    Parameters:
        model: the Transformer model; called with the tokenizer output as
            keyword arguments and expected to return an object with ``logits``.
        tokenizer: the tokenizer corresponding to the model
        text: the input text to perturb
        relevance_scores: the relevance scores used for determining token
            importance; ``relevance_scores[0]`` must hold one vector per
            position. If it has exactly two more positions than
            ``tokenizer.tokenize(text)``, the first and last are taken to be
            the special tokens ([CLS]/[SEP], BERT-style encoding) and dropped
            so the scores line up with the tokens.
        perturb_type: 'remove' to mask least relevant tokens first,
                      or any other value to mask most relevant tokens first.

    Returns:
        confidences: a list of confidence scores after successive perturbations
            (one entry per token of the text).
    """
    # Tokenize the text to get a list of tokens (no special tokens here).
    tokens = tokenizer.tokenize(text)

    # Compute the norm of the relevance score for each position.
    token_relevance_norm = torch.norm(relevance_scores[0], dim=-1)

    # Relevance computed on the encoded input also covers the leading and
    # trailing special tokens; without this shift, position i would be used
    # to mask token i instead of token i-1.
    if token_relevance_norm.shape[0] == len(tokens) + 2:
        token_relevance_norm = token_relevance_norm[1:-1]

    # 'remove' -> ascending (least relevant first); anything else -> descending.
    most_relevant_first = perturb_type != 'remove'
    sorted_indices = torch.argsort(token_relevance_norm, descending=most_relevant_first).tolist()
    # Drop positions that still do not correspond to a token, so every
    # perturbation step masks an actual token.
    sorted_indices = [position for position in sorted_indices if position < len(tokens)]

    # Build the cumulative perturbations: step k has the first k ranked tokens masked.
    perturbed_texts = []
    perturbed_tokens = list(tokens)
    for step in range(len(tokens)):
        if step < len(sorted_indices):
            perturbed_tokens[sorted_indices[step]] = '[PAD]'
        perturbed_texts.append(tokenizer.convert_tokens_to_string(perturbed_tokens))

    confidences = []
    # Evaluate model confidence on each perturbed text. No gradients are
    # needed here, so autograd bookkeeping is switched off.
    with torch.no_grad():
        for pert_text in perturbed_texts:
            pert_inputs = tokenizer(pert_text, return_tensors='pt')
            pert_output = model(**pert_inputs)
            confidences.append(pert_output.logits.softmax(dim=-1).max().item())

    return confidences

# Process each selected sentence for Step 3 (Input Perturbation)

# Make sure the output directory exists even when this cell is run on its own;
# savefig does not create missing directories.
os.makedirs('figuresi1', exist_ok=True)

for idx, (text, true_label) in enumerate(selected_sentences):
    # Tokenize and get embeddings for the sentence.
    inputs = tokenizer(text, return_tensors='pt')
    inputs_embeds = model.get_input_embeddings()(inputs.input_ids)
    hidden_states = inputs_embeds.clone().detach().requires_grad_(True)

    # Forward pass to get outputs and predicted class.
    outputs = model(inputs_embeds=hidden_states, attention_mask=inputs['attention_mask'])
    predicted_class = outputs.logits.argmax(dim=-1)

    # Initialize relevance scores as in the previous step: zeros everywhere,
    # with the predicted-class logit written to the last position.
    # NOTE(review): only the last position carries non-zero relevance here, so
    # the ranking used for the perturbation is not informed by a per-token
    # explanation — confirm whether a computed relevance should be passed instead.
    relevance_scores = torch.zeros_like(hidden_states)
    relevance_scores[:, -1, :] = outputs.logits[:, predicted_class].unsqueeze(-1)

    # Calculate token-level perturbation confidences.
    confidences_remove = perturb_input_and_evaluate(model, tokenizer, text, relevance_scores, perturb_type='remove')

    # Tokenize the sentence to bound the number of points on the x-axis.
    tokenized_text = tokenizer.tokenize(text)
    min_length = min(len(tokenized_text), len(confidences_remove))

    # Plot the model confidence as tokens are removed.
    fig = plt.figure(figsize=(12, 5))
    plt.plot(range(1, min_length + 1), confidences_remove[:min_length], marker='o', color='blue')
    plt.axhline(y=0.5, color='red', linestyle='--', label="50% Confidence Threshold")
    # Shade the region where the confidence drops below the 0.5 line.
    plt.fill_between(
        range(1, min_length + 1),
        confidences_remove[:min_length],
        0.5,
        where=(np.array(confidences_remove[:min_length]) < 0.5),
        color='red',
        alpha=0.2
    )
    plt.title(f"Model Confidence During Token Removal for Sentence {idx+1}")
    plt.xlabel("Number of Tokens Removed")
    plt.ylabel("Confidence")
    plt.legend()

    confidence_plot_filename = f"confidence_removal_sentence_{idx+1}"
    plt.savefig(f"figuresi1/{confidence_plot_filename}.png", bbox_inches='tight')
    plt.show()
    # Release the figure so the loop does not accumulate open figures.
    plt.close(fig)


FileNotFoundError: [Errno 2] No such file or directory: '/mnt/data/XAI_Analysis_Dutch_BERT.md'