### Dependencies

In [None]:
%pip install -U "transformers[torch]" accelerate datasets captum shap bertviz matplotlib pandas fsspec "huggingface_hub>=0.24.0" -q

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from bertviz import head_view, model_view
from captum.attr import LayerIntegratedGradients
from datasets import load_dataset
from scipy.stats import spearmanr
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Model

In [None]:
MODEL_NAME = 'distilbert/distilbert-base-uncased'
SEED = 42

# The 2-way classification head is freshly (randomly) initialised by
# from_pretrained, before Trainer applies its own seed, so seed here for a
# reproducible starting point.
torch.manual_seed(SEED)
# Eager attention so that output_attentions=True (used by BertViz and the
# attention analysis below) returns attention probabilities.
# NOTE(review): fused SDPA kernels do not materialise these weights in some
# transformers versions — confirm for the installed version.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=2, attn_implementation="eager"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

## Data

In [None]:
def tokenize(batch):
    """Tokenize a batch of examples for ``Dataset.map(batched=True)``.

    Reads the ``'text'`` column and truncates each text to the tokenizer's
    maximum length. No padding is applied here; that is left to the data
    collator.
    """
    texts = batch['text']
    encoded = tokenizer(texts, truncation=True)
    return encoded

In [None]:
ds = load_dataset("imdb")
# Batched tokenization without padding; DataCollatorWithPadding pads each
# training/eval batch only to its own longest sequence.
ds = ds.map(tokenize, batched=True)

In [None]:
# Dynamic padding: inputs were tokenized unpadded, so each batch is padded here.
collator = DataCollatorWithPadding(tokenizer)

## Trainer

In [None]:
def compute_metrics(eval_pred):
    """Binary-classification metrics for ``Trainer`` evaluation.

    Args:
        eval_pred: ``(logits, labels)`` pair as passed by ``Trainer``; logits are
            expected to be ``(N, 2)`` class scores and labels ``(N,)`` ints.

    Returns:
        dict with ``accuracy``, ``f1``, ``precision``, ``recall`` and ``roc_auc``.
        ``roc_auc`` is NaN when it is undefined (e.g. only one class in ``labels``).
    """
    logits, labels = eval_pred
    # Models that return extra outputs hand over a tuple whose first item is the logits.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    # Under fp16 the gathered logits may be float16; do the softmax in float32
    # for stability (and because Half softmax on CPU is not available in every torch build).
    probs = torch.as_tensor(np.asarray(logits), dtype=torch.float32).softmax(dim=-1).numpy()
    preds = probs.argmax(axis=-1)
    acc = accuracy_score(labels, preds)
    # zero_division=0 yields the same value sklearn's default "warn" returns, minus the warning.
    f1 = f1_score(labels, preds, zero_division=0)
    prec = precision_score(labels, preds, zero_division=0)
    rec = recall_score(labels, preds, zero_division=0)
    try:
        auc = roc_auc_score(labels, probs[:, 1])
    except ValueError:
        # Raised when ROC AUC is undefined, e.g. labels contain a single class.
        auc = np.nan
    return {"accuracy": acc, "f1": f1, "precision": prec, "recall": rec, "roc_auc": auc}

In [None]:
args = TrainingArguments(
    output_dir="./data/signal_explanation/ckpt",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=4,  # effective train batch of 64 per device
    learning_rate=2e-5,
    weight_decay=0.01,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    # Mixed precision only when CUDA is available: the CPU fallback chosen for
    # `device` cannot run fp16 AMP, and TrainingArguments rejects fp16 there.
    fp16=torch.cuda.is_available(),
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,  # must stay a multiple of eval_steps for load_best_model_at_end
    save_total_limit=2,
    metric_for_best_model="f1",
    greater_is_better=True,
    load_best_model_at_end=True,
    logging_steps=30,
    dataloader_num_workers=0,
    report_to="none",
)

