# Debug Projection Intervention: GPT-2 vs LLaMA

This notebook investigates why the projection-intervention pipeline produces 185–365 interventions for CODI-GPT-2 but 0 for CODI-LLaMA.

## Setup and Imports

In [29]:
!pip install transformers torch peft matplotlib datasets tqdm hf_transfer dotenv

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [30]:
import os

# Sanity-check the contents of the CODI repo that is put on sys.path later.
codi_repo_dir = "/workspace/CoT_Exploration/codi"
os.listdir(codi_repo_dir)

['train.py',
 'test.py',
 'src',
 'scripts',
 'requirements.txt',
 'probe_latent_token.py',
 'outputs',
 'imgs',
 'README.md',
 '.git']

In [31]:
import torch
import sys
import re
import os
from pathlib import Path

sys.path.insert(0, "/workspace/CoT_Exploration/codi")
from src.model import CODI, ModelArguments, TrainingArguments
from peft import LoraConfig, TaskType
from transformers import AutoTokenizer
from safetensors.torch import load_file

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda


## Test Example

In [32]:
# One GSM8K example (question + reference solution ending in "#### <answer>"),
# used as the probe text for both models.
question = "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"
answer = "Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day. She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer's market. #### 18"

# Question and answer on separate lines.
test_text = "\n".join([question, answer])
print("Test text:\n" + test_text)
print("\nLength: {} characters".format(len(test_text)))

Test text:
Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day. She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer's market. #### 18

Length: 410 characters


## Load CODI-GPT-2

In [None]:
print("="*80)
print("Loading CODI-GPT-2 from Local Checkpoint")
print("="*80)

# Hub ID of the published CODI-GPT-2 weights, kept for reference only; this cell
# passes the local `ckpt_dir` below to CODI instead.
gpt2_model_name = "lhao499/codi-gpt2"

# Base model "gpt2" with a rank-128 LoRA adapter; CODI weights come from ckpt_dir.
gpt2_model_args = ModelArguments(
    model_name_or_path="gpt2",
    lora_init=True,
    lora_r=128,
    lora_alpha=32,
    ckpt_dir="/workspace/CoT_Exploration/models/CODI-gpt2",  # local checkpoint
    full_precision=True,
    token=None,
)

# Inference-time settings: 6 latent iterations, projection layer of width 768
# (GPT-2's hidden size), greedy decoding.
gpt2_training_args = TrainingArguments(
    output_dir="./outputs",
    model_max_length=512,
    inf_latent_iterations=6,
    use_prj=True,
    prj_dim=768,
    remove_eos=True,
    greedy=True,
    bf16=False,
    inf_num_iterations=1,
)

# LoRA on GPT-2's attention (c_attn, c_proj) and MLP (c_fc) modules.
gpt2_lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=gpt2_model_args.lora_r,
    lora_alpha=gpt2_model_args.lora_alpha,
    lora_dropout=0.1,
    target_modules=["c_attn", "c_proj", "c_fc"],
    init_lora_weights=True,
)

gpt2_model = CODI(gpt2_model_args, gpt2_training_args, gpt2_lora_config)
gpt2_model = gpt2_model.to(device)
# NOTE(review): the weights are cast to bfloat16 even though the config requests
# full_precision=True / bf16=False. Confirm this matches the pipeline being debugged,
# since it changes the logits that the argmax decoding below sees.
gpt2_model = gpt2_model.to(torch.bfloat16)
gpt2_model.eval()
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")

print("✓ CODI-GPT-2 loaded successfully (local checkpoint)")

## Load CODI-LLaMA

In [24]:
# Log in to the Hugging Face Hub when a token is available. This is optional and
# only needed for gated models such as meta-llama/Llama-3.2-1B.
import os
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment.
load_dotenv()

hf_token = os.environ.get('HF_TOKEN')

if not hf_token:
    print("⚠ No HF_TOKEN found in .env file -  proceeding without authentication")
    print("  (This is fine if models are public)")
else:
    from huggingface_hub import login
    login(token=hf_token)
    print("✓ Logged in to HuggingFace")

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


✓ Logged in to HuggingFace


In [None]:
print("="*80)
print("Loading CODI-LLaMA from Local Checkpoint")
print("="*80)

# Hub ID of the published CODI-LLaMA weights, kept for reference only; this cell
# passes the local `ckpt_dir` below to CODI instead.
llama_model_name = "lhao499/codi-llama3.2-1b"

