In [None]:
import mlflow
from captum.attr import IntegratedGradients, visualization
import torch
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import numpy as np

# Load the trained PyTorch model from MLflow.
# NOTE(review): the run URI below looks like a placeholder — replace it with a real
# "runs:/<run_id>/<artifact_path>" before executing.
model = mlflow.pytorch.load_model("runs:/your_run_id/your_model_path")
model.eval()  # Switch to evaluation mode before running inference/attribution

# Use the device of the model's first parameter; input tensors are moved to it later
device = next(model.parameters()).device
print(f"Model is on device: {device}")

# Load the SciBERT tokenizer.
# NOTE(review): assumes the loaded model was trained with this vocabulary — TODO confirm.
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")

# Prepare the model for explainability
class ModelWrapper(torch.nn.Module):
    """Adapter that gives the wrapped model a uniform ``forward`` result.

    The wrapped model is called as ``model(input_ids, attention_mask=...)``.
    If what it returns exposes a ``logits`` attribute, that attribute is
    returned; otherwise the raw return value is passed through unchanged.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and unwrap ``.logits`` when it is present."""
        result = self.model(input_ids, attention_mask=attention_mask)
        if hasattr(result, 'logits'):
            return result.logits
        return result

# Wrap the loaded model so callers get `.logits` when the model's output has that
# attribute, and the raw output otherwise.
wrapped_model = ModelWrapper(model)

# Custom forward function for IntegratedGradients
def forward_func(inputs, attention_mask=None):
    """Forward function handed to Captum; returns the wrapped model's output.

    Captum unpacks a tuple of inputs into separate positional arguments, so
    this function must accept ``forward_func(input_ids, attention_mask)``.
    The original single-argument form, ``forward_func((input_ids,
    attention_mask))``, is still supported for existing callers.

    Args:
        inputs: Either the token-ID tensor (when ``attention_mask`` is given)
            or an indexable pair ``(input_ids, attention_mask)``.
        attention_mask: Attention-mask tensor, or ``None`` when ``inputs`` is
            the legacy pair.

    Returns:
        Whatever ``wrapped_model(input_ids, attention_mask)`` returns.
    """
    if attention_mask is None:
        # Legacy call style: one (input_ids, attention_mask) pair.
        input_ids, attention_mask = inputs[0], inputs[1]
    else:
        input_ids = inputs

    # NOTE(review): casting to an integer dtype is not differentiable, so
    # gradients cannot flow back to `inputs` through this function.
    return wrapped_model(input_ids.long(), attention_mask.long())

# Initialize Integrated Gradients with the custom forward function.
# NOTE(review): Integrated Gradients differentiates the forward function with respect
# to its inputs, and forward_func casts those inputs to long. Attributing over integer
# token IDs is expected to fail — consider attributing at the embedding layer
# (e.g. Captum's LayerIntegratedGradients). TODO confirm.
ig = IntegratedGradients(forward_func)

# Function to explain predictions
def explain_prediction(text, target=1, n_steps=50):
    """Tokenize *text* and attribute the model's *target* score to its tokens.

    Attribution is computed with Layer Integrated Gradients on an embedding
    layer rather than on the token IDs themselves: token IDs are integers,
    so no gradient can be taken with respect to them.

    Args:
        text: Input string (or list of strings) passed to the tokenizer.
        target: Index of the output unit to explain. Defaults to 1.
        n_steps: Number of interpolation steps used by the integral
            approximation. Defaults to 50.

    Returns:
        ``(attributions, input_ids)``. ``attributions`` is a 1-tuple whose
        only element is the attribution tensor for the embedding layer's
        output (expected shape ``(batch, seq_len, embedding_dim)``), or
        ``None`` when attribution could not be computed. ``input_ids`` is the
        token-ID tensor on the model's device.
    """
    # Local import: the notebook's import cell only pulls in IntegratedGradients.
    from captum.attr import LayerIntegratedGradients

    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    print("Input IDs shape:", input_ids.shape)
    print("Input IDs dtype:", input_ids.dtype)
    print("Attention Mask shape:", attention_mask.shape)
    print("Attention Mask dtype:", attention_mask.dtype)

    # Run a plain forward pass first so model problems surface before attribution
    with torch.no_grad():
        outputs = wrapped_model(input_ids, attention_mask)
    print("Model output shape:", outputs.shape)

    # NOTE(review): takes the first torch.nn.Embedding found in the model and
    # assumes it is the token (word) embedding — TODO confirm for this architecture.
    embedding_layer = next(
        (module for module in wrapped_model.modules() if isinstance(module, torch.nn.Embedding)),
        None,
    )
    if embedding_layer is None:
        print("Error during attribution:", "model has no torch.nn.Embedding layer to attribute against")
        return None, input_ids

    # Baseline: the same sequence with every position replaced by the PAD token
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    baseline_ids = torch.full_like(input_ids, pad_id)

    # Best-effort attribution: report the failure and hand back None so the
    # caller (and visualize_attributions) can carry on.
    try:
        layer_ig = LayerIntegratedGradients(wrapped_model, embedding_layer)
        layer_attributions = layer_ig.attribute(
            inputs=input_ids,
            baselines=baseline_ids,
            target=target,
            additional_forward_args=(attention_mask,),
            n_steps=n_steps,
        )
        # Keep the tuple-of-tensors shape that callers index with attributions[0]
        attributions = (layer_attributions,)
        print("Attribution shape:", attributions[0].shape)
    except Exception as e:
        print("Error during attribution:", str(e))
        attributions = None

    return attributions, input_ids

