# Explain predictions using LIME

In [3]:
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import json
from lime.lime_text import LimeTextExplainer
import torch
from scipy.special import softmax
from IPython.display import display, HTML
import sqlite3

# Load the validation split from the project's SQLite database. The connection
# is closed in ``finally`` so a failing query (e.g. a missing ``validation_set``
# table) does not leave the database file handle open.
conn = sqlite3.connect('../../giicg.db')
try:
    data_set = pd.read_sql("Select * from validation_set", conn)
finally:
    conn.close()


## Load Model and Tokenizer

In [4]:
# Pick the fastest available backend: CUDA, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

fine_tuned_model_path = "Mayaryin/gender-prompt-roberta-base"

tokenizer = AutoTokenizer.from_pretrained(fine_tuned_model_path)
model = AutoModelForSequenceClassification.from_pretrained(fine_tuned_model_path).to(device)
# ``from_pretrained`` already returns the model in eval mode. It is stated
# explicitly here because the LIME scores must not be perturbed by dropout.
model.eval()

# Label-name -> id mapping, presumably written by the fine-tuning step (verify).
# It is read as UTF-8 explicitly so non-ASCII label names decode identically on
# every platform, rather than depending on the OS default encoding.
with open("finetune/label2id.json", "r", encoding="utf-8") as f:
    label2id = json.load(f)

## Define the prediction function used by LIME

In [5]:
# LIME names classes by position: class_names[i] labels column i of the
# probability matrix returned by ``predict``. Those columns follow the model's
# label ids, so order the names by id, not by the key order of label2id.json
# (which need not be sorted by id). Assumes label2id.json uses the same ids as
# the model's output columns — TODO confirm against the model config.
class_names = [label for label, _ in sorted(label2id.items(), key=lambda item: int(item[1]))]
# Reverse mapping: label id -> label name.
id2label = {v: k for k, v in label2id.items()}


