# Debug Projection Intervention: GPT-2 vs LLaMA

This notebook contains minimal working examples of CODI, focusing on LLaMA. Similar code for GPT-2 is included, but it has not been properly tested.

This includes:
- loading model
- generating continuous chain of thought
- decoding continuous chain of thought
- projection replacement: swapping a decoded number token for a specific target token such as '5'
- generating chain of thought with and without BOT (beginning of thought token)

## Setup and Imports

In [None]:
!pip install transformers torch peft matplotlib datasets tqdm hf_transfer dotenv

In [None]:
import os
# Sanity check: list the CODI checkout that is added to sys.path in the imports cell below,
# so a missing or misplaced clone shows up before any import fails.
os.listdir("/workspace/CoT_Exploration/codi")

In [None]:
import torch
import sys
import re
import os
from pathlib import Path

# Put the local CODI checkout first on sys.path so `src.model` resolves to it.
sys.path.insert(0, "/workspace/CoT_Exploration/codi")
from src.model import CODI, ModelArguments, TrainingArguments
from peft import LoraConfig, TaskType
from transformers import AutoTokenizer
from safetensors.torch import load_file

# Use the GPU when torch can see one; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## Test Example

In [39]:
# GSM8K example: the question, its worked reference solution, and the numeric gold answer.
question = "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"
# The worked solution gets its own name. Previously it was stored in `answer` and
# immediately overwritten by the integer below, so `test_text` ended in "18"
# instead of the full solution shown in the recorded output.
answer_text = "Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day. She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer's market. #### 18"
answer = 18

test_text = f"{question}\n{answer_text}"
print(f"Test text:\n{test_text}")
print(f"\nLength: {len(test_text)} characters")

Test text:
Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day. She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer's market. #### 18

Length: 410 characters


In [None]:
# A decoded token counts as a "number" when it starts with digits,
# optionally preceded by a single whitespace character.
NUMBER_TOKEN_PATTERN = r'^\s?\d+'
number_regex = re.compile(NUMBER_TOKEN_PATTERN)

## Load CODI-LLaMA

In [None]:
# Optional Hugging Face authentication - only gated repos (e.g. LLaMA) need it.
import os
from dotenv import load_dotenv

# Read a local .env file into the process environment, then look the token up there.
load_dotenv()
hf_token = os.environ.get('HF_TOKEN')

if not hf_token:
    print("⚠ No HF_TOKEN found in .env file -  proceeding without authentication")
    print("  (This is fine if models are public)")
else:
    # Imported here so huggingface_hub is only needed when a token is present.
    from huggingface_hub import login
    login(token=hf_token)
    print("✓ Logged in to HuggingFace")

In [None]:
import os

# Pre-flight check: report the checkpoint size, or show what is actually on disk
# next to the expected location when the file is absent.
checkpoint_path = "/workspace/CoT_Exploration/models/CODI-llama3.2-1b/pytorch_model.bin"

if not os.path.exists(checkpoint_path):
    print(f"❌ Checkpoint NOT found at: {checkpoint_path}")
    print(f"\nChecking directory contents:")
    dir_path = os.path.dirname(checkpoint_path)
    if not os.path.exists(dir_path):
        print(f"Directory doesn't exist: {dir_path}")
    else:
        files = os.listdir(dir_path)
        print(f"Files in {dir_path}:")
        for f in files:
            print(f"  - {f}")
else:
    size_mb = os.path.getsize(checkpoint_path) / (1024 ** 2)
    print(f"✓ Checkpoint found: {size_mb:.1f} MB")

In [None]:
print("="*80)
print("Loading CODI-LLaMA from Local Checkpoint")
print("="*80)

llama_model_args = ModelArguments(
    model_name_or_path="meta-llama/Llama-3.2-1B",
    lora_init=True,
    lora_r=128,
    lora_alpha=32,
    ckpt_dir="/workspace/CoT_Exploration/models/CODI-llama3.2-1b",  # Local checkpoint
    full_precision=True,
    token=None
)

llama_training_args = TrainingArguments(
    output_dir="./outputs",
    model_max_length=512,
    inf_latent_iterations=6,
    use_prj=True,
    prj_dim=2048,
    remove_eos=True,
    greedy=True,
    bf16=False,
    inf_num_iterations=1
)

llama_lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=llama_model_args.lora_r,
    lora_alpha=llama_model_args.lora_alpha,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", 
                    "gate_proj", "up_proj", "down_proj"],
    init_lora_weights=True,
)

# Build the CODI wrapper. The trained weights are NOT in place yet at this
# point; they are loaded from the checkpoint file just below.
llama_model = CODI(llama_model_args, llama_training_args, llama_lora_config)

checkpoint_path = os.path.join(llama_model_args.ckpt_dir, "pytorch_model.bin")
print(f"\nLoading checkpoint from: {checkpoint_path}")

# Fail fast when the checkpoint is missing. Previously the cell only printed an
# error, carried on with the untrained model, and still reported
# "loaded successfully with trained weights!" at the end.
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(
        f"CODI checkpoint not found at {checkpoint_path}; "
        "refusing to continue without the trained weights"
    )

# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source.
state_dict = torch.load(checkpoint_path, map_location='cpu')

# strict=False: key mismatches are reported below instead of raising.
missing_keys, unexpected_keys = llama_model.load_state_dict(state_dict, strict=False)

print(f"✓ Checkpoint loaded successfully!")
if missing_keys:
    print(f"  ⚠ Missing keys: {len(missing_keys)} (this is often normal for LoRA)")
if unexpected_keys:
    print(f"  ⚠ Unexpected keys: {len(unexpected_keys)}")

# Tie weights for LLaMA models (important!)
llama_model.codi.tie_weights()
print(f"✓ Weights tied")

# Move to device and set precision
llama_model = llama_model.to(device)
llama_model = llama_model.to(torch.bfloat16)
llama_model.eval()

print(f"✓ Model moved to {device} with bfloat16 precision")

# Load tokenizer
llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

# Verify critical model components
print("\n" + "="*80)
print("Model Verification:")
print("="*80)
print(f"✓ BoT token ID: {llama_model.bot_id}")
print(f"✓ EoT token ID: {llama_model.eot_id}")
print(f"✓ Projection layer: {'Yes' if hasattr(llama_model, 'prj') else 'No'}")
if hasattr(llama_model, 'prj'):
    print(f"  - Projection dim: {llama_training_args.prj_dim}")
print(f"✓ LoRA rank: {llama_model_args.lora_r}")
print(f"✓ Chain-of-thought iterations: {llama_training_args.inf_latent_iterations}")

print("\n✓ CODI-LLaMA loaded successfully with trained weights!")

In [None]:
# Print the configuration the loaded CODI-LLaMA model is running with.
print("Checking if CODI weights are loaded:")
print(f"Model checkpoint dir: {llama_model_args.ckpt_dir}")
print(f"BoT token ID: {llama_model.bot_id}")
print(f"EoT token ID: {llama_model.eot_id}")
print(f"Using projection: {llama_training_args.use_prj}")
if llama_training_args.use_prj:
    print(f"Projection dim: {llama_training_args.prj_dim}")