trainer = Trainer(
    model=model.to(device),
    args=args,
    # Fixed-seed subsets (7000 train / 700 eval reviews) keep the run short and reproducible.
    train_dataset=ds["train"].shuffle(seed=42).select(range(7000)),
    eval_dataset=ds["test"].shuffle(seed=42).select(range(700)),
    data_collator=collator,
    # `tokenizer=` is deprecated in favour of `processing_class=` (transformers >= 4.46).
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()
# With load_best_model_at_end=True the best checkpoint (by F1) has been
# restored, so this scores that checkpoint on the eval subset.
metrics = trainer.evaluate()
metrics

In [None]:
# Switch to inference mode (disables dropout) on the target device; the
# trailing ';' suppresses printing the full model repr as the cell output.
model.to(device).eval();

### BertViz

In [None]:
text = "The movie was surprisingly good!"
inputs = tokenizer.encode(text, return_tensors='pt').to(device)

# Inference only: no autograd graph is needed to visualise attention.
with torch.no_grad():
    outputs = model(inputs, output_attentions=True)
tokens = tokenizer.convert_ids_to_tokens(inputs[0])

head_view(outputs.attentions, tokens)

In [None]:
# Layer x head grid of attention maps for the example encoded above.
model_view(attention=outputs.attentions, tokens=tokens)

### Attention heatmap helper

In [None]:
def plot_attention_heatmap(attn_4d, tokens, layer_idx=0, savepath="data/signal_explanation/attention_heatmap.png"):
    """Save a heatmap of one layer's head-averaged attention matrix.

    Args:
        attn_4d: per-layer attention arrays indexable by layer, each of shape
            (heads, seq_len, seq_len), e.g. the list returned by ``get_attentions``.
        tokens: token strings used as axis labels (length seq_len).
        layer_idx: which layer to plot; also used in the title.
        savepath: output image path; missing parent directories are created.

    Returns:
        ``savepath``.
    """
    # (heads, L, L) -> (L, L): average over heads.
    attn_mean = attn_4d[layer_idx].mean(axis=0)

    save_dir = os.path.dirname(savepath)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Explicit figure handle so exactly this figure is saved and then closed.
    fig, ax = plt.subplots(figsize=(12, 10))
    im = ax.imshow(attn_mean, aspect="auto")
    ticks = range(len(tokens))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticklabels(tokens)
    ax.set_xlabel("Key token (attended to)")
    ax.set_ylabel("Query token (attending)")
    ax.set_title(f"Attention heatmap (layer {layer_idx}, mean heads)")
    fig.colorbar(im, ax=ax, fraction=0.02)
    fig.tight_layout()
    fig.savefig(savepath, dpi=200)
    plt.close(fig)
    return savepath

### Integrated Gradients (token importance)

In [None]:
def forward_for_class(input_ids, attention_mask, target_class):
    """Return the logit of ``target_class`` for every sequence in ``input_ids``.

    Token embeddings are looked up with ``model.get_input_embeddings()`` and fed
    to the model as ``inputs_embeds``. This is the forward function wrapped by
    ``LayerIntegratedGradients`` in ``token_importance_ig``.

    Args:
        input_ids: LongTensor of token ids, shape (batch, seq_len).
        attention_mask: mask of shape (batch, seq_len) or (1, seq_len); a
            single-row mask is broadcast to the batch size of ``input_ids``.
        target_class: index of the logit to return.

    Returns:
        Tensor of shape (batch,).
    """
    # Captum forwards all n_steps interpolation rows at once, while a mask
    # captured from the caller may still have batch size 1; expand it to match.
    if attention_mask is not None and attention_mask.size(0) == 1 and input_ids.size(0) > 1:
        attention_mask = attention_mask.expand(input_ids.size(0), -1)
    embeds = model.get_input_embeddings()(input_ids)
    logits = model(inputs_embeds=embeds, attention_mask=attention_mask, return_dict=True).logits
    return logits[:, target_class]

In [None]:
def token_importance_ig(text, target_class=None, max_len=128, n_steps=32):
    """Per-token Integrated Gradients attribution for a single text.

    Attributions are taken w.r.t. the output of the model's input-embedding
    layer, using an all-[PAD] sequence as the baseline. They are summed over the
    embedding dimension and min-max scaled to [0, 1]. Because the attributions
    are signed, 0 marks the most negative token, not "no effect".

    Args:
        text: input string.
        target_class: class whose logit is explained; defaults to the predicted class.
        max_len: truncation length in tokens.
        n_steps: number of interpolation steps for the IG integral.

    Returns:
        (tokens, scores, target_class, probs):
          tokens — token strings of the encoded input;
          scores — min-max scaled attribution per token (numpy, same length as tokens);
          target_class — the explained class index;
          probs — numpy array of class probabilities for ``text``.
    """
    model.eval()
    enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_len).to(device)
    input_ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    with torch.no_grad():
        probs = model(**enc).logits.softmax(-1)
    if target_class is None:
        target_class = int(probs.argmax(-1).item())

    # The mask goes through additional_forward_args (not a closure) so Captum
    # expands it alongside the n_steps interpolated inputs.
    lig = LayerIntegratedGradients(
        lambda ids, mask: forward_for_class(ids, mask, target_class),
        layer=model.get_input_embeddings(),
    )

    # Baseline = every position set to the pad token (id 0 for BERT-style vocabularies).
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    attributions, _delta = lig.attribute(
        inputs=input_ids,
        baselines=torch.full_like(input_ids, pad_id),
        additional_forward_args=(attention_mask,),
        n_steps=n_steps,
        return_convergence_delta=True,
    )

    # (1, L, D) -> (L,): total attribution per token.
    token_scores = attributions.sum(dim=-1).detach().cpu().numpy()[0]

    s_min, s_max = token_scores.min(), token_scores.max()
    s_norm = (token_scores - s_min) / (s_max - s_min + 1e-9)
    return tokens, s_norm, target_class, probs[0].detach().cpu().numpy()