def predict(texts, batch_size=64):
    """Return softmax class probabilities for a sequence of raw strings.

    LIME calls this classifier function with many perturbed copies of one
    input at once. Scoring them all in a single padded forward pass can
    exhaust device memory, so the texts are sent to the model in chunks of
    ``batch_size``.

    Args:
        texts: Sequence of strings to classify.
        batch_size: Maximum number of texts tokenized and scored per forward
            pass. Must be at least 1.

    Returns:
        numpy array with one row per text and one column per model output
        label, holding softmax probabilities.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    texts = list(texts)
    if not texts:
        # Nothing to score; keep a 2-D (0, num_labels) result shape.
        return torch.empty((0, model.config.num_labels)).numpy()

    batch_logits = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            # Tokenize one chunk and move every tensor to the model's device.
            encodings = tokenizer(
                texts[start:start + batch_size],
                return_tensors="pt",
                truncation=True,
                padding=True,
            ).to(device)
            # Bring logits back to CPU right away so device memory is freed per chunk.
            batch_logits.append(model(**encodings).logits.cpu())

    logits = torch.cat(batch_logits).numpy()
    return softmax(logits, axis=1)


## Set up a custom LIME explainer that uses the RoBERTa tokenizer

In [6]:
class SubwordLimeTextExplainer(LimeTextExplainer):
    """LIME text explainer whose interpretable features are tokenizer sub-words.

    ``LimeTextExplainer`` does not call ``tokenize``/``untokenize`` methods.
    It segments text with its ``split_expression`` constructor argument
    (a regex or a callable). Overriding those two methods therefore left LIME
    splitting on its default word regex.

    This class instead passes a callable ``split_expression`` that returns
    the literal substring of the input covered by each sub-word token. The
    substrings are found through the fast tokenizer's offset mapping. LIME
    requires every returned token to occur verbatim in the text, which raw
    RoBERTa tokens (carrying the ``Ġ`` space marker) do not.

    With a slow (non-fast) tokenizer, offsets are unavailable. LIME's default
    word splitting is then kept, as before.
    """

    def __init__(self, hf_tokenizer, **kwargs):
        # Set before super().__init__ so the split callable can always reach it.
        self.hf_tokenizer = hf_tokenizer
        # An explicit split_expression from the caller still wins.
        if getattr(hf_tokenizer, "is_fast", False):
            kwargs.setdefault("split_expression", self._subword_pieces)
        super().__init__(**kwargs)

    def _subword_pieces(self, text):
        """Return, in order, the substring of ``text`` covered by each sub-word token."""
        offsets = self.hf_tokenizer(
            text, add_special_tokens=False, return_offsets_mapping=True
        )["offset_mapping"]
        pieces = []
        end_of_previous = 0
        for start, end in offsets:
            # Byte-level BPE can map several tokens onto the same character span.
            # Keep only the first, so pieces are non-overlapping and in order.
            if start < end_of_previous:
                continue
            piece = text[start:end]
            # Empty or whitespace-only spans make useless features. Drop them and
            # let LIME treat that text as separator material.
            if not piece.strip():
                continue
            pieces.append(piece)
            end_of_previous = end
        return pieces

    def tokenize(self, text):
        """Return the raw tokenizer tokens for ``text`` (not used by LIME itself)."""
        tokens = self.hf_tokenizer.tokenize(text)
        return tokens

    def untokenize(self, tokens):
        """Convert a list of tokenizer tokens back to a string (not used by LIME itself)."""
        return self.hf_tokenizer.convert_tokens_to_string(tokens)

In [None]:
from collections import defaultdict
from tqdm import tqdm


# Class index (column of ``predict``'s output) whose evidence is aggregated.
# One constant feeds both ``labels=`` and ``as_list(label=...)`` so they cannot drift apart.
TARGET_LABEL = 1
# Perturbed samples LIME scores per text (5000 is LIME's default). Lowering it
# trades explanation stability for speed.
NUM_SAMPLES = 5000
# Maximum number of features LIME reports per explained text.
NUM_FEATURES = 100
# Number of aggregated tokens printed at the end.
TOP_K = 20
# Fixed seed so LIME's random perturbations, and hence the weights, are reproducible.
RANDOM_STATE = 42

explainer = SubwordLimeTextExplainer(
    hf_tokenizer=tokenizer,
    class_names=class_names,
    random_state=RANDOM_STATE,
)

importance_agg = defaultdict(float)  # token -> sum of LIME weights
token_counts = defaultdict(int)      # token -> number of explanations it appeared in

# Texts to explain come from the 'conversational' column of the validation set.
for sample_text in tqdm(data_set['conversational'], desc="Explaining samples"):
    # Missing (None/NaN) or blank rows cannot be explained; skip rather than abort the run.
    if not isinstance(sample_text, str) or not sample_text.strip():
        continue

    explanation = explainer.explain_instance(
        sample_text,
        predict,
        num_features=NUM_FEATURES,
        labels=[TARGET_LABEL],
        num_samples=NUM_SAMPLES,
    )

    # (token, weight) pairs for the target class only.
    for token, weight in explanation.as_list(label=TARGET_LABEL):
        importance_agg[token] += weight
        token_counts[token] += 1

# Mean weight of each token over the explanations it appeared in.
average_importance = {token: importance_agg[token] / token_counts[token]
                      for token in importance_agg}

# Rank by magnitude, so strong negative evidence ranks alongside strong positive evidence.
sorted_tokens = sorted(average_importance.items(), key=lambda item: abs(item[1]), reverse=True)

print("Top tokens by average importance:")
for token, score in sorted_tokens[:TOP_K]:
    # n = support; averages over very few explanations are noisy.
    print(f"{token}: {score:.4f} (n={token_counts[token]})")


Explaining samples:   3%|▎         | 3/114 [00:50<33:37, 18.17s/it]