else:
    print(f"Projection dim: {'N/A'}")

# Report whether the model object exposes a `prj` attribute at all.
if not hasattr(llama_model, 'prj'):
    print("✗ No projection layer - this might be the problem!")
else:
    print(f"✓ Projection layer found: {llama_model.prj}")

## Run LLaMA Forward Pass

In [48]:
# Switch the working example to a different problem. This overwrites the
# `question` / `answer` values assigned in the "Test Example" cell, so every
# cell below operates on this question.
question = "Kylar went to the store to buy glasses for his new apartment. One glass costs $5, but every second glass costs only 60% of the price. Kylar wants to buy 16 glasses. How much does he need to pay for them?"
answer = 64

In [66]:
print("="*80)
print("LLaMA: Running inference with Chain-of-Thought")
print("="*80)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 1

questions = [question]

# One hidden state is collected at the BoT position plus one per latent
# iteration, so the expected count follows the configured iteration count
# instead of a hard-coded 7.
expected_cot_positions = llama_training_args.inf_latent_iterations + 1

# Build the BoT suffix (preceded by EOS when remove_eos is off)
if llama_training_args.remove_eos:
    bot_tensor = torch.tensor([llama_model.bot_id], dtype=torch.long).expand(batch_size, 1).to(device)
else:
    bot_tensor = torch.tensor([llama_tokenizer.eos_token_id, llama_model.bot_id], 
                              dtype=torch.long).expand(batch_size, 2).to(device)

# Tokenize the question and append the BoT suffix to both ids and mask
inputs = llama_tokenizer(questions, return_tensors="pt", padding=False)
inputs = {k: v.to(device) for k, v in inputs.items()}
inputs["input_ids"] = torch.cat((inputs["input_ids"], bot_tensor), dim=1)
inputs["attention_mask"] = torch.cat((inputs["attention_mask"], torch.ones_like(bot_tensor)), dim=1)

print(f"WITH BoT - Input IDs: {inputs['input_ids']}")
print(f"WITH BoT - Last 5 token IDs: {inputs['input_ids'][0, -5:].tolist()}")
print(f"WITH BoT - BoT token ID should be: {llama_model.bot_id}")
print(f"WITH BoT - Last token is BoT?: {inputs['input_ids'][0, -1].item() == llama_model.bot_id}")


print(f"Input shape: {inputs['input_ids'].shape}")

# Last-layer hidden state at the final position of every step (BoT + latent iterations)
cot_hidden_states = []

with torch.no_grad():
    # Initial encoding (position 0: BoT)
    past_key_values = None
    outputs = llama_model.codi(
        input_ids=inputs["input_ids"],
        use_cache=True,
        output_hidden_states=True,
        past_key_values=past_key_values,
        attention_mask=inputs["attention_mask"]
    )
    past_key_values = outputs.past_key_values
    
    # Only the LAST position feeds the next step, not all positions
    latent_embd = outputs.hidden_states[-1][:, -1:, :]  # Shape: [batch, 1, hidden_dim]
    
    # Store BoT position (BEFORE projection)
    cot_hidden_states.append(latent_embd.clone())
    print(f"✓ Position 0 (BoT): shape={latent_embd.shape}")
    
    # Apply initial projection
    if llama_training_args.use_prj:
        latent_embd = llama_model.prj(latent_embd)
    
    # Latent chain-of-thought iterations (positions 1..inf_latent_iterations)
    for i in range(llama_training_args.inf_latent_iterations):
        outputs = llama_model.codi(
            inputs_embeds=latent_embd,
            use_cache=True,
            output_hidden_states=True,
            past_key_values=past_key_values
        )
        past_key_values = outputs.past_key_values
        
        # Only the LAST position feeds the next step
        latent_embd = outputs.hidden_states[-1][:, -1:, :]  # Shape: [batch, 1, hidden_dim]
        
        # Store BEFORE projection
        cot_hidden_states.append(latent_embd.clone())
        print(f"✓ Position {i+1} (T{i+1}): shape={latent_embd.shape}")
        
        # Apply projection for next iteration
        if llama_training_args.use_prj:
            latent_embd = llama_model.prj(latent_embd)

# Stack all collected positions along the sequence axis
llama_continuous_thoughts = torch.cat(cot_hidden_states, dim=1)
print(f"\nContinuous thoughts shape: {llama_continuous_thoughts.shape}")
print(f"Number of thought positions: {llama_continuous_thoughts.shape[1]}")

# Verify the number of collected positions matches the configuration
if llama_continuous_thoughts.shape[1] != expected_cot_positions:
    print(f"\n❌ ERROR: Expected {expected_cot_positions} positions, got {llama_continuous_thoughts.shape[1]}")
else:
    print(f"✓ SUCCESS: Collected exactly {expected_cot_positions} CoT positions!")

# Also create llama_latent_embd as alias for compatibility with other cells
llama_latent_embd = llama_continuous_thoughts

LLaMA: Running inference with Chain-of-Thought
WITH BoT - Input IDs: tensor([[128000,     42,   4010,    277,   4024,    311,    279,   3637,    311,
           3780,  29247,    369,    813,    502,  13455,     13,   3861,   9168,
           7194,    400,     20,     11,    719,   1475,   2132,   9168,   7194,
           1193,    220,   1399,      4,    315,    279,   3430,     13,    735,
           4010,    277,   6944,    311,   3780,    220,    845,  29247,     13,
           2650,   1790,   1587,    568,   1205,    311,   2343,    369,   1124,
             30, 128257]], device='cuda:0')
WITH BoT - Last 5 token IDs: [2343, 369, 1124, 30, 128257]
WITH BoT - BoT token ID should be: 128257
WITH BoT - Last token is BoT?: True
Input shape: torch.Size([1, 56])
✓ Position 0 (BoT): shape=torch.Size([1, 1, 2048])
✓ Position 1 (T1): shape=torch.Size([1, 1, 2048])
✓ Position 2 (T2): shape=torch.Size([1, 1, 2048])
✓ Position 3 (T3): shape=torch.Size([1, 1, 2048])
✓ Position 4 (T4): shape=torch

## Decode LLaMA Continuous Thoughts (like section5_analysis.py)

In [50]:
print("="*80)
print("LLaMA: Decoding Continuous Thoughts")
print("="*80)

# The stored thoughts are passed straight to the LM head; `prj` is not applied here.
print(f"Continuous thoughts shape: {llama_continuous_thoughts.shape}")

# Greedy-decode each stored position (capped at 50) and note which ones look numeric
print("\nDecoded tokens from continuous thoughts:")
llama_decoded_tokens = []
llama_number_positions = []

