# LLM Classifier Interpretability Analysis

This notebook provides interactive exploration of the trained LLM classifier through:

1. **Logistic Regression Coefficient Analysis** - Understanding which hidden dimensions drive predictions
2. **Logit Lens Risk Trajectories** - Visualizing how risk evolves across patient sequences
3. **Spurious Correlation Detection** - Identifying non-clinical patterns that may be driving predictions

## Model Architecture Recap

```
Input Tokens → Qwen3-8B + LoRA → Hidden States (B, T, 4096) → Last Token → Linear(4096, 2) → Cancer/Control
```

The classifier weights `W ∈ R^(2×4096)` map the last-token hidden state to the two class logits, so each hidden dimension contributes its weight times its activation to the prediction.
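As a rough illustration of the head in the diagram, the NumPy sketch below pools the last-token hidden state and applies a two-way linear layer. The sizes are shrunk for the demo and all arrays are random stand-ins, not the trained weights. It also assumes the final token sits at index `T - 1`, i.e. no right padding.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes; the real model uses a hidden size of 4096
B, T, D = 3, 5, 8

hidden_states = rng.normal(size=(B, T, D))   # stand-in for the LLM output
W = rng.normal(size=(2, D))                  # row 0: control, row 1: cancer
b = np.zeros(2)

# Last-token pooling: one vector per sequence (assumes no right padding)
last_hidden = hidden_states[:, -1, :]        # shape (B, D)

# Linear head: one logit per class
logits = last_hidden @ W.T + b               # shape (B, 2)

predicted_class = logits.argmax(axis=1)      # 0 = control, 1 = cancer
print(logits.shape, predicted_class)
```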


## Setup


In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML
import warnings
warnings.filterwarnings('ignore')

# Add project root to path
project_root = os.path.abspath(os.path.join(os.getcwd(), '../../..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

print(f"Project root: {project_root}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


In [None]:
# Import interpretability module
from src.pipelines.text_based.interpretability import (
    load_classifier_for_analysis,
    extract_lr_weights,
    compute_risk_trajectory,
    get_top_contributing_tokens,
    visualize_weight_distribution,
    visualize_risk_trajectory,
    visualize_top_risk_tokens,
    detect_spurious_patterns,
)
from src.data.unified_dataset import UnifiedEHRDataset
from src.training.utils import seed_all

seed_all(42)


## Configuration

Set the paths to your config and checkpoint:
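Before loading a multi-gigabyte checkpoint it can help to confirm that the paths resolve. The helper below is a hypothetical convenience, not part of the project's code. Note that a relative `CONFIG_PATH` resolves against the current working directory unless the loader handles it differently, which is an assumption worth checking.

```python
from pathlib import Path

def missing_paths(paths):
    """Return the names whose path does not exist (hypothetical helper)."""
    return [name for name, p in paths.items() if not Path(p).exists()]

# Placeholder paths for the demo; substitute CONFIG_PATH and CHECKPOINT_PATH
example = {
    "config": "this/config/does/not/exist.yaml",
    "current_dir": ".",
}
print("Missing:", missing_paths(example))
```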


In [None]:
# ============================================================
# CONFIGURATION - Update these paths for your setup
# ============================================================

CONFIG_PATH = "src/pipelines/text_based/configs/llm_classify_pretrained_cls_lora.yaml"
CHECKPOINT_PATH = "/data/scratch/qc25022/pancreas/experiments/lora-6-month-logistic-raw/checkpoint-7856"
OUTPUT_DIR = "./interpretability_results_notebook"

# Device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Config: {CONFIG_PATH}")
print(f"Checkpoint: {CHECKPOINT_PATH}")
print(f"Output: {OUTPUT_DIR}")
print(f"Device: {DEVICE}")


## 1. Load Model and Tokenizer


In [None]:
# Load the trained classifier
model, tokenizer, config = load_classifier_for_analysis(
    config_path=CONFIG_PATH,
    checkpoint_path=CHECKPOINT_PATH,
    device=DEVICE
)

print("\nModel loaded successfully!")
print(f"Classifier shape: {model.classifier.weight.shape}")


## 2. Logistic Regression Weight Analysis

The classifier is a simple linear layer: `logits = W @ hidden_state + b`

For binary classification:
- `W[0]` = weights for control class
- `W[1]` = weights for cancer class
- `W[1] - W[0]` = log-odds direction (positive = more cancer-like); the log-odds themselves are `(W[1] - W[0]) @ hidden_state + (b[1] - b[0])`
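`W[1] - W[0]` acts as a log-odds direction because a two-class softmax reduces to a sigmoid of the logit difference. The NumPy check below uses random stand-in values and assumes the head has a bias term, as in the formula above.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8                                  # toy size; the real head uses 4096
W = rng.normal(size=(2, D))
b = rng.normal(size=2)
h = rng.normal(size=D)                 # stand-in for a last-token hidden state

# Two-class softmax over the logits
logits = W @ h + b
shifted = logits - logits.max()        # subtract the max for numerical stability
p_softmax = np.exp(shifted)[1] / np.exp(shifted).sum()

# Equivalent form: sigmoid of the log-odds
w_diff = W[1] - W[0]
log_odds = w_diff @ h + (b[1] - b[0])
p_sigmoid = 1.0 / (1.0 + np.exp(-log_odds))

# Per-dimension contributions plus the bias difference give the log-odds
contributions = w_diff * h
assert np.isclose(contributions.sum() + (b[1] - b[0]), log_odds)
assert np.isclose(p_softmax, p_sigmoid)
```

Since a dimension's contribution is weight times activation, a large weight on a dimension that is usually near zero may matter less than a moderate weight on a high-variance dimension.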


In [None]:
# Extract and analyze weights
weight_analysis = extract_lr_weights(model, top_k=50)

print("\n" + "="*70)
print("CLASSIFIER WEIGHT STATISTICS")
print("="*70)

stats = weight_analysis.statistics
for key, value in stats.items():
    print(f"  {key:20s}: {value:.6f}")


In [None]:
# Visualize weight distribution
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Histogram
ax1 = axes[0]
ax1.hist(weight_analysis.diff_weights, bins=100, edgecolor='black', alpha=0.7, color='steelblue')
ax1.axvline(x=0, color='red', linestyle='--', linewidth=2, label='Zero')
ax1.axvline(x=stats['mean'], color='green', linestyle='--', linewidth=2, label=f"Mean={stats['mean']:.4f}")
ax1.set_xlabel('Weight (Cancer - Control)', fontsize=12)
ax1.set_ylabel('Frequency', fontsize=12)
ax1.set_title('Distribution of Log-Odds Weights', fontsize=14)
ax1.legend()
ax1.grid(True, alpha=0.3)

# Sorted weights
ax2 = axes[1]
sorted_weights = np.sort(weight_analysis.diff_weights)
ax2.plot(sorted_weights, linewidth=0.5)
ax2.axhline(y=0, color='red', linestyle='--', linewidth=1)
ax2.fill_between(range(len(sorted_weights)), 0, sorted_weights,
                 where=(sorted_weights > 0), color='crimson', alpha=0.3)
ax2.fill_between(range(len(sorted_weights)), 0, sorted_weights,
                 where=(sorted_weights < 0), color='steelblue', alpha=0.3)
ax2.set_xlabel('Dimension (sorted)', fontsize=12)
ax2.set_ylabel('Weight Value', fontsize=12)
ax2.set_title('Sorted Weights: Cancer vs Control', fontsize=14)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'weight_distribution_interactive.png'), dpi=150)
plt.show()


In [None]:
# Top dimensions; bars are scaled relative to the largest absolute weight shown,
# so they stay visible even when all weights are small
top_pos = weight_analysis.top_positive_dims[:25]
top_neg = weight_analysis.top_negative_dims[:25]
max_abs = max((abs(w) for _, w in [*top_pos, *top_neg]), default=0.0) or 1.0

for title, dims in [("CANCER (Positive Weights)", top_pos),
                    ("CONTROL (Negative Weights)", top_neg)]:
    print("\n" + "="*70)
    print(f"TOP 25 DIMENSIONS PUSHING TOWARD {title}")
    print("="*70)
    for i, (dim, weight) in enumerate(dims, 1):
        bar = '█' * int(abs(weight) / max_abs * 40)
        print(f"{i:3d}. Dim {dim:4d}: {weight:+.5f} {bar}")


## 3. Load Dataset for Trajectory Analysis


In [None]:
# Load validation dataset
data_config = config['data']

dataset_args = {
    "data_dir": data_config["data_dir"],
    "vocab_file": data_config["vocab_filepath"],
    "labels_file": data_config["labels_filepath"],
    "medical_lookup_file": data_config["medical_lookup_filepath"],
    "lab_lookup_file": data_config["lab_lookup_filepath"],
    "region_lookup_file": data_config["region_lookup_filepath"],
    "time_lookup_file": data_config["time_lookup_filepath"],
    "format": 'text',
    "cutoff_months": data_config.get("cutoff_months", 1),
    "max_sequence_length": None,
    "tokenizer": None,
    "data_type": data_config.get('data_type', 'raw')
}

val_dataset = UnifiedEHRDataset(split="tuning", **dataset_args)
print(f"Loaded {len(val_dataset)} validation samples")


In [None]:
# Find cancer and control patients among the first 500 samples (at most)
cancer_indices = []
control_indices = []

for i in range(min(len(val_dataset), 500)):
    sample = val_dataset[i]
    if sample is not None:
        label = sample['label'].item() if torch.is_tensor(sample['label']) else sample['label']
        if label > 0:
            cancer_indices.append(i)
        else:
            control_indices.append(i)

print(f"Found {len(cancer_indices)} cancer patients")
print(f"Found {len(control_indices)} control patients")


## 4. Risk Trajectory Analysis (Logit Lens)

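The Logit Lens applies the classifier weights to hidden states at **every position**, not just the last token. This shows how the model's "opinion" evolves as it reads through the patient's history. Because the classifier was trained only on the last-token state, scores at earlier positions are best read as a qualitative trend rather than as calibrated probabilities.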
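The per-position readout can be sketched in a few lines of NumPy. This is a minimal illustration with random stand-in arrays; `compute_risk_trajectory` may differ in its details (batching, truncation, dtype handling), and the names `w_diff` and `b_diff` are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D = 6, 8                              # toy sizes
hidden = rng.normal(size=(T, D))         # stand-in: one hidden state per token
w_diff = rng.normal(size=D)              # stand-in for W[1] - W[0]
b_diff = 0.1                             # stand-in for b[1] - b[0]

# Score every position with the same linear readout
risk_scores = hidden @ w_diff + b_diff   # shape (T,), log-odds per position
probabilities = 1.0 / (1.0 + np.exp(-risk_scores))

# The last position corresponds to the classifier's actual prediction
final_prediction = probabilities[-1]
for pos, (score, prob) in enumerate(zip(risk_scores, probabilities)):
    print(f"pos {pos}: log-odds={score:+.3f}  P(cancer)={prob:.3f}")
```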


In [None]:
def analyze_patient(patient_idx, dataset, model, tokenizer, weight_analysis, output_dir):
    """Analyze and visualize a single patient."""
    sample = dataset[patient_idx]
    text = sample['text']
    label = sample['label'].item() if torch.is_tensor(sample['label']) else sample['label']
    label_str = "CANCER" if label > 0 else "CONTROL"
    
    print(f"\n{'='*70}")
    print(f"PATIENT {patient_idx} - {label_str}")
    print(f"{'='*70}")
    
    # Compute trajectory
    trajectory = compute_risk_trajectory(
        model=model,
        tokenizer=tokenizer,
        text=text,
        weight_analysis=weight_analysis,
        true_label=label,
        max_length=12000,
        device=DEVICE
    )
    
    print(f"\nSequence Length: {len(trajectory.tokens)} tokens")
    print(f"Final P(Cancer): {trajectory.final_prediction:.4f}")
    print(f"Final Log-Odds:  {trajectory.risk_scores[-1]:.4f}")
    print(f"Max Risk Score:  {np.max(trajectory.risk_scores):.4f}")
    print(f"Min Risk Score:  {np.min(trajectory.risk_scores):.4f}")
    
    # Plot trajectory
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 8), sharex=True)
    
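    # Subsample to roughly 500 points; np.asarray keeps the boolean masks
    # below valid even if the trajectory fields are plain lists
    subsample = max(1, len(trajectory.tokens) // 500)
    positions = np.asarray(trajectory.positions)[::subsample]
    risk_scores = np.asarray(trajectory.risk_scores)[::subsample]
    probabilities = np.asarray(trajectory.probabilities)[::subsample]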
    
    # Log-odds plot
    ax1.plot(positions, risk_scores, color='crimson', linewidth=1, alpha=0.8)
    ax1.fill_between(positions, 0, risk_scores,
                     where=(risk_scores > 0), color='crimson', alpha=0.3, label='Cancer risk')
    ax1.fill_between(positions, 0, risk_scores,
                     where=(risk_scores <= 0), color='steelblue', alpha=0.3, label='Control risk')
    ax1.axhline(y=0, color='black', linestyle='--', linewidth=1)
    ax1.set_ylabel('Log-Odds (Cancer vs Control)', fontsize=12)
    ax1.set_title(f'Risk Trajectory - Patient {patient_idx} ({label_str})', fontsize=14)
    ax1.legend(loc='upper left')
    ax1.grid(True, alpha=0.3)
    
    # Probability plot
    ax2.plot(positions, probabilities, color='purple', linewidth=1, alpha=0.8)
    ax2.fill_between(positions, 0.5, probabilities,
                     where=(probabilities > 0.5), color='crimson', alpha=0.3)
    ax2.fill_between(positions, 0.5, probabilities,
                     where=(probabilities <= 0.5), color='steelblue', alpha=0.3)
    ax2.axhline(y=0.5, color='black', linestyle='--', linewidth=1)
    ax2.set_xlabel('Token Position', fontsize=12)
    ax2.set_ylabel('P(Cancer)', fontsize=12)
    ax2.set_ylim(0, 1)
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'trajectory_patient_{patient_idx}.png'), dpi=150)
    plt.show()
    
    return trajectory


In [None]:
# Analyze a cancer patient
if cancer_indices:
    cancer_patient_idx = cancer_indices[0]
    cancer_trajectory = analyze_patient(
        cancer_patient_idx, val_dataset, model, tokenizer, weight_analysis, OUTPUT_DIR
    )
else:
    print("No cancer patients found in the scanned range.")


In [None]:
# Analyze a control patient
if control_indices:
    control_patient_idx = control_indices[0]
    control_trajectory = analyze_patient(
        control_patient_idx, val_dataset, model, tokenizer, weight_analysis, OUTPUT_DIR
    )
else:
    print("No control patients found in the scanned range.")


## 5. High-Risk Token Analysis

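Which tokens appear at positions where the risk score is highest? Because each hidden state summarizes the entire prefix up to that position, a high score marks where evidence has accumulated rather than proving that the token alone caused it. Inspecting these positions can reveal: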
- **Clinical signals**: Symptoms, diagnoses, lab values that indicate cancer risk
- **Spurious patterns**: Administrative tokens, dates, doctor names that shouldn't be predictive
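One way to pull out high-risk positions with their surrounding text is sketched below. `top_positions_with_context` is a hypothetical helper written for illustration; it is not the implementation of `get_top_contributing_tokens`, and the tokens and scores are made up.

```python
import numpy as np

def top_positions_with_context(tokens, risk_scores, top_k=3, context_window=2):
    """Return the top_k highest-scoring positions with surrounding tokens.

    Hypothetical helper for illustration only.
    """
    scores = np.asarray(risk_scores, dtype=float)
    order = np.argsort(scores)[::-1][:top_k]      # highest score first
    results = []
    for pos in order:
        start = max(0, pos - context_window)
        end = min(len(tokens), pos + context_window + 1)
        results.append({
            "position": int(pos),
            "token": tokens[pos],
            "risk_score": float(scores[pos]),
            "context": "".join(tokens[start:end]),
        })
    return results

tokens = ["Patient", " reports", " weight", " loss", " and", " jaundice", "."]
scores = [0.1, 0.0, 0.8, 1.5, 0.2, 2.3, 1.9]      # made-up scores

top = top_positions_with_context(tokens, scores)
for info in top:
    print(info)
```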


In [None]:
def show_high_risk_tokens(trajectory, top_k=20):
    """Display tokens at high-risk positions with context."""
    top_tokens = get_top_contributing_tokens(trajectory, top_k=top_k, context_window=10)
    
    print(f"\n{'='*70}")
    print(f"TOP {top_k} HIGH-RISK POSITIONS")
    print(f"{'='*70}")
    
    for i, info in enumerate(top_tokens, 1):
        print(f"\n{i}. Position {info['position']}: Score={info['risk_score']:.4f}, P={info['probability']:.4f}")
        print(f"   Token: '{info['token']}'")
        print(f"   Context: ...{info['context'][:100]}...")


In [None]:
# Show high-risk tokens for the cancer patient
if 'cancer_trajectory' in globals():
    show_high_risk_tokens(cancer_trajectory, top_k=15)


In [None]:
# Visualize token risk distribution
if 'cancer_trajectory' in globals():
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    
    # Get top and bottom tokens
    sorted_idx = np.argsort(cancer_trajectory.risk_scores)
    top_cancer = sorted_idx[-20:][::-1]
    top_control = sorted_idx[:20]
    
    # Top cancer tokens
    tokens_cancer = [f"[{i}] {cancer_trajectory.tokens[i][:12]}" for i in top_cancer]
    scores_cancer = [cancer_trajectory.risk_scores[i] for i in top_cancer]
    ax1.barh(tokens_cancer[::-1], scores_cancer[::-1], color='crimson', edgecolor='black', alpha=0.7)
    ax1.set_xlabel('Log-Odds Score', fontsize=12)
    ax1.set_title('Top 20 Tokens -> Cancer', fontsize=14)
    ax1.grid(True, alpha=0.3, axis='x')
    
    # Top control tokens
    tokens_control = [f"[{i}] {cancer_trajectory.tokens[i][:12]}" for i in top_control]
    scores_control = [cancer_trajectory.risk_scores[i] for i in top_control]
    ax2.barh(tokens_control[::-1], scores_control[::-1], color='steelblue', edgecolor='black', alpha=0.7)
    ax2.set_xlabel('Log-Odds Score', fontsize=12)
    ax2.set_title('Top 20 Tokens -> Control', fontsize=14)
    ax2.grid(True, alpha=0.3, axis='x')
    
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'high_risk_tokens.png'), dpi=150)
    plt.show()


## 6. Interactive Patient Explorer

Use this section to explore specific patients interactively.


In [None]:
# Choose a patient index to analyze
PATIENT_INDEX = cancer_indices[1] if len(cancer_indices) > 1 else (control_indices[0] if control_indices else 0)

trajectory = analyze_patient(
    PATIENT_INDEX, val_dataset, model, tokenizer, weight_analysis, OUTPUT_DIR
)
show_high_risk_tokens(trajectory, top_k=10)


## 7. Spurious Correlation Detection

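Analyze multiple patients (here up to five cancer and five control patients) to detect patterns that might indicate spurious correlations. With a sample this small, treat any flagged pattern as a lead to investigate rather than as a finding.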
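The general idea can be sketched as: count which tokens recur at the highest-risk positions across patients, then flag those that look non-clinical. The rule below (purely numeric or date-like tokens) and the helper `flag_recurring_tokens` are assumptions for illustration; `detect_spurious_patterns` may use different criteria.

```python
import re
from collections import Counter

# Hypothetical rule: purely numeric or date-like tokens count as non-clinical.
# Real criteria would need domain review.
NON_CLINICAL = re.compile(r"^\s*(\d+|\d{1,4}[-/]\d{1,2}([-/]\d{1,4})?)\s*$")

def flag_recurring_tokens(trajectories, top_k=2, min_patients=2):
    """Count tokens at each patient's top_k risk positions and flag suspicious ones.

    `trajectories` is a list of (tokens, risk_scores) pairs.
    """
    counts = Counter()
    for tokens, scores in trajectories:
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        # Use a set so each patient contributes a given token at most once
        counts.update({tokens[i] for i in ranked[:top_k]})
    flagged = [
        tok for tok, n in counts.items()
        if n >= min_patients and NON_CLINICAL.match(tok)
    ]
    return counts, flagged

# Made-up toy trajectories
toy = [
    (["2019", " jaundice", " visit"], [2.0, 1.5, 0.1]),
    (["2019", " cough", " visit"], [1.8, 0.2, 0.1]),
    ([" pain", "2019", " visit"], [0.3, 2.2, 0.0]),
]
counts, flagged = flag_recurring_tokens(toy)
print(counts.most_common(3))
print("Flagged:", flagged)
```

Counting each token once per patient keeps a single long record from dominating the tally.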


In [None]:
# Collect trajectories from multiple patients
import random

all_trajectories = []
sample_cancer = random.sample(cancer_indices, min(5, len(cancer_indices)))
sample_control = random.sample(control_indices, min(5, len(control_indices)))

# Compute one trajectory per sampled patient, cancer group first
for group_name, indices in [("cancer", sample_cancer), ("control", sample_control)]:
    print(f"\nAnalyzing {group_name} patients...")
    for idx in indices:
        sample = val_dataset[idx]
        label = sample['label'].item() if torch.is_tensor(sample['label']) else sample['label']

        traj = compute_risk_trajectory(
            model=model,
            tokenizer=tokenizer,
            text=sample['text'],
            weight_analysis=weight_analysis,
            true_label=label,
            max_length=12000,
            device=DEVICE
        )
        all_trajectories.append(traj)
        print(f"  Patient {idx}: P(Cancer)={traj.final_prediction:.3f}")


In [None]:
# Run spurious detection
spurious_results = detect_spurious_patterns(all_trajectories, tokenizer, top_k=50)

print(f"\n{'='*70}")
print("SPURIOUS CORRELATION ANALYSIS")
print(f"{'='*70}")

print(f"\nAnalyzed {len(all_trajectories)} patient trajectories")

print("\nMost Common High-Risk Tokens:")
print("-" * 50)
for i, (token, count) in enumerate(spurious_results['high_risk_tokens'][:20], 1):
    token_display = token.replace('\n', '\\n')[:30]
    print(f"  {i:2d}. '{token_display}': {count} occurrences")

if spurious_results['warnings']:
    print(f"\nWARNINGS ({spurious_results['num_warnings']} potential spurious patterns):")
    print("-" * 50)
    for warning in spurious_results['warnings']:
        print(f"  - {warning['warning']}")
else:
    print("\nNo obvious spurious patterns detected in high-risk tokens")


## 8. Save Analysis Results


In [None]:
from src.pipelines.text_based.interpretability import save_analysis_results

# Save weight analysis
save_analysis_results(weight_analysis, OUTPUT_DIR)

# Save spurious analysis
import json
spurious_path = os.path.join(OUTPUT_DIR, "spurious_analysis.json")
with open(spurious_path, 'w') as f:
    serializable_results = {
        "high_risk_tokens": [(str(t), c) for t, c in spurious_results['high_risk_tokens']],
        "low_risk_tokens": [(str(t), c) for t, c in spurious_results['low_risk_tokens']],
        "warnings": spurious_results['warnings'],
        "num_warnings": spurious_results['num_warnings']
    }
    json.dump(serializable_results, f, indent=2)

print(f"\nResults saved to: {OUTPUT_DIR}")
print("\nGenerated files:")
# Use a distinct name so the loop does not shadow the file handle `f` above
for fname in sorted(os.listdir(OUTPUT_DIR)):
    print(f"  - {fname}")


## Summary

This notebook provided:

1. **Weight Analysis**: Extracted and visualized the logistic regression coefficients, showing which hidden dimensions push toward cancer vs control predictions.

2. **Risk Trajectories**: Applied the Logit Lens to visualize how the model's "opinion" evolves as it reads patient histories.

3. **Token Attribution**: Identified tokens at high-risk positions to understand what clinical (or spurious) signals drive predictions.

4. **Spurious Detection**: Analyzed patterns across multiple patients to flag potential non-clinical biases.

### Next Steps

- **Integrated Gradients**: For more precise token-level attribution, consider implementing Integrated Gradients using the Captum library.
- **Attention Analysis**: Visualize attention patterns to understand token interactions.
- **Dimension Probing**: Investigate what specific hidden dimensions encode (e.g., by correlating with known clinical features).
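As a starting point for dimension probing, one can correlate each hidden dimension with a known clinical feature across patients. The sketch below uses entirely synthetic data with a signal deliberately planted in dimension 3; in practice `hidden` would hold real last-token states and `feature` a real clinical variable, both of which are assumptions here.

```python
import numpy as np

rng = np.random.default_rng(0)
n_patients, D = 200, 8                    # toy sizes

# Synthetic stand-ins: a binary clinical feature and last-token hidden states
feature = rng.integers(0, 2, size=n_patients).astype(float)
hidden = rng.normal(size=(n_patients, D))
hidden[:, 3] += 2.0 * feature             # plant the feature in dimension 3

# Pearson correlation of each dimension with the feature
centered_h = hidden - hidden.mean(axis=0)
centered_f = feature - feature.mean()
corr = (centered_h * centered_f[:, None]).sum(axis=0) / (
    np.sqrt((centered_h ** 2).sum(axis=0)) * np.sqrt((centered_f ** 2).sum())
)

best_dim = int(np.argmax(np.abs(corr)))
print(f"Most correlated dimension: {best_dim} (r={corr[best_dim]:+.2f})")
```

Correlation only detects features aligned with a single axis; a feature spread across many dimensions would need a trained linear probe instead.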
