# V8 Phase 1 v2: Epistemic Uncertainty for Unlearning Verification

## Key Changes from v1

**Problems in v1:**
1. Fine-tuning INCREASED entropy (0.428 → 0.925) - unexpected
2. Gradient ascent caused complete collapse (entropy → 0)
3. Base model hallucinated about TOFU authors (treated them as real)

**Fixes in v2:**
1. Monitor perplexity to detect collapse early
2. Use milder unlearning (fewer steps, lower LR)
3. Track UQ at each unlearning step
4. Early stopping when model starts degrading

---

In [None]:
!pip install -q transformers accelerate bitsandbytes datasets peft trl
!pip install -q scipy matplotlib seaborn

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import List, Dict
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

## 1. Enhanced Uncertainty Measurer with Perplexity

In [None]:
@dataclass
class UncertaintyResult:
    prompt: str
    response: str
    mean_entropy: float
    first_token_entropy: float
    max_entropy: float
    entropy_std: float
    num_tokens: int

class TokenEntropyMeasurer:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
    
    def measure(self, prompt: str, max_tokens: int = 50) -> UncertaintyResult:
        formatted = f"<s>[INST] {prompt} [/INST]"
        inputs = self.tokenizer(formatted, return_tensors="pt").to(self.device)
        prompt_len = inputs.input_ids.shape[1]
        
        generated_ids = inputs.input_ids.clone()
        entropies = []
        
        self.model.eval()
        for _ in range(max_tokens):
            with torch.no_grad():
                outputs = self.model(generated_ids)
                logits = outputs.logits[0, -1]
                probs = F.softmax(logits.float(), dim=-1)
                entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()
                entropies.append(entropy)
                
                next_token = torch.argmax(probs).unsqueeze(0).unsqueeze(0)
                generated_ids = torch.cat([generated_ids, next_token], dim=1)
                
                if next_token.item() == self.tokenizer.eos_token_id:
                    break
        
        response = self.tokenizer.decode(generated_ids[0, prompt_len:], skip_special_tokens=True)
        
        return UncertaintyResult(
            prompt=prompt,
            response=response,
            mean_entropy=np.mean(entropies) if entropies else 0.0,
            first_token_entropy=entropies[0] if entropies else 0.0,
            max_entropy=np.max(entropies) if entropies else 0.0,
            entropy_std=np.std(entropies) if entropies else 0.0,
            num_tokens=len(entropies),
        )
    
    def measure_batch(self, prompts: List[str], max_tokens: int = 50) -> List[UncertaintyResult]:
        results = []
        for prompt in tqdm(prompts, desc="Measuring UQ"):
            results.append(self.measure(prompt, max_tokens))
        return results

def compute_perplexity(model, tokenizer, texts: List[str], max_length: int = 256) -> float:
    """Compute perplexity on a set of texts - key health metric."""
    model.eval()
    total_loss = 0
    total_tokens = 0
    
    with torch.no_grad():
        for text in texts[:20]:  # Sample for speed
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            outputs = model(**inputs, labels=inputs["input_ids"])
            total_loss += outputs.loss.item() * inputs["input_ids"].shape[1]
            total_tokens += inputs["input_ids"].shape[1]
    
    avg_loss = total_loss / total_tokens
    perplexity = np.exp(avg_loss)
    return perplexity

## 2. Load TOFU Dataset

In [None]:
from datasets import load_dataset

print("Loading TOFU dataset...")
forget_data = load_dataset("locuslab/TOFU", "forget10")['train']
retain_data = load_dataset("locuslab/TOFU", "retain90")['train']

print(f"Forget set: {len(forget_data)} samples")
print(f"Retain set: {len(retain_data)} samples")

# Get questions
forget_questions = [item['question'] for item in forget_data]
retain_questions = [item['question'] for item in retain_data][:100]  # Sample

# Prepare retain texts for perplexity monitoring
retain_texts = [f"Question: {item['question']}\nAnswer: {item['answer']}" for item in retain_data]

## 3. Load Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

print(f"Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
)
base_model.eval()
print("Model loaded!")

## 4. Baseline Measurements

In [None]:
# Measure base model
print("Measuring base model uncertainty...")
measurer = TokenEntropyMeasurer(base_model, tokenizer)
base_results = measurer.measure_batch(forget_questions[:30], max_tokens=30)