num_positions = min(50, llama_continuous_thoughts.shape[1])
for i in range(num_positions):
    logits = llama_model.codi.lm_head(llama_continuous_thoughts[:, i, :])
    top1_token_id = logits.argmax(dim=-1).item()
    top1_token_str = llama_tokenizer.decode([top1_token_id])

    is_number = number_regex.match(top1_token_str) is not None
    if i == 0:
        pos_type = "BoT"
    else:
        pos_type = f"T{i}"

    llama_decoded_tokens.append((i, top1_token_id, top1_token_str, is_number))

    marker = ""
    if is_number:
        llama_number_positions.append(i)
        marker = " ← NUMBER!"
    print(f"  {pos_type:4s} [pos={i:2d}]: token_id={top1_token_id:5d} → '{top1_token_str}'{marker}")

print(f"\n✓ Numbers detected at positions: {llama_number_positions}")
print(f"✓ Total numbers: {len(llama_number_positions)}")

LLaMA: Decoding Continuous Thoughts
Continuous thoughts shape: torch.Size([1, 7, 2048])

Decoded tokens from continuous thoughts:
  BoT  [pos= 0]: token_id=   23 → '8' ← NUMBER!
  T1   [pos= 1]: token_id= 1272 → '40' ← NUMBER!
  T2   [pos= 2]: token_id= 1490 → '80' ← NUMBER!
  T3   [pos= 3]: token_id=  320 → ' ('
  T4   [pos= 4]: token_id= 1187 → '24' ← NUMBER!
  T5   [pos= 5]: token_id= 1187 → '24' ← NUMBER!
  T6   [pos= 6]: token_id= 2511 → '>>'

✓ Numbers detected at positions: [0, 1, 2, 4, 5]
✓ Total numbers: 5


In [68]:
print("="*80)
print("LLaMA: Generate Final Answer")
print("="*80)


def extract_answer_number(text):
    """Return the last number found in `text` as a float, or None if there is none.

    Thousands separators (commas) are stripped before matching.
    """
    text = text.replace(',', '')
    numbers = re.findall(r'-?\d+\.?\d*', text)
    if not numbers:
        return None
    return float(numbers[-1])


# Decoding budget. The original loop was declared with a 256-step range but
# always bailed out at step 49, so the effective limit of 50 tokens is made explicit.
max_new_tokens = 50

# NOTE: this cell continues from `past_key_values` left behind by the
# chain-of-thought inference cell; re-run that cell before re-running this one.
with torch.no_grad():
    # Signal end-of-thought with EOT token
    if llama_training_args.remove_eos:
        eot_tensor = torch.tensor([[llama_model.eot_id]], dtype=torch.long).expand(batch_size, 1, 1).to(device)
    else:
        eot_tensor = torch.tensor([[llama_tokenizer.eos_token_id, llama_model.eot_id]], 
                                   dtype=torch.long).expand(batch_size, 1, 2).to(device)
    
    eot_emb = llama_model.get_embd(llama_model.codi, llama_model.model_name)(eot_tensor).squeeze(1)
    output = eot_emb
    
    # Generate answer tokens autoregressively (greedy)
    pred_tokens = []
    
    for step in range(max_new_tokens):
        out = llama_model.codi(
            inputs_embeds=output,
            output_hidden_states=False,
            attention_mask=None,
            use_cache=True,
            output_attentions=False,
            past_key_values=past_key_values
        )
        past_key_values = out.past_key_values
        
        # Greedy pick over the logits slice [0, vocab_size-1)
        logits = out.logits[:, -1, :llama_model.codi.config.vocab_size-1]
        next_token_id = torch.argmax(logits, dim=-1).item()
        
        pred_tokens.append(next_token_id)
        
        # Decode current token
        current_token_str = llama_tokenizer.decode([next_token_id])
        
        # Stop if EOS token
        if next_token_id == llama_tokenizer.eos_token_id:
            print(f"Stopped at step {step} (EOS token)")
            break
        
        # Stop immediately after generating a number token
        if number_regex.match(current_token_str.strip()):
            print(f"Stopped at step {step} (found number: '{current_token_str}')")
            break
        
        # Budget exhausted
        if step == max_new_tokens - 1:
            print(f"Stopped at step {step} (max length)")
            break
        
        # Prepare next input
        output = llama_model.get_embd(llama_model.codi, llama_model.model_name)(
            torch.tensor([[next_token_id]], device=device)
        )

# Decode the generated tokens. `full_answer` was previously read without ever
# being assigned, which raises NameError on a fresh kernel.
full_answer = llama_tokenizer.decode(pred_tokens, skip_special_tokens=True)

predicted_number = extract_answer_number(full_answer)
print(f"\nExtracted numerical answer: {predicted_number}")
print(f"Expected answer: {answer}")

LLaMA: Generate Final Answer
Stopped at step 0 (found number: '64')

Extracted numerical answer: 64.0
Expected answer: 64


## Load CODI-GPT-2

In [None]:
print("="*80)
print("Loading CODI-GPT-2 from Local Checkpoint")
print("="*80)

# Hub repo id, kept for reference only: the weights are read from `ckpt_dir` below.
gpt2_model_name = "lhao499/codi-gpt2"

gpt2_model_args = ModelArguments(
    model_name_or_path="gpt2",
    lora_init=True,
    lora_r=128,
    lora_alpha=32,
    ckpt_dir="/workspace/CoT_Exploration/models/CODI-gpt2",  # Local checkpoint
    full_precision=True,
    token=None
)

gpt2_training_args = TrainingArguments(
    output_dir="./outputs",
    model_max_length=512,
    inf_latent_iterations=6,
    use_prj=True,
    prj_dim=768,
    remove_eos=True,
    greedy=True,
    bf16=False,
    inf_num_iterations=1
)

gpt2_lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=gpt2_model_args.lora_r,
    lora_alpha=gpt2_model_args.lora_alpha,
    lora_dropout=0.1,
    target_modules=["c_attn", "c_proj", "c_fc"],
    init_lora_weights=True,
)

gpt2_model = CODI(gpt2_model_args, gpt2_training_args, gpt2_lora_config)

# Load the trained checkpoint, mirroring the CODI-LLaMA cell. This step was
# missing entirely, so the cell reported success without reading any weights
# from `ckpt_dir`.
# NOTE(review): the checkpoint file name is assumed to match the LLaMA one
# (pytorch_model.bin), with model.safetensors as a fallback - confirm on disk.
gpt2_bin_path = os.path.join(gpt2_model_args.ckpt_dir, "pytorch_model.bin")
gpt2_safetensors_path = os.path.join(gpt2_model_args.ckpt_dir, "model.safetensors")

if os.path.exists(gpt2_bin_path):
    print(f"\nLoading checkpoint from: {gpt2_bin_path}")
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    gpt2_state_dict = torch.load(gpt2_bin_path, map_location='cpu')
elif os.path.exists(gpt2_safetensors_path):
    print(f"\nLoading checkpoint from: {gpt2_safetensors_path}")
    gpt2_state_dict = load_file(gpt2_safetensors_path)
else:
    raise FileNotFoundError(
        f"CODI-GPT-2 checkpoint not found in {gpt2_model_args.ckpt_dir} "
        "(looked for pytorch_model.bin and model.safetensors)"
    )

