# In [None]:
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import torch
from captum.attr import (
    IntegratedGradients,
    LayerIntegratedGradients,
    visualization,
)
from transformers import AutoTokenizer

# Restore the trained network from MLflow.
# NOTE: the run URI below is a placeholder and must be replaced with a real one.
MODEL_URI = "runs:/your_run_id/your_model_path"
model = mlflow.pytorch.load_model(MODEL_URI)
model.eval()  # switch the model to evaluation mode before any forward pass

# Inputs are later moved to whichever device holds the model's first parameter
first_param = next(model.parameters())
device = first_param.device
print(f"Model is on device: {device}")

# Tokenizer used to turn raw text into input ids / attention masks
TOKENIZER_NAME = "allenai/scibert_scivocab_uncased"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

# Model wrapper that normalises input dtype / output type and logs tensor dtypes
class ModelWrapper(torch.nn.Module):
    """Thin adapter around a model so it can be called as ``f(ids, mask) -> tensor``.

    The wrapper casts ``input_ids`` to ``torch.long``, prints the dtypes it
    received (debugging aid), and unwraps ``.logits`` from the wrapped model's
    output when that attribute exists.
    """

    def __init__(self, model):
        """Store ``model``, the callable that is invoked by :meth:`forward`."""
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask=None):
        """Run the wrapped model and return a plain tensor.

        Args:
            input_ids: Tensor of token ids; cast to ``torch.long`` if needed.
            attention_mask: Optional mask forwarded unchanged to the model.

        Returns:
            ``outputs.logits`` when the wrapped model's output has a
            ``logits`` attribute, otherwise the output itself.
        """
        # Ensure that input_ids are always LongTensor.
        # NOTE: .long() is not differentiable, so no gradient can flow back to
        # a floating-point ``input_ids`` that passes through this cast.
        if input_ids.dtype != torch.long:
            input_ids = input_ids.long()
        # attention_mask is optional: only read its dtype when one was given,
        # otherwise the debug print itself would raise AttributeError.
        mask_dtype = None if attention_mask is None else attention_mask.dtype
        print(f"Inside model - input_ids dtype: {input_ids.dtype}, attention_mask dtype: {mask_dtype}")
        outputs = self.model(input_ids, attention_mask=attention_mask)
        return outputs.logits if hasattr(outputs, 'logits') else outputs

wrapped_model = ModelWrapper(model)


# Custom forward function for the attribution method
def forward_func(input_ids, attention_mask=None):
    """Return the wrapped model's output for ``input_ids`` and ``attention_mask``."""
    return wrapped_model(input_ids, attention_mask=attention_mask)


def _find_embedding_layer(module):
    """Locate the layer that maps token ids to embedding vectors.

    Uses ``module.get_input_embeddings()`` when the model provides it,
    otherwise the first ``torch.nn.Embedding`` found among its submodules.
    Returns ``None`` when neither is available.
    """
    getter = getattr(module, "get_input_embeddings", None)
    if callable(getter):
        try:
            layer = getter()
        except NotImplementedError:
            layer = None
        if isinstance(layer, torch.nn.Module):
            return layer
    # Heuristic: assumes the first Embedding registered is the token
    # embedding (true for BERT-style encoders) - TODO confirm for other models.
    return next(
        (sub for sub in module.modules() if isinstance(sub, torch.nn.Embedding)),
        None,
    )


# Token ids are discrete and the wrapper casts them with .long(), so gradients
# cannot be taken with respect to the ids themselves. Integrated Gradients is
# therefore applied to the output of the embedding layer instead.
_embedding_layer = _find_embedding_layer(model)
if _embedding_layer is not None:
    ig = LayerIntegratedGradients(forward_func, _embedding_layer)
else:
    # No embedding layer could be located: keep the previous plain setup so
    # that ``ig`` is always defined.
    ig = IntegratedGradients(forward_func)

