<div align="center">
  <img src="https://raw.githubusercontent.com/FarnoushRJ/MambaLRP/main/assets/MambaLRP_logo.jpeg" width="1000"/>
</div>


<div align="center"><h1>🐍 MambaLRP is here! 🎉</h1></div>

Clone the repository and install MambaLRP.

In [10]:
!git clone https://github.com/FarnoushRJ/MambaLRP.git
!pip install git+file:///content/MambaLRP --quiet

fatal: destination path 'MambaLRP' already exists and is not an empty directory.
  Preparing metadata (setup.py) ... done


Import necessary packages.

In [11]:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import sys

from mamba_lrp.model.mamba_huggingface import ModifiedMambaForCausalLM
from mamba_lrp.model.utils import resize_token_embeddings
from mamba_lrp.lrp.utils import relevance_propagation
from mamba_lrp.dataset.general_dataset import get_sst_dataset
import torch
import numpy as np

## Load model

Download the fine-tuned SST-2 weights, then load the tokenizer and model.

In [12]:
!pip install gdown

# Import gdown
import gdown

# Define the file ID and the destination file name
file_id = '1RnIygUDodGeKPqbcEQTOYR5dztpF6X1b'  # Replace with your actual file ID
destination = 'mamba_sst2_weights.pt'  # Desired output file name

# Construct the URL
url = f'https://drive.google.com/uc?id={file_id}'

# Download the file
gdown.download(url, destination, quiet=False)



Downloading...
From (original): https://drive.google.com/uc?id=1RnIygUDodGeKPqbcEQTOYR5dztpF6X1b
From (redirected): https://drive.google.com/uc?id=1RnIygUDodGeKPqbcEQTOYR5dztpF6X1b&confirm=t&uuid=a7a6744f-6be8-4f33-9330-fd6d45f0395b
To: /content/mamba_sst2_weights.pt
100%|██████████| 517M/517M [00:07<00:00, 72.4MB/s]


'mamba_sst2_weights.pt'
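
Re-running the download cell fetches the 517 MB file again. The following is a minimal sketch of a guard that skips the download when the file is already on disk. The helper name `download_if_missing` and its `fetch` callback are hypothetical and not part of gdown or MambaLRP; in this notebook `fetch` would wrap `gdown.download`.

```python
import os
from typing import Callable


def download_if_missing(url: str, destination: str,
                        fetch: Callable[[str, str], object]) -> bool:
    """Call fetch(url, destination) only when destination is absent or empty.

    Returns True if a download was triggered, False if the file was reused.
    `fetch` is a hypothetical callback supplied by the caller.
    """
    if os.path.isfile(destination) and os.path.getsize(destination) > 0:
        return False
    fetch(url, destination)
    return True
```

With the variables from the cell above, a call could look like `download_if_missing(url, destination, lambda u, d: gdown.download(u, d, quiet=False))`.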

In [13]:
# Load tokenizer.
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
tokenizer.eos_token = "<|endoftext|>"
tokenizer.bos_token = "<|startoftext|>"
tokenizer.pad_token = "<|pad|>"
tokenizer.unk_token = "<|unkown|>"
tokenizer.add_tokens(['<|unkown|>', '<|pad|>', "<|startoftext|>"], special_tokens=True)

# Load model.
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", use_cache=True)
resize_token_embeddings(model, len(tokenizer))
model.lm_head = torch.nn.Linear(768, 2, bias=True)

# Load the model's weights
model.load_state_dict(
    torch.load('mamba_sst2_weights.pt', map_location=torch.device('cpu')),
    strict=True
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

# Make model explainable.
modified_model = ModifiedMambaForCausalLM(model, is_fast_forward_available=False)
modified_model.eval()
model.backbone.embeddings.requires_grad = False
pretrained_embeddings = model.backbone.embeddings

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## Load dataset

Load SST-2 dataset.

In [14]:
validation_dataset = get_sst_dataset(
    tokenizer=tokenizer,
    truncation=False,
    max_length=None
    )

## Generate explanation

Generate explanation for one sample.

In [15]:
i = 413
sample = validation_dataset[i]
input_ids = sample['input_ids'].unsqueeze(0).to(device)
label = torch.tensor(sample['label']).long().to(device)
# Keep everything up to and including token ID 0 (assumes it occurs exactly once).
idx = torch.where(input_ids == 0)[1] + 1
input_ids = input_ids[:, :idx]
# Look up the input embeddings that relevance is propagated back to.
embeddings = pretrained_embeddings(input_ids)

R, prediction = relevance_propagation(
    model=modified_model,
    embeddings=embeddings,
    targets=label,
    n_classes=2
    )
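
The cell above cuts the sequence right after the position of token ID 0, which only works when that ID occurs exactly once in the sample. The sketch below shows one way to make the cut tolerant of zero or several occurrences. It assumes that ID 0 marks the end of the sentence in this dataset; `truncate_after_first` is a hypothetical helper, not part of MambaLRP, and it works on plain Python lists rather than tensors.

```python
from typing import List, Sequence


def truncate_after_first(ids: Sequence[int], stop_id: int = 0) -> List[int]:
    """Return ids up to and including the first occurrence of stop_id.

    If stop_id never occurs, the whole sequence is returned unchanged.
    """
    for position, token_id in enumerate(ids):
        if token_id == stop_id:
            return list(ids[: position + 1])
    return list(ids)
```

To use it with the tensors in this notebook, one option is `torch.tensor([truncate_after_first(input_ids[0].tolist())], device=device)`.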

## Visualization

For simplicity, we use the visualization utilities in Captum to display the results.

In [16]:
from captum.attr import visualization as viz

In [17]:
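# Decode each token ID to text, dropping the first token and the last two.
tokens = []
for token_id in input_ids[0][1: -2]:
    tokens.append(tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token_id.item()])))
# Apply the same slice to the relevance scores, then divide by the largest score.
attributions = R[0][1: -2]
attributions = attributions / attributions.max()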
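
Dividing by `attributions.max()` works as intended when the largest score is positive, but it flips signs if every score is negative and can leave values below -1. As an alternative (not the normalization used in the cell above), the following sketch scales by the largest absolute value so that all scores land in [-1, 1] with their signs preserved. `normalize_by_max_abs` is a hypothetical helper operating on plain Python floats.

```python
from typing import List, Sequence


def normalize_by_max_abs(scores: Sequence[float]) -> List[float]:
    """Scale scores into [-1, 1] by dividing by the largest absolute value.

    Signs are preserved; an all-zero or empty input yields zeros / an empty list.
    """
    largest = max((abs(s) for s in scores), default=0.0)
    if largest == 0.0:
        return [0.0 for _ in scores]
    return [s / largest for s in scores]
```

If `attributions` is a torch tensor, the equivalent one-liner would be `attributions / attributions.abs().max()`.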

In [18]:
# Visualize the attributions
# Run the forward pass once and read the logits at the last position.
with torch.no_grad():
    logits = model(input_ids).logits[:, -1, :]
viz.visualize_text([viz.VisualizationDataRecord(
    attributions,
    logits.max(dim=1).values.item(),  # max logit, shown in the probability slot
    logits.argmax(dim=1).item(),
    true_class=label.item(),
    attr_class=label.item(),
    attr_score=attributions.sum(),
    raw_input_ids=tokens,
    convergence_score=None
)])

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 0.0 | 0 (2.86) | 0.0 | 1.67 | at least one scene is so disgusting that viewers may be hard pressed to retain their lunch . |
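
The Captum view colors tokens by relevance but does not list them in order. For a quick text-only check, the sketch below ranks tokens by their score. `top_k_tokens` is a hypothetical helper, not part of Captum or MambaLRP, and it expects plain Python floats.

```python
from typing import List, Sequence, Tuple


def top_k_tokens(tokens: Sequence[str], scores: Sequence[float],
                 k: int = 3) -> List[Tuple[str, float]]:
    """Return the k (token, score) pairs with the highest scores, best first."""
    if len(tokens) != len(scores):
        raise ValueError("tokens and scores must have the same length")
    ranked = sorted(zip(tokens, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]
```

With the variables from the cells above, a call could look like `top_k_tokens(tokens, [float(a) for a in attributions])`.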


