<a href="https://colab.research.google.com/github/MLDreamer/AIMathematicallyexplained/blob/main/DOLA_self_doubt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
MODEL_NAME = "microsoft/phi-2"  # 2.7B parameters

In [6]:
"""
Mathematical Consciousness Lab: The Self-Correction Engine
A DOLA Playground - Reproduce the mathematics of truthfulness

This notebook demonstrates how contrasting early and late transformer layers
mathematically eliminates hallucinations. You'll see the equation in action:

    log(P_Corrected) = log(P_Late) - \u03bb \u00b7 log(P_Early)

Author: AI Demystification Series
Topic: Decoding by Contrasting Layers (DOLA)
"""

# ==============================================================================
# SECTION 0: SETUP & INSTALLATION
# ==============================================================================

print("=" * 70)
print("MATHEMATICAL CONSCIOUSNESS LAB: THE SELF-CORRECTION ENGINE")
print("=" * 70)
print("\nInstalling required libraries...")
print("This may take 2-3 minutes on first run.\n")

# Install dependencies
!pip install -q transformers torch accelerate

import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
warnings.filterwarnings('ignore')

print("\u2713 Libraries installed successfully!\n")

# ==============================================================================
# CONFIGURATION: Choose Your Model & Layers
# ==============================================================================

print("=" * 70)
print("CONFIGURATION")
print("=" * 70)

# Model selection (using smaller, open-source models for Colab compatibility)
# MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"  # 1B parameter model
MODEL_NAME = "microsoft/phi-2"  # 2.7B parameters
# Alternative options:
# MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # 1.1B parameters

# Layer configuration
EARLY_LAYER = 8   # Where factual knowledge lives (typically 10-30% of depth)
LATE_LAYER = -1   # Final layer (use -1 for last layer)
LAMBDA = 0.7      # Self-correction intensity (0.5 - 1.0 recommended)

print(f"\n\u23f9 Configuration:")
print(f"   Model: {MODEL_NAME}")
print(f"   Early Layer (Factual): {EARLY_LAYER}")
print(f"   Late Layer (Final): {LATE_LAYER}")
print(f"   Lambda (\u03bb): {LAMBDA}")
print(f"\n   Interpretation:")
print(f"   \u03bb = 0.0 \u2192 No correction (standard generation)")
print(f"   \u03bb = 0.5 \u2192 Moderate correction")
print(f"   \u03bb = 1.0 \u2192 Strong correction (maximum factuality)")
print("=" * 70)

# ==============================================================================
# SECTION 1: LOAD THE MODEL & TOKENIZER
# ==============================================================================

print("\n" + "=" * 70)
print("STEP 1: LOADING MODEL")
print("=" * 70)
print(f"\nLoading {MODEL_NAME}...")
print("This may take 1-2 minutes...\n")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load model with output_hidden_states enabled (crucial for DOLA!)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    output_hidden_states=True,  # This lets us access intermediate layers
    trust_remote_code=True
)

model.eval()  # Set to evaluation mode

total_layers = model.config.num_hidden_layers
print(f"\u2713 Model loaded successfully!")
print(f"\u2713 Total layers: {total_layers}")
print(f"\u2713 Vocabulary size: {model.config.vocab_size:,}")
print(f"\u2713 Parameters: ~{sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")

# ==============================================================================
# SECTION 2: THE FACTUAL QUERY
# ==============================================================================

print("\n" + "=" * 70)
print("STEP 2: RUNNING THE FACTUAL QUERY")
print("=" * 70)

# Test prompts (choose one or add your own)
TEST_PROMPTS = [
    "The capital of France is",
    "The first person to walk on the moon was",
    "The largest ocean on Earth is the",
    "Mount Everest is located in",
    "The speed of light is approximately"
]

PROMPT = TEST_PROMPTS[0]  # Change index to test different prompts

print(f"\n البحث Query: '{PROMPT}'")
print(f"   Expected answer: Should be factually correct")
print(f"\n   Generating forward pass through all {total_layers} layers...")

# Tokenize input
inputs = tokenizer(PROMPT, return_tensors="pt").to(model.device)
input_ids = inputs["input_ids"]

print(f"   Tokenized input: {tokenizer.convert_ids_to_tokens(input_ids[0])}")

# Forward pass with hidden states
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

    # Extract hidden states from all layers
    hidden_states = outputs.hidden_states  # Tuple of (num_layers + 1) tensors

    # Get logits (raw scores) from different layers
    # We need to manually compute logits from hidden states using the LM head

    # Late layer logits (final prediction)
    late_hidden = hidden_states[LATE_LAYER]  # Shape: [batch, seq_len, hidden_dim]
    late_logits = model.lm_head(late_hidden)  # Shape: [batch, seq_len, vocab_size]

    # Early layer logits (factual foundation)
    early_hidden = hidden_states[EARLY_LAYER]
    early_logits = model.lm_head(early_hidden)

    # Focus on the last token's prediction (what comes next)
    late_logits_next = late_logits[0, -1, :]  # Shape: [vocab_size]
    early_logits_next = early_logits[0, -1, :]  # Shape: [vocab_size]