# Function to explain predictions
def explain_prediction(text, target=0, n_steps=50):
    """Tokenize ``text`` and compute attributions for one output index.

    Relies on the module-level ``tokenizer``, ``device``, ``wrapped_model``
    and ``ig`` objects. Prints the input dtypes, the model output shape and
    the attribution shape as a debugging aid.

    Args:
        text: Input passed to the tokenizer.
        target: Output index to explain. Defaults to 0, the value that was
            previously hard-coded.
        n_steps: Number of steps passed to ``ig.attribute``. Defaults to 50,
            the value that was previously hard-coded.

    Returns:
        A tuple ``(attributions, input_ids)``. ``attributions`` is ``None``
        when the attribution call raised; the error is printed, not re-raised.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

    # Ensure input_ids and attention_mask are of type torch.LongTensor
    input_ids = inputs["input_ids"].to(device).long()
    attention_mask = inputs["attention_mask"].to(device).long()

    # Check input tensor types before passing to the model
    print("Before model input_ids dtype:", input_ids.dtype)
    print("Before model attention_mask dtype:", attention_mask.dtype)

    # Run a forward pass to check if the model works
    with torch.no_grad():
        outputs = wrapped_model(input_ids, attention_mask)
    print("Model output shape:", outputs.shape)

    # Use an explicit all-padding reference input. The implicit default
    # baseline is all zeros, which only equals the padding id for
    # vocabularies where that id happens to be 0.
    pad_id = tokenizer.pad_token_id
    baseline_ids = torch.full_like(input_ids, pad_id if pad_id is not None else 0)

    # Now try to get attributions. This stays best-effort on purpose: the
    # caller receives None and can skip the visualisation.
    try:
        attributions = ig.attribute(
            input_ids,
            baselines=baseline_ids,
            additional_forward_args=(attention_mask,),
            target=target,
            n_steps=n_steps,
        )
        print("Attribution shape:", attributions.shape)
    except Exception as e:
        print("Error during attribution:", str(e))
        attributions = None

    return attributions, input_ids

# Function to visualize attributions
def visualize_attributions(text, attributions, input_ids):
    """Plot and print one normalised attribution score per token.

    Args:
        text: The original input text (unused; kept for interface
            compatibility).
        attributions: Attribution tensor, or ``None``. A 3-D tensor is summed
            over its last dimension; for a 2-D result only the first row is
            used, matching ``input_ids[0]``.
        input_ids: Token id tensor; only ``input_ids[0]`` is decoded with the
            module-level ``tokenizer``.

    Returns:
        None. Shows a matplotlib bar chart and prints ``token: score`` lines.
    """
    if attributions is None:
        print("No attributions to visualize.")
        return

    # Reduce to a single score per token of the first sequence.
    scores = attributions.detach()
    if scores.dim() == 3:
        scores = scores.sum(dim=-1)
    if scores.dim() == 2:
        scores = scores[0]
    scores = scores.reshape(-1)

    # L2-normalise, but leave all-zero attributions untouched instead of
    # turning them into NaNs.
    norm = torch.norm(scores)
    if norm > 0:
        scores = scores / norm
    scores = scores.cpu().numpy()

    # Decode tokens
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    # Remove padding tokens together with their scores so that both stay
    # aligned wherever the padding sits in the sequence.
    kept = [(tok, float(score)) for tok, score in zip(tokens, scores) if tok != '[PAD]']
    if not kept:
        print("No attributions to visualize.")
        return
    tokens, word_attributions = zip(*kept)

    # Visualization: captum.attr.visualization offers no function that draws
    # text attributions onto a matplotlib axis, so plot a bar chart directly.
    fig, ax = plt.subplots(figsize=(20, 2))
    positions = list(range(len(tokens)))
    colors = ["tab:green" if score >= 0 else "tab:red" for score in word_attributions]
    ax.bar(positions, word_attributions, color=colors)
    ax.axhline(0.0, color="black", linewidth=0.8)
    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_title("Integrated Gradients Attribution")
    plt.show()

    # Print attributions
    for token, attribution in zip(tokens, word_attributions):
        print(f"{token}: {attribution:.4f}")

# Demo: explain a single sample sentence, then plot and print its token scores
text = "The CRISPR-Cas9 system has revolutionized gene editing techniques in molecular biology."
explanation = explain_prediction(text)
attributions, input_ids = explanation
visualize_attributions(text, attributions, input_ids)
