# Deep Dive: Advanced Interpretability

This notebook explores the fine-tuned Cancer Classifier model in depth.
We go beyond single-example explanations to understand global model behavior and latent representations.

## Goals
1. **Global Feature Importance**: Identify which words are *globally* most predictive for each cancer type.
2. **Interactive Latent Space**: Visualize the document embeddings in an interactive 2D plot.
3. **Automated Error Analysis**: Identify and explain the most confusing examples.

In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
from tqdm.auto import tqdm
import torch
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel, PeftConfig

# Install plotly if not present
try:
    import plotly.express as px
except ImportError:
    !pip install plotly
    import plotly.express as px

# Local imports
sys.path.append("../src")
from dataset import CancerDataset as TCGADataset
from interpretability.feature_importance.global_importance import GlobalExplainer
from interpretability.error_analysis import ErrorAnalyzer
from interpretability.viz.utils import plot_token_importance, plot_latent_space_interactive, plot_text_heatmap

## 1. Load Model and Data

In [2]:
# Configuration
MODEL_PATH = "../checkpoints/classifier_run/final_model"
DATA_PATH = "../tcga_reports_valid.csv"
# Size of the classification head passed to the base model.
# NOTE(review): this must equal the number of classes the adapter was trained with;
# confirm against the training label set rather than relying on this constant.
NUM_LABELS = 20

print(f"Loading model from {MODEL_PATH}...")

# Load Config & Tokenizer
# The adapter config stores the name of the base checkpoint the adapter was saved against.
peft_config = PeftConfig.from_pretrained(MODEL_PATH)
base_model_name = peft_config.base_model_name_or_path

tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
    # No dedicated pad token: reuse EOS so padding is defined.
    tokenizer.pad_token = tokenizer.eos_token

# Load full model: base checkpoint with a sequence-classification head, then the adapter on top.
base_model = AutoModelForSequenceClassification.from_pretrained(
    base_model_name,
    num_labels=NUM_LABELS,
    device_map="auto",
    dtype=torch.bfloat16,
    trust_remote_code=True
)
model = PeftModel.from_pretrained(base_model, MODEL_PATH)
model.eval()  # inference only
print("Model loaded.")

Loading model from ../checkpoints/classifier_run/final_model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of Gemma3ForSequenceClassification were not initialized from the model checkpoint at google/medgemma-1.5-4b-it and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded.


In [3]:
# Load Dataset
from sklearn.preprocessing import LabelEncoder

df = pd.read_csv(DATA_PATH)

# Encode the cancer-type names as integer ids and keep the encoder around so later
# cells can translate between names and ids.
le = LabelEncoder()
le.fit(df['cancer_type'])
df['label_idx'] = le.transform(df['cancer_type'])
class_names = le.classes_

n_examples = len(df)
n_classes = len(class_names)
print(f"Loaded {n_examples} examples with {n_classes} classes.")

Loaded 7391 examples with 20 classes.


## 2. Interactive Latent Space
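We take the final-layer hidden state of each report's last token (the representation computed before the classification head) and project it to 2D with t-SNE. For speed, only a random sample of up to 200 reports is embedded.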

In [4]:
embeddings, labels, texts = [], [], []

# Work on a fixed random sample (at most 200 rows) to keep the run time manageable.
subset_df = df.sample(min(200, len(df)), random_state=42)

print("Extracting embeddings...")
with torch.no_grad():
    report_iter = zip(subset_df['text'], subset_df['cancer_type'])
    for text, label in tqdm(report_iter, total=len(subset_df)):
        # One report at a time, truncated to 512 tokens, moved to the model's device.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(model.device)

        # Ask the wrapped model for all hidden states instead of installing a hook.
        outputs = model.base_model(inputs.input_ids, output_hidden_states=True)

        # Final layer, last token position (used as the CLS-equivalent summary vector).
        # Cast to float32 first: numpy has no bfloat16 dtype.
        last_token_vec = outputs.hidden_states[-1][0, -1, :].float().cpu().numpy()

        embeddings.append(last_token_vec)
        labels.append(label)
        texts.append(text)

embeddings = np.array(embeddings)

Extracting embeddings...


  0%|          | 0/200 [00:00<?, ?it/s]

In [5]:
from sklearn.manifold import TSNE

print("Running t-SNE...")
# Cap perplexity at n_points - 1 so it always stays below the number of embedded points.
n_points = len(embeddings)
tsne = TSNE(
    n_components=2,
    perplexity=min(30, n_points - 1),
    random_state=42,
)
emb_2d = tsne.fit_transform(embeddings)

# Interactive scatter of the 2D projection; labels and report texts are passed alongside the points.
plot_latent_space_interactive(
    emb_2d,
    labels,
    texts,
    title="Latent Space (t-SNE) - Hover to see reports",
)

Running t-SNE...


## 3. Global Feature Importance
We aggregate attribution scores across the subset to find global keywords.

In [None]:
# Inputs for the attribution run: the sampled report texts and their integer class ids,
# both taken from subset_df so they stay in the same row order.
subset_texts = subset_df['text'].tolist()
target_indices = le.transform(subset_df['cancer_type'])

explainer = GlobalExplainer(model, tokenizer)

# Accumulate per-token importance over the whole subset.
global_scores = explainer.accumulate_token_importance(
    subset_texts,
    target_indices,
    n_steps=10,  # Faster for demo
)

Calculating global importance for 200 examples...


  0%|          | 0/200 [00:00<?, ?it/s]Gemma3ForSequenceClassification will not detect padding tokens in `inputs_embeds`. Results may be unexpected if using padding tokens in conjunction with `inputs_embeds.`
  8%|▊         | 15/200 [01:00<12:54,  4.19s/it]

In [1]:
# Show the top globally important tokens for one class: table first, then bar plot.
target_class = "Breast Invasive Carcinoma"
if target_class not in class_names:
    print("Class not found.")
else:
    idx = le.transform([target_class])[0]
    top_tokens = explainer.get_top_global_tokens(global_scores, idx, k=15)

    if top_tokens.empty:
        print("No sufficient data for this class in subset.")
    else:
        print(f"Global Importance for {target_class}:")
        display(top_tokens)
        token_list = top_tokens['Token'].tolist()
        score_list = top_tokens['MeanScore'].tolist()
        plot_token_importance(
            token_list,
            score_list,
            title=f"Global Importance: {target_class}"
        )

NameError: name 'class_names' is not defined

In [None]:
# Let's check another class
target_class = "Lung Adenocarcinoma"
if target_class in class_names:
    idx = le.transform([target_class])[0]
    top_tokens = explainer.get_top_global_tokens(global_scores, idx, k=15)
    
    if not top_tokens.empty:
        plot_token_importance(
            top_tokens['Token'].tolist(), 
            top_tokens['MeanScore'].tolist(), 
            title=f"Global Importance: {target_class}"
        )
    else:
        print("No data.")

## 4. Automated Error Analysis
Identify the examples where the model is most "confused" (highest loss) and analyze them.

In [None]:
analyzer = ErrorAnalyzer(model, tokenizer)

# Find top 5 most confusing examples in the subset
# Using subset for speed, but ideally use full validation set
# The reports are read from the 'text' column, the same column the embedding and
# global-importance cells use; the previous 'text_report' key did not match it.
confusing_df = analyzer.find_most_confusing_examples(
    subset_df['text'].tolist(), 
    le.transform(subset_df['cancer_type']).tolist(), 
    class_names,
    top_k=5
)

# Show only the summary columns; the full report text stays in confusing_df.
display(confusing_df[['True_Label', 'Predicted_Label', 'Confidence', 'Loss']])

In [None]:
# Explain the first row of confusing_df (the example the cell treats as most confusing).
worst_case = confusing_df.iloc[0]
true_label = worst_case['True_Label']
predicted_label = worst_case['Predicted_Label']

print("Analyzing the most confusing example:")
print(f"True: {true_label}, Pred: {predicted_label}")

# Attribute with respect to the TRUE class to see which tokens supported or opposed it.
target_idx = le.transform([true_label])[0]
res = explainer.cig.interpret(worst_case['Text'], target_class_idx=target_idx, n_steps=20)

# Only the first 512 tokens and scores are drawn in the heatmap.
shown_tokens = res['tokens'][:512]
shown_scores = res['scores'][:512]
plot_text_heatmap(
    shown_tokens,
    shown_scores,
    title=f"Why was this NOT {true_label}? (Attribution to True Class)"
)