print(f"\u2713 Forward pass complete!")
print(f"\u2713 Extracted logits from Layer {EARLY_LAYER} (Early) and Layer {LATE_LAYER} (Late)")

# ==============================================================================
# SECTION 3: EXAMINING THE LAYERED VOTES
# ==============================================================================

print("\n" + "=" * 70)
print("STEP 3: EXTRACTING THE LAYERED VOTES (LOGITS)")
print("=" * 70)

def get_top_k_tokens(logits, k=10):
    """Get top k tokens and their logits/probabilities"""
    # Convert logits to probabilities
    probs = torch.softmax(logits, dim=-1)

    # Get top k
    top_probs, top_indices = torch.topk(probs, k)

    # Decode tokens
    top_tokens = [tokenizer.decode([idx]) for idx in top_indices]
    top_logits = logits[top_indices].cpu().numpy()
    top_probs = top_probs.cpu().numpy()

    return top_tokens, top_logits, top_probs

print("\n\U0001f4ca TOP 10 PREDICTIONS FROM EACH LAYER:\n")

# Early layer predictions
print(f"\U0001f7e2 EARLY LAYER {EARLY_LAYER} (Factual Foundation):")
print("-" * 70)
early_tokens, early_logits_top, early_probs = get_top_k_tokens(early_logits_next, k=10)

print(f"{'Rank':<6} {'Token':<20} {'Logit':<12} {'Probability':<12}")
print("-" * 70)
for i, (token, logit, prob) in enumerate(zip(early_tokens, early_logits_top, early_probs), 1):
    print(f"{i:<6} {token:<20} {logit:>10.3f}  {prob:>10.4f} ({prob*100:>5.2f}%)")

print("\n" + "\U0001f534 LATE LAYER (Final Prediction):")
print("-" * 70)
late_tokens, late_logits_top, late_probs = get_top_k_tokens(late_logits_next, k=10)

print(f"{'Rank':<6} {'Token':<20} {'Logit':<12} {'Probability':<12}")
print("-" * 70)
for i, (token, logit, prob) in enumerate(zip(late_tokens, late_logits_top, late_probs), 1):
    print(f"{i:<6} {token:<20} {logit:>10.3f}  {prob:>10.4f} ({prob*100:>5.2f}%)")

print("\n\U0001f4a1 OBSERVATION:")
print("   Notice how the early layer might be more confident about factual tokens,")
print("   while the late layer may distribute probability across more options.")

# ==============================================================================
# SECTION 4: THE MATHEMATICAL SELF-CORRECTION (DOLA)
# ==============================================================================

print("\n" + "=" * 70)
print("STEP 4: IMPLEMENTING DOLA - THE SELF-CORRECTION EQUATION")
print("=" * 70)

print(f"\n\U0001f4cf Applying the DOLA formula:")
print(f"   log(P_Corrected) = log(P_Late) - \u03bb \u00b7 log(P_Early)")
print(f"   where \u03bb = {LAMBDA}\n")

# Convert logits to log probabilities
late_log_probs = torch.log_softmax(late_logits_next, dim=-1)
early_log_probs = torch.log_softmax(early_logits_next, dim=-1)

# Apply DOLA correction
dola_log_probs = late_log_probs - LAMBDA * early_log_probs

# Convert back to probabilities
dola_probs = torch.softmax(dola_log_probs, dim=-1)

# Alternative: Work directly with logits (mathematically equivalent)
# dola_logits = late_logits_next - LAMBDA * early_logits_next
# dola_probs = torch.softmax(dola_logits, dim=-1)

print("\u2713 DOLA correction applied!")
print(f"   Mathematical operation:")
print(f"   - Subtracted {LAMBDA} \u00d7 Early Layer log-probs from Late Layer log-probs")
print(f"   - Tokens the late layer likes but early layer doesn't \u2192 PENALIZED")
print(f"   - Tokens both layers agree on \u2192 PRESERVED or BOOSTED")

# ==============================================================================
# SECTION 5: THE GRAND REVEAL
# ==============================================================================

print("\n" + "=" * 70)
print("STEP 5: THE GRAND REVEAL - TRUTH EMERGES")
print("=" * 70)

print("\n\U0001f3af TOP 10 PREDICTIONS AFTER DOLA CORRECTION:\n")
print("-" * 70)

# Get top k from DOLA-corrected distribution
top_probs_dola, top_indices_dola = torch.topk(dola_probs, 10)
top_tokens_dola = [tokenizer.decode([idx]) for idx in top_indices_dola]
top_logits_dola = dola_log_probs[top_indices_dola].cpu().numpy()

print(f"{'Rank':<6} {'Token':<20} {'Log-Prob':<12} {'Probability':<12} {'Change':<10}")
print("-" * 70)

