# V9 Phase 2.7: Layer-wise Forgetting Analysis

## Research Question
**Where does forgetting happen in the network?**

### Hypothesis
- **True Unlearning**: Representations diverge from fine-tuned model in **early layers** (information never encoded)
- **Obfuscation/Hiding**: Representations diverge only in **late layers** (output suppression, knowledge retained)

### Method
For each layer `l` and each prompt `p`:
```
dist_to_base[l][p] = cosine_distance(hidden_unlearned[l][p], hidden_base[l][p])
dist_to_finetuned[l][p] = cosine_distance(hidden_unlearned[l][p], hidden_finetuned[l][p])
```

### Expected Results
| Unlearning Type | Early Layers | Late Layers |
|-----------------|--------------|-------------|
| True Forgetting | Close to base | Close to base |
| Obfuscation | Close to fine-tuned | Close to base (far from fine-tuned) |

### Novelty vs FADE
- FADE: Output distribution only (black-box)
- **Our method**: Internal representations (white-box, mechanistic)
- Can explain *why* unlearning fails

---

In [None]:
# Install dependencies
!pip install -q protobuf==3.20.3 transformers accelerate datasets scipy matplotlib seaborn

# HuggingFace login
import os
from huggingface_hub import login

# Try Kaggle secrets first; if any step of that path fails (import, secret
# lookup, or the token login itself) fall back to the interactive prompt.
try:
    from kaggle_secrets import UserSecretsClient
    secrets = UserSecretsClient()
    hf_token = secrets.get_secret("HF_TOKEN")
    login(token=hf_token)
    print("✓ Logged in via Kaggle Secrets")
except Exception:
    # `except Exception` instead of a bare `except` so that KeyboardInterrupt
    # and SystemExit still propagate instead of triggering a second login.
    print("Kaggle secrets not found, using interactive login...")
    login()

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from dataclasses import dataclass
from typing import List, Dict, Tuple
from tqdm import tqdm
import gc
import warnings
warnings.filterwarnings('ignore')