# strict=False: key mismatches are reported instead of raising.
gpt2_missing_keys, gpt2_unexpected_keys = gpt2_model.load_state_dict(gpt2_state_dict, strict=False)
print(f"✓ Checkpoint loaded successfully!")
if gpt2_missing_keys:
    print(f"  ⚠ Missing keys: {len(gpt2_missing_keys)} (this is often normal for LoRA)")
if gpt2_unexpected_keys:
    print(f"  ⚠ Unexpected keys: {len(gpt2_unexpected_keys)}")

# Same post-load step as the LLaMA cell.
gpt2_model.codi.tie_weights()

gpt2_model = gpt2_model.to(device)
gpt2_model = gpt2_model.to(torch.bfloat16)
gpt2_model.eval()
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")

print("✓ CODI-GPT-2 loaded successfully with trained weights!")

## Run GPT-2 Forward Pass

In [None]:
print("="*80)
print("GPT-2 Forward Pass")
print("="*80)

# Tokenize the question and append the BoT suffix, mirroring the LLaMA
# inference cell. The previous version ran one plain forward pass over the
# question tokens and called those per-token hidden states "continuous
# thoughts"; the decode/intervention cells label position 0 as BoT and the rest
# as T1, T2, ..., which only holds when the states come from the latent loop below.
gpt2_inputs = gpt2_tokenizer(question, return_tensors="pt", add_special_tokens=True)
gpt2_input_ids = gpt2_inputs.input_ids.to(device)

if gpt2_training_args.remove_eos:
    gpt2_bot_ids = [gpt2_model.bot_id]
else:
    gpt2_bot_ids = [gpt2_tokenizer.eos_token_id, gpt2_model.bot_id]
gpt2_bot_tensor = torch.tensor([gpt2_bot_ids], dtype=torch.long, device=device)
gpt2_input_ids = torch.cat((gpt2_input_ids, gpt2_bot_tensor), dim=1)

print(f"Input shape: {gpt2_input_ids.shape}")

# Last-layer hidden state at the final position of every step (BoT + latent iterations)
gpt2_cot_hidden_states = []

with torch.no_grad():
    # Initial encoding (position 0: BoT)
    gpt2_outputs = gpt2_model.codi(
        input_ids=gpt2_input_ids,
        use_cache=True,
        output_hidden_states=True
    )
    gpt2_past_key_values = gpt2_outputs.past_key_values
    gpt2_step_embd = gpt2_outputs.hidden_states[-1][:, -1:, :]
    gpt2_cot_hidden_states.append(gpt2_step_embd.clone())  # stored BEFORE projection

    for _ in range(gpt2_training_args.inf_latent_iterations):
        # Project the previous thought before feeding it back in
        if gpt2_training_args.use_prj:
            gpt2_step_embd = gpt2_model.prj(gpt2_step_embd)

        gpt2_outputs = gpt2_model.codi(
            inputs_embeds=gpt2_step_embd,
            use_cache=True,
            output_hidden_states=True,
            past_key_values=gpt2_past_key_values
        )
        gpt2_past_key_values = gpt2_outputs.past_key_values
        gpt2_step_embd = gpt2_outputs.hidden_states[-1][:, -1:, :]
        gpt2_cot_hidden_states.append(gpt2_step_embd.clone())  # stored BEFORE projection

gpt2_continuous_thoughts = torch.cat(gpt2_cot_hidden_states, dim=1)

# Alias used by the GPT-2 projection-intervention cell
gpt2_latent_embd = gpt2_continuous_thoughts

print(f"Continuous thoughts shape: {gpt2_continuous_thoughts.shape}")
print(f"Number of thought positions: {gpt2_continuous_thoughts.shape[1]}")


## Decode GPT-2 Continuous Thoughts (like section5_analysis.py)

In [None]:
print("="*80)
print("GPT-2: Decoding Continuous Thoughts")
print("="*80)

# The stored states are passed straight to the LM head; `prj` is not applied here.
print(f"Continuous thoughts shape: {gpt2_continuous_thoughts.shape}")

# Greedy-decode each stored position (capped at 15) and note which ones look numeric
print("\nDecoded tokens from continuous thoughts:")
gpt2_decoded_tokens = []
gpt2_number_positions = []

num_positions = min(15, gpt2_continuous_thoughts.shape[1])
for i in range(num_positions):
    logits = gpt2_model.codi.lm_head(gpt2_continuous_thoughts[:, i, :])
    top1_token_id = logits.argmax(dim=-1).item()
    top1_token_str = gpt2_tokenizer.decode([top1_token_id])

    is_number = number_regex.match(top1_token_str) is not None
    if i == 0:
        pos_type = "BoT"
    else:
        pos_type = f"T{i}"

    gpt2_decoded_tokens.append((i, top1_token_id, top1_token_str, is_number))

    marker = ""
    if is_number:
        gpt2_number_positions.append(i)
        marker = " ← NUMBER!"
    print(f"  {pos_type:4s} [pos={i:2d}]: token_id={top1_token_id:5d} → '{top1_token_str}'{marker}")

print(f"\n✓ Numbers detected at positions: {gpt2_number_positions}")
print(f"✓ Total numbers: {len(gpt2_number_positions)}")

## GPT-2: Projection Intervention (Target Token = '5', k=3)

In [None]:
print("="*80)
print("GPT-2: Projection Intervention")
print("="*80)

target_token = '5'
k = 3  # multiplier applied to the target-direction term added back in

# Look up the target token and the input-embedding table
target_token_id = gpt2_tokenizer.encode(target_token, add_special_tokens=False)[0]
embedding_layer = gpt2_model.codi.get_input_embeddings()

print(f"Target token: '{target_token}'")
print(f"Target token ID: {target_token_id}")
# `k` is a scale factor, not a top-k cutoff; the old label said "top-k".
print(f"k (scale applied to the target projection): {k}")

gpt2_interventions = []