# Function to visualize attributions
def visualize_attributions(text, attributions, input_ids):
    """Plot and print per-token attribution scores for the first sequence.

    Args:
        text: The original input text (kept for interface compatibility; the
            tokens shown are decoded from ``input_ids``).
        attributions: Indexable whose first element is the attribution
            tensor; its last axis is summed to get one score per token.
            ``None`` means attribution failed and nothing is drawn.
        input_ids: Token-ID tensor; only row 0 is decoded.

    Raises:
        ValueError: If the number of per-token scores does not match the
            number of tokens in ``input_ids[0]``.
    """
    if attributions is None:
        print("No attributions to visualize.")
        return

    # Collapse the last axis to a single score per token position
    token_scores = attributions[0].sum(dim=-1).squeeze(0)
    token_scores = token_scores.detach().cpu().reshape(-1)

    # Decode tokens
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    if token_scores.numel() != len(tokens):
        raise ValueError(
            f"Expected one attribution score per token ({len(tokens)} tokens), "
            f"got {token_scores.numel()} scores; attributions must cover a single "
            "sequence with a trailing embedding axis."
        )

    # L2-normalize, skipping the division when every score is zero (avoids NaN)
    norm = torch.norm(token_scores)
    if norm > 0:
        token_scores = token_scores / norm
    word_attributions = token_scores.numpy()

    # Drop padding positions from tokens AND scores with the same mask, so the
    # two stay aligned wherever the padding sits.
    keep = np.array([token != '[PAD]' for token in tokens], dtype=bool)
    tokens = [token for token, kept in zip(tokens, keep) if kept]
    word_attributions = word_attributions[keep]

    # Visualization: one bar per token, green for positive and red for negative
    fig, ax = plt.subplots(figsize=(20, 4))
    positions = np.arange(len(tokens))
    colors = ["tab:green" if score >= 0 else "tab:red" for score in word_attributions]
    ax.bar(positions, word_attributions, color=colors)
    ax.axhline(0.0, color="black", linewidth=0.8)
    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=60, ha="right")
    ax.set_ylabel("Normalized attribution")
    ax.set_title("Integrated Gradients Attribution")
    fig.tight_layout()
    plt.show()

    # Print attributions
    for token, attribution in zip(tokens, word_attributions):
        print(f"{token}: {attribution:.4f}")

# Example usage: explain one sentence, then plot/print its token attributions.
# `attributions` is None when explain_prediction could not compute them;
# visualize_attributions handles that case by printing a message and returning.
text = "The CRISPR-Cas9 system has revolutionized gene editing techniques in molecular biology."
attributions, input_ids = explain_prediction(text)
visualize_attributions(text, attributions, input_ids)

## Old Code
Dimensionality Issue

In [None]:
# Install Captum if not already installed
!pip install captum

import torch
from captum.attr import IntegratedGradients
import numpy as np