base_entropies = [r.mean_entropy for r in base_results]
base_perplexity = compute_perplexity(base_model, tokenizer, retain_texts)

print(f"\nBase Model:")
print(f"  Mean entropy: {np.mean(base_entropies):.3f}")
print(f"  Perplexity on retain set: {base_perplexity:.2f}")

In [None]:
# Check base model responses - does it hallucinate about TOFU authors?
print("\nBase model responses (checking for hallucination):")
print("="*60)
for i in range(3):
    print(f"\nQ: {base_results[i].prompt}")
    print(f"A: {base_results[i].response[:150]}...")
    print(f"Entropy: {base_results[i].mean_entropy:.3f}")

## 5. Fine-tune on TOFU

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

print("Loading fresh model for fine-tuning...")
finetune_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
)
finetune_model = prepare_model_for_kbit_training(finetune_model)

lora_config = LoraConfig(
    r=8,  # Reduced from 16
    lora_alpha=16,  # Reduced from 32
    target_modules=["q_proj", "v_proj"],  # Fewer modules
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)

finetune_model = get_peft_model(finetune_model, lora_config)
finetune_model.print_trainable_parameters()

In [None]:
# Prepare training data
def format_sample(example):
    return {"text": f"<s>[INST] {example['question']} [/INST] {example['answer']}</s>"}

train_data = forget_data.map(format_sample)

def tokenize(example):
    return tokenizer(example["text"], truncation=True, max_length=256, padding="max_length")

tokenized_data = train_data.map(tokenize, batched=True, remove_columns=train_data.column_names)
print(f"Training on {len(tokenized_data)} samples")

In [None]:
# Train with fewer epochs
training_args = TrainingArguments(
    output_dir="./tofu_finetuned",
    num_train_epochs=2,  # Reduced from 3
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,  # Reduced from 2e-4
    fp16=True,
    logging_steps=10,
    save_strategy="no",
    report_to="none",
)

trainer = Trainer(
    model=finetune_model,
    args=training_args,
    train_dataset=tokenized_data,
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)

print("Fine-tuning...")
trainer.train()
print("Done!")

In [None]:
# Measure fine-tuned model
finetune_model.eval()
measurer_ft = TokenEntropyMeasurer(finetune_model, tokenizer)
ft_results = measurer_ft.measure_batch(forget_questions[:30], max_tokens=30)

ft_entropies = [r.mean_entropy for r in ft_results]
ft_perplexity = compute_perplexity(finetune_model, tokenizer, retain_texts)

print(f"\nFine-tuned Model:")
print(f"  Mean entropy: {np.mean(ft_entropies):.3f} (was {np.mean(base_entropies):.3f})")
print(f"  Perplexity: {ft_perplexity:.2f} (was {base_perplexity:.2f})")

In [None]:
# Check fine-tuned responses
print("\nFine-tuned model responses:")
print("="*60)
for i in range(3):
    print(f"\nQ: {ft_results[i].prompt}")
    print(f"A: {ft_results[i].response[:150]}...")
    print(f"Entropy: {ft_results[i].mean_entropy:.3f}")

## 6. Gradual Unlearning with Monitoring

**Key improvement:** Track UQ and perplexity at each step, stop before collapse.