# Base model Llama-3.2-1B (gated: needs the Hub login above) with a rank-128 LoRA.
llama_model_args = ModelArguments(
    model_name_or_path="meta-llama/Llama-3.2-1B",
    lora_init=True,
    lora_r=128,
    lora_alpha=32,
    ckpt_dir="/workspace/CoT_Exploration/models/CODI-llama3.2-1b",  # local checkpoint
    full_precision=True,
    token=None,
)

# Same inference settings as GPT-2, except the projection width is 2048.
llama_training_args = TrainingArguments(
    output_dir="./outputs",
    model_max_length=512,
    inf_latent_iterations=6,
    use_prj=True,
    prj_dim=2048,
    remove_eos=True,
    greedy=True,
    bf16=False,
    inf_num_iterations=1,
)

# LoRA on all attention projections and MLP projections of the LLaMA blocks.
llama_lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=llama_model_args.lora_r,
    lora_alpha=llama_model_args.lora_alpha,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    init_lora_weights=True,
)

llama_model = CODI(llama_model_args, llama_training_args, llama_lora_config)
llama_model = llama_model.to(device)
# NOTE(review): the weights are cast to bfloat16 even though the config requests
# full_precision=True / bf16=False. Confirm this matches the pipeline being debugged.
llama_model = llama_model.to(torch.bfloat16)
llama_model.eval()

llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

print("✓ CODI-LLaMA loaded successfully (local checkpoint)")

## Run GPT-2 Forward Pass

In [34]:
print("=" * 80)
print("GPT-2 Forward Pass")
print("=" * 80)

# Encode the question + answer probe text for CODI-GPT-2.
gpt2_inputs = gpt2_tokenizer(test_text, return_tensors="pt", add_special_tokens=True)
gpt2_input_ids = gpt2_inputs.input_ids.to(device)
print(f"Input shape: {gpt2_input_ids.shape}")

# One forward pass of the wrapped LM, requesting every layer's output, then keep the
# last entry of `hidden_states`, which has one vector per input position.
# NOTE(review): these are hidden states of the teacher-forced input text (the sequence
# length equals the input length). They are not obviously the latent thoughts that CODI
# produces in its projection loop. Confirm against section5_analysis.py before
# treating them as "continuous thoughts".
with torch.no_grad():
    gpt2_outputs = gpt2_model.codi(input_ids=gpt2_input_ids, output_hidden_states=True)
    gpt2_continuous_thoughts = gpt2_outputs.hidden_states[-1]

gpt2_thought_shape = gpt2_continuous_thoughts.shape
print(f"Continuous thoughts shape: {gpt2_thought_shape}")
print(f"Number of thought positions: {gpt2_thought_shape[1]}")


GPT-2 Forward Pass
Input shape: torch.Size([1, 114])
Continuous thoughts shape: torch.Size([1, 114, 768])
Number of thought positions: 114


## Run LLaMA Forward Pass

In [None]:
print("=" * 80)
print("LLaMA Forward Pass")
print("=" * 80)

# Encode the same probe text for CODI-LLaMA. With add_special_tokens=True, position 0
# is presumably the tokenizer's BOS token (TODO confirm).
llama_inputs = llama_tokenizer(test_text, return_tensors="pt", add_special_tokens=True)
llama_input_ids = llama_inputs.input_ids.to(device)
print(f"Input shape: {llama_input_ids.shape}")

# One forward pass with every layer's output, then keep the last entry of
# `hidden_states` (one vector per input position).
# NOTE(review): these are hidden states of the teacher-forced input text, not obviously
# CODI's latent-iteration outputs. Confirm against section5_analysis.py.
with torch.no_grad():
    llama_outputs = llama_model.codi(input_ids=llama_input_ids, output_hidden_states=True)
    llama_continuous_thoughts = llama_outputs.hidden_states[-1]

llama_thought_shape = llama_continuous_thoughts.shape
print(f"Continuous thoughts shape: {llama_thought_shape}")
print(f"Number of thought positions: {llama_thought_shape[1]}")


## Decode GPT-2 Continuous Thoughts (like section5_analysis.py)

In [38]:
print("="*80)
print("GPT-2: Decoding Continuous Thoughts")
print("="*80)

# NOTE(review): the CODI config sets use_prj=True, so the model does have a projection
# layer. This cell skips it and applies lm_head directly to the final hidden states
# from the forward pass.
print(f"Continuous thoughts shape: {gpt2_continuous_thoughts.shape}")

