# Interpretability Analysis

In [None]:
import os
import sys
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM

sys.path.append(os.path.abspath('..'))

from src.dataset import load_data
from src.interpretability.feature_importance.integrated_gradients import LLMIsDefault
from src.interpretability.viz.utils import plot_token_importance, plot_text_heatmap, plot_latent_space
from src.interpretability.latent_analysis.clustering import compute_tsne, compute_pca
from src.interpretability.latent_analysis.probing import ProbingClassifier
from src.interpretability.attention.visualization import get_attention_weights, plot_attention_heatmap
from src.interpretability.counterfactuals.generation import sensitivity_check

%matplotlib inline

## Load Model & Data

In [None]:
# Checkpoint under analysis. `device` is only used for the log line in this
# cell; actual weight placement is delegated to device_map="auto".
MODEL_ID = "google/medgemma-1.5-4b-it"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Loading {MODEL_ID} on {device}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Keyword options forwarded to from_pretrained: bfloat16 weights, automatic
# device mapping.
load_options = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_options)

# Put the module in evaluation mode before running any analysis.
model.eval()
print("Model loaded.")

In [None]:
# load_data returns three values here; the first two are discarded and only
# the third is kept as test_df.
_, _, test_df = load_data('../tcga_reports_valid.csv')
print(f"Loaded {len(test_df)} test samples.")
# Last expression of the cell: the first two rows are shown as the cell output.
test_df.head(2)

## Feature Importance (Integrated Gradients)

In [None]:
# Draw a single report at random (no seed is fixed, so the row can differ
# between runs) and keep only the first 500 characters of its text.
sample = test_df.sample(n=1).iloc[0]
text = sample['text'][:500]
print(f"Patient: {sample['patient_id']} | Cancer: {sample['cancer_type']}")

# Instruction-style prompt wrapped around the truncated report.
prompt = f"### Instruction:\nAnalyze the report.\n\n### Input:\n{text}\n\n### Response:\n"

# Run the attribution helper on the prompt; `res` is reused by later cells.
ig = LLMIsDefault(model, tokenizer)
res = ig.interpret(prompt, n_steps=20, internal_batch_size=4)

# Only the final 20 tokens and their scores are plotted.
tail = slice(-20, None)
plot_token_importance(
    res['tokens'][tail],
    res['scores'][tail],
    title=f"Importance for '{res['target_token']}'",
)

In [None]:
# Render the attribution scores from the previous cell as an HTML heatmap.
#
# `display` and `HTML` are imported from IPython.display: importing `display`
# from IPython.core.display is deprecated.
from html import escape

from IPython.display import display, HTML
from src.interpretability.viz.utils import plot_text_heatmap

html_path = "heatmap.html"

# Escape the token before embedding it in markup so that tokens containing
# angle brackets or ampersands are shown literally instead of being parsed
# as HTML.
target_token_html = escape(str(res['target_token']))
display(HTML(f"<h3>Reference Token: {target_token_html}</h3>"))

# The heatmap is written to `html_path`, then the saved file is shown inline.
plot_text_heatmap(res['tokens'], res['scores'], save_path=html_path)
display(HTML(filename=html_path))

## Latent Space Analysis

In [None]:
# Fixed-seed sample of reports whose embeddings are projected with t-SNE.
batch_size = 50
batch = test_df.sample(batch_size, random_state=42)

print("Extracting embeddings...")
activations = []
labels = []
for report_text, cancer_type in zip(batch['text'], batch['cancer_type']):
    # Tokenize the first 500 characters and move the tensors to the model's device.
    encoded = tokenizer(report_text[:500], return_tensors="pt").to(model.device)
    with torch.no_grad():
        forward = model(**encoded, output_hidden_states=True)
    # Last entry of hidden_states, last sequence position; cast to float32
    # before converting to NumPy.
    final_position = forward.hidden_states[-1][:, -1, :]
    activations.append(final_position.float().cpu().numpy())
    labels.append(cancer_type)

# Stack the per-report arrays row-wise into a single matrix.
activations = np.vstack(activations)

try:
    # Cap perplexity below the number of samples.
    tsne_perplexity = min(30, len(activations) - 1)
    tsne_emb = compute_tsne(activations, perplexity=tsne_perplexity)
    plot_latent_space(tsne_emb, labels, method="t-SNE")
except Exception as e:
    # Best-effort visualisation: report the failure and let the notebook continue.
    print(f"t-SNE failed: {e}")

## Attention Visualization

In [None]:
# Reuses `prompt` from the integrated-gradients cell. layer_idx=-1 is
# presumably the model's last layer (matches the plot title) -- confirm in
# get_attention_weights.
attn, tokens = get_attention_weights(model, tokenizer, prompt, layer_idx=-1)
# Draw the returned weights as a heatmap labelled with the returned tokens.
plot_attention_heatmap(attn, tokens, title="Last Layer Attention")

## Counterfactual Analysis

In [None]:
# Template sentence and the candidate words to try in it; presumably each
# candidate is substituted for {INSERT} -- confirm in sensitivity_check.
template = "DIAGNOSIS: {INSERT} TUMOR. BIOPSY PERFORMED."
candidates = ["KIDNEY", "LUNG", "BRAIN", "PROSTATE"]

results = sensitivity_check(model, tokenizer, template, candidates)

# Tabulate the results; the bare name on the last line makes the table the
# cell's displayed output.
results_table = pd.DataFrame(results)
results_table