In [None]:
def gradual_unlearn_with_monitoring(
    model, tokenizer, forget_data, retain_texts,
    num_steps=10, lr=5e-6, samples_per_step=50,
    max_perplexity_ratio=3.0,  # Stop if perplexity > 3x baseline
    measure_every=2,
):
    """
    Gradual unlearning with health monitoring.
    
    Returns trajectory of metrics at each checkpoint.
    """
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    
    # Prepare forget texts
    forget_texts = [f"<s>[INST] {item['question']} [/INST] {item['answer']}</s>" 
                   for item in forget_data]
    
    # Get baseline perplexity
    model.eval()
    baseline_ppl = compute_perplexity(model, tokenizer, retain_texts)
    model.train()
    
    # Track metrics
    trajectory = {
        'step': [0],
        'loss': [0],
        'perplexity': [baseline_ppl],
        'entropy': [],
    }
    
    # Measure initial entropy
    model.eval()
    measurer = TokenEntropyMeasurer(model, tokenizer)
    forget_questions = [item['question'] for item in forget_data][:20]
    initial_results = measurer.measure_batch(forget_questions, max_tokens=20)
    trajectory['entropy'].append(np.mean([r.mean_entropy for r in initial_results]))
    model.train()
    
    print(f"Starting gradual unlearning...")
    print(f"Baseline perplexity: {baseline_ppl:.2f}")
    print(f"Will stop if perplexity > {baseline_ppl * max_perplexity_ratio:.2f}")
    print()
    
    for step in range(1, num_steps + 1):
        step_loss = 0
        
        # Sample from forget set
        indices = np.random.choice(len(forget_texts), min(samples_per_step, len(forget_texts)), replace=False)
        
        for idx in indices:
            text = forget_texts[idx]
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss
            
            # Gradient ASCENT
            (-loss).backward()
            optimizer.step()
            optimizer.zero_grad()
            
            step_loss += loss.item()
        
        avg_loss = step_loss / len(indices)
        
        # Check perplexity (model health)
        model.eval()
        current_ppl = compute_perplexity(model, tokenizer, retain_texts)
        model.train()
        
        trajectory['step'].append(step)
        trajectory['loss'].append(avg_loss)
        trajectory['perplexity'].append(current_ppl)
        
        # Measure entropy periodically
        if step % measure_every == 0:
            model.eval()
            results = measurer.measure_batch(forget_questions, max_tokens=20)
            mean_entropy = np.mean([r.mean_entropy for r in results])
            trajectory['entropy'].append(mean_entropy)
            model.train()
            
            print(f"Step {step}: Loss={avg_loss:.2f}, PPL={current_ppl:.2f}, Entropy={mean_entropy:.3f}")
        else:
            print(f"Step {step}: Loss={avg_loss:.2f}, PPL={current_ppl:.2f}")
        
        # Early stopping if model is collapsing
        if current_ppl > baseline_ppl * max_perplexity_ratio:
            print(f"\n[STOP] Perplexity too high ({current_ppl:.2f} > {baseline_ppl * max_perplexity_ratio:.2f})")
            print("Model starting to collapse - stopping early.")
            break
    
    model.eval()
    return model, trajectory

In [None]:
# Run gradual unlearning
unlearned_model, trajectory = gradual_unlearn_with_monitoring(
    finetune_model,
    tokenizer,
    list(forget_data),
    retain_texts,
    num_steps=15,
    lr=5e-6,
    samples_per_step=40,
    max_perplexity_ratio=2.5,
    measure_every=3,
)

## 7. Analyze Trajectory

In [None]:
# Plot trajectory
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Loss trajectory
ax1 = axes[0]
ax1.plot(trajectory['step'], trajectory['loss'], 'b-o')
ax1.set_xlabel('Unlearning Step')
ax1.set_ylabel('Loss')
ax1.set_title('Unlearning Loss (higher = more unlearned)')
ax1.grid(True, alpha=0.3)

# Perplexity trajectory
ax2 = axes[1]
ax2.plot(trajectory['step'], trajectory['perplexity'], 'r-o')
ax2.axhline(trajectory['perplexity'][0] * 2.5, color='r', linestyle='--', alpha=0.5, label='Collapse threshold')
ax2.set_xlabel('Unlearning Step')
ax2.set_ylabel('Perplexity')
ax2.set_title('Model Health (perplexity on retain set)')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Entropy trajectory  
ax3 = axes[2]
entropy_steps = [0] + list(range(3, len(trajectory['step']), 3))[:len(trajectory['entropy'])-1]
if len(trajectory['entropy']) > 0:
    ax3.plot(entropy_steps[:len(trajectory['entropy'])], trajectory['entropy'], 'g-o')
    ax3.axhline(np.mean(base_entropies), color='blue', linestyle='--', alpha=0.5, label='Base model')
ax3.set_xlabel('Unlearning Step')
ax3.set_ylabel('Mean Entropy')
ax3.set_title('Uncertainty on Forget Set')
ax3.legend()
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('unlearning_trajectory.png', dpi=150)
plt.show()

## 8. Final Comparison