# A decoded token counts as a number if it starts with digits, optionally preceded by
# one whitespace character (e.g. '5' or ' 16'). Later cells reuse this regex.
number_regex = re.compile(r'^\s?\d+')

# Inspect at most the first 15 positions; fewer if the sequence is shorter.
n_gpt2_decode = min(15, gpt2_continuous_thoughts.shape[1])

print("\nDecoded tokens from continuous thoughts:")
gpt2_decoded_tokens = []
gpt2_number_positions = []

# Pure inspection, so no autograd graph is needed for the lm_head calls.
with torch.no_grad():
    for i in range(n_gpt2_decode):
        logits = gpt2_model.codi.lm_head(gpt2_continuous_thoughts[:, i, :])
        top1_token_id = torch.argmax(logits, dim=-1).item()
        top1_token_str = gpt2_tokenizer.decode([top1_token_id])

        is_number = bool(number_regex.match(top1_token_str))
        # "BoT" is only a display label for index 0. The forward pass above inserts
        # no BoT token itself, so this is input position 0.
        pos_type = "BoT" if i == 0 else f"T{i}"

        gpt2_decoded_tokens.append((i, top1_token_id, top1_token_str, is_number))
        if is_number:
            gpt2_number_positions.append(i)

        marker = " ← NUMBER!" if is_number else ""
        print(f"  {pos_type:4s} [pos={i:2d}]: token_id={top1_token_id:5d} → '{top1_token_str}'{marker}")

print(f"\n✓ Numbers detected at positions: {gpt2_number_positions}")
print(f"✓ Total numbers: {len(gpt2_number_positions)}/{n_gpt2_decode}")

GPT-2: Decoding Continuous Thoughts
Continuous thoughts shape: torch.Size([1, 114, 768])

Decoded tokens from continuous thoughts:
  BoT  [pos= 0]: token_id=   13 → '.'
  T1   [pos= 1]: token_id=   11 → ','
  T2   [pos= 2]: token_id=  717 → ' first'
  T3   [pos= 3]: token_id=  389 → ' are'
  T4   [pos= 4]: token_id=  287 → ' in'
  T5   [pos= 5]: token_id= 3625 → ' feet'
  T6   [pos= 6]: token_id=  287 → ' in'
  T7   [pos= 7]: token_id= 1110 → ' day'
  T8   [pos= 8]: token_id=   11 → ','
  T9   [pos= 9]: token_id=  198 → '
'
  T10  [pos=10]: token_id=  318 → ' is'
  T11  [pos=11]: token_id=  546 → ' about'
  T12  [pos=12]: token_id=  284 → ' to'
  T13  [pos=13]: token_id=  790 → ' every'
  T14  [pos=14]: token_id= 1123 → ' each'

✓ Numbers detected at positions: []
✓ Total numbers: 0/15


## Decode LLaMA Continuous Thoughts (like section5_analysis.py)

In [None]:
print("="*80)
print("LLaMA: Decoding Continuous Thoughts")
print("="*80)

# NOTE(review): the CODI config sets use_prj=True, so the model does have a projection
# layer. This cell skips it and applies lm_head directly to the final hidden states
# from the forward pass.
print(f"Continuous thoughts shape: {llama_continuous_thoughts.shape}")

# Inspect at most the first 15 positions; fewer if the sequence is shorter.
# `number_regex` is defined in the GPT-2 decoding cell.
n_llama_decode = min(15, llama_continuous_thoughts.shape[1])

print("\nDecoded tokens from continuous thoughts:")
llama_decoded_tokens = []
llama_number_positions = []

# Pure inspection, so no autograd graph is needed for the lm_head calls.
with torch.no_grad():
    for i in range(n_llama_decode):
        logits = llama_model.codi.lm_head(llama_continuous_thoughts[:, i, :])
        top1_token_id = torch.argmax(logits, dim=-1).item()
        top1_token_str = llama_tokenizer.decode([top1_token_id])

        is_number = bool(number_regex.match(top1_token_str))
        # "BoT" is only a display label for index 0, which is input position 0
        # (presumably the tokenizer's BOS token; TODO confirm).
        pos_type = "BoT" if i == 0 else f"T{i}"

        llama_decoded_tokens.append((i, top1_token_id, top1_token_str, is_number))
        if is_number:
            llama_number_positions.append(i)

        marker = " ← NUMBER!" if is_number else ""
        print(f"  {pos_type:4s} [pos={i:2d}]: token_id={top1_token_id:5d} → '{top1_token_str}'{marker}")