# Function to visualize token attributions
def visualize_token_attributions(input_text, attributions, tokenizer):
    """Print each token of *input_text* next to its normalized attribution.

    The attributions are expected to cover a sequence encoded with special
    tokens (as ``compute_integrated_gradients`` does), so the text is
    re-encoded the same way here. Using ``tokenizer.tokenize`` alone would
    omit the leading special token and shift every score by one position.

    Args:
        input_text: The text that was attributed.
        attributions: Tensor whose last axis is summed to get one score per
            token position; a leading batch axis of size 1 is squeezed away.
        tokenizer: Tokenizer providing ``encode`` and
            ``convert_ids_to_tokens``.
    """
    scores = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()

    # Min-max normalize to [0, 1]; the epsilon guards against a zero range
    lowest = np.min(scores)
    scores = (scores - lowest) / (np.max(scores) - lowest + 1e-8)

    # Re-encode with special tokens, truncated to the number of scores, so
    # position i of `tokens` matches position i of `scores`. Padding is not
    # requested, so trailing padded positions are simply not printed.
    token_ids = tokenizer.encode(
        input_text,
        add_special_tokens=True,
        truncation=True,
        max_length=len(scores),
    )
    tokens = tokenizer.convert_ids_to_tokens(token_ids)

    # Display tokens with their attribution scores
    for token, score in zip(tokens, scores[:len(tokens)]):
        print(f"{token}: {score:.4f}")

# Function to compute integrated gradients
def compute_integrated_gradients(model, tokenizer, input_text, label, max_len=256, baseline_text="[PAD]", n_steps=50):
    """Attribute the model's score for *label* to the tokens of *input_text*.

    Uses Layer Integrated Gradients on an embedding layer: token IDs are
    integers, so gradients cannot be taken with respect to them directly.
    Prints the convergence delta and the per-token scores.

    Args:
        model: PyTorch model called as
            ``model(input_ids=..., attention_mask=...)``.
        tokenizer: Tokenizer providing ``encode_plus`` and ``encode``.
        input_text: Text to explain.
        label: Index of the output unit to explain.
        max_len: Length both the input and the baseline are padded/truncated to.
        baseline_text: Text encoded (and padded) to form the baseline sequence.
        n_steps: Number of interpolation steps for the integral approximation.

    Returns:
        ``(attributions, delta)``: the attribution tensor for the embedding
        layer's output and the convergence delta reported by Captum.

    Raises:
        ValueError: If the model contains no ``torch.nn.Embedding`` layer.
    """
    # Local import: this cell's imports only pull in IntegratedGradients.
    from captum.attr import LayerIntegratedGradients

    model.eval()

    # Put the tensors on the device of the model that was actually passed in
    model_device = next(model.parameters()).device

    # Tokenize the input text, padded to max_len
    inputs = tokenizer.encode_plus(
        input_text,
        add_special_tokens=True,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    # Embedding lookups require integer (long) IDs
    input_ids = inputs['input_ids'].to(model_device).type(torch.long)
    attention_mask = inputs['attention_mask'].to(model_device).type(torch.long)

    # Baseline with the same shape as the input, built from baseline_text plus padding
    baseline_ids = tokenizer.encode(
        baseline_text,
        add_special_tokens=True,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    ).to(model_device).type(torch.long)

    # NOTE(review): takes the first torch.nn.Embedding found in the model and
    # assumes it is the token (word) embedding — TODO confirm for this architecture.
    embedding_layer = next(
        (module for module in model.modules() if isinstance(module, torch.nn.Embedding)),
        None,
    )
    if embedding_layer is None:
        raise ValueError("model has no torch.nn.Embedding layer to attribute against")

    # Defined BEFORE it is handed to Captum; a later nested def would make the
    # name local and unbound at the point of use.
    def forward_func(input_ids, attention_mask):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Accept both raw tensors and output objects exposing `.logits`
        outputs = outputs.logits if hasattr(outputs, 'logits') else outputs
        return outputs.squeeze(1)  # Adjust this as per your model's output shape

    layer_ig = LayerIntegratedGradients(forward_func, embedding_layer)

    # Compute attributions with respect to the embedding layer's output
    attributions, delta = layer_ig.attribute(
        inputs=input_ids,
        baselines=baseline_ids,
        target=label,
        additional_forward_args=(attention_mask,),
        n_steps=n_steps,
        return_convergence_delta=True
    )

    # Report the approximation error, then the per-token scores
    print(f"Integrated Gradients Delta: {delta.item():.4f}")
    visualize_token_attributions(input_text, attributions, tokenizer)

    return attributions, delta

# Example usage: apply IG to a specific test sample.
# The text and label below are placeholders to be replaced with real data.
sample_text = "Replace this with an example text from your dataset."
true_label = 1  # Replace with the correct label for this example; passed as the attribution target

# Call the function with the sample text; relies on `model` and `tokenizer`
# having been defined by an earlier cell.
compute_integrated_gradients(model, tokenizer, sample_text, true_label)