# Non-causal check: each stored position is edited and re-decoded on its own;
# nothing is fed back into the model. no_grad avoids building autograd graphs.
with torch.no_grad():
    target_embd = embedding_layer(torch.tensor([target_token_id], device=device))
    # Unit vector of the target embedding (loop-invariant, so computed once)
    E_target_norm = target_embd / torch.norm(target_embd, dim=-1, keepdim=True)

    # Read from `gpt2_continuous_thoughts` (produced by the GPT-2 forward-pass
    # cell). The loop previously referenced `gpt2_latent_embd`, which no cell
    # defines, so it raised NameError on a fresh kernel.
    for i in range(min(15, gpt2_continuous_thoughts.shape[1])):
        pos_type = "BoT" if i == 0 else f"T{i}"

        # Activation at this position and the token it currently decodes to
        A = gpt2_continuous_thoughts[:, i, :]  # [1, hidden_dim]
        logits = gpt2_model.codi.lm_head(A)
        top1_token_id = torch.argmax(logits, dim=-1).item()
        top1_token_str = gpt2_tokenizer.decode([top1_token_id])

        # Only numeric tokens are intervened on
        is_number = bool(number_regex.match(top1_token_str))

        if is_number:
            # Unit vector of the predicted token's embedding
            predicted_embd = embedding_layer(torch.tensor([top1_token_id], device=device))
            E_pred_norm = predicted_embd / torch.norm(predicted_embd, dim=-1, keepdim=True)

            # Remove the component along the predicted token and add k times
            # the component along the target token.
            # NOTE(review): the LLaMA intervention cell's BoT step scales the
            # target direction by the norm of the removed component instead of
            # this dot product - confirm which variant is intended here.
            proj_predicted = torch.sum(A * E_pred_norm, dim=-1, keepdim=True) * E_pred_norm
            proj_target = torch.sum(A * E_target_norm, dim=-1, keepdim=True) * E_target_norm

            A_modified = A - proj_predicted + k * proj_target

            # Decode modified activation
            logits_modified = gpt2_model.codi.lm_head(A_modified)
            new_token_id = torch.argmax(logits_modified, dim=-1).item()
            new_token_str = gpt2_tokenizer.decode([new_token_id])

            gpt2_interventions.append({
                'position': i,
                'predicted_token': top1_token_str,
                'new_token': new_token_str,
                'intervened': True
            })

            print(f"{pos_type:4s} [pos={i:2d}]: '{top1_token_str}' → '{new_token_str}' (intervened)")
        else:
            print(f"{pos_type:4s} [pos={i:2d}]: '{top1_token_str}' (not a number, skipped)")

print(f"\n✓ GPT-2 interventions: {len(gpt2_interventions)}")

## LLaMA: Projection Intervention (Target Token = '5', k=3)

In [None]:
print("="*80)
print("LLaMA: CAUSAL Projection Intervention")
print("="*80)

# CRITICAL: Reset to ensure clean state
torch.manual_seed(42)  # For reproducibility
past_key_values = None  # Clear any cached state

target_token = '5'
k = 3

# Get target token embedding
target_token_id = llama_tokenizer.encode(target_token, add_special_tokens=False)[0]
embedding_layer = llama_model.codi.get_input_embeddings()
target_embd = embedding_layer(torch.tensor([target_token_id], device=device))

print(f"Target token: '{target_token}'")
print(f"Target token ID: {target_token_id}")
print(f"k (top-k intervention): {k}")
print(f"\nRunning CAUSAL intervention (affects downstream positions)...\n")

# Re-run the chain-of-thought with interventions
batch_size = 1

# Set pad token if needed
if llama_tokenizer.pad_token is None:
    llama_tokenizer.pad_token = llama_tokenizer.eos_token
    llama_tokenizer.pad_token_id = llama_tokenizer.eos_token_id

# Tokenize with BoT token
if llama_training_args.remove_eos:
    bot_tensor = torch.tensor([llama_model.bot_id], dtype=torch.long).expand(batch_size, 1).to(device)
else:
    bot_tensor = torch.tensor([llama_tokenizer.eos_token_id, llama_model.bot_id], 
                              dtype=torch.long).expand(batch_size, 2).to(device)

inputs = llama_tokenizer([question], return_tensors="pt", padding=False)
inputs = {k: v.to(device) for k, v in inputs.items()}
inputs["input_ids"] = torch.cat((inputs["input_ids"], bot_tensor), dim=1)
inputs["attention_mask"] = torch.cat((inputs["attention_mask"], torch.ones_like(bot_tensor)), dim=1)

llama_interventions = []
intervened_positions = []

with torch.no_grad():
    # Initial encoding (position 0: BoT)
    past_key_values = None
    outputs = llama_model.codi(
        input_ids=inputs["input_ids"],
        use_cache=True,
        output_hidden_states=True,
        past_key_values=past_key_values,
        attention_mask=inputs["attention_mask"]
    )
    past_key_values = outputs.past_key_values
    latent_embd = outputs.hidden_states[-1][:, -1:, :]
    
    # Check BoT position
    logits = llama_model.codi.lm_head(latent_embd.squeeze(1))
    top1_token_id = torch.argmax(logits, dim=-1).item()
    top1_token_str = llama_tokenizer.decode([top1_token_id])
    is_number = bool(number_regex.match(top1_token_str))
    
    # Intervene at BoT if it's a number
    if is_number:
        predicted_embd = embedding_layer(torch.tensor([top1_token_id], device=device))
        A = latent_embd.squeeze(1)
        
        # Normalize embeddings
        E_pred_norm = predicted_embd / torch.norm(predicted_embd, dim=-1, keepdim=True)
        E_target_norm = target_embd / torch.norm(target_embd, dim=-1, keepdim=True)
        
        # Projection intervention
        proj_predicted = torch.sum(A * E_pred_norm, dim=-1, keepdim=True) * E_pred_norm
        proj_target = torch.norm(proj_predicted, dim =-1, keepdim=True) * E_target_norm
        #proj_target = torch.sum(A * E_target_norm, dim=-1, keepdim=True) * E_target_norm
        A_modified = A - proj_predicted + k * proj_target
        
        # Check what it decodes to after intervention
        logits_modified = llama_model.codi.lm_head(A_modified)
        new_token_id = torch.argmax(logits_modified, dim=-1).item()
        new_token_str = llama_tokenizer.decode([new_token_id])
        
        # CAUSALLY apply the intervention
        latent_embd = A_modified.unsqueeze(1)
        
        llama_interventions.append({
            'position': 0,
            'predicted_token': top1_token_str,
            'new_token': new_token_str,
            'intervened': True
        })
        intervened_positions.append(0)
        print(f"BoT  [pos= 0]: '{top1_token_str}' → '{new_token_str}' ✓ INTERVENED (causal)")
    else:
        print(f"BoT  [pos= 0]: '{top1_token_str}' (not a number, no intervention)")
    
    # Apply initial projection
    if llama_training_args.use_prj:
        latent_embd = llama_model.prj(latent_embd)
    
    # Chain-of-Thought iterations (positions 1-6)
    for i in range(llama_training_args.inf_latent_iterations):
        outputs = llama_model.codi(
            inputs_embeds=latent_embd,
            use_cache=True,
            output_hidden_states=True,
            past_key_values=past_key_values
        )
        past_key_values = outputs.past_key_values
        latent_embd = outputs.hidden_states[-1][:, -1:, :]
        
        # Check this position
        logits = llama_model.codi.lm_head(latent_embd.squeeze(1))
        top1_token_id = torch.argmax(logits, dim=-1).item()
        top1_token_str = llama_tokenizer.decode([top1_token_id])
        is_number = bool(number_regex.match(top1_token_str))
        
        pos_type = f"T{i+1}"
        
        # Intervene if it's a number
        if is_number:
            predicted_embd = embedding_layer(torch.tensor([top1_token_id], device=device))
            A = latent_embd.squeeze(1)
            
            # Normalize embeddings
            E_pred_norm = predicted_embd / torch.norm(predicted_embd, dim=-1, keepdim=True)
            E_target_norm = target_embd / torch.norm(target_embd, dim=-1, keepdim=True)
            
            # Projection intervention
            proj_predicted = torch.sum(A * E_pred_norm, dim=-1, keepdim=True) * E_pred_norm
            proj_target = torch.sum(A * E_target_norm, dim=-1, keepdim=True) * E_target_norm
            A_modified = A - proj_predicted + k * proj_target
            
            # Check what it decodes to after intervention
            logits_modified = llama_model.codi.lm_head(A_modified)
            new_token_id = torch.argmax(logits_modified, dim=-1).item()
            new_token_str = llama_tokenizer.decode([new_token_id])
            
            # CAUSALLY apply the intervention
            latent_embd = A_modified.unsqueeze(1)
            
            llama_interventions.append({
                'position': i+1,
                'predicted_token': top1_token_str,
                'new_token': new_token_str,
                'intervened': True
            })
            intervened_positions.append(i+1)
            print(f"{pos_type:4s} [pos={i+1:2d}]: '{top1_token_str}' → '{new_token_str}' ✓ INTERVENED (causal)")
        else:
            print(f"{pos_type:4s} [pos={i+1:2d}]: '{top1_token_str}' (not a number, no intervention)")
        
        # Apply projection for next iteration
        if llama_training_args.use_prj:
            latent_embd = llama_model.prj(latent_embd)

