# Interpretability: Attention and Attributions

This notebook provides privacy-preserving interpretability: BiLSTM attention visualization and optional BERT attributions.

In [None]:
import matplotlib.pyplot as plt
import torch

from suicide_detection.models.bilstm_attention import BiLSTMAttention, BiLSTMAttentionConfig

# Demo: build a tiny vocabulary and a freshly constructed model (no trained
# weights are loaded here), then plot its attention weights for one
# synthetic sample.
vocab = {'<pad>': 0, '<unk>': 1, 'i': 2, 'feel': 3, 'okay': 4, 'today': 5, 'in': 6, 'pain': 7}
tokens = ['i', 'am', 'in', 'pain']
# Tokens missing from the vocabulary (here 'am') fall back to the '<unk>' id.
ids = torch.tensor([[vocab.get(t, vocab['<unk>']) for t in tokens]], dtype=torch.long)
# All-ones mask: the single sample has no padding positions.
attn_mask = torch.ones_like(ids)
cfg = BiLSTMAttentionConfig(vocab_size=max(vocab.values()) + 1, embedding_dim=16, hidden_dim=8)
model = BiLSTMAttention(cfg)
# Inference mode, consistent with the BERT cell: disables train-only layers
# (e.g. dropout) so the plotted weights are not perturbed by them.
model.eval()
with torch.no_grad():
    _ = model(ids, attn_mask)
# NOTE(review): `last_attn` is presumably populated by the forward pass with
# shape (batch, seq_len) - confirm against BiLSTMAttention.
# detach()/cpu() keep the NumPy conversion safe regardless of grad tracking
# or device placement.
aw = model.last_attn.detach().cpu().squeeze(0).numpy()

positions = range(len(tokens))
plt.figure(figsize=(6, 1.5))
plt.bar(positions, aw[:len(tokens)])
plt.xticks(positions, tokens)
plt.title('BiLSTM Attention (demo)')
plt.show()

## Optional: BERT attributions (Integrated Gradients)
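This cell requires the `transformers-interpret` and `captum` packages. If they are unavailable (or the model cannot be loaded), the cell prints a message and skips the attribution step instead of failing.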

In [None]:
# Optional BERT attribution demo. The whole cell is best-effort: any failure
# (missing packages, no network access for the model download, API mismatch)
# is reported with a message rather than raised.
try:
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from transformers_interpret import SequenceClassificationExplainer

    model_name = 'bert-base-uncased'
    tok = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): 'bert-base-uncased' is a base checkpoint, so the 2-label
    # classification head is presumably newly initialised; treat the scores
    # below as an illustration of the workflow, not as meaningful attributions.
    m = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
    m.eval()
    explainer = SequenceClassificationExplainer(m, tok)
    txt = 'I feel okay today but sometimes I am in pain'
    # Calling the explainer returns the per-token attributions as a list of
    # (token, score) pairs, so no further attribute lookups are needed.
    word_attributions = explainer(txt)
    # Top 10 tokens by absolute attribution (tokens and values only, no raw
    # text snippet beyond the tokens themselves).
    sorted_attrs = sorted(
        ((w, float(a)) for w, a in word_attributions),
        key=lambda pair: abs(pair[1]),
        reverse=True,
    )[:10]
    # A bare expression inside `try` is not echoed by the notebook, so print
    # the result explicitly.
    for token, score in sorted_attrs:
        print(f'{token:>12s}  {score:+.4f}')
except Exception as e:
    print('Attribution libraries not available or failed:', e)