<a href="https://colab.research.google.com/github/Kristina-26/LLM-interpretability/blob/main/feature_attribution_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the pinned library versions used throughout this notebook (quiet mode).
!pip -q install transformers==4.44.2 captum==0.7.0 torch==2.3.1 scikit-learn==1.5.1 matplotlib==3.9.0 seaborn==0.13.2

In [None]:
import io
import json
import math
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
from dataclasses import dataclass
from typing import Dict, Any, List, Tuple
import torch.nn as nn

from google.colab import files

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, f1_score, accuracy_score
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC

from torch.utils.data import Dataset, DataLoader

from transformers import BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup

from captum.attr import (
    IntegratedGradients,
    LayerIntegratedGradients,
    TokenReferenceBase,
    Saliency,
    visualization
)

## 1. Controlled experiment with multinomial logistic regression + coefficient heatmap

### 1.1 Heatmaps method

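$M_{i,c} = w_{i,c} \cdot x_i$, where $w_{i,c}$ is the model weight of feature $i$ for class $c$ and $x_i$ is the value of feature $i$ in the sample being explained.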

In [None]:
# Synthetic 3-class toy problem: 18 samples x 9 features; each class receives
# a +3 shift on its own block of three features, then two labels are flipped.
np.random.seed(0)
n_classes = 3
n_per = 6
n = n_classes * n_per  # 18 samples
m = 9                  # 9 features
X = np.random.randn(n, m)
y = np.repeat(np.arange(n_classes), n_per)

# class c is informative on features [3c, 3c + 3): 0-2, 3-5 and 6-8
for c in range(n_classes):
    X[y == c, 3 * c:3 * c + 3] += 3

# label noise: move samples 2 and 13 to the next class (mod n_classes)
flipped = [2, 13]
y_noisy = y.copy()
y_noisy[flipped] = (y_noisy[flipped] + 1) % n_classes

class MultinomialLogReg(nn.Module):
    """Multinomial logistic regression as a single affine layer.

    ``forward`` returns raw class logits (no softmax); the softmax is applied
    by ``nn.CrossEntropyLoss`` when the model is trained in this notebook.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        logits = self.linear(x)
        return logits

# Fit the toy model on the noisy labels with full-batch L-BFGS.
X_tensor = torch.FloatTensor(X)
y_tensor = torch.LongTensor(y_noisy)

model = MultinomialLogReg(m, n_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.LBFGS(model.parameters(), lr=0.1)


def closure():
    """Recompute the full-batch loss and its gradients; L-BFGS may call this repeatedly per step."""
    optimizer.zero_grad()
    loss = criterion(model(X_tensor), y_tensor)
    loss.backward()
    return loss


for epoch in range(10):
    optimizer.step(closure)

# Training-set accuracy against the (noisy) labels the model was fit on.
model.eval()
with torch.no_grad():
    outputs = model(X_tensor)
    y_pred_torch = outputs.argmax(dim=1)
    acc_torch = accuracy_score(y_noisy, y_pred_torch.numpy())

print(f"PyTorch model accuracy: {acc_torch:.3f}")

# Method 1 on the toy model: M_{i,c} = w_{i,c} * x_i, averaged over the samples
# whose (noisy) label is c.
coefs = model.linear.weight.data.numpy()

feature_importance = np.zeros((n_classes, m))
for class_idx in range(n_classes):
    members = X[y_noisy == class_idx]
    if len(members) == 0:
        continue
    # accumulate sample by sample from a zero row, then take the mean
    total = sum((coefs[class_idx] * row for row in members), np.zeros(m))
    feature_importance[class_idx] = total / len(members)

for class_idx, values in enumerate(feature_importance):
    print(f"Class {class_idx}: " + " ".join(f"{v:6.3f}" for v in values))

### 1.2 Integrated gradients

In [None]:
# Method 2 on the toy model: Integrated Gradients from an all-zero baseline,
# explaining each sample's (noisy) label.
ig = IntegratedGradients(model)
baseline = torch.zeros_like(X_tensor[0:1])  # zero baseline, shape (1, m)

all_attributions = [
    ig.attribute(X_tensor[row:row + 1], baseline, target=int(y_noisy[row]), n_steps=50)
    .squeeze()
    .detach()
    .numpy()
    for row in range(n)
]
attributions = np.array(all_attributions)

# average per-sample attributions within each (noisy) class
class_attributions = np.zeros((n_classes, m))
for class_idx in range(n_classes):
    in_class = y_noisy == class_idx
    if in_class.any():
        class_attributions[class_idx] = attributions[in_class].mean(axis=0)

In [None]:
# Three side-by-side heatmaps: raw weights, Method 1 (w * x) and Method 2 (IG).
class_labels = [f"class_{i}" for i in range(n_classes)]
feature_labels = [f"x{j}" for j in range(m)]
heatmap_panels = [
    (coefs, ".2f", "Coefficients (wi,c)"),
    (feature_importance, ".3f", "Method 1: Feature importance\n(Mi,c = wi,c * xi)"),
    (class_attributions, ".3f", "Method 2: Integrated gradients\n(avg attribution per class)"),
]

plt.figure(figsize=(15, 4))
for position, (matrix, number_format, title) in enumerate(heatmap_panels, start=1):
    plt.subplot(1, 3, position)
    sns.heatmap(matrix, annot=True, fmt=number_format, cmap="coolwarm", cbar=True,
                yticklabels=class_labels, xticklabels=feature_labels, center=0)
    plt.title(title)

plt.tight_layout()
plt.show()

In [None]:
# Recompute both attribution methods, check that they agree, and plot the
# weights next to the (shared) per-class attribution heatmap.

# Model coefficients w_{i,c}: one row per class.
coefs = model.linear.weight.data.numpy()

# Method 1 (vectorised): broadcast each class's weight row over that class's samples.
feature_importance = np.zeros((n_classes, m))
for class_idx in range(n_classes):
    members = X[y_noisy == class_idx]
    if members.shape[0] > 0:
        feature_importance[class_idx] = (members * coefs[class_idx]).mean(axis=0)

# Method 2: Integrated Gradients per sample, targeting its noisy label.
ig = IntegratedGradients(model)
baseline = torch.zeros_like(X_tensor[0:1])  # zero baseline
attributions_list = []
for row_idx in range(n):
    attr = ig.attribute(X_tensor[row_idx:row_idx + 1], baseline,
                        target=int(y_noisy[row_idx]), n_steps=50)
    attributions_list.append(attr.detach().numpy().squeeze())
attributions = np.array(attributions_list)

class_attributions = np.zeros((n_classes, m))
for class_idx in range(n_classes):
    in_class = y_noisy == class_idx
    if in_class.sum() > 0:
        class_attributions[class_idx] = attributions[in_class].mean(axis=0)

# A linear model has a constant input gradient, so IG should reproduce w * x.
diff = np.abs(feature_importance - class_attributions).max()
print(f"Max difference between Analytical and IG importance: {diff:.6f}")
print("(This confirms that for Linear Models, IG == Weight * Input. We only need to plot one.)")

row_labels = [f"Class {i}" for i in range(n_classes)]
col_labels = [f"x{j}" for j in range(m)]

plt.figure(figsize=(12, 5))

# Left: the model weights.
plt.subplot(1, 2, 1)
sns.heatmap(coefs, annot=True, fmt=".1f", cmap="coolwarm", cbar=True,
            yticklabels=row_labels, xticklabels=col_labels,
            center=0, annot_kws={"size": 10})
plt.title("Model coefficients (weights)", fontsize=12, pad=10)

# Right: per-class IG attributions (compared against Method 1 by the check above).
plt.subplot(1, 2, 2)
sns.heatmap(class_attributions, annot=True, fmt=".1f", cmap="coolwarm", cbar=True,
            yticklabels=row_labels, xticklabels=col_labels,
            center=0, annot_kws={"size": 10})
plt.title("Feature importance (Method 1)\n and Integrated Gradients (Method 2)", fontsize=12, pad=10)

plt.tight_layout()
plt.show()

### Comparison of feature importance and integrated gradients

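Both methods give the same per-class results, up to floating-point error. This is expected. For a linear model such as logistic regression, the gradient of a class logit with respect to the input is the constant weight vector. Integrated gradients from a zero baseline therefore reduce exactly to the feature importance $w_{i,c} \cdot x_i$.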

In [None]:
# Print the weights and both per-class attribution matrices row by row.
for heading, matrix, spec in (
    ("Coefficients (wi,c):", coefs, "6.2f"),
    ("\n1. Feature importance (Mi,c = wi,c × xi):", feature_importance, "6.9f"),
    ("\n2. Integrated gradients:", class_attributions, "6.9f"),
):
    print(heading)
    for class_idx, values in enumerate(matrix):
        print(f"Class {class_idx}: " + " ".join(format(v, spec) for v in values))

# Pearson correlation between the two flattened attribution matrices (a scalar).
correlation_matrix = np.corrcoef(feature_importance.flatten(), class_attributions.flatten())[0, 1]
print(f"\nCorrelation between feature importance and IG: {correlation_matrix:.20f}")

### Concrete examples with feature importance and IG

In [None]:
# Per-sample comparison of Method 1 (w * x) and Method 2 (IG) on the toy model.
# Both methods explain the SAME class (the model's prediction). The precomputed
# `attributions[idx]` target the noisy label, which may differ from the
# prediction, so IG is recomputed here for the predicted class.
sample_indices = [0, 6, 12]  # one from each class
features_per_class = m // n_classes  # width of each class's signal block in X
for idx in sample_indices:
    true_class = y[idx]
    noisy_class = y_noisy[idx]
    predicted_class = y_pred_torch[idx].item()

    print(f"\nSample {idx}: True={true_class}, Noisy={noisy_class}, Predicted={predicted_class}")
    print(f"Feature values: {X[idx]}")

    # Method 1: feature importance Mi,c = wi,c × xi for the predicted class
    sample_feature_importance = coefs[predicted_class] * X[idx]

    # Method 2: integrated gradients for the same (predicted) class, zero baseline
    sample_input = X_tensor[idx:idx + 1]
    sample_ig = ig.attribute(sample_input, torch.zeros_like(sample_input),
                             target=predicted_class, n_steps=50).squeeze().detach().numpy()

    print(f"Method 1 (feature importance): {sample_feature_importance}")
    print(f"Method 2 (integrated gradients):  {sample_ig}")

    # three features each method ranks highest by |attribution| (ascending order)
    fi_top3 = np.argsort(np.abs(sample_feature_importance))[-3:]
    ig_top3 = np.argsort(np.abs(sample_ig))[-3:]

    # ground truth: the block of features that received the +3 shift for this class
    start = int(true_class) * features_per_class
    expected_important = list(range(start, start + features_per_class))

    print(f"Expected important features:     {expected_important}")
    print(f"Method 1 top 3 features:         {fi_top3}")
    print(f"Method 2 top 3 features:         {ig_top3}")

    # agreement between methods
    agreement = len(set(fi_top3) & set(ig_top3))
    print(f"Agreement (features in both top-3): {agreement}/3")

    # correlation between the two attribution vectors for this sample
    sample_correlation = np.corrcoef(sample_feature_importance, sample_ig)[0, 1]
    print(f"Sample correlation: {sample_correlation:.6f}")

# Feature attribution for text classification

## Classical baseline models

### Data loading

In [None]:
def _standardize_text_label_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return ``df`` reduced to lowercase ``text``/``label`` columns with NaN rows dropped.

    A column is renamed only when no column with the exact lowercase name exists.
    The match is case-insensitive, and the first matching column wins, so an
    existing ``text`` is never duplicated by renaming ``Text``.

    Raises:
        ValueError: if no ``text`` or no ``label`` column (case-insensitive) exists.
    """
    renames = {}
    for target in ("text", "label"):
        if target in df.columns:
            continue
        matches = [c for c in df.columns if str(c).lower() == target]
        if matches:
            renames[matches[0]] = target
    df = df.rename(columns=renames)

    missing = [c for c in ("text", "label") if c not in df.columns]
    if missing:
        # explicit error instead of `assert`, which is stripped under `python -O`
        raise ValueError(
            "uploaded CSV must contain 'text' and 'label' columns (case-insensitive); "
            f"missing {missing}, found {list(df.columns)}"
        )
    return df[['text', 'label']].dropna().reset_index(drop=True)


def load_dataset_from_upload() -> pd.DataFrame:
    """Prompt for a CSV upload in Colab and return it as a clean ``text``/``label`` frame.

    Only the first uploaded file is read. Prints the file name, the sample count
    and the label distribution.

    Raises:
        RuntimeError: if the upload dialog returns no file.
        ValueError: if the CSV lacks a text or label column.
    """
    uploaded = files.upload()
    if not uploaded:
        raise RuntimeError("no file uploaded.")
    fname = next(iter(uploaded))
    print(f"file uploaded: {fname}")
    df = _standardize_text_label_columns(pd.read_csv(io.BytesIO(uploaded[fname])))
    print(f"loaded {len(df)} samples")
    print("class distribution:")
    print(df['label'].value_counts())
    return df

# Load the uploaded CSV, encode labels as integers, and make the stratified
# 80/20 train/validation split shared by every model below.
df = load_dataset_from_upload()
print(df.head())  # a bare `df.head()` that is not the cell's last line is never displayed

# encode labels + stratified split
label_encoder = LabelEncoder()
df['label_encoded'] = label_encoder.fit_transform(df['label'])

X_train, X_val, y_train, y_val = train_test_split(
    df['text'], df['label_encoded'], test_size=0.2, stratify=df['label_encoded'], random_state=42
)

num_classes = len(label_encoder.classes_)
print(f"num classes: {num_classes}, classes: {list(label_encoder.classes_)}")

In [None]:
def bootstrap_ci(y_true, y_pred, metric_fn, B: int = 1000, seed: int = 42):
    """Percentile bootstrap of ``metric_fn(y_true, y_pred)``.

    Resamples (true, predicted) pairs with replacement ``B`` times using a
    ``numpy`` Generator seeded with ``seed``, so a given seed is reproducible.

    Returns:
        (2.5th percentile, median, 97.5th percentile) of the resampled scores.

    Raises:
        ValueError: if the inputs differ in length or are empty.
    """
    # convert once instead of on every resample
    y_true_arr = np.asarray(y_true)
    y_pred_arr = np.asarray(y_pred)
    if len(y_true_arr) != len(y_pred_arr):
        raise ValueError(
            f"y_true and y_pred must have equal length, got {len(y_true_arr)} and {len(y_pred_arr)}"
        )
    if len(y_true_arr) == 0:
        raise ValueError("cannot bootstrap an empty sample")

    rng = np.random.default_rng(seed)
    idx = np.arange(len(y_true_arr))
    scores = []
    for _ in range(B):
        s = rng.choice(idx, size=len(idx), replace=True)
        scores.append(metric_fn(y_true_arr[s], y_pred_arr[s]))
    lo, med, hi = np.percentile(scores, [2.5, 50, 97.5])
    return lo, med, hi

def print_ci(name: str, lo: float, med: float, hi: float):
    """Print ``name``, the bootstrap median, and its 95% interval on one line."""
    summary = "{}: {:.3f}  (95% ci: {:.3f} … {:.3f})".format(name, med, lo, hi)
    print(summary)

### Classical models: tf-idf + logistic regression, tf-idf + linear svm

In [None]:
# classical baselines: tf-idf + logistic regression, linear svm (with bootstrap cis)
vectorizer = TfidfVectorizer(max_features=20000, ngram_range=(1, 2))
Xtr_vec = vectorizer.fit_transform(X_train)
Xva_vec = vectorizer.transform(X_val)

# Logistic regression. With 3+ classes the default lbfgs solver already fits the
# multinomial (softmax) model. Passing multi_class='multinomial' explicitly is
# deprecated since scikit-learn 1.5. For binary labels this gives the standard
# single-logit model.
logreg = LogisticRegression(max_iter=2000)
logreg.fit(Xtr_vec, y_train)
yhat_log = logreg.predict(Xva_vec)

# linear svm
linsvm = LinearSVC()
linsvm.fit(Xtr_vec, y_train)
yhat_svm = linsvm.predict(Xva_vec)

# classification_report measures len(name) for column widths, so names must be
# strings (raw numeric labels would raise). Passing `labels` keeps the report
# valid even if a class is absent from both y_val and the predictions.
report_names = [str(c) for c in label_encoder.classes_]
report_labels = list(range(len(report_names)))

# metrics + cis
for name, yhat in [("logreg", yhat_log), ("linear_svm", yhat_svm)]:
    print(f"\n{name} classification report:")
    print(classification_report(y_val, yhat, labels=report_labels,
                                target_names=report_names, digits=3))
    acc_lo, acc_med, acc_hi = bootstrap_ci(y_val, yhat, accuracy_score, B=1000)
    f1_lo, f1_med, f1_hi = bootstrap_ci(y_val, yhat, lambda a, b: f1_score(a, b, average='weighted'), B=1000)
    print_ci("accuracy", acc_lo, acc_med, acc_hi)
    print_ci("weighted f1", f1_lo, f1_med, f1_hi)

## BERT

In [None]:
# Use the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"using device: {device}")

# Pretrained uncased BERT with a sequence-classification head sized to num_classes.
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_classes)
bert_model = bert_model.to(device)

# dataset for BERT using existing train/val split
class BertTextDataset(Dataset):
    """Map-style dataset that tokenizes one text per access.

    Each item is a dict with ``input_ids`` and ``attention_mask`` (the
    tokenizer's output with its leading size-1 dimension squeezed away) and a
    scalar long ``labels`` tensor. Texts are padded/truncated to ``max_length``.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        # materialise as lists so positional indexing works for pandas Series too
        self.texts = list(texts)
        self.labels = list(labels)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        raw_text = str(self.texts[idx])
        target = int(self.labels[idx])
        encoded = self.tokenizer(
            raw_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        item = {}
        for key in ('input_ids', 'attention_mask'):
            item[key] = encoded[key].squeeze(0)
        item['labels'] = torch.tensor(target, dtype=torch.long)
        return item

# Datasets and loaders over the same split as the classical baselines.
bert_train_ds = BertTextDataset(X_train, y_train, tokenizer)
bert_val_ds = BertTextDataset(X_val, y_val, tokenizer)
bert_train_loader = DataLoader(bert_train_ds, shuffle=True, batch_size=8)  # smaller batch for stability
bert_val_loader = DataLoader(bert_val_ds, shuffle=False, batch_size=16)

# Fine-tune all BERT parameters with AdamW at a constant learning rate.
epochs = 4
optimizer = optim.AdamW(bert_model.parameters(), lr=2e-5)
bert_model.train()

for epoch in range(epochs):
    total_loss, batch_count = 0, 0
    for batch_count, batch in enumerate(bert_train_loader, start=1):
        input_ids, attention_mask, labels = (
            batch[key].to(device) for key in ('input_ids', 'attention_mask', 'labels')
        )
        optimizer.zero_grad()
        loss = bert_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        if batch_count % 5 == 0:
            print(f"  Epoch {epoch+1}, Batch {batch_count}, Loss: {loss.item():.4f}")

    print(f"Epoch {epoch+1} avg loss: {total_loss/batch_count:.4f}")

# Evaluate BERT on the validation split shared with the classical models.
bert_model.eval()
bert_predictions, bert_true_labels = [], []

with torch.no_grad():
    for batch in bert_val_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        logits = bert_model(input_ids=input_ids, attention_mask=attention_mask).logits
        bert_predictions.extend(logits.argmax(dim=1).cpu().numpy())
        bert_true_labels.extend(labels.cpu().numpy())

bert_accuracy = accuracy_score(bert_true_labels, bert_predictions)
print(f"BERT accuracy on validation set: {bert_accuracy:.3f}")

# Captum setup for BERT: an all-[PAD] token baseline and Layer IG over the
# word-embedding layer.
PAD_IND = tokenizer.pad_token_id
bert_token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)


def bert_forward_func(input_ids, attention_mask):
    """Return BERT classification logits, switching the model to eval mode first (disables dropout)."""
    bert_model.eval()
    outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
    return outputs.logits


bert_lig = LayerIntegratedGradients(bert_forward_func, bert_model.bert.embeddings.word_embeddings)

### TF-IDF vocabulary setup for LogReg and SVM interpretability (BERT uses its own WordPiece tokens)

In [None]:
# Vocabulary of the fitted TF-IDF vectorizer; entry i names column i of Xtr_vec / Xva_vec.
feature_names = vectorizer.get_feature_names_out()
n_tfidf_features = len(feature_names)
print(f"Total TF-IDF features: {n_tfidf_features}")

### 2.2 IG method for tf-idf + logistic regression, tf-idf + linear svm

In [None]:
# Validation TF-IDF features as a dense float32 tensor; encoded labels as int64.
X_val_dense = torch.tensor(Xva_vec.toarray(), dtype=torch.float32)
y_val_tensor = torch.tensor(y_val.to_numpy(), dtype=torch.long)

print(f"Validation data shape: {X_val_dense.shape}")

# create PyTorch wrapper using existing sklearn weights
class SklearnToPyTorch(nn.Module):
    """Torch ``nn.Linear`` mirror of a fitted linear sklearn classifier.

    Copies ``coef_`` (n_outputs, n_features) and ``intercept_``, so that for
    multiclass estimators ``forward(x)`` equals ``x @ coef_.T + intercept_``.

    Binary estimators store a single score row ``z``. It is expanded to two
    logits ``(0, z)``, so that ``argmax`` picks class 1 exactly when ``z > 0``
    and ``softmax`` gives ``sigmoid(z)`` for class 1. Without this expansion
    the wrapper would emit one logit and argmax would always return 0.

    A scalar ``intercept_`` (e.g. from ``fit_intercept=False``) is broadcast.
    """

    def __init__(self, sklearn_model):
        super().__init__()
        weight = torch.as_tensor(np.asarray(sklearn_model.coef_), dtype=torch.float32)
        bias = torch.as_tensor(np.asarray(sklearn_model.intercept_), dtype=torch.float32).reshape(-1)
        if weight.shape[0] == 1:
            weight = torch.cat([torch.zeros_like(weight), weight], dim=0)
            bias = torch.cat([torch.zeros(1), bias])
        self.linear = nn.Linear(weight.shape[1], weight.shape[0])
        # copy into the existing parameters instead of rebinding `.data`
        with torch.no_grad():
            self.linear.weight.copy_(weight)
            self.linear.bias.copy_(bias)  # broadcasts a length-1 bias to all outputs

    def forward(self, x):
        return self.linear(x)

# Torch mirrors of the fitted LogReg and SVM, each with its own IG explainer.
pytorch_logreg = SklearnToPyTorch(logreg).eval()
pytorch_svm = SklearnToPyTorch(linsvm).eval()

ig_logreg = IntegratedGradients(pytorch_logreg)
ig_svm = IntegratedGradients(pytorch_svm)


## Comparison of feature importance vs. IG for LogReg, SVM, and BERT

In [None]:
# Validation examples to visualise with every (model, method) pair.
sample_indices = [1, 5, 10, 15]
val_texts = list(X_val)
sample_texts = [val_texts[i] for i in sample_indices if i < len(X_val)]

# One list of Captum VisualizationDataRecord objects per (model, method).
vis_records_logreg_m1, vis_records_logreg_m2 = [], []  # LogReg Method 1 / Method 2
vis_records_svm_m1, vis_records_svm_m2 = [], []        # SVM Method 1 / Method 2
vis_records_bert_m1, vis_records_bert_m2 = [], []      # BERT Method 1 / Method 2

# convert TF-IDF attributions back to word-level attributions for visualization
def create_tfidf_word_attribution_record(sample_text, tfidf_attributions, feature_names, predicted_class, confidence, model_name):
    """Project TF-IDF feature attributions onto the whitespace-separated words of ``sample_text``.

    A word's score is the sum of the attributions of every vocabulary feature
    (unigram or n-gram) whose space-separated tokens include the word,
    lowercased and stripped of surrounding ``.,!?";``. Scores are then divided
    by the largest absolute score, so they lie in [-1, 1].

    Returns:
        A Captum ``VisualizationDataRecord`` labelled "<model_name>: <predicted_class>".
    """
    # Build token -> feature-index lists once (one pass over the vocabulary)
    # instead of scanning every feature for every word. Each feature counts
    # at most once per token, as with the original `token in feature.split()`.
    token_to_features = {}
    for i, feature in enumerate(feature_names):
        for token in set(feature.split()):
            token_to_features.setdefault(token, []).append(i)

    words = sample_text.split()
    word_attributions = []
    for word in words:
        word_lower = word.lower().strip('.,!?";')
        word_attr = 0.0
        for i in token_to_features.get(word_lower, ()):  # ascending feature order
            word_attr += tfidf_attributions[i]
        word_attributions.append(word_attr)

    # normalize
    max_attr = max(abs(attr) for attr in word_attributions) if word_attributions else 1
    if max_attr > 0:
        word_attributions = [attr / max_attr for attr in word_attributions]

    return visualization.VisualizationDataRecord(
        word_attributions, confidence, f"{model_name}: {predicted_class}",
        "", "", sum(word_attributions), words, 0.0
    )

def create_bert_word_attribution_record(sample_text, attributions, predicted_class, confidence, method_name, max_length=128):
    """Build a Captum visualization record of BERT attributions over the words of ``sample_text``.

    Words are ``sample_text.split()``.

    * ``method_name == "BERT M1"``: ``attributions`` are per hidden dimension,
      not per token, so no word alignment exists. Every word gets
      ``mean(|attributions|)``.
    * otherwise: ``attributions`` are per token position of the text encoded
      with [CLS]/[SEP] and truncated to ``max_length``. Each word is aligned to
      the word pieces ``tokenizer.tokenize(word)`` yields for it and gets the
      mean attribution of those pieces; words cut off by truncation get 0.0.
      This assumes per-word tokenization reproduces the full-text pieces,
      which holds for BERT's whitespace-first pre-tokenizer. It also keeps
      punctuation pieces (e.g. the "," in "hello,") attached to their word.

    Scores are divided by the largest absolute score before building the record.
    """
    original_words = sample_text.split()

    if method_name == "BERT M1":
        # per-dimension attributions cannot be mapped to words; spread the overall magnitude
        if len(original_words) > 0:
            overall_importance = abs(attributions).mean() if hasattr(attributions, 'mean') else abs(sum(attributions))
            word_attributions = [overall_importance] * len(original_words)
        else:
            word_attributions = [0.0]
    else:
        word_attributions = []
        # Position 0 is [CLS]; at most max_length - 2 content pieces survive
        # truncation, so content occupies positions 1 .. max_length - 2.
        content_end = min(max_length - 1, len(attributions))
        position = 1
        for word in original_words:
            n_pieces = len(tokenizer.tokenize(word))
            start, stop = position, min(position + n_pieces, content_end)
            position += n_pieces
            if start < stop:
                # average attribution across this word's pieces
                word_attributions.append(sum(attributions[start:stop]) / (stop - start))
            else:
                word_attributions.append(0.0)

    # normalize
    max_attr = max(abs(attr) for attr in word_attributions) if word_attributions else 1
    if max_attr > 0:
        word_attributions = [attr / max_attr for attr in word_attributions]

    return visualization.VisualizationDataRecord(
        word_attributions, confidence, f"{method_name}: {predicted_class}",
        "", "", sum(word_attributions), original_words, 0.0
    )

# For each chosen validation sample, build word-level attribution records for
# LogReg, SVM and BERT under Method 1 (w_{i,c} * x_i) and Method 2 (Integrated
# Gradients). Both methods explain the class each model predicted.
feature_names = vectorizer.get_feature_names_out()  # same vocabulary for every sample

for sample_idx, sample_text in zip(sample_indices[:len(sample_texts)], sample_texts):
    if sample_idx >= len(X_val_dense):
        continue

    print(f"\nSAMPLE {sample_idx}: {sample_text}")
    true_class = label_encoder.classes_[y_val.iloc[sample_idx]]
    print(f"True class: {true_class}")

    sample_tensor = X_val_dense[sample_idx:sample_idx+1]
    sample_tfidf = Xva_vec[sample_idx].toarray()[0]  # xi: this sample's TF-IDF vector
    baseline = torch.zeros_like(sample_tensor)       # all-zero TF-IDF baseline for IG

    ##### LOGREG PREDICTIONS AND ATTRIBUTIONS
    with torch.no_grad():
        logreg_logits = pytorch_logreg(sample_tensor)
        logreg_pred_idx = torch.argmax(logreg_logits, dim=1).item()
        logreg_confidence = torch.softmax(logreg_logits, dim=1)[0][logreg_pred_idx].item()
    logreg_pred_class = label_encoder.classes_[logreg_pred_idx]

    # Method 1: Mi,c = wi,c × xi
    logreg_m1_attrs = logreg.coef_[logreg_pred_idx] * sample_tfidf
    # Method 2: integrated gradients
    logreg_m2_attrs = ig_logreg.attribute(
        sample_tensor, baseline, target=logreg_pred_idx, n_steps=50
    ).squeeze().detach().numpy()

    ##### SVM PREDICTIONS AND ATTRIBUTIONS
    svm_pred_idx = int(linsvm.predict(Xva_vec[sample_idx])[0])
    svm_pred_class = label_encoder.classes_[svm_pred_idx]
    # LinearSVC has no predict_proba. Show the softmax of its decision scores as
    # a display-only pseudo-confidence (not a calibrated probability).
    with torch.no_grad():
        svm_confidence = torch.softmax(pytorch_svm(sample_tensor), dim=1)[0][svm_pred_idx].item()

    # Method 1: Mi,c = wi,c × xi
    svm_m1_attrs = linsvm.coef_[svm_pred_idx] * sample_tfidf
    # Method 2: IG on the SVM's own weights. The two models share the TF-IDF
    # input space, but not their coefficients.
    svm_m2_attrs = ig_svm.attribute(
        sample_tensor, baseline, target=svm_pred_idx, n_steps=50
    ).squeeze().detach().numpy()

    ##### BERT PREDICTIONS AND ATTRIBUTIONS
    enc = tokenizer(sample_text, add_special_tokens=True, padding='max_length',
                    truncation=True, max_length=128, return_tensors='pt').to(device)

    with torch.no_grad():
        bert_outputs = bert_model(**enc, output_hidden_states=True)
        bert_logits = bert_outputs.logits
        bert_pred_idx = torch.argmax(bert_logits, dim=1).item()
        bert_confidence = torch.softmax(bert_logits, dim=1)[0][bert_pred_idx].item()

        # Method 1: Mi,c = wi,c × xi. The classification head consumes BERT's
        # pooled output (pooler dense + tanh over the last-layer [CLS] state),
        # not the raw [CLS] hidden state. xi must therefore be the pooled vector
        # for sum_i wi,c xi + b_c to equal the logit.
        pooled_representation = bert_model.bert.pooler(bert_outputs.hidden_states[-1])[0]
        bert_classification_weights = bert_model.classifier.weight[bert_pred_idx]  # wi,c
        bert_m1_attrs = (bert_classification_weights * pooled_representation).cpu().numpy()

    # Method 2: layer integrated gradients on the word embeddings, all-[PAD] baseline
    seq_len = int(enc['input_ids'].size(1))
    ref = bert_token_reference.generate_reference(seq_len, device=device).unsqueeze(0)
    try:
        atts_ig = bert_lig.attribute(
            inputs=enc['input_ids'], baselines=ref,
            additional_forward_args=(enc['attention_mask'],),
            target=bert_pred_idx, n_steps=50
        )
        bert_m2_attrs = atts_ig.sum(dim=2).squeeze(0).detach().cpu().numpy()
    except Exception as err:  # keep the loop going, but report why (a bare except also swallowed Ctrl-C)
        print(f"BERT IG failed for sample {sample_idx}: {err!r}")
        bert_m2_attrs = np.zeros(seq_len)

    # TF-IDF models (LogReg and SVM)
    vis_records_logreg_m1.append(create_tfidf_word_attribution_record(
        sample_text, logreg_m1_attrs, feature_names, logreg_pred_class, logreg_confidence, "LogReg M1"))
    vis_records_logreg_m2.append(create_tfidf_word_attribution_record(
        sample_text, logreg_m2_attrs, feature_names, logreg_pred_class, logreg_confidence, "LogReg M2"))
    vis_records_svm_m1.append(create_tfidf_word_attribution_record(
        sample_text, svm_m1_attrs, feature_names, svm_pred_class, svm_confidence, "SVM M1"))
    vis_records_svm_m2.append(create_tfidf_word_attribution_record(
        sample_text, svm_m2_attrs, feature_names, svm_pred_class, svm_confidence, "SVM M2"))

    # BERT
    vis_records_bert_m1.append(create_bert_word_attribution_record(
        sample_text, bert_m1_attrs, bert_pred_class, bert_confidence, "BERT M1"))
    vis_records_bert_m2.append(create_bert_word_attribution_record(
        sample_text, bert_m2_attrs, bert_pred_class, bert_confidence, "BERT M2"))


# Render each sentence with all six (model, method) attribution views.
display_sections = (
    ("\nLogistic regression - Method 1 (feature importance: wi,c * xi):", vis_records_logreg_m1),
    ("\nLogistic regression - Method 2 (integrated gradients):", vis_records_logreg_m2),
    ("\nSVM - Method 1 (feature importance: wi,c * xi):", vis_records_svm_m1),
    ("\nSVM - Method 2 (integrated gradients):", vis_records_svm_m2),
    ("\nBERT - Method 1 (feature importance: wi,c * xi):", vis_records_bert_m1),
    ("\nBERT - Method 2 (integrated gradients):", vis_records_bert_m2),
)

n_rendered = min(len(sample_texts), len(vis_records_logreg_m1))
for i in range(n_rendered):
    print("\n" + "=" * 60)
    print(f"SENTENCE {i+1}: {sample_texts[i]}")
    print("=" * 60)

    for heading, records in display_sections:
        print(heading)
        try:
            _ = visualization.visualize_text([records[i]])
        except Exception as e:
            print(f"Error: {e}")

In [None]:
# top attributed features for a sample
def analyze_text_sample(sample_idx, model, ig_method, feature_names, n_top=10):
    """Print the prediction, confidence and strongest IG features for one validation sample.

    Reads the module-level ``X_val_dense``, ``y_val_tensor`` and ``label_encoder``.
    ``model`` maps a (1, n_features) tensor to class logits, and ``ig_method``
    is an IntegratedGradients explainer built on it. Attributions are computed
    for the predicted class from an all-zero baseline with 50 steps. Among the
    ``n_top`` largest attributions by magnitude, only non-zero ones are printed.

    Returns:
        (1-D numpy array of attributions, predicted class index)
    """
    sample = X_val_dense[sample_idx:sample_idx+1]
    true_label = y_val_tensor[sample_idx].item()

    with torch.no_grad():
        logits = model(sample)
    predicted_class = logits.argmax(dim=1).item()
    probs = torch.softmax(logits, dim=1)[0]

    attributions = ig_method.attribute(
        sample, torch.zeros_like(sample), target=predicted_class, n_steps=50
    ).squeeze().detach().numpy()

    print(f"\nSample {sample_idx}:")
    print(f"  True class: {label_encoder.classes_[true_label]}")
    print(f"  Predicted: {label_encoder.classes_[predicted_class]}")
    print(f"  Confidence: {probs[predicted_class]:.3f}")

    # largest |attribution| first
    ranked = np.argsort(np.abs(attributions))[-n_top:][::-1]
    print(f"  Top {n_top} attributed features:")
    for feat_idx in ranked:
        value = attributions[feat_idx]
        if value == 0:
            continue  # only show non-zero attributions
        print(f"    {feature_names[feat_idx]}: {value:.4f}")

    return attributions, predicted_class

# Run IG on the same four validation samples for each TF-IDF model. Note that
# `sample_attributions` is rebuilt for every model, so afterwards it holds the
# SVM results.
print("\nAnalyzing individual text samples with IG:")
for model_label, linear_model, ig_explainer in (
    ("LogReg", pytorch_logreg, ig_logreg),
    ("SVM", pytorch_svm, ig_svm),
):
    print(f"\n{model_label}")
    sample_attributions = []
    for i in [0, 5, 10, 15]:
        if i >= len(X_val_dense):
            continue
        attrs, pred = analyze_text_sample(i, linear_model, ig_explainer, feature_names, n_top=8)
        sample_attributions.append(attrs)