print(f"\n{'='*80}")
print(f"✓ Total interventions: {len(llama_interventions)}")
print(f"✓ Intervened at positions: {intervened_positions}")
print(f"{'='*80}")

# Now generate the final answer with the intervened chain-of-thought
print("\nGenerating final answer with intervened CoT...")

# Signal end-of-thought with EOT token
if llama_training_args.remove_eos:
    eot_tensor = torch.tensor([[llama_model.eot_id]], dtype=torch.long).expand(batch_size, 1, 1).to(device)
else:
    eot_tensor = torch.tensor([[llama_tokenizer.eos_token_id, llama_model.eot_id]], 
                               dtype=torch.long).expand(batch_size, 1, 2).to(device)

eot_emb = llama_model.get_embd(llama_model.codi, llama_model.model_name)(eot_tensor).squeeze(1)
output = eot_emb

# Generate answer tokens
# Generate answer tokens
pred_tokens = []
found_number = False

for step in range(256):
    out = llama_model.codi(
        inputs_embeds=output,
        output_hidden_states=False,
        attention_mask=None,
        use_cache=True,
        output_attentions=False,
        past_key_values=past_key_values
    )
    past_key_values = out.past_key_values
    
    logits = out.logits[:, -1, :llama_model.codi.config.vocab_size-1]
    next_token_id = torch.argmax(logits, dim=-1).item()
    pred_tokens.append(next_token_id)
    
    # Decode current token
    current_token_str = llama_tokenizer.decode([next_token_id])
    
    # Stop if EOS token
    if next_token_id == llama_tokenizer.eos_token_id:
        print(f"Stopped at step {step} (EOS token)")
        break
    
    # Stop immediately after generating a number token
    if number_regex.match(current_token_str.strip()):
        print(f"Stopped at step {step} (found number: '{current_token_str}')")
        break
    
    # Hard limit to prevent loops
    if step >= 49:
        print(f"Stopped at step {step} (max length)")
        break
    
    output = llama_model.get_embd(llama_model.codi, llama_model.model_name)(
        torch.tensor([[next_token_id]], device=device)
    )
# Decode and extract answer
intervened_answer = llama_tokenizer.decode(pred_tokens, skip_special_tokens=True)
print(f"\nIntervened Answer: {intervened_answer}")

def extract_answer_number(text):
    """Return the first number that appears in ``text`` as a float.

    Commas are removed before searching, so ``"1,234"`` is read as ``1234``.
    A number is an optional minus sign, one or more digits, an optional
    decimal point and optional further digits.

    Returns:
        The first match converted with ``float``, or ``None`` when ``text``
        contains no digits.
    """
    without_commas = text.replace(',', '')
    first_match = re.search(r'-?\d+\.?\d*', without_commas)
    return float(first_match.group()) if first_match else None

# Pull the numeric value out of the decoded answer text.
intervened_number = extract_answer_number(intervened_answer)
print(f"Intervened numerical answer: {intervened_number}")
# Report the reference answer stored in `answer` (set next to `question`) instead
# of a literal, so the comparison stays right when the test example is changed.
print(f"\n(Compare to original expected: {answer})")

In [65]:
# Cell banner
print("=" * 80, "LLaMA: Chain-of-Thought WITHOUT Beginning-of-Thought Token", "=" * 80, sep="\n")

batch_size = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single-question batch
questions = [question]

# Fall back to EOS as the pad token when the tokenizer defines none
if llama_tokenizer.pad_token is None:
    llama_tokenizer.pad_token = llama_tokenizer.eos_token
    llama_tokenizer.pad_token_id = llama_tokenizer.eos_token_id

# Tokenize the raw question only (no BoT token appended) and move every tensor to `device`
inputs = {
    name: tensor.to(device)
    for name, tensor in llama_tokenizer(questions, return_tensors="pt", padding=False).items()
}

# Token ids of the first (only) sequence as a plain Python list, for the checks below
no_bot_token_ids = inputs["input_ids"][0].tolist()

print(f"WITHOUT BoT - Input IDs: {inputs['input_ids']}")
print(f"WITHOUT BoT - Last 5 token IDs: {no_bot_token_ids[-5:]}")
print(f"WITHOUT BoT - BoT token ID should be: {llama_model.bot_id}")
print(f"WITHOUT BoT - Contains BoT?: {llama_model.bot_id in no_bot_token_ids}")
print(f"Input shape (no BoT): {inputs['input_ids'].shape}")

# Number of latent reasoning steps, read once from the inference config so the
# rollout loop and the shape check below always agree.
num_latent_steps = llama_training_args.inf_latent_iterations

# Collect one pre-projection hidden state per chain-of-thought step (no BoT position)
cot_hidden_states_no_bot = []

with torch.no_grad():
    # Initial encoding - start directly from the last token of the question
    past_key_values = None
    outputs = llama_model.codi(
        input_ids=inputs["input_ids"],
        use_cache=True,
        output_hidden_states=True,
        past_key_values=past_key_values,
        attention_mask=inputs["attention_mask"]
    )
    past_key_values = outputs.past_key_values

    # Start from the last question token (no BoT): last layer, last position.
    # The `-1:` slice keeps the sequence axis, so the tensor stays 3-D.
    latent_embd = outputs.hidden_states[-1][:, -1:, :]  # Shape: [batch, 1, hidden_dim]

    # Apply initial projection
    if llama_training_args.use_prj:
        latent_embd = llama_model.prj(latent_embd)

    # Chain-of-Thought iterations (one latent position per step, no BoT)
    for i in range(num_latent_steps):
        # Feed the previous latent back in as an input embedding, extending the KV cache
        outputs = llama_model.codi(
            inputs_embeds=latent_embd,
            use_cache=True,
            output_hidden_states=True,
            past_key_values=past_key_values
        )
        past_key_values = outputs.past_key_values

        # Get hidden state BEFORE projection
        latent_embd = outputs.hidden_states[-1][:, -1:, :]  # Shape: [batch, 1, hidden_dim]

        # Store BEFORE projection for decoding
        cot_hidden_states_no_bot.append(latent_embd.clone())
        print(f"✓ Position {i} (T{i}): shape={latent_embd.shape}")

        # Apply projection for next iteration
        if llama_training_args.use_prj:
            latent_embd = llama_model.prj(latent_embd)