print(f"\n✓ Numbers detected at positions: {llama_number_positions}")
print(f"✓ Total numbers: {len(llama_number_positions)}/{n_llama_decode}")

## Compare Tokenizers (Sanity Check)

In [None]:
print("=" * 80)
print("Tokenizer Comparison")
print("=" * 80)

test_numbers = ['0', '1', '2', '3', '4', '5', '16', '18']


def _roundtrip_number(tokenizer, text):
    """Encode `text` without special tokens, decode it back, and test it with number_regex.

    Returns a tuple (token_ids, decoded_text, decoded_text_matches_number_regex).
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    decoded_text = tokenizer.decode(token_ids)
    return token_ids, decoded_text, bool(number_regex.match(decoded_text))


print("\nHow do tokenizers encode/decode numbers?\n")
for num in test_numbers:
    gpt2_ids, gpt2_decoded, gpt2_matches = _roundtrip_number(gpt2_tokenizer, num)
    llama_ids, llama_decoded, llama_matches = _roundtrip_number(llama_tokenizer, num)

    # Flag any number whose round-tripped text differs between the two tokenizers.
    if gpt2_decoded == llama_decoded:
        match_indicator = "✓"
    else:
        match_indicator = "✗ MISMATCH"

    print(f"Number: '{num}'")
    print(f"  GPT-2:  IDs={gpt2_ids} → '{gpt2_decoded}' (matches={gpt2_matches})")
    print(f"  LLaMA:  IDs={llama_ids} → '{llama_decoded}' (matches={llama_matches})")
    print(f"  {match_indicator}\n")

## GPT-2: Projection Intervention (Target Token = '5', k=3)

In [None]:
print("="*80)
print("GPT-2: Projection Intervention")
print("="*80)

target_token = '5'
# Scale applied to the activation's component along the target token's embedding.
k = 3

# Activations to intervene on: the final hidden states from the GPT-2 forward pass,
# also exposed as `gpt2_latent_embd` for any cell that reads that name.
gpt2_latent_embd = gpt2_continuous_thoughts

embedding_layer = gpt2_model.codi.get_input_embeddings()

with torch.no_grad():
    # Target token embedding and its unit direction (loop-invariant, so computed once).
    target_token_id = gpt2_tokenizer.encode(target_token, add_special_tokens=False)[0]
    target_embd = embedding_layer(torch.tensor([target_token_id], device=device))
    E_target_norm = target_embd / torch.norm(target_embd, dim=-1, keepdim=True)

print(f"Target token: '{target_token}'")
print(f"Target token ID: {target_token_id}")
print(f"k (scale on target projection): {k}")

gpt2_interventions = []
n_gpt2_probe = min(15, gpt2_latent_embd.shape[1])

# Pure inspection, so no autograd graph is needed.
with torch.no_grad():
    for i in range(n_gpt2_probe):
        pos_type = "BoT" if i == 0 else f"T{i}"

        # Activation at position i, shape [1, hidden_dim], and its top-1 predicted token.
        A = gpt2_latent_embd[:, i, :]
        logits = gpt2_model.codi.lm_head(A)
        top1_token_id = torch.argmax(logits, dim=-1).item()
        top1_token_str = gpt2_tokenizer.decode([top1_token_id])

        # Only positions whose prediction is a number are intervened on.
        if not number_regex.match(top1_token_str):
            print(f"{pos_type:4s} [pos={i:2d}]: '{top1_token_str}' (not a number, skipped)")
            continue

        # Unit direction of the predicted token's input embedding.
        predicted_embd = embedding_layer(torch.tensor([top1_token_id], device=device))
        E_pred_norm = predicted_embd / torch.norm(predicted_embd, dim=-1, keepdim=True)

        # Remove A's component along the predicted-token direction and add k times its
        # component along the target-token direction.
        proj_predicted = torch.sum(A * E_pred_norm, dim=-1, keepdim=True) * E_pred_norm
        proj_target = torch.sum(A * E_target_norm, dim=-1, keepdim=True) * E_target_norm
        # NOTE(review): proj_target scales A's *existing* component along the target
        # direction. If that component is near zero or negative, this does not push the
        # prediction toward the target token. Confirm this is the intended formula.
        A_modified = A - proj_predicted + k * proj_target

        # Decode the modified activation.
        logits_modified = gpt2_model.codi.lm_head(A_modified)
        new_token_id = torch.argmax(logits_modified, dim=-1).item()
        new_token_str = gpt2_tokenizer.decode([new_token_id])

        gpt2_interventions.append({
            'position': i,
            'predicted_token': top1_token_str,
            'new_token': new_token_str,
            'intervened': True
        })

        print(f"{pos_type:4s} [pos={i:2d}]: '{top1_token_str}' → '{new_token_str}' (intervened)")

print(f"\n✓ GPT-2 interventions: {len(gpt2_interventions)}")

## LLaMA: Projection Intervention (Target Token = '5', k=3)

In [None]:
print("="*80)
print("LLaMA: Projection Intervention")
print("="*80)

target_token = '5'
# Scale applied to the activation's component along the target token's embedding.
k = 3

# Activations to intervene on: the final hidden states from the LLaMA forward pass,
# also exposed as `llama_latent_embd` for any cell that reads that name.
llama_latent_embd = llama_continuous_thoughts

embedding_layer = llama_model.codi.get_input_embeddings()

with torch.no_grad():
    # Target token embedding and its unit direction (loop-invariant, so computed once).
    target_token_id = llama_tokenizer.encode(target_token, add_special_tokens=False)[0]
    target_embd = embedding_layer(torch.tensor([target_token_id], device=device))
    E_target_norm = target_embd / torch.norm(target_embd, dim=-1, keepdim=True)

print(f"Target token: '{target_token}'")
print(f"Target token ID: {target_token_id}")
print(f"k (scale on target projection): {k}")

llama_interventions = []
n_llama_probe = min(15, llama_latent_embd.shape[1])

# Pure inspection, so no autograd graph is needed.
with torch.no_grad():
    for i in range(n_llama_probe):
        pos_type = "BoT" if i == 0 else f"T{i}"

        # Activation at position i, shape [1, hidden_dim], and its top-1 predicted token.
        A = llama_latent_embd[:, i, :]
        logits = llama_model.codi.lm_head(A)
        top1_token_id = torch.argmax(logits, dim=-1).item()
        top1_token_str = llama_tokenizer.decode([top1_token_id])

        # Only positions whose prediction is a number are intervened on.
        if not number_regex.match(top1_token_str):
            print(f"{pos_type:4s} [pos={i:2d}]: '{top1_token_str}' (not a number, skipped)")
            continue

        # Unit direction of the predicted token's input embedding.
        predicted_embd = embedding_layer(torch.tensor([top1_token_id], device=device))
        E_pred_norm = predicted_embd / torch.norm(predicted_embd, dim=-1, keepdim=True)

        # Remove A's component along the predicted-token direction and add k times its
        # component along the target-token direction.
        proj_predicted = torch.sum(A * E_pred_norm, dim=-1, keepdim=True) * E_pred_norm
        proj_target = torch.sum(A * E_target_norm, dim=-1, keepdim=True) * E_target_norm
        # NOTE(review): proj_target scales A's *existing* component along the target
        # direction. If that component is near zero or negative, this does not push the
        # prediction toward the target token. Confirm this is the intended formula.
        A_modified = A - proj_predicted + k * proj_target

        # Decode the modified activation.
        logits_modified = llama_model.codi.lm_head(A_modified)
        new_token_id = torch.argmax(logits_modified, dim=-1).item()
        new_token_str = llama_tokenizer.decode([new_token_id])

        llama_interventions.append({
            'position': i,
            'predicted_token': top1_token_str,
            'new_token': new_token_str,
            'intervened': True
        })

        print(f"{pos_type:4s} [pos={i:2d}]: '{top1_token_str}' → '{new_token_str}' (intervened)")

print(f"\n✓ LLaMA interventions: {len(llama_interventions)}")

## Summary and Diagnosis

In [None]:
print("="*80)
print("SUMMARY")
print("="*80)

gpt2_n_numbers = len(gpt2_number_positions)
llama_n_numbers = len(llama_number_positions)

print(f"\nGPT-2:")
print(f"  - Numbers detected: {gpt2_n_numbers}")
print(f"  - Interventions: {len(gpt2_interventions)}")

print(f"\nLLaMA:")
print(f"  - Numbers detected: {llama_n_numbers}")
print(f"  - Interventions: {len(llama_interventions)}")

print("\n" + "="*80)
print("DIAGNOSIS")
print("="*80)

if gpt2_n_numbers == 0 and llama_n_numbers == 0:
    # GPT-2 is the reference that should produce interventions. If it also decodes
    # no numbers, this probe does not reproduce the pipeline, so a LLaMA-specific
    # conclusion is not supported.
    print("\n⚠️  NEITHER model decodes number tokens from these hidden states.")
    print("This probe does not reproduce GPT-2's non-zero intervention counts,")
    print("so it cannot isolate a LLaMA-specific cause.")
    print("\nCheck first:")
    print("  1. The decoded tensor is the latent-thought output used by the real pipeline")
    print("     (including projection), not per-input-token hidden states")
    print("  2. The inspected positions match the pipeline's latent-thought positions")
elif llama_n_numbers == 0:
    print("\n⚠️  ROOT CAUSE IDENTIFIED:")
    print("LLaMA's lm_head does NOT predict number tokens from continuous thoughts.")
    print("This explains why there are 0 interventions.")
    print("\nPossible reasons:")
    print("  1. Projection layers (bot_projection/thought_projection) not trained correctly")
    print("  2. lm_head not loaded correctly (check strict=False in load_state_dict)")
    print("  3. Vocabulary size mismatch causing wrong token predictions")
    print("  4. Continuous thoughts from different distribution than GPT-2")
elif llama_n_numbers < gpt2_n_numbers:
    print("\n⚠️  LLaMA detects fewer numbers than GPT-2")
    print(f"  GPT-2: {gpt2_n_numbers} numbers")
    print(f"  LLaMA: {llama_n_numbers} numbers")
else:
    print("\n✓ Both models detect similar numbers of tokens")
    print("  The issue may be in the intervention logic, not the decoding.")

## Additional Investigation: Top-5 Predictions

In [None]:
def show_top5_predictions(label, model, tokenizer, hidden_states, n_positions=5, top_k=5):
    """Print the top-k lm_head predictions for the leading positions of `hidden_states`.

    Args:
        label: Model name used in the section header.
        model: CODI wrapper whose `.codi.lm_head` maps a hidden state to logits.
        tokenizer: Tokenizer used to decode the predicted token ids.
        hidden_states: Tensor indexed as [batch, position, hidden]; only batch 0 is shown.
        n_positions: Maximum number of leading positions to show.
        top_k: Number of predictions to list per position.
    """
    print(f"\n{label} - First {n_positions} thought positions:")
    # Pure inspection, so no autograd graph is needed.
    with torch.no_grad():
        for i in range(min(n_positions, hidden_states.shape[1])):
            logits = model.codi.lm_head(hidden_states[:, i, :])
            top_vals, top_ids = torch.topk(logits[0], top_k)
            top_tokens = [tokenizer.decode([tid.item()]) for tid in top_ids]

            pos_type = "BoT" if i == 0 else f"T{i}"
            print(f"\n  {pos_type} [pos={i}]:")
            for j, (token, val) in enumerate(zip(top_tokens, top_vals)):
                is_num = "← NUM" if number_regex.match(token) else ""
                print(f"    {j+1}. '{token}' (logit={val.item():.2f}) {is_num}")


print("="*80)
print("Top-5 Predictions Comparison")
print("="*80)

show_top5_predictions("GPT-2", gpt2_model, gpt2_tokenizer, gpt2_continuous_thoughts)
print("\n" + "-"*80)
show_top5_predictions("LLaMA", llama_model, llama_tokenizer, llama_continuous_thoughts)

## Inspect Activation Norms

In [None]:
print("="*80)
print("Activation Norms")
print("="*80)

# L2 norm of each position's final hidden state (batch 0). Upcast to float32 first,
# because the models were cast to bfloat16 when loaded.
gpt2_norms = torch.norm(gpt2_continuous_thoughts[0].float(), dim=-1)
llama_norms = torch.norm(llama_continuous_thoughts[0].float(), dim=-1)

for norms_label, norms in (("GPT-2", gpt2_norms), ("LLaMA", llama_norms)):
    print(f"\n{norms_label} latent embedding norms (first 10 positions):")
    for i in range(min(10, len(norms))):
        print(f"  Position {i}: {norms[i].item():.2f}")

# Means are over all positions, not just the 10 printed above.
print(f"\nGPT-2 mean norm: {gpt2_norms.mean().item():.2f}")
print(f"LLaMA mean norm: {llama_norms.mean().item():.2f}")