# Pick the compute device once and report what hardware is visible.
has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"
print(f"Device: {device}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"Memory: {total_bytes / 1e9:.1f} GB")

## 1. Hidden State Extraction

In [None]:
def extract_hidden_states(model, tokenizer, prompt: str) -> Dict[int, torch.Tensor]:
    """
    Run one forward pass and collect the last-token hidden state of every layer.

    The prompt is wrapped as a single user turn ("Answer briefly: <prompt>")
    through the tokenizer's chat template, with the generation prompt appended.

    Args:
        model: Model invoked as ``model(**inputs, output_hidden_states=True,
            return_dict=True)``; its output must expose ``hidden_states`` and
            the model must expose ``device``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        prompt: Raw question text.

    Returns:
        ``{layer_idx: tensor}`` where ``layer_idx`` follows the order of
        ``outputs.hidden_states`` and each tensor is a float32 CPU copy of the
        vector at the final sequence position of the first batch element.
    """
    # Format prompt
    messages = [{"role": "user", "content": f"Answer briefly: {prompt}"}]
    formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # The templated string is expected to already contain its special tokens,
    # so do not let the tokenizer add them a second time (e.g. a duplicate BOS).
    # Inputs follow the model's own device rather than a module-level global.
    inputs = tokenizer(
        formatted,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)

    with torch.no_grad():
        outputs = model(
            **inputs,
            output_hidden_states=True,
            return_dict=True
        )

    # Each entry of outputs.hidden_states is indexed as [batch, position, feature];
    # keep only the last position, i.e. the one that predicts the next token.
    return {
        layer_idx: hs[0, -1, :].cpu().float()
        for layer_idx, hs in enumerate(outputs.hidden_states)
    }


def cosine_distance(v1: torch.Tensor, v2: torch.Tensor) -> float:
    """Return one minus the cosine similarity of ``v1`` and ``v2``."""
    # A leading batch axis is added so the similarity is taken along dim 1.
    similarity = F.cosine_similarity(v1[None], v2[None])
    return 1.0 - similarity.item()


def l2_distance(v1: torch.Tensor, v2: torch.Tensor) -> float:
    """Return the Euclidean (L2) norm of the difference ``v1 - v2``."""
    difference = v1 - v2
    return difference.norm().item()


print("Hidden state extraction functions defined")

## 2. Load TOFU Dataset

In [None]:
from datasets import load_dataset

print("Loading TOFU forget10 dataset...")
forget10_data = load_dataset("locuslab/TOFU", "forget10")['train']

# Use subset for efficiency (hidden state extraction is expensive)
N_SAMPLES = 20  # Start with 20, increase if needed

# Slice the dataset first so only the requested rows are materialized, instead
# of walking every row twice and discarding all but the first N_SAMPLES.
# Slicing a Dataset yields a {column: list_of_values} batch.
subset = forget10_data[:N_SAMPLES]
test_questions = list(subset['question'])
test_answers = list(subset['answer'])

# Report the number actually obtained (may be < N_SAMPLES on a short split).
print(f"Using {len(test_questions)} questions from forget10")
print(f"Sample Q: {test_questions[0]}")

## 3. Define Models

In [None]:
# Models for layer-wise analysis: the pretrained base, the TOFU fine-tuned
# checkpoint, and one unlearned checkpoint to compare against both.
MODELS = {
    "base": "meta-llama/Llama-3.2-1B-Instruct",
    "fine_tuned": "open-unlearning/tofu_Llama-3.2-1B-Instruct_full",
    # Most promising unlearned model from Phase 2.6
    "idk_dpo_e10": "open-unlearning/unlearn_tofu_Llama-3.2-1B-Instruct_forget10_IdkDPO_lr2e-05_beta0.1_alpha1_epoch10",
}

print("Models for layer-wise analysis:")
for name, path in MODELS.items():
    print(f"  {name}: {path}")

# Header emoji restored from its mis-decoded (mojibake) form.
print("\n🔬 Analysis Plan:")
print("  1. Extract hidden states from all layers")
print("  2. Compare unlearned model to base AND fine-tuned")
print("  3. Identify WHERE forgetting happens (early vs late layers)")

## 4. Extract Hidden States from All Models

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

import traceback

# Store all hidden states: {model_name: {question_idx: {layer: tensor}}}
all_hidden_states = {}

# Models are processed one at a time; each one is released before the next
# is loaded so only a single checkpoint is resident at once.
for model_name, model_path in MODELS.items():
    print(f"\n{'='*60}")
    print(f"Loading {model_name}: {model_path}")
    print("="*60)

    # Pre-bind so the cleanup in `finally` is valid even if loading fails early.
    model = None
    tokenizer = None

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model.eval()

        # Get number of layers
        n_layers = model.config.num_hidden_layers + 1  # +1 for embedding layer
        print(f"  Model has {n_layers} layers (including embedding)")

        # Extract hidden states for each question
        model_hidden = {}
        for q_idx, question in enumerate(tqdm(test_questions, desc=f"Extracting {model_name}")):
            model_hidden[q_idx] = extract_hidden_states(model, tokenizer, question)

        # Only record a model once every question has been processed.
        all_hidden_states[model_name] = model_hidden
        print(f"  ✓ Extracted hidden states for {len(test_questions)} questions")

    except Exception as e:
        # Best-effort: report the failure and continue with the next model.
        print(f"  ✗ Error: {e}")
        traceback.print_exc()

    finally:
        # Free memory on success AND on failure. Drop the references and run
        # the garbage collector first so the CUDA cache can actually be
        # released by empty_cache() afterwards.
        del model
        del tokenizer
        gc.collect()
        torch.cuda.empty_cache()

print(f"\n✓ Extracted hidden states from {len(all_hidden_states)} models")

## 5. Compute Layer-wise Distances

In [None]:
def compute_layerwise_distances(
    hidden_states: Dict[str, Dict[int, Dict[int, torch.Tensor]]],
    target_model: str,
    reference_model: str,
) -> Dict[int, List[float]]:
    """
    Compute cosine distance between target and reference model at each layer.

    Args:
        hidden_states: ``{model_name: {question_idx: {layer_idx: tensor}}}``.
        target_model: Key of the model whose representations are measured.
        reference_model: Key of the model they are compared against.

    Returns:
        ``{layer_idx: [distance for each question]}``, with questions in the
        iteration order of the target model's mapping. An empty dict is
        returned when the target model has no questions.

    Raises:
        KeyError: If either model name is absent, or the reference model
            lacks a question or layer that the target model has.
    """
    target_hidden = hidden_states[target_model]
    ref_hidden = hidden_states[reference_model]

    if not target_hidden:
        return {}

    # Take the layer count from whichever question comes first instead of
    # assuming a question with index 0 exists.
    n_layers = len(next(iter(target_hidden.values())))

    layer_distances = {layer: [] for layer in range(n_layers)}

    for q_idx, target_layers in target_hidden.items():
        ref_layers = ref_hidden[q_idx]
        for layer in range(n_layers):
            layer_distances[layer].append(
                cosine_distance(target_layers[layer], ref_layers[layer])
            )

    return layer_distances


# Compute distances for unlearned model
unlearned_model = "idk_dpo_e10"

# A model whose hidden states were never stored would otherwise surface below
# as a bare KeyError; fail early with a message that names what is missing.
missing_models = [
    name for name in (unlearned_model, "base", "fine_tuned")
    if name not in all_hidden_states
]
if missing_models:
    raise RuntimeError(
        f"Hidden states missing for: {', '.join(missing_models)}. "
        "Re-run the extraction cell and check its error output."
    )

print(f"Computing layer-wise distances for {unlearned_model}...")

# Distance to base model
dist_to_base = compute_layerwise_distances(all_hidden_states, unlearned_model, "base")
print("  ✓ Distance to base computed")

# Distance to fine-tuned model
dist_to_finetuned = compute_layerwise_distances(all_hidden_states, unlearned_model, "fine_tuned")
print("  ✓ Distance to fine-tuned computed")

# Also compute baseline: fine-tuned vs base (to understand the scale)
dist_ft_to_base = compute_layerwise_distances(all_hidden_states, "fine_tuned", "base")
print("  ✓ Baseline (fine-tuned vs base) computed")

## 6. Analyze Layer-wise Patterns

In [None]:
def summarize_distances(layer_distances: Dict[int, List[float]]) -> Dict[int, Dict[str, float]]:
    """Reduce each layer's list of distances to its mean, std, min and max."""

    def _stats(values):
        # One summary record per layer; keys match what downstream cells read.
        return {
            "mean": np.mean(values),
            "std": np.std(values),
            "min": np.min(values),
            "max": np.max(values),
        }

    return {layer: _stats(values) for layer, values in layer_distances.items()}


# Summarize all distances (per-layer mean/std/min/max over questions)
summary_to_base = summarize_distances(dist_to_base)
summary_to_ft = summarize_distances(dist_to_finetuned)
summary_ft_base = summarize_distances(dist_ft_to_base)

n_layers = len(summary_to_base)

print("\n" + "="*80)
print("LAYER-WISE DISTANCE ANALYSIS")
print("="*80)
print(f"\n{unlearned_model} distances:")
# Arrows restored from mojibake; the garbled form was three characters wide
# per arrow, which also threw off the fixed-width column padding.
print(f"{'Layer':<8} {'→ Base':<15} {'→ Fine-tuned':<15} {'Closer to':<15}")
print("-"*55)

for layer in range(n_layers):
    d_base = summary_to_base[layer]['mean']
    d_ft = summary_to_ft[layer]['mean']
    # Ties are reported as FINE-TUNED because the comparison is strict.
    closer = "BASE" if d_base < d_ft else "FINE-TUNED"

    print(f"{layer:<8} {d_base:<15.4f} {d_ft:<15.4f} {closer:<15}")

In [None]:
# Compute "forgetting ratio" per layer
# Ratio = dist_to_finetuned / (dist_to_base + dist_to_finetuned)
# High ratio (>0.5) = closer to base = forgetting
# Low ratio (<0.5) = closer to fine-tuned = retained
# When both mean distances are zero the ratio is undefined; 0.5 is used as
# the neutral value.

forgetting_ratio = {}
for layer in range(n_layers):
    d_base = summary_to_base[layer]['mean']
    d_ft = summary_to_ft[layer]['mean']
    ratio = d_ft / (d_base + d_ft) if (d_base + d_ft) > 0 else 0.5
    forgetting_ratio[layer] = ratio

print("\n" + "="*80)
print("FORGETTING RATIO BY LAYER")
print("="*80)
print("\nRatio = dist_to_finetuned / (dist_to_base + dist_to_finetuned)")
print("  > 0.5: Closer to BASE (forgetting)")
print("  < 0.5: Closer to FINE-TUNED (retained)")
print()

# Divide into early, middle, late (three contiguous thirds of the layer range)
early_layers = list(range(0, n_layers // 3))
middle_layers = list(range(n_layers // 3, 2 * n_layers // 3))
late_layers = list(range(2 * n_layers // 3, n_layers))

early_ratio = np.mean([forgetting_ratio[l] for l in early_layers])
middle_ratio = np.mean([forgetting_ratio[l] for l in middle_layers])
late_ratio = np.mean([forgetting_ratio[l] for l in late_layers])

print(f"Early layers  (0-{early_layers[-1]}):   Forgetting ratio = {early_ratio:.3f}")
print(f"Middle layers ({middle_layers[0]}-{middle_layers[-1]}):  Forgetting ratio = {middle_ratio:.3f}")
print(f"Late layers   ({late_layers[0]}-{late_layers[-1]}):  Forgetting ratio = {late_ratio:.3f}")

# Interpretation
# All comparisons are strict, so a ratio of exactly 0.5 in either region
# falls through to the MIXED branch.
print("\n" + "="*80)
print("INTERPRETATION")
print("="*80)

if early_ratio > 0.5 and late_ratio > 0.5:
    print("\n✓ TRUE UNLEARNING PATTERN")
    print("  Representations are closer to BASE across all layers")
    print("  → Knowledge was likely removed from the model")
elif early_ratio < 0.5 and late_ratio > 0.5:
    print("\n⚠ OBFUSCATION/HIDING PATTERN")
    print("  Early layers: closer to FINE-TUNED (knowledge retained)")
    print("  Late layers: closer to BASE (output suppressed)")
    print("  → Knowledge retained internally, only output changed!")
elif early_ratio < 0.5 and late_ratio < 0.5:
    print("\n✗ NO UNLEARNING")
    print("  Representations remain close to FINE-TUNED throughout")
    print("  → Unlearning method had minimal effect")
else:
    print("\n? MIXED PATTERN")
    print(f"  Early: {early_ratio:.3f}, Late: {late_ratio:.3f}")
    print("  → Requires further investigation")

## 7. Visualization

In [None]:
# Four-panel summary figure: per-layer distances, per-layer forgetting ratio,
# ratio distribution by layer region, and the early-vs-late signature point.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

layers = list(range(n_layers))

# 1. Distance comparison across layers (mean line with a +/- 1 std band)
ax1 = axes[0, 0]
ax1.plot(layers, [summary_to_base[l]['mean'] for l in layers], 'b-o', label=f'{unlearned_model} → base', linewidth=2)
ax1.plot(layers, [summary_to_ft[l]['mean'] for l in layers], 'r-s', label=f'{unlearned_model} → fine-tuned', linewidth=2)
ax1.plot(layers, [summary_ft_base[l]['mean'] for l in layers], 'g--', label='fine-tuned → base (baseline)', alpha=0.5)
ax1.fill_between(layers, 
                  [summary_to_base[l]['mean'] - summary_to_base[l]['std'] for l in layers],
                  [summary_to_base[l]['mean'] + summary_to_base[l]['std'] for l in layers],
                  alpha=0.2, color='blue')
ax1.fill_between(layers,
                  [summary_to_ft[l]['mean'] - summary_to_ft[l]['std'] for l in layers],
                  [summary_to_ft[l]['mean'] + summary_to_ft[l]['std'] for l in layers],
                  alpha=0.2, color='red')
ax1.set_xlabel('Layer')
ax1.set_ylabel('Cosine Distance')
ax1.set_title('Layer-wise Distance to Reference Models')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Forgetting ratio by layer
ax2 = axes[0, 1]
# Look values up by layer index so bars, colours and x positions are
# guaranteed to line up regardless of the dict's insertion order.
ratios_by_layer = [forgetting_ratio[l] for l in layers]
colors = ['green' if r > 0.5 else 'red' for r in ratios_by_layer]
ax2.bar(layers, ratios_by_layer, color=colors, alpha=0.7)
ax2.axhline(0.5, color='black', linestyle='--', linewidth=2, label='Threshold')
ax2.set_xlabel('Layer')
ax2.set_ylabel('Forgetting Ratio')
ax2.set_title('Forgetting Ratio by Layer\n(>0.5 = closer to base = forgetting)')
ax2.legend()
ax2.grid(True, alpha=0.3)

# 3. Early vs Late comparison (box plot)
ax3 = axes[1, 0]
early_ratios = [forgetting_ratio[l] for l in early_layers]
middle_ratios = [forgetting_ratio[l] for l in middle_layers]
late_ratios = [forgetting_ratio[l] for l in late_layers]

# NOTE(review): newer Matplotlib releases rename boxplot's `labels` to
# `tick_labels`; kept as `labels` here for compatibility with older versions.
bp = ax3.boxplot([early_ratios, middle_ratios, late_ratios], 
                  labels=['Early\n(embedding)', 'Middle\n(processing)', 'Late\n(output)'],
                  patch_artist=True)
colors_box = ['lightblue', 'lightyellow', 'lightgreen']
for patch, color in zip(bp['boxes'], colors_box):
    patch.set_facecolor(color)
ax3.axhline(0.5, color='red', linestyle='--', linewidth=2)
ax3.set_ylabel('Forgetting Ratio')
ax3.set_title('Forgetting Ratio by Layer Region')
ax3.grid(True, alpha=0.3)

# 4. 2D scatter: Early vs Late forgetting ratio
ax4 = axes[1, 1]
ax4.scatter(early_ratio, late_ratio, s=200, c='purple', marker='*', zorder=5)
ax4.annotate(unlearned_model, (early_ratio, late_ratio), fontsize=10, 
              xytext=(10, 10), textcoords='offset points')

# Add reference regions
ax4.axhline(0.5, color='gray', linestyle='--', alpha=0.5)
ax4.axvline(0.5, color='gray', linestyle='--', alpha=0.5)

# Label quadrants
ax4.text(0.25, 0.75, 'OBFUSCATION\n(hidden→output suppression)', ha='center', fontsize=9, color='red')
ax4.text(0.75, 0.75, 'TRUE UNLEARNING\n(knowledge removed)', ha='center', fontsize=9, color='green')
ax4.text(0.25, 0.25, 'NO EFFECT\n(still fine-tuned)', ha='center', fontsize=9, color='orange')
ax4.text(0.75, 0.25, 'STRANGE\n(early forget, late retain?)', ha='center', fontsize=9, color='gray')

ax4.set_xlabel('Early Layer Forgetting Ratio')
ax4.set_ylabel('Late Layer Forgetting Ratio')
ax4.set_title('Unlearning Signature Space')
ax4.set_xlim(0, 1)
ax4.set_ylim(0, 1)
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('v9_phase2.7_layerwise.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✓ Saved visualization to v9_phase2.7_layerwise.png")

## 8. Per-Question Analysis

In [None]:
# Classify every question by its own early/late forgetting ratios so that
# questions deviating from the aggregate pattern can be spotted.
from collections import Counter

print("\n" + "="*80)
print("PER-QUESTION ANALYSIS")
print("="*80)


def _region_ratio(region_layers, q_idx):
    """Forgetting ratio of one question, averaged over the given layers."""
    to_base = np.mean([dist_to_base[l][q_idx] for l in region_layers])
    to_ft = np.mean([dist_to_finetuned[l][q_idx] for l in region_layers])
    total = to_base + to_ft
    # Fall back to the neutral 0.5 when the sum is not strictly positive.
    return to_ft / total if total > 0 else 0.5


def _classify(early_value, late_value):
    """Map an (early, late) ratio pair to its pattern label."""
    if early_value > 0.5 and late_value > 0.5:
        return "TRUE_UNLEARN"
    if early_value < 0.5 and late_value > 0.5:
        return "OBFUSCATION"
    if early_value < 0.5 and late_value < 0.5:
        return "NO_EFFECT"
    return "MIXED"


# Compute per-question forgetting pattern
question_patterns = []
for q_idx, question_text in enumerate(test_questions):
    early_ratio_q = _region_ratio(early_layers, q_idx)
    late_ratio_q = _region_ratio(late_layers, q_idx)
    question_patterns.append({
        'q_idx': q_idx,
        'question': question_text[:50] + "...",
        'early_ratio': early_ratio_q,
        'late_ratio': late_ratio_q,
        'pattern': _classify(early_ratio_q, late_ratio_q),
    })

# Count patterns
pattern_counts = Counter(entry['pattern'] for entry in question_patterns)
n_classified = len(question_patterns)

print("\nPattern Distribution:")
for pattern, count in pattern_counts.most_common():
    print(f"  {pattern}: {count}/{n_classified} ({count/n_classified*100:.1f}%)")

# Show up to two examples of each pattern
print("\nExamples by Pattern:")
for pattern in ["OBFUSCATION", "TRUE_UNLEARN", "NO_EFFECT", "MIXED"]:
    matching = [entry for entry in question_patterns if entry['pattern'] == pattern]
    examples = matching[:2]
    if not examples:
        continue
    print(f"\n{pattern}:")
    for ex in examples:
        print(f"  Q: {ex['question']}")
        print(f"     Early: {ex['early_ratio']:.3f}, Late: {ex['late_ratio']:.3f}")

## 9. Save Results

In [None]:
import json

# Resolve the overall label first so the results dict below stays flat.
if early_ratio > 0.5 and late_ratio > 0.5:
    overall_pattern = "TRUE_UNLEARNING"
elif early_ratio < 0.5 and late_ratio > 0.5:
    overall_pattern = "OBFUSCATION"
elif early_ratio < 0.5 and late_ratio < 0.5:
    overall_pattern = "NO_EFFECT"
else:
    overall_pattern = "MIXED"

# Region summaries, in early / middle / late order.
layer_regions = {
    region: {"layers": region_layers, "forgetting_ratio": float(region_ratio)}
    for region, region_layers, region_ratio in (
        ("early", early_layers, early_ratio),
        ("middle", middle_layers, middle_ratio),
        ("late", late_layers, late_ratio),
    )
}

results = {
    "experiment": "V9 Phase 2.7: Layer-wise Forgetting Analysis",
    "unlearned_model": unlearned_model,
    "n_questions": len(test_questions),
    "n_layers": n_layers,
    "layer_regions": layer_regions,
    "interpretation": {
        "early_closer_to": "BASE" if early_ratio > 0.5 else "FINE-TUNED",
        "late_closer_to": "BASE" if late_ratio > 0.5 else "FINE-TUNED",
        "overall_pattern": overall_pattern,
    },
    # JSON object keys must be strings, and values plain floats.
    "per_layer_forgetting_ratio": {
        str(layer): float(ratio) for layer, ratio in forgetting_ratio.items()
    },
    "pattern_distribution": dict(pattern_counts),
}

with open("v9_phase2.7_results.json", "w") as f:
    json.dump(results, f, indent=2)

print("Saved results to v9_phase2.7_results.json")
print("\n" + "="*60)
print("PHASE 2.7 COMPLETE")
print("="*60)
print(f"\nKey Finding: {results['interpretation']['overall_pattern']}")
print(f"  Early layers ({early_layers[0]}-{early_layers[-1]}): closer to {results['interpretation']['early_closer_to']}")
print(f"  Late layers ({late_layers[0]}-{late_layers[-1]}): closer to {results['interpretation']['late_closer_to']}")

## 10. Conclusion

### What This Analysis Tells Us

| Pattern | Early Layers | Late Layers | Meaning |
|---------|--------------|-------------|----------|
| **TRUE UNLEARNING** | Close to base | Close to base | Knowledge removed from all layers |
| **OBFUSCATION** | Close to fine-tuned | Close to base | Knowledge retained, only output changed |
| **NO EFFECT** | Close to fine-tuned | Close to fine-tuned | Unlearning method failed |

### Novelty Over FADE
- FADE: Only looks at output distributions (black-box)
- **Our method**: Analyzes internal representations (white-box)
- Can identify *where* in the network forgetting happens
- Provides mechanistic understanding of unlearning failure

### Next Steps
1. Test more unlearning methods (GradDiff, NPO, etc.)
2. Correlate layer patterns with downstream attack success
3. Propose layer-targeted unlearning (focus on early layers)