# Stack all latent positions along the sequence axis (no BoT)
llama_continuous_thoughts_no_bot = torch.cat(cot_hidden_states_no_bot, dim=1)
print(f"\n✓ Total continuous thoughts (no BoT): {llama_continuous_thoughts_no_bot.shape}")
print(f"  Expected: [1, {num_latent_steps}, hidden_dim]")
print(f"  Got: {list(llama_continuous_thoughts_no_bot.shape)}")

# Sanity check against the configured step count rather than a hard-coded 6
if llama_continuous_thoughts_no_bot.shape[1] != num_latent_steps:
    print(f"\n❌ ERROR: Expected {num_latent_steps} positions, got {llama_continuous_thoughts_no_bot.shape[1]}")
else:
    print(f"\n✓ SUCCESS: Collected exactly {num_latent_steps} CoT positions (no BoT)!")

# Decode the continuous thoughts
print("\n" + "="*80)
print("Decoding Continuous Thoughts (no BoT)")
print("="*80)

# (position, token id, decoded token, is_number) for every latent position
llama_decoded_tokens_no_bot = []
# Indices of the positions whose top-1 token looks like a number
llama_number_positions_no_bot = []

# Number of latent positions actually collected; also the denominator in the summary
num_positions_no_bot = llama_continuous_thoughts_no_bot.shape[1]

for i in range(num_positions_no_bot):
    hidden_state = llama_continuous_thoughts_no_bot[:, i, :]

    # Decode using CODI's LM head and keep only the highest-scoring token
    logits = llama_model.codi.lm_head(hidden_state)
    top1_token_id = torch.argmax(logits, dim=-1).item()
    top1_token_str = llama_tokenizer.decode([top1_token_id])

    # Whitespace is stripped before testing the decoded token against number_regex
    is_number = bool(number_regex.match(top1_token_str.strip()))
    pos_type = f"T{i}"

    llama_decoded_tokens_no_bot.append((i, top1_token_id, top1_token_str, is_number))
    if is_number:
        llama_number_positions_no_bot.append(i)

    marker = " 🔢 NUMBER!" if is_number else ""
    print(f"  {pos_type:4s} [pos={i:2d}]: token_id={top1_token_id:5d} → '{top1_token_str}'{marker}")

print(f"\n📊 Numbers at positions: {llama_number_positions_no_bot}")
print(f"📊 Total numbers: {len(llama_number_positions_no_bot)}/{num_positions_no_bot}")

# Generate final answer (no BoT)
print("\n" + "="*80)
print("Generate Final Answer (no BoT)")
print("="*80)

# Upper bound on the number of generated answer tokens (guards against loops).
max_answer_tokens = 50


def extract_answer_number(text):
    """Return the last number that appears in ``text`` as a float.

    Commas are removed before searching, so ``"1,234"`` is read as ``1234``.

    Returns:
        The last match converted with ``float``, or ``None`` when ``text``
        contains no digits.
    """
    text = text.replace(',', '')
    numbers = re.findall(r'-?\d+\.?\d*', text)
    if not numbers:
        return None
    return float(numbers[-1])


with torch.no_grad():
    # Signal end-of-thought with EOT token (preceded by EOS unless remove_eos is set)
    if llama_training_args.remove_eos:
        eot_tensor = torch.tensor([[llama_model.eot_id]], dtype=torch.long).expand(batch_size, 1, 1).to(device)
    else:
        eot_tensor = torch.tensor([[llama_tokenizer.eos_token_id, llama_model.eot_id]],
                                   dtype=torch.long).expand(batch_size, 1, 2).to(device)

    eot_emb = llama_model.get_embd(llama_model.codi, llama_model.model_name)(eot_tensor).squeeze(1)
    output = eot_emb

    # Generate answer tokens greedily, one token per step, reusing the KV cache
    pred_tokens = []

    for step in range(max_answer_tokens):
        out = llama_model.codi(
            inputs_embeds=output,
            output_hidden_states=False,
            attention_mask=None,
            use_cache=True,
            output_attentions=False,
            past_key_values=past_key_values
        )
        past_key_values = out.past_key_values

        # Logits at the last position; the slice leaves the final vocabulary index out of the argmax
        logits = out.logits[:, -1, :llama_model.codi.config.vocab_size-1]
        # .item() needs a single-element tensor, so this loop decodes one sequence
        next_token_id = torch.argmax(logits, dim=-1).item()
        pred_tokens.append(next_token_id)

        # Decode current token
        current_token_str = llama_tokenizer.decode([next_token_id])

        # Stop if EOS token
        if next_token_id == llama_tokenizer.eos_token_id:
            print(f"Stopped at step {step} (EOS token)")
            break

        # Stop immediately after generating a number token
        if number_regex.match(current_token_str.strip()):
            print(f"Stopped at step {step} (found number: '{current_token_str}')")
            break

        # Hard limit to prevent loops
        if step >= max_answer_tokens - 1:
            print(f"Stopped at step {step} (max length)")
            break

        # Feed the embedding of the token just produced as the next input
        output = llama_model.get_embd(llama_model.codi, llama_model.model_name)(
            torch.tensor([[next_token_id]], device=device)
        )

    # Decode full answer
    full_answer_no_bot = llama_tokenizer.decode(pred_tokens, skip_special_tokens=True)
    print(f"\nGenerated Answer (no BoT): {full_answer_no_bot}")

    # Extract numerical answer and show it next to the reference
    print(f"Question: {question}")
    predicted_number_no_bot = extract_answer_number(full_answer_no_bot)
    print(f"Numerical answer (no BoT): {predicted_number_no_bot}")
    print("Expected answer: ", answer)

# Closing note: position counts for the two setups compared above.
print(
    "\n" + "=" * 80,
    "Comparison: With BoT vs Without BoT",
    "=" * 80,
    "With BoT:    7 positions (1 BoT + 6 CoT)",
    "Without BoT: 6 positions (0 BoT + 6 CoT)",
    "\nThis tests whether the BoT token improves reasoning performance.",
    sep="\n",
)