for i, (token, log_prob, prob) in enumerate(zip(top_tokens_dola, top_logits_dola, top_probs_dola), 1):
    # Find this token's rank in late layer
    token_id = top_indices_dola[i-1]
    late_prob = late_probs[0] if token == late_tokens[0] else 0.0

    # Compute change
    change = ""
    if token in late_tokens[:3]:
        late_rank = late_tokens.index(token) + 1
        if i < late_rank:
            change = f"\u2191 from #{late_rank}"
        elif i > late_rank:
            change = f"\u2193 from #{late_rank}"
        else:
            change = f"= #{late_rank}"
    else:
        change = "NEW"

    prob_val = prob.item()
    print(f"{i:<6} {token:<20} {log_prob:>10.3f}  {prob_val:>10.4f} ({prob_val*100:>5.2f}%)  {change}")

# ==============================================================================
# SECTION 6: COMPARATIVE ANALYSIS
# ==============================================================================

print("\n" + "=" * 70)
print("STEP 6: COMPARATIVE ANALYSIS")
print("=" * 70)

print("\n\U0001f4ca SIDE-BY-SIDE COMPARISON (Top 5):\n")
print(f"{'Rank':<6} {'Late Layer':<25} {'Early Layer':<25} {'DOLA Corrected':<25}")
print("=" * 90)

for i in range(5):
    late_tok = late_tokens[i] if i < len(late_tokens) else "---"
    early_tok = early_tokens[i] if i < len(early_tokens) else "---"
    dola_tok = top_tokens_dola[i] if i < len(top_tokens_dola) else "---"

    late_pct = f"({late_probs[i]*100:.1f}%)" if i < len(late_probs) else ""
    early_pct = f"({early_probs[i]*100:.1f}%)" if i < len(early_probs) else ""
    dola_pct = f"({top_probs_dola[i]*100:.1f}%)" if i < len(top_probs_dola) else ""

    print(f"{i+1:<6} {late_tok:<15}{late_pct:<10} {early_tok:<15}{early_pct:<10} {dola_tok:<15}{dola_pct:<10}")

# ==============================================================================
# SECTION 7: CHALLENGE YOUR CONCLUSION
# ==============================================================================

print("\n" + "=" * 70)
print("STEP 7: CHALLENGE YOUR CONCLUSION")
print("=" * 70)

print("\n\U0001f914 CRITICAL ANALYSIS:\n")

# Get the actual winner from each method
late_winner = late_tokens[0]
early_winner = early_tokens[0]
dola_winner = top_tokens_dola[0]

print(f"   Without DOLA (Late Layer): '{late_winner}' wins")
print(f"   Early Layer says: '{early_winner}' should win")
print(f"   With DOLA (\u03bb={LAMBDA}): '{dola_winner}' wins\n")

if dola_winner == early_winner:
    print("   \u2705 RESULT: DOLA shifted prediction toward the factual foundation!")
    print("      The mathematical subtraction successfully penalized speculation.")
elif dola_winner == late_winner:
    print("   \u26a0\ufe0f  RESULT: DOLA preserved the late layer's choice")
    print("      This suggests both layers agreed, or \u03bb is too low.")
    print("      Try increasing \u03bb to 0.9 or 1.0 for stronger correction.")
else:
    print("   \U0001f3af RESULT: DOLA found a compromise between layers")
    print("      The correction balanced factual grounding with context.")

print("\n" + "=" * 70)
print("PHILOSOPHICAL QUESTION")
print("=" * 70)
print("""
Did we make the model 'more truthful'? Consider:

1. The model's weights didn't change
2. We just changed HOW we listen to it
3. Both answers were always there, in different layers
4. Truth emerged from mathematical contrast, not training

Is this genuine truthfulness, or clever probability manipulation?
Perhaps truth isn't IN the model\u2014it's in the tension BETWEEN its layers.

The machine didn't learn to stop lying.
We learned to hear when it was telling itself the truth all along.
""")

# ==============================================================================
# SECTION 8: INTERACTIVE EXPERIMENTS
# ==============================================================================

print("=" * 70)
print("SECTION 8: YOUR TURN TO EXPERIMENT")
print("=" * 70)
print("""
Try these experiments:

1. Change LAMBDA (line 36):
   - \u03bb = 0.0 (no correction)
   - \u03bb = 0.3 (gentle correction)
   - \u03bb = 0.7 (moderate correction)
   - \u03bb = 1.0 (strong correction)
   - \u03bb = 1.5 (aggressive correction)

2. Change EARLY_LAYER (line 34):
   - Try 5, 10, 15, 20
   - Earlier layers = more primitive/factual
   - Later layers = more abstract

3. Change PROMPT (line 130):
   - Try different factual questions
   - Try questions with common misconceptions
   - Try creative prompts

4. Change MODEL (line 24):
   - Try different open-source models
   - Larger models may show stronger layer differentiation

After each change, rerun all cells to see how DOLA responds!
""")

print("\n\u2713 Notebook complete! The mathematics of self-correction is yours to explore.")
print("=" * 70)

SyntaxError: (unicode error) 'unicodeescape' codec can't decode bytes in position 2-3: truncated \uXXXX escape (ipython-input-866142177.py, line 114)