In [None]:
def plot_token_importance(tokens, scores, title="Token attribution (IG)", savepath="data/signal_explanation/ig_importance.png"):
    """Save a bar chart of per-token importance scores.

    Args:
        tokens: token strings, one per bar (x-axis labels).
        scores: importance values aligned with ``tokens``.
        title: figure title.
        savepath: output image path; missing parent directories are created.

    Returns:
        ``savepath``.
    """
    save_dir = os.path.dirname(savepath)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    idx = np.arange(len(tokens))
    # Explicit figure handle so exactly this figure is saved and then closed.
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.bar(idx, scores)
    ax.set_xticks(idx)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_ylabel("Importance score")
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(savepath, dpi=200)
    plt.close(fig)
    return savepath

#### Attention vs Attribution

In [None]:
def is_content_token(tok):
    """Return False for the [CLS], [SEP] and [PAD] markers, True for any other token."""
    special_markers = ("[CLS]", "[SEP]", "[PAD]")
    is_special = tok in special_markers
    return not is_special

In [None]:
@torch.no_grad()
def get_attentions(model, tokenizer, text, max_len=128):
    """Run one forward pass and return per-layer attention maps for ``text``.

    Args:
        model: sequence-classification model accepting ``output_attentions``.
        tokenizer: tokenizer matching ``model``.
        text: input string.
        max_len: truncation length in tokens.

    Returns:
        (attentions, tokens, probs):
          attentions — list with one numpy array per layer, holding that layer's
              attention for the single (first) batch element;
          tokens — token strings of the encoded input;
          probs — numpy array of class probabilities (softmax over logits).
    """
    # Put inputs on the device holding the model's weights instead of relying on
    # a global, so any model passed in works.
    model_device = next(model.parameters()).device
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_len).to(model_device)
    outputs = model(**inputs, output_attentions=True, return_dict=True)

    attentions = [layer_attn[0].cpu().numpy() for layer_attn in outputs.attentions]
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    probs = outputs.logits.softmax(-1)[0].cpu().numpy()
    return attentions, tokens, probs