LLaMA: Chain-of-Thought WITHOUT Beginning-of-Thought Token
WITHOUT BoT - Input IDs: tensor([[128000,     42,   4010,    277,   4024,    311,    279,   3637,    311,
           3780,  29247,    369,    813,    502,  13455,     13,   3861,   9168,
           7194,    400,     20,     11,    719,   1475,   2132,   9168,   7194,
           1193,    220,   1399,      4,    315,    279,   3430,     13,    735,
           4010,    277,   6944,    311,   3780,    220,    845,  29247,     13,
           2650,   1790,   1587,    568,   1205,    311,   2343,    369,   1124,
             30]], device='cuda:0')
WITHOUT BoT - Last 5 token IDs: [311, 2343, 369, 1124, 30]
WITHOUT BoT - BoT token ID should be: 128257
WITHOUT BoT - Contains BoT?: False
Input shape (no BoT): torch.Size([1, 55])
✓ Position 0 (T0): shape=torch.Size([1, 1, 2048])
✓ Position 1 (T1): shape=torch.Size([1, 1, 2048])
✓ Position 2 (T2): shape=torch.Size([1, 1, 2048])
✓ Position 3 (T3): shape=torch.Size([1, 1, 2048])
✓ Position 4 

## Compare Tokenizers (Sanity Check)

In [None]:
print("=" * 80, "Tokenizer Comparison", "=" * 80, sep="\n")

test_numbers = ['0', '1', '2', '3', '4', '5', '16', '18']


def _roundtrip(tokenizer, text):
    """Encode ``text`` without special tokens, decode it back, and test the result.

    Returns:
        A tuple ``(ids, decoded, matches)`` where ``matches`` is True when
        ``number_regex.match`` succeeds on the decoded string.
    """
    ids = tokenizer.encode(text, add_special_tokens=False)
    decoded = tokenizer.decode(ids)
    return ids, decoded, bool(number_regex.match(decoded))


print("\nHow do tokenizers encode/decode numbers?\n")
for num in test_numbers:
    gpt2_ids, gpt2_decoded, gpt2_matches = _roundtrip(gpt2_tokenizer, num)
    llama_ids, llama_decoded, llama_matches = _roundtrip(llama_tokenizer, num)

    # Flag numbers whose decoded text differs between the two tokenizers
    if gpt2_decoded == llama_decoded:
        match_indicator = "✓"
    else:
        match_indicator = "✗ MISMATCH"

    print(f"Number: '{num}'")
    print(f"  GPT-2:  IDs={gpt2_ids} → '{gpt2_decoded}' (matches={gpt2_matches})")
    print(f"  LLaMA:  IDs={llama_ids} → '{llama_decoded}' (matches={llama_matches})")
    print(f"  {match_indicator}\n")

## Summary and Diagnosis

In [None]:
print("=" * 80, "SUMMARY", "=" * 80, sep="\n")

# Per-model counts of detected number positions and applied interventions
print("\nGPT-2:")
gpt2_number_count = len(gpt2_number_positions)
print(f"  - Numbers detected: {gpt2_number_count}")
print(f"  - Interventions: {len(gpt2_interventions)}")

print("\nLLaMA:")
llama_number_count = len(llama_number_positions)
print(f"  - Numbers detected: {llama_number_count}")
print(f"  - Interventions: {len(llama_interventions)}")

print("\n" + "=" * 80, "DIAGNOSIS", "=" * 80, sep="\n")

# Three cases: LLaMA found no numbers, fewer than GPT-2, or at least as many.
if llama_number_count == 0:
    print(
        "\n⚠️  ROOT CAUSE IDENTIFIED:",
        "LLaMA's lm_head does NOT predict number tokens from continuous thoughts.",
        "This explains why there are 0 interventions.",
        "\nPossible reasons:",
        "  1. Projection layers (bot_projection/thought_projection) not trained correctly",
        "  2. lm_head not loaded correctly (check strict=False in load_state_dict)",
        "  3. Vocabulary size mismatch causing wrong token predictions",
        "  4. Continuous thoughts from different distribution than GPT-2",
        sep="\n",
    )
elif llama_number_count < gpt2_number_count:
    print(
        "\n⚠️  LLaMA detects fewer numbers than GPT-2",
        f"  GPT-2: {gpt2_number_count} numbers",
        f"  LLaMA: {llama_number_count} numbers",
        sep="\n",
    )
else:
    print(
        "\n✓ Both models detect similar numbers of tokens",
        "  The issue may be in the intervention logic, not the decoding.",
        sep="\n",
    )

## Additional Investigation: Top-5 Predictions

In [None]:
print("="*80)
print("Top-5 Predictions Comparison")
print("="*80)


def _print_top_predictions(label, model, tokenizer, latent_embd, max_positions=5, top_k=5):
    """Print the ``top_k`` LM-head tokens for the first few latent positions.

    Args:
        label: Model name used in the section heading (e.g. ``"GPT-2"``).
        model: Wrapper exposing ``codi.lm_head``.
        tokenizer: Tokenizer used to decode the predicted token ids.
        latent_embd: Tensor indexed as ``latent_embd[:, i, :]`` per position.
            Position 0 is labelled "BoT" — assumes the first position is the
            beginning-of-thought state; TODO confirm against how it was built.
        max_positions: Maximum number of positions to print.
        top_k: Number of highest-logit tokens to list per position.
    """
    print(f"\n{label} - First {max_positions} thought positions:")
    for i in range(min(max_positions, latent_embd.shape[1])):
        logits = model.codi.lm_head(latent_embd[:, i, :])
        # Only the first batch element is inspected
        top_vals, top_ids = torch.topk(logits[0], top_k)

        pos_type = "BoT" if i == 0 else f"T{i}"
        print(f"\n  {pos_type} [pos={i}]:")
        for rank, (token_id, val) in enumerate(zip(top_ids, top_vals), start=1):
            token = tokenizer.decode([token_id.item()])
            # Strip whitespace before matching, as the other number checks in this
            # notebook do, so tokens with a leading space are flagged consistently.
            is_num = "← NUM" if number_regex.match(token.strip()) else ""
            print(f"    {rank}. '{token}' (logit={val.item():.2f}) {is_num}")


_print_top_predictions("GPT-2", gpt2_model, gpt2_tokenizer, gpt2_latent_embd)

print("\n" + "-"*80)
_print_top_predictions("LLaMA", llama_model, llama_tokenizer, llama_latent_embd)

## Inspect Activation Norms

In [None]:
print("=" * 80, "Activation Norms", "=" * 80, sep="\n")

# Norm over the last axis for each position of the first batch element
gpt2_norms = gpt2_latent_embd[0].norm(dim=-1)
llama_norms = llama_latent_embd[0].norm(dim=-1)

# Per-position norms, at most the first ten positions per model
for model_label, position_norms in (("GPT-2", gpt2_norms), ("LLaMA", llama_norms)):
    print(f"\n{model_label} latent embedding norms (first 10 positions):")
    for i, position_norm in enumerate(position_norms[:10]):
        print(f"  Position {i}: {position_norm.item():.2f}")

print(f"\nGPT-2 mean norm: {gpt2_norms.mean().item():.2f}")
print(f"LLaMA mean norm: {llama_norms.mean().item():.2f}")