# Task 3: Saliency Mapping & Model Interpretability
## Interpreting DistilBERT + LoRA AI Text Classifier

**Objective**: Perform post-hoc interpretability analysis on the fine-tuned DistilBERT model to understand which tokens/features drive AI vs Human classification decisions.

**Methods**:
- **Captum Integrated Gradients** - Token-level attribution analysis
- **HTML Heatmap Visualization** - Color-coded importance scores
- **Error Analysis** - Examining misclassified or borderline cases

**Key Questions**:
1. Does the model rely on specific lexical markers ("AI-isms")?
2. Does it capture broader structural patterns (rhythm, coherence)?
3. What causes false positives (Human → AI)?

---

## 1. Setup & Imports

In [None]:
import torch
import numpy as np
import pandas as pd
import json
from pathlib import Path
from IPython.display import HTML, display
import warnings
warnings.filterwarnings('ignore')

# Transformers & PEFT
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel

# Captum for interpretability
from captum.attr import IntegratedGradients, LayerIntegratedGradients
from captum.attr import visualization as viz

# Use the GPU when CUDA is available; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# One print call, one line per item (identical output to three separate prints).
print(
    "✓ All imports successful",
    f"Device: {device}",
    f"PyTorch version: {torch.__version__}",
    sep="\n",
)

---
## 2. Load Fine-Tuned DistilBERT + LoRA Model

In [None]:
# Location of the fine-tuned checkpoint (tokenizer files + classifier weights).
MODEL_PATH = Path('../output/tier_c_models/tier_c_lora_model')
MAX_LENGTH = 256  # truncation / padding length used by every tokenizer call in this notebook

print(f"Loading model from: {MODEL_PATH}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# NOTE(review): if this directory holds only a LoRA adapter, loading relies on
# transformers' built-in PEFT integration rather than PeftModel (imported but not
# used here) — confirm the adapter weights are actually applied before trusting
# any attributions computed below.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH, num_labels=2)
model = model.to(device)
model.eval()  # inference mode (e.g. disables dropout) for deterministic predictions

print(
    "✓ Model loaded successfully",
    f"✓ Tokenizer vocab size: {len(tokenizer)}",
    f"✓ Model has {model.config.num_labels} classes (0=Human, 1=AI)",
    f"✓ Model parameters: {sum(p.numel() for p in model.parameters()):,}",
    sep="\n",
)

---
## 3. Load Test Samples

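Load a small set of AI-generated samples (`../output/class2/*.jsonl`) and human-written samples (`../output/class1/*_cleaned.txt`) for analysis. These are read directly from the class folders, not from a held-out split, so confirm they were not used during fine-tuning.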

In [None]:
# Load a handful of AI-generated and human-written passages for attribution analysis.
N_SAMPLES_PER_CLASS = 10   # passages to collect per class
MIN_SAMPLE_WORDS = 50      # skip passages with fewer whitespace-separated words than this
HUMAN_CHUNK_WORDS = 200    # human files are split into chunks of this many words
HUMAN_MAX_WORDS = 2000     # only the first HUMAN_MAX_WORDS words of each human file are chunked


def load_ai_samples(root, limit=N_SAMPLES_PER_CLASS, min_words=MIN_SAMPLE_WORDS):
    """
    Collect up to `limit` AI-generated passages from the `*.jsonl` files in `root`.

    Each line is parsed as JSON; the passage is the first truthy value among the
    object's 'text', 'paragraph' and 'content' fields. Blank or malformed lines,
    non-object entries and non-string passages are skipped. Files are visited in
    sorted order so the selection is reproducible across machines.

    Returns:
        list[str]: the selected passages (possibly fewer than `limit`).
    """
    samples = []
    for jsonl_file in sorted(Path(root).glob('*.jsonl')):
        with open(jsonl_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    continue  # blank or malformed line
                if not isinstance(entry, dict):
                    continue
                text = entry.get('text') or entry.get('paragraph') or entry.get('content', '')
                if isinstance(text, str) and len(text.split()) >= min_words:
                    samples.append(text)
                    if len(samples) >= limit:
                        return samples
    return samples


def load_human_samples(root, limit=N_SAMPLES_PER_CLASS, min_words=MIN_SAMPLE_WORDS,
                       chunk_words=HUMAN_CHUNK_WORDS, max_words=HUMAN_MAX_WORDS):
    """
    Collect up to `limit` human-written passages from the `*_cleaned.txt` files in `root`.

    The first `max_words` words of each file are split into consecutive chunks of
    `chunk_words` words; chunks shorter than `min_words` (the tail of a short file)
    are dropped. Files are visited in sorted order for reproducibility.

    Returns:
        list[str]: space-joined word chunks (possibly fewer than `limit`).
    """
    samples = []
    for txt_file in sorted(Path(root).glob('*_cleaned.txt')):
        words = txt_file.read_text(encoding='utf-8').split()
        for start in range(0, min(len(words), max_words), chunk_words):
            chunk = words[start:start + chunk_words]
            if len(chunk) >= min_words:
                samples.append(' '.join(chunk))
                if len(samples) >= limit:
                    return samples
    return samples


def _print_sample_preview(label, samples, source_dir):
    """Print the first 300 characters of the first sample, or a warning if there are none."""
    print(f"\n{'='*70}")
    if samples:
        print(f"{label} (first 300 chars):")
        print(samples[0][:300] + "...")
    else:
        print(f"{label}: no samples found in {source_dir} — later cells index [0] and will fail")


AI_SAMPLES_PATH = Path('../output/class2')
HUMAN_SAMPLES_PATH = Path('../output/class1')

sample_ai_texts = load_ai_samples(AI_SAMPLES_PATH)
sample_human_texts = load_human_samples(HUMAN_SAMPLES_PATH)

print(f"✓ Loaded {len(sample_ai_texts)} AI-generated samples")
print(f"✓ Loaded {len(sample_human_texts)} Human-written samples")

_print_sample_preview("SAMPLE AI TEXT", sample_ai_texts, AI_SAMPLES_PATH)
_print_sample_preview("SAMPLE HUMAN TEXT", sample_human_texts, HUMAN_SAMPLES_PATH)

---
## 4. Integrated Gradients Attribution Implementation

In [None]:
def predict_with_confidence(text, model, tokenizer, device):
    """
    Classify a single text and report how sure the model is.

    The text is tokenized (truncated and padded to MAX_LENGTH tokens), moved to
    `device`, and run through `model` with gradient tracking disabled.

    Returns:
        tuple: (prediction_class, confidence, probabilities) where
            prediction_class is the arg-max class index (int),
            confidence is the softmax probability of that class (float), and
            probabilities is a NumPy array of softmax probabilities for the
            single input, one entry per class.
    """
    encoded = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=MAX_LENGTH,
        padding='max_length',
    )
    model_inputs = {
        'input_ids': encoded['input_ids'].to(device),
        'attention_mask': encoded['attention_mask'].to(device),
    }

    with torch.no_grad():
        class_probs = torch.softmax(model(**model_inputs).logits, dim=1)
        best_class = torch.argmax(class_probs, dim=1).item()
        best_prob = class_probs[0, best_class].item()

    return best_class, best_prob, class_probs[0].cpu().numpy()


def compute_attributions(text, model, tokenizer, device, target_class=1, n_steps=50):
    """
    Compute token-level attributions using (Layer) Integrated Gradients.

    Token IDs are discrete integers, so Integrated Gradients cannot interpolate
    between them directly: the interpolated values are not valid vocabulary
    indices, and no gradient flows through an embedding lookup. Instead, the path
    integral is taken over the output of the model's input-embedding layer
    (`model.get_input_embeddings()`), and the per-dimension attributions are
    summed to give one score per token.

    The reference (baseline) input replaces every token with the tokenizer's PAD
    token, except that [CLS]/[SEP] are kept in place, so only the content tokens
    are "removed" along the integration path.

    Args:
        text: Input text to analyze
        model: Fine-tuned sequence-classification model
        tokenizer: Tokenizer matching `model`
        device: torch device to run on
        target_class: Class whose logit is explained (0=Human, 1=AI)
        n_steps: Number of integration steps (more steps typically shrink the
            convergence delta)

    Returns:
        tuple: (tokens, attributions, prediction, confidence, delta)
            tokens: list of MAX_LENGTH token strings (special and [PAD] tokens included)
            attributions: NumPy array with one score per token; positive values
                raise the target-class logit, negative values lower it
            prediction: predicted class index for `text`
            confidence: softmax probability of the predicted class
            delta: Captum convergence delta tensor (a large |delta| suggests
                increasing n_steps)
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True,
                       max_length=MAX_LENGTH, padding='max_length')
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Prediction on exactly the tensors being explained (no second tokenization).
    with torch.no_grad():
        probs = torch.softmax(
            model(input_ids=input_ids, attention_mask=attention_mask).logits, dim=1
        )
    prediction = torch.argmax(probs, dim=1).item()
    confidence = probs[0, prediction].item()

    def forward_func(input_ids, attention_mask):
        # One scalar per example, as Captum expects: the target-class logit.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits[:, target_class]

    # Attribute at the embedding layer; integer token IDs are not differentiable.
    lig = LayerIntegratedGradients(forward_func, model.get_input_embeddings())

    # Baseline: PAD everywhere, but keep [CLS]/[SEP] so sequence framing is unchanged.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    baseline_ids = torch.full_like(input_ids, pad_id)
    for special_id in (tokenizer.cls_token_id, tokenizer.sep_token_id):
        if special_id is not None:
            keep = input_ids == special_id
            baseline_ids[keep] = special_id

    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=baseline_ids,
        additional_forward_args=(attention_mask,),
        return_convergence_delta=True,
        n_steps=n_steps,
    )

    # Layer attributions have the embedding output's shape (batch=1, seq_len, dim);
    # summing over the embedding dimension leaves one score per token.
    attributions = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    return tokens, attributions, prediction, confidence, delta


def aggregate_subword_attributions(tokens, attributions, special_tokens=('[CLS]', '[SEP]', '[PAD]')):
    """
    Aggregate WordPiece subword tokens to word-level attributions.

    DistilBERT's WordPiece tokenizer marks subword continuations with a '##'
    prefix. Each continuation is appended to the preceding word, and its score
    is added to that word's score. Tokens listed in `special_tokens` are
    dropped entirely.

    Args:
        tokens: sequence of token strings
        attributions: sequence of scores, parallel to `tokens`
        special_tokens: token strings to skip (defaults to [CLS]/[SEP]/[PAD];
            pass a different collection, e.g. tokenizer.all_special_tokens,
            to skip others as well)

    Returns:
        list of tuples: [(word, attribution_score), ...] in input order
    """
    skip = frozenset(special_tokens)
    words = []
    word_scores = []
    current_word = ""
    current_score = 0

    for token, score in zip(tokens, attributions):
        if token in skip:
            continue

        if token.startswith('##'):
            # Continuation piece: extend the word being built.
            current_word += token[2:]
            current_score += score
        else:
            # A new word begins; flush the previous one first.
            if current_word:
                words.append(current_word)
                word_scores.append(current_score)
            current_word = token
            current_score = score

    # Flush the last word.
    if current_word:
        words.append(current_word)
        word_scores.append(current_score)

    return list(zip(words, word_scores))


print("✓ Attribution functions defined")

---
## 5. HTML Heatmap Visualization

In [None]:
def create_html_heatmap(words, scores, prediction, confidence, title="Saliency Map"):
    """
    Build an HTML heatmap of per-word attribution scores.

    Each word's background colour encodes its score relative to the largest
    absolute score in `scores`:
      - positive scores (push toward AI) fade from white to red,
      - negative scores (push toward Human) fade from white to blue,
      - a score of 0 stays white.
    Words and the title are HTML-escaped, so characters such as '<' or '&'
    render literally instead of breaking the markup.

    Args:
        words: sequence of word strings to display
        scores: sequence of attribution scores, parallel to `words`
        prediction: predicted class index (1 is shown as AI, anything else as Human)
        confidence: probability of the predicted class, shown as a percentage
        title: heading shown above the heatmap

    Returns:
        str: an HTML fragment suitable for IPython.display.HTML
    """
    from html import escape  # imported locally so this cell stays self-contained

    # Scale by the strongest attribution of either sign. Fall back to 1 for empty
    # or all-zero input so neither max() nor the division below can fail.
    max_abs_score = max((abs(s) for s in scores), default=0) or 1

    pred_color = '#d32f2f' if prediction == 1 else '#1976d2'
    pred_label = 'AI-generated' if prediction == 1 else 'Human-written'

    parts = [f"""
    <div style="font-family: Arial, sans-serif; padding: 20px; background-color: #f5f5f5; border-radius: 10px;">
        <h3 style="color: #333;">{escape(str(title))}</h3>
        <p style="margin: 10px 0;">
            <strong>Prediction:</strong> <span style="color: {pred_color};">
                {pred_label}
            </span> 
            ({confidence:.1%} confidence)
        </p>
        <div style="margin-top: 15px; line-height: 2.2; font-size: 16px;">
    """]

    for word, score in zip(words, scores):
        norm_score = score / max_abs_score  # in [-1, 1]

        # shade 0 = white, 200 = strongest tint
        shade = min(int(abs(norm_score) * 200), 200)
        if norm_score > 0:
            bg_color = f"rgb(255, {255 - shade}, {255 - shade})"  # white -> red (AI)
        else:
            bg_color = f"rgb({255 - shade}, {255 - shade}, 255)"  # white -> blue (Human)

        # Dark text stays readable on every tint except the deepest blues.
        text_color = "#fff" if norm_score < -0.8 else "#000"
        weight = 'bold' if abs(norm_score) > 0.3 else 'normal'

        parts.append(f'''
            <span style="
                background-color: {bg_color};
                color: {text_color};
                padding: 3px 6px;
                margin: 2px;
                border-radius: 4px;
                display: inline-block;
                font-weight: {weight};
            " title="Score: {score:.4f}">
                {escape(str(word))}
            </span>
        ''')

    parts.append("""
        </div>
        <div style="margin-top: 20px; font-size: 12px; color: #666;">
            <strong>Legend:</strong>
            <span style="background-color: #ffcccc; padding: 2px 8px; margin: 0 5px; border-radius: 3px;">Red = AI signal</span>
            <span style="background-color: #ccccff; padding: 2px 8px; margin: 0 5px; border-radius: 3px;">Blue = Human signal</span>
            <span style="color: #999; margin-left: 10px;">(Hover over words for exact scores)</span>
        </div>
    </div>
    """)

    return "".join(parts)


print("✓ Visualization function defined")

---
## 6. Analysis: AI-Generated Text Sample

In [None]:
# Explain the first AI-generated passage with respect to the AI class (index 1).
ai_text = sample_ai_texts[0]

print("\n".join(["=" * 70, "ANALYZING AI-GENERATED TEXT", "=" * 70]))
print(f"\nText:\n{ai_text[:400]}...\n")

# Per-token attributions toward the AI logit
tokens, attrs, pred, conf, delta = compute_attributions(
    ai_text, model, tokenizer, device, target_class=1
)

# Merge WordPiece pieces into whole words, then split into parallel tuples for plotting
word_attrs = aggregate_subword_attributions(tokens, attrs)
words, scores = zip(*word_attrs)

print(f"Prediction: {'AI' if pred == 1 else 'Human'} ({conf:.2%} confidence)")
print(f"Convergence delta: {delta.item():.6f}")
print(f"\nTop 10 words pushing toward AI:")

# Largest (most AI-ward) attribution first
sorted_words = sorted(word_attrs, key=lambda pair: pair[1], reverse=True)
for i, (w, s) in enumerate(sorted_words[:10], start=1):
    print(f"  {i:2d}. {w:20s} {s:+.4f}")

html_viz = create_html_heatmap(words, scores, pred, conf, title="Saliency Map: AI-Generated Text")
display(HTML(html_viz))

---
## 7. Analysis: Human-Written Text Sample

In [None]:
# Explain the first human-written passage, still against the AI class (index 1),
# so that negative scores show which words pull the prediction away from AI.
human_text = sample_human_texts[0]

print("\n".join(["=" * 70, "ANALYZING HUMAN-WRITTEN TEXT", "=" * 70]))
print(f"\nText:\n{human_text[:400]}...\n")

tokens, attrs, pred, conf, delta = compute_attributions(
    human_text, model, tokenizer, device, target_class=1
)

# Word-level view of the token attributions, plus parallel tuples for plotting
word_attrs = aggregate_subword_attributions(tokens, attrs)
words, scores = zip(*word_attrs)

print(f"Prediction: {'AI' if pred == 1 else 'Human'} ({conf:.2%} confidence)")
print(f"Convergence delta: {delta.item():.6f}")
print(f"\nTop 10 words pushing AWAY from AI (toward Human):")

# Most negative (most Human-ward) attribution first
sorted_words = sorted(word_attrs, key=lambda pair: pair[1])
for i, (w, s) in enumerate(sorted_words[:10], start=1):
    print(f"  {i:2d}. {w:20s} {s:+.4f}")

html_viz = create_html_heatmap(words, scores, pred, conf, title="Saliency Map: Human-Written Text")
display(HTML(html_viz))

---
## 8. Find Borderline/Misclassified Cases

Find human-written samples that the model classifies as AI, or classifies as Human with low confidence, for the error analysis below.

In [None]:
# Find borderline or misclassified human samples
print("Searching for borderline Human samples (predicted as AI or low confidence)...\n")

MAX_BORDERLINE_CASES = 5  # keep at most this many cases for the detailed analysis
LOW_CONFIDENCE = 0.7      # predictions below this confidence count as borderline
HIGH_AI_PROB = 0.3        # ... as do predictions giving the AI class more than this

borderline_cases = []

# Scan every human sample (the list is small) instead of stopping at the first
# few hits, so the ranking below covers all candidates.
for text in sample_human_texts:
    pred, conf, probs = predict_with_confidence(text, model, tokenizer, device)
    ai_prob = float(probs[1])  # probability of the AI class (index 1)

    # Flag: (1) misclassified as AI, (2) low-confidence predictions,
    # (3) Human predictions that still give the AI class substantial probability.
    # With two classes and these thresholds, (2) and (3) select the same Human predictions.
    if pred == 1 or conf < LOW_CONFIDENCE or ai_prob > HIGH_AI_PROB:
        borderline_cases.append({
            'text': text,
            'prediction': pred,
            'confidence': conf,
            'ai_probability': ai_prob,
            'type': 'MISCLASSIFIED' if pred == 1 else 'BORDERLINE',
        })

# Most AI-like first, so the cases printed (and analysed later) are the hardest ones.
borderline_cases.sort(key=lambda c: c['ai_probability'], reverse=True)
borderline_cases = borderline_cases[:MAX_BORDERLINE_CASES]

print(f"Found {len(borderline_cases)} borderline/misclassified cases\n")
print("="*70)

for i, case in enumerate(borderline_cases[:3], 1):  # Show the 3 most AI-like cases
    print(f"\nCase {i} - {case['type']}")
    print(f"Prediction: {'AI' if case['prediction'] == 1 else 'Human'}")
    print(f"Confidence: {case['confidence']:.2%}")
    print(f"AI Probability: {case['ai_probability']:.2%}")
    print(f"Text preview: {case['text'][:200]}...")
    print("-"*70)

---
## 9. Error Analysis: Detailed Investigation of 3 Cases

Now we'll perform detailed saliency analysis on 3 borderline/misclassified cases, looking for:
1. **Lexical repetition** - Repeated words or phrases
2. **Stylistic regularity** - Overly formal or structured language
3. **AI-isms** - Words like "tapestry", "delve", "testament", "furthermore", "consequently"