NoiseTunnel on LayerIntegratedGradients Perturbs Input Instead of Attribution Target (RuntimeError) #1281

Open
EldadTalShir opened this issue May 15, 2024 · 0 comments

When NoiseTunnel wraps a LayerIntegratedGradients instance initialized on a specific model layer, the noise is applied to the model input, regardless of the chosen layer and attribution target.

For example, with a BERT-based model and the layer set to bert.embeddings with attribute_to_layer_input=False, the noise is added to the token IDs rather than to the embeddings. The same happens when the layer is set to bert.encoder with attribute_to_layer_input=True.

This results in the following error when the perturbed token IDs are processed by the BERT module:

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)

To Reproduce

Steps to reproduce the behavior:

  1. Set up a BERT-based classifier/regressor.
  2. Load the classifier and initialize the attribution methods (LayerIntegratedGradients and SmoothGrad via NoiseTunnel).
  3. Tokenize the input and the baseline.
  4. Run attribution (SmoothGrad on LayerIntegratedGradients).

import torch
from captum.attr import NoiseTunnel, LayerIntegratedGradients, TokenReferenceBase
from transformers import AutoTokenizer, AutoModel
from pytorch_lightning import LightningModule

# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Transformer(LightningModule):
  def __init__(self,model_name='mixedbread-ai/mxbai-embed-large-v1'):
    super().__init__()
    self.bert = AutoModel.from_pretrained(model_name)
    self.fc = torch.nn.Linear(1024,1)

  def forward(self, x):
    '''
    x: inputs['input_ids'] from AutoTokenizer.from_pretrained(model_name)(...)
    '''
    outputs = self.bert(x).last_hidden_state[:,0,:]
    prob = self.fc(outputs)
    return prob

# Load the Transformer model defined above
model = Transformer()
model.to(device)
model.eval()
model.zero_grad()

# Init attribution methods
lig = LayerIntegratedGradients(model, model.bert.embeddings)
sg = NoiseTunnel(lig)

# Get token indices for text
sentence = "Captum is great, but I found a bug today."
tokenizer = AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-embed-large-v1')
tokenized_text = tokenizer(sentence, return_tensors="pt", padding=False, truncation=False)
inputs = tokenized_text['input_ids'].to(device)

# Reference token for IG
ref_token_id = tokenizer.pad_token_id  # Use padding token as reference token
token_reference = TokenReferenceBase(reference_token_idx=ref_token_id)
ref = token_reference.generate_reference(len(inputs[0]),device=device)
# Set [CLS] and [SEP] token ids in baseline (101 and 102 for this tokenizer)
ref[0] = tokenizer.cls_token_id
ref[-1] = tokenizer.sep_token_id

# Run attribution
attributions, delta = sg.attribute(
    inputs,
    nt_type='smoothgrad',
    nt_samples=100,
    nt_samples_batch_size=50,
    stdevs=1.0,
    draw_baseline_from_distrib=False,
    baselines=ref.unsqueeze(0),
    internal_batch_size=50,
    n_steps=50,
    method='riemann_trapezoid',
    attribute_to_layer_input=False,
    return_convergence_delta=True,
)

Expected behavior

NoiseTunnel should apply noise to the attribution target (in this example, the embeddings rather than the token IDs).

Environment

 - Captum / PyTorch Version: (0.7.0 / 2.2.1+cu121)
 - OS: Linux
 - How you installed Captum / PyTorch: pip
 - Build command you used (if compiling from source): N/A
 - Python version: 3.10.12
 - CUDA/cuDNN version: 12.2.r12.2
 - GPU models and configuration: T4 (Google Colab)
 - Any other relevant information: N/A
@EldadTalShir EldadTalShir changed the title NoiseTunnel on LayerIntegratedGradients for Token Inputs (RunetimeError) NoiseTunnel on LayerIntegratedGradients Perturbs Input Instead of Attribution Target (RunetimeError) May 15, 2024