In [None]:
# Measure final unlearned model
measurer_ul = TokenEntropyMeasurer(unlearned_model, tokenizer)
ul_results = measurer_ul.measure_batch(forget_questions[:30], max_tokens=30)
ul_entropies = [r.mean_entropy for r in ul_results]
ul_perplexity = compute_perplexity(unlearned_model, tokenizer, retain_texts)

print("=" * 60)
print("FINAL RESULTS")
print("=" * 60)
print(f"\n{'Model':<20} {'Entropy':<12} {'Perplexity':<12}")
print("-" * 44)
print(f"{'Base':<20} {np.mean(base_entropies):<12.3f} {base_perplexity:<12.2f}")
print(f"{'Fine-tuned':<20} {np.mean(ft_entropies):<12.3f} {ft_perplexity:<12.2f}")
print(f"{'Unlearned':<20} {np.mean(ul_entropies):<12.3f} {ul_perplexity:<12.2f}")

# Uncertainty Ratio
ur = np.mean(ul_entropies) / np.mean(base_entropies) if np.mean(base_entropies) > 0 else 0
print(f"\nUncertainty Ratio (UR): {ur:.3f}")

# Interpretation
print("\n" + "=" * 60)
print("INTERPRETATION")
print("=" * 60)

if np.mean(ul_entropies) < 0.01:  # Check for collapse
    print("[COLLAPSED] Model outputs garbage - unlearning too aggressive")
elif ur < 0.7:
    print(f"[HIDING] UR={ur:.3f} < 0.7")
    print("Model uncertainty lower than base - knowledge likely still hidden")
elif ur < 1.0:
    print(f"[PARTIAL] UR={ur:.3f} in [0.7, 1.0)")
    print("Model approaching base uncertainty - some knowledge may remain")
else:
    print(f"[CANDIDATE] UR={ur:.3f} >= 1.0")
    print("Model uncertainty matches/exceeds base - possible true unlearning")
    print("Recommend: Run adversarial recovery test to confirm")

In [None]:
# Sample responses comparison
print("\n" + "=" * 70)
print("SAMPLE RESPONSES")
print("=" * 70)

for i in range(min(3, len(base_results))):
    print(f"\n--- Question {i+1} ---")
    print(f"Q: {base_results[i].prompt}")
    print(f"\nBase (UQ={base_results[i].mean_entropy:.2f}): {base_results[i].response[:100]}")
    print(f"Fine-tuned (UQ={ft_results[i].mean_entropy:.2f}): {ft_results[i].response[:100]}")
    print(f"Unlearned (UQ={ul_results[i].mean_entropy:.2f}): {ul_results[i].response[:100]}")

## 9. Save Results

In [None]:
import json

results = {
    "model": MODEL_NAME,
    "base_entropy": float(np.mean(base_entropies)),
    "base_perplexity": float(base_perplexity),
    "finetuned_entropy": float(np.mean(ft_entropies)),
    "finetuned_perplexity": float(ft_perplexity),
    "unlearned_entropy": float(np.mean(ul_entropies)),
    "unlearned_perplexity": float(ul_perplexity),
    "uncertainty_ratio": float(ur),
    "unlearning_steps": len(trajectory['step']) - 1,
    "trajectory": {
        "steps": trajectory['step'],
        "loss": trajectory['loss'],
        "perplexity": trajectory['perplexity'],
        "entropy": trajectory['entropy'],
    }
}

with open("phase1_v2_results.json", "w") as f:
    json.dump(results, f, indent=2)

print("Results saved to phase1_v2_results.json")

## 10. Key Insights

### What We Learned:

1. **Gradient ascent without stopping → collapse**
   - v1 showed this clearly with 0 entropy garbage output
   - This validates the need for a stopping criterion

2. **Perplexity monitoring is essential**
   - Tracks model health during unlearning
   - Can detect collapse before it's complete

3. **Uncertainty Ratio (UR) as diagnostic**
   - UR < 0.7: Likely HIDING (knowledge suppressed but present)
   - UR ≈ 1.0: Candidate for TRUE UNLEARNING
   - UR = 0: Model collapsed (not useful)

### Next Steps:

1. If UR < 1.0: Run adversarial recovery test
2. If recovery succeeds: Confirms HIDING hypothesis
3. Phase 2: Use UQ as feedback signal for iterative unlearning