In [None]:
def attention_cls_importance(attn_4d, layer_idx=-1):
    """Score each token by the head-averaged attention the [CLS] position pays to it.

    Takes layer ``layer_idx`` of ``attn_4d`` (each layer laid out as
    heads x query x key) and averages over heads. It keeps the row for query
    position 0 (the [CLS] position) and min-max scales that row to [0, 1].
    The 1e-9 term avoids division by zero for a constant row.
    """
    head_avg = attn_4d[layer_idx].mean(axis=0)
    cls_row = head_avg[0]
    lo = cls_row.min()
    hi = cls_row.max()
    return (cls_row - lo) / (hi - lo + 1e-9)

In [None]:
def compare_attention_vs_ig(text, layer_idx=-1, max_len=128):
    """Compare [CLS]-attention importance with Integrated Gradients for one text.

    Args:
        text: input string.
        layer_idx: attention layer to use; negative values count from the last layer.
        max_len: truncation length in tokens, shared by both methods.

    Returns:
        dict with:
          - the tokens, class probabilities and predicted class;
          - per-token attention and IG scores (both min-max scaled);
          - the Spearman rank correlation (and p-value) between the two scores
            over non-special tokens;
          - the top-10 non-special tokens by each score;
          - the paths of the two saved plots.

    Raises:
        ValueError: if the two tokenizations differ, so scores cannot be aligned.
    """
    attn, toks, probs = get_attentions(model, tokenizer, text, max_len=max_len)
    toks_ig, ig_scores, _target_class, _probs_ig = token_importance_ig(text, max_len=max_len)
    # Both helpers tokenize independently; positions must match to compare scores.
    if list(toks_ig) != list(toks):
        raise ValueError("Attention and IG token sequences differ; cannot align scores.")

    attn_imp = attention_cls_importance(attn, layer_idx=layer_idx)

    # Drop special tokens ([CLS]/[SEP]/[PAD]) so the correlation and the top-k
    # lists rank the same set of content tokens.
    content_mask = np.array([is_content_token(t) for t in toks], dtype=bool)
    content_toks = [t for t, keep in zip(toks, content_mask) if keep]
    ig_c = ig_scores[content_mask]
    at_c = attn_imp[content_mask]

    rho, pval = spearmanr(ig_c, at_c)

    def top_k(tokens, scores, k=10):
        """Return the k (token, score) pairs with the highest scores, best first."""
        idx = np.argsort(scores)[::-1][:k]
        return [(tokens[i], float(scores[i])) for i in idx]

    top_ig = top_k(content_toks, ig_c)
    top_attn = top_k(content_toks, at_c)

    # Resolve a negative index to the absolute layer number so the heatmap shows
    # (and is titled with) the same layer that produced attn_imp.
    heatmap_layer = layer_idx % len(attn)
    heatmap_path = plot_attention_heatmap(attn, toks, heatmap_layer)
    ig_plot_path = plot_token_importance(toks, ig_scores)

    return {
        "text": text,
        "tokens": toks,
        "probs": probs.tolist(),
        "pred_class": int(np.argmax(probs)),
        "attention_importance": list(zip(toks, attn_imp)),
        "ig_importance": list(zip(toks, ig_scores)),
        "spearman_rho": float(rho),
        "spearman_pval": float(pval),
        "top_tokens_ig": top_ig,
        "top_tokens_attention": top_attn,
        "plots": {
            "attention_heatmap": heatmap_path,
            "ig_importance": ig_plot_path
        },
    }

In [None]:
res = compare_attention_vs_ig("The movie was surprisingly good, not boring at all!")
# One entry per result field: the key, its value, then a blank separator line.
for key, value in res.items():
    print(f"{key}\n{value}\n")