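# Feature Importance Analysis (Integrated Gradients)

Attribute the MedGemma model's (or its latest local fine-tuned checkpoint's) predicted target token for a TCGA pathology report back to the individual input tokens using Integrated Gradients, then visualise the per-token importance scores as a bar plot and a text heatmap.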


In [1]:
import gc
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from IPython.display import HTML, display
from transformers import AutoTokenizer, AutoModelForCausalLM

# Make the repository root importable so the local `src` package resolves.
sys.path.append(os.path.abspath('..'))

from src.dataset import load_data
from src.interpretability.feature_importance.integrated_gradients import LLMIsDefault
from src.interpretability.viz.utils import plot_token_importance, plot_text_heatmap
from src.model_utils import get_latest_checkpoint

%matplotlib inline

## Load Model & Data

In [2]:
# Base Hugging Face model id. The tokenizer is always loaded from it.
DEFAULT_MODEL_ID = "google/medgemma-1.5-4b-it"
# Local fine-tuning checkpoints live one directory above this notebook.
CHECKPOINT_DIR = "../checkpoints"

# Resolve which weights to load. NOTE(review): presumably get_latest_checkpoint
# falls back to DEFAULT_MODEL_ID when CHECKPOINT_DIR holds no checkpoint — confirm.
model_id = get_latest_checkpoint(DEFAULT_MODEL_ID, checkpoint_dir=CHECKPOINT_DIR)

# NOTE(review): `device` is not referenced by the visible cells; weight placement
# is decided by device_map="auto" when the model is loaded.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The tokenizer comes from the base model id even when model_id is a local checkpoint.
print(f"Loading tokenizer from base model: {DEFAULT_MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL_ID)

print(f"Loading model: {model_id}...")

# Keyword arguments shared by both loading paths, defined once so the adapter
# and full-checkpoint branches cannot drift apart.
# NOTE(review): "eager" attention is requested explicitly, presumably because the
# interpretability code needs the non-fused attention path — confirm.
MODEL_LOAD_KWARGS = dict(
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)

# A local directory containing adapter_config.json is treated as a PEFT adapter:
# the base model is loaded first and the adapter is applied on top of it.
# For a hub id (no such local file) this is simply False.
is_adapter = os.path.isfile(os.path.join(model_id, "adapter_config.json"))

if is_adapter:
    # Imported lazily: peft is only needed when an adapter checkpoint is found.
    from peft import PeftModel
    print(f"Detected PEFT adapter at {model_id}. Loading base model {DEFAULT_MODEL_ID} first...")
    base_model = AutoModelForCausalLM.from_pretrained(DEFAULT_MODEL_ID, **MODEL_LOAD_KWARGS)
    model = PeftModel.from_pretrained(base_model, model_id)
else:
    model = AutoModelForCausalLM.from_pretrained(model_id, **MODEL_LOAD_KWARGS)

# Enable gradient checkpointing, intended to reduce activation memory during IG.
# NOTE(review): in Hugging Face transformers, decoder layers only recompute
# activations while the module is in training mode. After model.eval() below,
# this flag saves no memory unless the IG helper switches modes — confirm in
# LLMIsDefault before relying on it to avoid OOM.
model.gradient_checkpointing_enable()
# eval() disables dropout so repeated attribution runs are deterministic.
model.eval()
print("Model loaded with gradient checkpointing enabled.")

Found latest checkpoint: ../checkpoints/final_medgemma_model
Loading tokenizer from base model: google/medgemma-1.5-4b-it...
Loading model: ../checkpoints/final_medgemma_model...
Detected PEFT adapter at ../checkpoints/final_medgemma_model. Loading base model google/medgemma-1.5-4b-it first...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model loaded with gradient checkpointing enabled.


In [3]:
# Source CSV of pathology reports, relative to this notebook.
REPORTS_CSV = '../tcga_reports_valid.csv'

# load_data returns three splits; only the third (test) split is analysed here.
_, _, test_df = load_data(REPORTS_CSV)
n_test = len(test_df)
print(f"Loaded {n_test} test samples.")
test_df.head(2)

Loaded 1479 test samples.


Unnamed: 0,patient_id,cancer_type,study_name,icd_o_3_site,icd_o_3_histology,icd_o_3_behavior,text
3778,TCGA-CS-5394,LGG,Brain Lower Grade Glioma,C719,9401,3,FINAL DIAGNOSIS: 1. LEFT FRONTAL TUMOR: ANAPLA...
2402,TCGA-B0-4848,KIRC,Kidney renal clear cell carcinoma,C649,8310,3,"FINAL DIAGNOSIS. PART 1: LEFT KIDNEY, RADICAL ..."


## Integrated Gradients Analysis

In [None]:
# Release memory left over from a previous run of this cell before the
# memory-hungry IG pass. The old references are dropped first so gc can reclaim
# their tensors. gc.collect() then runs *before* empty_cache(), so blocks it
# frees can be handed back from PyTorch's caching allocator to the driver.
ig = res = None
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Fixed seed so the analysed report is reproducible; set to None for a random report each run.
SAMPLE_SEED = 42
sample = test_df.sample(n=1, random_state=SAMPLE_SEED).iloc[0]
text = sample['text']
print(f"Patient: {sample['patient_id']} | Cancer: {sample['cancer_type']}")

# Instruction-style prompt. NOTE(review): presumably mirrors the fine-tuning
# template — confirm, since attributions depend on the exact prompt tokens.
prompt = f"### Instruction:\nAnalyze the report.\n\n### Input:\n{text}\n\n### Response:\n"

IG_N_STEPS = 20      # n_steps passed to LLMIsDefault.interpret
IG_BATCH_SIZE = 1    # internal_batch_size; kept at 1 to avoid OOM
TOP_K_TOKENS = 20    # number of trailing tokens shown in the bar plot

ig = LLMIsDefault(model, tokenizer)
res = ig.interpret(prompt, n_steps=IG_N_STEPS, internal_batch_size=IG_BATCH_SIZE)

# Plot only the last TOP_K_TOKENS tokens and their scores.
plot_token_importance(
    res['tokens'][-TOP_K_TOKENS:],
    res['scores'][-TOP_K_TOKENS:],
    title=f"Importance for '{res['target_token']}'",
)

In [None]:
def display_heatmap(tokens, scores):
    """Render a per-token attribution heatmap inline in the notebook.

    Args:
        tokens: token strings, one per score.
        scores: attribution score for each token, aligned with ``tokens``.

    Returns:
        None. The heatmap markup from ``plot_text_heatmap`` (presumably an HTML
        string, since it is wrapped in ``HTML``) is displayed as a side effect.
    """
    # `display`/`HTML` come from IPython.display; importing `display` from
    # IPython.core.display is deprecated.
    display(HTML(plot_text_heatmap(tokens, scores)))


display_heatmap(res['tokens'], res['scores'])