# Debug Projection Intervention: GPT-2 vs LLaMA

This notebook contains minimal working examples of CODI, focussing on LLAMA. There is also similar code for GPT-2, but it has not been tested properly.

This includes:
- loading model
- generating continuous chain of thought
- decoding continuous chain of thought
- projection replacement of decoded number tokens by specific token like '5'
- generating chain of thought with and without BOT (beginning of thought token)

## Setup and Imports

In [1]:
!pip install transformers torch peft matplotlib datasets tqdm hf_transfer dotenv

Collecting transformers
  Using cached transformers-4.57.1-py3-none-any.whl.metadata (43 kB)
Collecting peft
  Using cached peft-0.17.1-py3-none-any.whl.metadata (14 kB)
Collecting matplotlib
  Using cached matplotlib-3.10.7-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Collecting datasets
  Using cached datasets-4.3.0-py3-none-any.whl.metadata (18 kB)
Collecting tqdm
  Using cached tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Collecting hf_transfer
  Using cached hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.7 kB)
Collecting dotenv
  Using cached dotenv-0.9.9-py2.py3-none-any.whl.metadata (279 bytes)
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers)
  Using cached huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting regex!=2019.12.17 (from transformers)
  Using cached regex-2025.10.23-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (40 kB)
Collecti

In [2]:
import os
os.listdir("/workspace/CoT_Exploration/codi")

['train.py',
 'test.py',
 'src',
 'scripts',
 'requirements.txt',
 'probe_latent_token.py',
 'outputs',
 'imgs',
 'README.md',
 '.git']

In [128]:
import torch
import sys
import re
import os
from pathlib import Path

sys.path.insert(0, "/workspace/CoT_Exploration/codi")
from src.model import CODI, ModelArguments, TrainingArguments
from peft import LoraConfig, TaskType
from transformers import AutoTokenizer
from safetensors.torch import load_file

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


## Test Example

In [129]:
# GSM8K example
question = "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"
answer = "Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day. She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer's market. #### 18"
answer = 18

test_text = f"{question}\n{answer}"
print(f"Test text:\n{test_text}")
print(f"\nLength: {len(test_text)} characters")

Test text:
Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
18

Length: 283 characters


In [130]:
# Number detection regex
number_regex = re.compile(r'^\s?\d+')

## Load CODI-LLaMA

In [131]:
# Login to HuggingFace (optional - only needed  for gated models like LLaMA)
import os
from dotenv import load_dotenv

  # Load environment variables from .env file
load_dotenv()

  # Get HuggingFace token
hf_token = os.getenv('HF_TOKEN')

if hf_token:
      from huggingface_hub import login
      login(token=hf_token)
      print("✓ Logged in to HuggingFace")
else:
      print("⚠ No HF_TOKEN found in .env file -  proceeding without authentication")
      print("  (This is fine if models are public)")

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


✓ Logged in to HuggingFace


In [132]:
def load_llama_model():
    """Load CODI-LLaMA model from HuggingFace Hub"""
    print("="*80)
    print("Loading CODI-LLaMA from HuggingFace Hub")
    print("="*80)

    llama_model_name = "zen-E/CODI-llama3.2-1b-Instruct"

    llama_model_args = ModelArguments(
        model_name_or_path="meta-llama/Llama-3.2-1B",
        lora_init=True,
        lora_r=128,
        lora_alpha=32,
        ckpt_dir=None,  # Not using local checkpoint
        full_precision=True,
        token=None
    )

    llama_training_args = TrainingArguments(
        output_dir="./outputs",
        model_max_length=512,
        inf_latent_iterations=6,
        use_prj=True,
        prj_dim=2048,
        remove_eos=True,
        greedy=True,
        bf16=False,
        inf_num_iterations=1
    )

    llama_lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=llama_model_args.lora_r,
        lora_alpha=llama_model_args.lora_alpha,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        init_lora_weights=True,
    )

    # Initialize the CODI model
    llama_model = CODI(llama_model_args, llama_training_args, llama_lora_config)
    
    # Download and load weights from HuggingFace Hub
    from huggingface_hub import hf_hub_download
    
    print(f"Downloading weights from {llama_model_name}...")
    checkpoint_path = hf_hub_download(
        repo_id=llama_model_name,
        filename="pytorch_model.bin",
        token=hf_token if hf_token else None
    )
    
    print(f"Loading weights from {checkpoint_path}...")
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    llama_model.load_state_dict(state_dict, strict=False)
    llama_model.codi.tie_weights()
    
    llama_model = llama_model.to(device)
    llama_model = llama_model.to(torch.bfloat16)
    llama_model.eval()

    llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
    
    print("✓ CODI-LLaMA loaded successfully from HuggingFace Hub")

    return llama_model, llama_tokenizer, llama_training_args

In [133]:
print("="*80)
print("Loading CODI-LLaMA from Hugging Face Hub")
print("="*80)

# Load from HuggingFace Hub
llama_model_name = "zen-E/CODI-llama3.2-1b-Instruct"

llama_model_args = ModelArguments(
    model_name_or_path="meta-llama/Llama-3.2-1B",
    lora_init=True,
    lora_r=128,
    lora_alpha=32,
    ckpt_dir=None,  # Not using local checkpoint
    full_precision=True,
    token=None
)

llama_training_args = TrainingArguments(
    output_dir="./outputs",
    model_max_length=512,
    inf_latent_iterations=6,
    use_prj=True,
    prj_dim=2048,
    remove_eos=True,
    greedy=True,
    bf16=False,
    inf_num_iterations=1
)

llama_lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=llama_model_args.lora_r,
    lora_alpha=llama_model_args.lora_alpha,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    init_lora_weights=True,
)

# Initialize the CODI model
llama_model = CODI(llama_model_args, llama_training_args, llama_lora_config)

# Download and load weights from HuggingFace Hub
from huggingface_hub import hf_hub_download

print(f"Downloading weights from {llama_model_name}...")
checkpoint_path = hf_hub_download(
    repo_id=llama_model_name,
    filename="pytorch_model.bin",
    token=hf_token if hf_token else None
)

print(f"Loading weights from {checkpoint_path}...")
state_dict = torch.load(checkpoint_path, map_location='cpu')
llama_model.load_state_dict(state_dict, strict=False)
llama_model.codi.tie_weights()

llama_model = llama_model.to(device)
llama_model = llama_model.to(torch.bfloat16)
llama_model.eval()

llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

print("✓ CODI-LLaMA loaded successfully from HuggingFace Hub")

Loading CODI-LLaMA from Hugging Face Hub




trainable params: 98574336 || all params: 1334394880 || trainable%: 7.387193811774817
Downloading weights from zen-E/CODI-llama3.2-1b-Instruct...
Loading weights from /workspace/.cache/huggingface/hub/models--zen-E--CODI-llama3.2-1b-Instruct/snapshots/b2c88ba224b06b12b52ef39b87f794b98a6eb1c8/pytorch_model.bin...
✓ CODI-LLaMA loaded successfully from HuggingFace Hub


In [5]:
import os

checkpoint_path = "/workspace/CoT_Exploration/models/CODI-llama3.2-1b/pytorch_model.bin"
if os.path.exists(checkpoint_path):
    size_mb = os.path.getsize(checkpoint_path) / (1024 * 1024)
    print(f"✓ Checkpoint found: {size_mb:.1f} MB")
else:
    print(f"❌ Checkpoint NOT found at: {checkpoint_path}")
    print(f"\nChecking directory contents:")
    dir_path = os.path.dirname(checkpoint_path)
    if os.path.exists(dir_path):
        files = os.listdir(dir_path)
        print(f"Files in {dir_path}:")
        for f in files:
            print(f"  - {f}")
    else:
        print(f"Directory doesn't exist: {dir_path}")

❌ Checkpoint NOT found at: /workspace/CoT_Exploration/models/CODI-llama3.2-1b/pytorch_model.bin

Checking directory contents:
Directory doesn't exist: /workspace/CoT_Exploration/models/CODI-llama3.2-1b


In [134]:
# Verify CODI weights are loaded
print("Checking if CODI weights are loaded:")
print(f"Model checkpoint dir: {llama_model_args.ckpt_dir}")
print(f"BoT token ID: {llama_model.bot_id}")
print(f"EoT token ID: {llama_model.eot_id}")
print(f"Using projection: {llama_training_args.use_prj}")
print(f"Projection dim: {llama_training_args.prj_dim if llama_training_args.use_prj else 'N/A'}")

# Check if projection layer exists
if hasattr(llama_model, 'prj'):
    print(f"✓ Projection layer found: {llama_model.prj}")
else:
    print("✗ No projection layer - this might be the problem!")

Checking if CODI weights are loaded:
Model checkpoint dir: None
BoT token ID: 128257
EoT token ID: 128258
Using projection: True
Projection dim: 2048
✓ Projection layer found: Sequential(
  (0): Dropout(p=0.0, inplace=False)
  (1): Linear(in_features=2048, out_features=2048, bias=True)
  (2): GELU(approximate='none')
  (3): Linear(in_features=2048, out_features=2048, bias=True)
  (ln): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
)


## Run LLaMA Forward Pass

In [99]:
question2 = "Kylar went to the store to buy glasses for his new apartment. One glass costs $5, but every second glass costs only 60% of the price. Kylar wants to buy 16 glasses. How much does he need to pay for them?"
answer2 = 64

In [135]:
print("="*80)
print("LLaMA: Running inference with Chain-of-Thought")
print("="*80)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 1

question = question
questions = [question]

# Tokenize with BoT token
if llama_training_args.remove_eos:
    bot_tensor = torch.tensor([llama_model.bot_id], dtype=torch.long).expand(batch_size, 1).to(device)
else:
    bot_tensor = torch.tensor([llama_tokenizer.eos_token_id, llama_model.bot_id], 
                              dtype=torch.long).expand(batch_size, 2).to(device)

inputs = llama_tokenizer(questions, return_tensors="pt", padding=False)
inputs = {k: v.to(device) for k, v in inputs.items()}
inputs["input_ids"] = torch.cat((inputs["input_ids"], bot_tensor), dim=1)
inputs["attention_mask"] = torch.cat((inputs["attention_mask"], torch.ones_like(bot_tensor)), dim=1)

print(f"WITH BoT - Input IDs: {inputs['input_ids']}")
print(f"WITH BoT - Last 5 token IDs: {inputs['input_ids'][0, -5:].tolist()}")
print(f"WITH BoT - BoT token ID should be: {llama_model.bot_id}")
print(f"WITH BoT - Last token is BoT?: {inputs['input_ids'][0, -1].item() == llama_model.bot_id}")


print(f"Input shape: {inputs['input_ids'].shape}")

# Store only the 7 chain-of-thought hidden states
cot_hidden_states = []

with torch.no_grad():
    # Initial encoding (position 0: BoT)
    past_key_values = None
    outputs = llama_model.codi(
        input_ids=inputs["input_ids"],
        use_cache=True,
        output_hidden_states=True,
        past_key_values=past_key_values,
        attention_mask=inputs["attention_mask"]
    )
    past_key_values = outputs.past_key_values
    
    # 🔥 CRITICAL: Only take the LAST position, not all positions
    latent_embd = outputs.hidden_states[-1][:, -1:, :]  # Shape: [batch, 1, hidden_dim]
    
    # Store BoT position (BEFORE projection)
    cot_hidden_states.append(latent_embd.clone())
    print(f"✓ Position 0 (BoT): shape={latent_embd.shape}")
    
    # Apply initial projection
    if llama_training_args.use_prj:
        latent_embd = llama_model.prj(latent_embd)
    
    # Chain-of-Thought iterations (positions 1-6)
    for i in range(llama_training_args.inf_latent_iterations):
        outputs = llama_model.codi(
            inputs_embeds=latent_embd,
            use_cache=True,
            output_hidden_states=True,
            past_key_values=past_key_values
        )
        past_key_values = outputs.past_key_values
        
        # 🔥 CRITICAL: Only take the LAST position
        latent_embd = outputs.hidden_states[-1][:, -1:, :]  # Shape: [batch, 1, hidden_dim]
        
        # Store BEFORE projection
        cot_hidden_states.append(latent_embd.clone())
        print(f"✓ Position {i+1} (T{i+1}): shape={latent_embd.shape}")
        
        # Apply projection for next iteration
        if llama_training_args.use_prj:
            latent_embd = llama_model.prj(latent_embd)

# Stack all 7 positions
llama_continuous_thoughts = torch.cat(cot_hidden_states, dim=1)
print(f"\nContinuous thoughts shape: {llama_continuous_thoughts.shape}")
print(f"Number of thought positions: {llama_continuous_thoughts.shape[1]}")



# Also create llama_latent_embd as alias for compatibility with other cells
llama_latent_embd = llama_continuous_thoughts

LLaMA: Running inference with Chain-of-Thought
WITH BoT - Input IDs: tensor([[128000,  18820,    295,    596,  78878,  11203,    220,    845,  19335,
            824,   1938,     13,   3005,  50777,   2380,    369,  17954,   1475,
           6693,    323,    293,   2094,  55404,   1354,    369,   1077,   4885,
           1475,   1938,    449,   3116,     13,   3005,  31878,    279,  27410,
            520,    279,  20957,      6,   3157,   7446,    369,    400,     17,
            824,   7878,  37085,  19151,     13,   2650,   1790,    304,  11441,
           1587,   1364,   1304,   1475,   1938,    520,    279,  20957,      6,
           3157,     30, 128257]], device='cuda:0')
WITH BoT - Last 5 token IDs: [20957, 6, 3157, 30, 128257]
WITH BoT - BoT token ID should be: 128257
WITH BoT - Last token is BoT?: True
Input shape: torch.Size([1, 66])
✓ Position 0 (BoT): shape=torch.Size([1, 1, 2048])
✓ Position 1 (T1): shape=torch.Size([1, 1, 2048])
✓ Position 2 (T2): shape=torch.Size([1, 1,

In [136]:
# Verify we have exactly 7 positions
if llama_continuous_thoughts.shape[1] != 7:
    print(f"\n❌ ERROR: Expected 7 positions, got {llama_continuous_thoughts.shape[1]}")
else:
    print(f"✓ SUCCESS: Collected exactly 7 CoT positions!")

✓ SUCCESS: Collected exactly 7 CoT positions!


In [137]:
## 

cot_length=2

print("="*80)
print("LLaMA: Running inference with Custom-length Chain-of-Thought")
print("="*80)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 1

question = question
questions = [question]

# Tokenize with BoT token
if llama_training_args.remove_eos:
    bot_tensor = torch.tensor([llama_model.bot_id], dtype=torch.long).expand(batch_size, 1).to(device)
else:
    bot_tensor = torch.tensor([llama_tokenizer.eos_token_id, llama_model.bot_id], 
                              dtype=torch.long).expand(batch_size, 2).to(device)

inputs = llama_tokenizer(questions, return_tensors="pt", padding=False)
inputs = {k: v.to(device) for k, v in inputs.items()}
inputs["input_ids"] = torch.cat((inputs["input_ids"], bot_tensor), dim=1)
inputs["attention_mask"] = torch.cat((inputs["attention_mask"], torch.ones_like(bot_tensor)), dim=1)

print(f"WITH BoT - Input IDs: {inputs['input_ids']}")
print(f"WITH BoT - Last 5 token IDs: {inputs['input_ids'][0, -5:].tolist()}")
print(f"WITH BoT - BoT token ID should be: {llama_model.bot_id}")
print(f"WITH BoT - Last token is BoT?: {inputs['input_ids'][0, -1].item() == llama_model.bot_id}")


print(f"Input shape: {inputs['input_ids'].shape}")


print(f"Question: {question}")
print(f"Trying with CoT length of {cot_length}")

cot_hidden_states = []

with torch.no_grad():
    # Initial encoding (position 0: BoT)
    past_key_values = None
    outputs = llama_model.codi(
        input_ids=inputs["input_ids"],
        use_cache=True,
        output_hidden_states=True,
        past_key_values=past_key_values,
        attention_mask=inputs["attention_mask"]
    )
    past_key_values = outputs.past_key_values
    
    # 🔥 CRITICAL: Only take the LAST position, not all positions
    latent_embd = outputs.hidden_states[-1][:, -1:, :]  # Shape: [batch, 1, hidden_dim]
    
    # Store BoT position (BEFORE projection)
    cot_hidden_states.append(latent_embd.clone())
    print(f"✓ Position 0 (BoT): shape={latent_embd.shape}")
    
    # Apply initial projection
    if llama_training_args.use_prj:
        latent_embd = llama_model.prj(latent_embd)
    
    # Chain-of-Thought iterations (positions 1-6)
    for i in range(cot_length):
        outputs = llama_model.codi(
            inputs_embeds=latent_embd,
            use_cache=True,
            output_hidden_states=True,
            past_key_values=past_key_values
        )
        past_key_values = outputs.past_key_values
        
        # 🔥 CRITICAL: Only take the LAST position
        latent_embd = outputs.hidden_states[-1][:, -1:, :]  # Shape: [batch, 1, hidden_dim]
        
        # Store BEFORE projection
        cot_hidden_states.append(latent_embd.clone())
        print(f"✓ Position {i+1} (T{i+1}): shape={latent_embd.shape}")
        
        # Apply projection for next iteration
        if llama_training_args.use_prj:
            latent_embd = llama_model.prj(latent_embd)

# Stack all positions
llama_continuous_thoughts = torch.cat(cot_hidden_states, dim=1)
print(f"\nContinuous thoughts shape: {llama_continuous_thoughts.shape}")
print(f"Number of thought positions: {llama_continuous_thoughts.shape[1]}")



# Also create llama_latent_embd as alias for compatibility with other cells
llama_latent_embd = llama_continuous_thoughts

LLaMA: Running inference with Custom-length Chain-of-Thought
WITH BoT - Input IDs: tensor([[128000,  18820,    295,    596,  78878,  11203,    220,    845,  19335,
            824,   1938,     13,   3005,  50777,   2380,    369,  17954,   1475,
           6693,    323,    293,   2094,  55404,   1354,    369,   1077,   4885,
           1475,   1938,    449,   3116,     13,   3005,  31878,    279,  27410,
            520,    279,  20957,      6,   3157,   7446,    369,    400,     17,
            824,   7878,  37085,  19151,     13,   2650,   1790,    304,  11441,
           1587,   1364,   1304,   1475,   1938,    520,    279,  20957,      6,
           3157,     30, 128257]], device='cuda:0')
WITH BoT - Last 5 token IDs: [20957, 6, 3157, 30, 128257]
WITH BoT - BoT token ID should be: 128257
WITH BoT - Last token is BoT?: True
Input shape: torch.Size([1, 66])
Question: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every d

## Decode LLaMA Continuous Thoughts (like section5_analysis.py)

In [138]:
print("="*80)
print("LLaMA: Decoding Continuous Thoughts")
print("="*80)

# Note: HuggingFace models don't have projection layers, decode directly
print(f"Continuous thoughts shape: {llama_continuous_thoughts.shape}")

# Decode first positions directly from continuous thoughts
print("\nDecoded tokens from continuous thoughts:")
llama_decoded_tokens = []
llama_number_positions = []

for i in range(min(50, llama_continuous_thoughts.shape[1])):
    logits = llama_model.codi.lm_head(llama_continuous_thoughts[:, i, :])
    top1_token_id = torch.argmax(logits, dim=-1).item()
    top1_token_str = llama_tokenizer.decode([top1_token_id])

    is_number = bool(number_regex.match(top1_token_str))
    pos_type = "BoT" if i == 0 else f"T{i}"

    llama_decoded_tokens.append((i, top1_token_id, top1_token_str, is_number))
    if is_number:
        llama_number_positions.append(i)

    marker = " ← NUMBER!" if is_number else ""
    print(f"  {pos_type:4s} [pos={i:2d}]: token_id={top1_token_id:5d} → '{top1_token_str}'{marker}")

print(f"\n✓ Numbers detected at positions: {llama_number_positions}")
print(f"✓ Total numbers: {len(llama_number_positions)}")

LLaMA: Decoding Continuous Thoughts
Continuous thoughts shape: torch.Size([1, 3, 2048])

Decoded tokens from continuous thoughts:
  BoT  [pos= 0]: token_id=   24 → '9' ← NUMBER!
  T1   [pos= 1]: token_id=   22 → '7' ← NUMBER!
  T2   [pos= 2]: token_id=   10 → '+'

✓ Numbers detected at positions: [0, 1]
✓ Total numbers: 2


In [139]:
print("="*80)
print("LLaMA: Generate Final Answer")
print("="*80)
with torch.no_grad():
    # Signal end-of-thought with EOT token
    if llama_training_args.remove_eos:
        eot_tensor = torch.tensor([[llama_model.eot_id]], dtype=torch.long).expand(batch_size, 1, 1).to(device)
    else:
        eot_tensor = torch.tensor([[llama_tokenizer.eos_token_id, llama_model.eot_id]], 
                                   dtype=torch.long).expand(batch_size, 1, 2).to(device)
    
    eot_emb = llama_model.get_embd(llama_model.codi, llama_model.model_name)(eot_tensor).squeeze(1)
    output = eot_emb
    
    # Generate answer tokens autoregressively
    full_answer = None
    pred_tokens = []
    max_length = 256
    
    for step in range(max_length):
        out = llama_model.codi(
            inputs_embeds=output,
            output_hidden_states=False,
            attention_mask=None,
            use_cache=True,
            output_attentions=False,
            past_key_values=past_key_values
        )
        past_key_values = out.past_key_values
        
        # Get logits for vocabulary tokens only
        logits = out.logits[:, -1, :llama_model.codi.config.vocab_size-1]
        next_token_id = torch.argmax(logits, dim=-1).item()
        
        pred_tokens.append(next_token_id)
        
        # Decode current token
        current_token_str = llama_tokenizer.decode([next_token_id])
        
        # Stop if EOS token
        if next_token_id == llama_tokenizer.eos_token_id:
            print(f"Stopped at step {step} (EOS token)")
            break
        
        # Stop immediately after generating a number token
        #if number_regex.match(current_token_str.strip()):
        #    print(f"Stopped at step {step} (found number: '{current_token_str}')")
        #    full_answer=current_token_str
        #    break
        
        # Hard limit to prevent loops
        if step >= 49:
            print(f"Stopped at step {step} (max length)")
            break
        
        # Prepare next input
        output = llama_model.get_embd(llama_model.codi, llama_model.model_name)(
            torch.tensor([[next_token_id]], device=device)
        )
    
    
    # Extract numerical answer
    def extract_answer_number(text):
        text = text.replace(',', '')
        numbers = [s for s in re.findall(r'-?\d+\.?\d*', text)]
        if not numbers:
            return None
        return float(numbers[0])

    # Decode the full answer from pred_tokens
    full_answer = llama_tokenizer.decode(pred_tokens, skip_special_tokens=True)
    print(f"\nFull decoded answer: {full_answer}")
    
    predicted_number = extract_answer_number(full_answer)
    print(f"\nExtracted numerical answer: {predicted_number}")
    print(f"Expected answer: {answer}")
    

LLaMA: Generate Final Answer
Stopped at step 49 (max length)

Full decoded answer: The answer is: 18The answer is: 36The answer is: 36The answer is: 36The answer is: 72The answer is: 72The answer is: 144The

Extracted numerical answer: 18.0
Expected answer: 18


In [140]:
import torch, re

def run_continuous_cot_codi_llama(
    llama_model,
    llama_tokenizer,
    llama_training_args,
    question: str,
    answer: str,
    cot_length: int = 2,
    device: str = None,
    max_answer_len: int = 256,
):
    """
    Run CODI-LLaMA continuous chain-of-thought inference, decoding, and evaluation
    exactly as implemented in the notebook.

    Parameters
    ----------
    llama_model : model with .codi, .bot_id, .eot_id, .prj, .get_embd, .model_name
    llama_tokenizer : tokenizer matching model
    llama_training_args : object with .remove_eos and .use_prj attributes
    question : str
    answer : str (reference numerical or textual answer)
    cot_length : int, number of CoT latent iterations
    device : torch.device or 'cuda'/'cpu' (auto if None)
    max_answer_len : int, maximum generated answer tokens

    Returns
    -------
    dict containing:
        - continuous_thoughts : torch.Tensor of latent states
        - decoded_tokens : list of (pos, token_id, token_str, is_number)
        - full_answer : str
        - predicted_number : float or None
        - match : bool or None
    """
    number_regex = re.compile(r"^-?\d+\.?\d*$")

    print("=" * 80)
    print("LLaMA: Running inference with Custom-length Chain-of-Thought")
    print("=" * 80)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    batch_size = 1
    questions = [question]

    # --- Build input with BoT token(s)
    if llama_training_args.remove_eos:
        bot_tensor = torch.tensor([llama_model.bot_id], dtype=torch.long).expand(batch_size, 1).to(device)
    else:
        bot_tensor = torch.tensor([llama_tokenizer.eos_token_id, llama_model.bot_id],
                                  dtype=torch.long).expand(batch_size, 2).to(device)

    inputs = llama_tokenizer(questions, return_tensors="pt", padding=False)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    inputs["input_ids"] = torch.cat((inputs["input_ids"], bot_tensor), dim=1)
    inputs["attention_mask"] = torch.cat((inputs["attention_mask"], torch.ones_like(bot_tensor)), dim=1)

    print(f"Question: {question}")
    print(f"Trying with CoT length of {cot_length}")

    cot_hidden_states = []

    with torch.no_grad():
        past_key_values = None
        outputs = llama_model.codi(
            input_ids=inputs["input_ids"],
            use_cache=True,
            output_hidden_states=True,
            past_key_values=past_key_values,
            attention_mask=inputs["attention_mask"]
        )
        past_key_values = outputs.past_key_values
        latent_embd = outputs.hidden_states[-1][:, -1:, :]  # BoT position
        cot_hidden_states.append(latent_embd.clone())
        print(f"✓ Position 0 (BoT): shape={latent_embd.shape}")

        if llama_training_args.use_prj:
            latent_embd = llama_model.prj(latent_embd)

        # Chain of Thought
        for i in range(cot_length):
            outputs = llama_model.codi(
                inputs_embeds=latent_embd,
                use_cache=True,
                output_hidden_states=True,
                past_key_values=past_key_values
            )
            past_key_values = outputs.past_key_values
            latent_embd = outputs.hidden_states[-1][:, -1:, :]
            cot_hidden_states.append(latent_embd.clone())
            print(f"✓ Position {i+1} (T{i+1}): shape={latent_embd.shape}")

            if llama_training_args.use_prj:
                latent_embd = llama_model.prj(latent_embd)

    llama_continuous_thoughts = torch.cat(cot_hidden_states, dim=1)
    print(f"\nContinuous thoughts shape: {llama_continuous_thoughts.shape}")

    # --- Decode Continuous Thoughts
    print("=" * 80)
    print("LLaMA: Decoding Continuous Thoughts")
    print("=" * 80)
    llama_decoded_tokens = []
    llama_number_positions = []

    for i in range(llama_continuous_thoughts.shape[1]):
        logits = llama_model.codi.lm_head(llama_continuous_thoughts[:, i, :])
        top1_token_id = torch.argmax(logits, dim=-1).item()
        top1_token_str = llama_tokenizer.decode([top1_token_id])
        is_number = bool(number_regex.match(top1_token_str.strip()))
        pos_type = "BoT" if i == 0 else f"T{i}"
        llama_decoded_tokens.append((i, top1_token_id, top1_token_str, is_number))
        if is_number:
            llama_number_positions.append(i)
        marker = " ← NUMBER!" if is_number else ""
        print(f"  {pos_type:4s} [pos={i:2d}]: token_id={top1_token_id:5d} → '{top1_token_str}'{marker}")

    print(f"\n✓ Numbers detected at positions: {llama_number_positions}")
    print(f"✓ Total numbers: {len(llama_number_positions)}")

    # --- Generate Final Answer
    print("=" * 80)
    print("LLaMA: Generate Final Answer")
    print("=" * 80)
    with torch.no_grad():
        if llama_training_args.remove_eos:
            eot_tensor = torch.tensor([[llama_model.eot_id]], dtype=torch.long).expand(batch_size, 1, 1).to(device)
        else:
            eot_tensor = torch.tensor([[llama_tokenizer.eos_token_id, llama_model.eot_id]],
                                      dtype=torch.long).expand(batch_size, 1, 2).to(device)
        eot_emb = llama_model.get_embd(llama_model.codi, llama_model.model_name)(eot_tensor).squeeze(1)
        output = eot_emb

        full_answer = ""
        pred_tokens = []

        for step in range(max_answer_len):
            out = llama_model.codi(
                inputs_embeds=output,
                output_hidden_states=False,
                use_cache=True,
                output_attentions=False,
                past_key_values=past_key_values
            )
            past_key_values = out.past_key_values
            logits = out.logits[:, -1, :llama_model.codi.config.vocab_size - 1]
            next_token_id = torch.argmax(logits, dim=-1).item()
            pred_tokens.append(next_token_id)
            tok_str = llama_tokenizer.decode([next_token_id])
            if next_token_id == llama_tokenizer.eos_token_id:
                print(f"Stopped at step {step} (EOS token)")
                break
            if step >= 49:
                print(f"Stopped at step {step} (max length)")
                break
            output = llama_model.get_embd(llama_model.codi, llama_model.model_name)(
                torch.tensor([[next_token_id]], device=device)
            )

        # decode
        full_answer = llama_tokenizer.decode(pred_tokens, skip_special_tokens=True)
        print(f"\nFull decoded answer: {full_answer}")

        def extract_answer_number(text):
            text = text.replace(',', '')
            numbers = re.findall(r'-?\d+\.?\d*', text)
            return float(numbers[0]) if numbers else None

        predicted_number = extract_answer_number(full_answer)
        print(f"\nExtracted numerical answer: {predicted_number}")
        print(f"Expected answer: \n{answer}")

    # --- Evaluate
    match = None
    if predicted_number is not None:
        try:
            ref = float(answer)
            match = abs(predicted_number - ref) < 1e-6
        except ValueError:
            match = (full_answer.strip() == str(answer).strip())

    return {
        "continuous_thoughts": llama_continuous_thoughts,
        "decoded_tokens": llama_decoded_tokens,
        "full_answer": full_answer,
        "predicted_number": predicted_number,
        "reference_answer": answer,
        "match": match,
    }


In [141]:
# Example usage (uncomment + adapt inside your notebook):
res = run_continuous_cot_codi_llama(
    llama_model=llama_model,
    llama_tokenizer=llama_tokenizer,
    llama_training_args=llama_training_args,
    question=question,
    answer=answer,
    #question="Q: 12 + 5 = ?",
    #answer="17",
    cot_length=1,
)
print(res)


LLaMA: Running inference with Custom-length Chain-of-Thought
Question: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Trying with CoT length of 1
✓ Position 0 (BoT): shape=torch.Size([1, 1, 2048])
✓ Position 1 (T1): shape=torch.Size([1, 1, 2048])

Continuous thoughts shape: torch.Size([1, 2, 2048])
LLaMA: Decoding Continuous Thoughts
  BoT  [pos= 0]: token_id=   24 → '9' ← NUMBER!
  T1   [pos= 1]: token_id=   22 → '7' ← NUMBER!

✓ Numbers detected at positions: [0, 1]
✓ Total numbers: 2
LLaMA: Generate Final Answer
Stopped at step 49 (max length)

Full decoded answer: The answer is: 18The answer is: 36The answer is: 36The answer is: 36The answer is: 72The answer is: 72The answer is: 144The

Extracted numerical answer: 18.0
Expected answer: 
18
{'continuous

In [142]:
run_continuous_cot_codi_llama(
    llama_model=llama_model,
    llama_tokenizer=llama_tokenizer,
    llama_training_args=llama_training_args,
    question=question,
    answer=answer,
    #question="Q: 12 + 5 = ?",
    #answer="17",
    cot_length=0,
)

LLaMA: Running inference with Custom-length Chain-of-Thought
Question: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Trying with CoT length of 0
✓ Position 0 (BoT): shape=torch.Size([1, 1, 2048])

Continuous thoughts shape: torch.Size([1, 1, 2048])
LLaMA: Decoding Continuous Thoughts
  BoT  [pos= 0]: token_id=   24 → '9' ← NUMBER!

✓ Numbers detected at positions: [0]
✓ Total numbers: 1
LLaMA: Generate Final Answer
Stopped at step 49 (max length)

Full decoded answer: 9/4=2.25>> <<16-3-2.25=10.75>>The answer is: 17The answer is: 34The answer is: 34The answer is: 68The

Extracted numerical answer: 9.0
Expected answer: 
18


{'continuous_thoughts': tensor([[[-0.1953,  2.3906, -0.5430,  ...,  1.5703, -2.6562, -0.9023]]],
        device='cuda:0', dtype=torch.bfloat16),
 'decoded_tokens': [(0, 24, '9', True)],
 'full_answer': '9/4=2.25>> <<16-3-2.25=10.75>>The answer is: 17The answer is: 34The answer is: 34The answer is: 68The',
 'predicted_number': 9.0,
 'reference_answer': 18,
 'match': False}

In [157]:
def run_continuous_cot_codi_llama_intervene(
    question: str,
    answer: str,
    cot_length: int = 6,
    bot: bool = True,
    target_token: str = None,
    k: float = 1.0,
    device=None
):
    """
    Run CODI chain-of-thought on a preloaded CODI-LLaMA model, optionally with
    causal projection intervention (replace activations with target token embedding).

    Args:
        question (str): The input question for inference.
        answer (str): Reference (expected) answer for comparison.
        cot_length (int): Number of CoT latent iterations.
        bot (bool): Whether to include BoT token at the end of input.
        target_token (str, optional): If given, perform causal intervention
            by replacing projection of number-token activations with the embedding
            of this token.
        k (float): Scaling strength for the projection replacement.
        device: torch.device to use (GPU or CPU).

    Returns:
        dict with decoded intermediate tokens, generated answer, and reference comparison.
    """
    print("=" * 80)
    print("LLaMA: Running Continuous Chain-of-Thought" +
          (" with CAUSAL INTERVENTION" if target_token else ""))
    print("=" * 80)

    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 1
    torch.manual_seed(42)

    number_regex = re.compile(r"^-?\d+\.?\d*$")

    # Prepare question input
    questions = [question]
    if bot:
        if llama_training_args.remove_eos:
            bot_tensor = torch.tensor([llama_model.bot_id], dtype=torch.long).expand(batch_size, 1).to(device)
        else:
            bot_tensor = torch.tensor(
                [llama_tokenizer.eos_token_id, llama_model.bot_id],
                dtype=torch.long
            ).expand(batch_size, 2).to(device)
    else:
        bot_tensor = None

    inputs = llama_tokenizer(questions, return_tensors="pt", padding=False)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    if bot:
        inputs["input_ids"] = torch.cat((inputs["input_ids"], bot_tensor), dim=1)
        inputs["attention_mask"] = torch.cat(
            (inputs["attention_mask"], torch.ones_like(bot_tensor)), dim=1
        )

    print(f"Question: {question}")
    print(f"CoT length: {cot_length}")

    # Prepare intervention
    intervention_enabled = target_token is not None
    if intervention_enabled:
        target_token_id = llama_tokenizer.encode(target_token, add_special_tokens=False)[0]
        embedding_layer = llama_model.codi.get_input_embeddings()
        target_embd = embedding_layer(torch.tensor([target_token_id], device=device))
        print(f"\nCAUSAL INTERVENTION enabled → target_token='{target_token}', k={k}")
    else:
        print("\nRunning baseline CoT (no intervention).")

    # Run chain-of-thought
    cot_hidden_states = []
    llama_interventions = []
    intervened_positions = []

    with torch.no_grad():
        past_key_values = None
        outputs = llama_model.codi(
            input_ids=inputs["input_ids"],
            use_cache=True,
            output_hidden_states=True,
            attention_mask=inputs["attention_mask"]
        )
        past_key_values = outputs.past_key_values
        latent_embd = outputs.hidden_states[-1][:, -1:, :]
        cot_hidden_states.append(latent_embd.clone())

        if llama_training_args.use_prj:
            latent_embd = llama_model.prj(latent_embd)

        for i in range(cot_length):
            outputs = llama_model.codi(
                inputs_embeds=latent_embd,
                use_cache=True,
                output_hidden_states=True,
                past_key_values=past_key_values
            )
            past_key_values = outputs.past_key_values
            latent_embd = outputs.hidden_states[-1][:, -1:, :]
            cot_hidden_states.append(latent_embd.clone())

            # decode predicted token
            logits = llama_model.codi.lm_head(latent_embd.squeeze(1))
            top1_token_id = torch.argmax(logits, dim=-1).item()
            top1_token_str = llama_tokenizer.decode([top1_token_id])
            is_number = bool(number_regex.match(top1_token_str))
            pos_type = "BoT" if i == 0 else f"T{i}"

            # Causal projection intervention
            if intervention_enabled and is_number:
                predicted_embd = embedding_layer(torch.tensor([top1_token_id], device=device))
                A = latent_embd.squeeze(1)

                E_pred_norm = predicted_embd / torch.norm(predicted_embd, dim=-1, keepdim=True)
                E_target_norm = target_embd / torch.norm(target_embd, dim=-1, keepdim=True)
                proj_predicted = torch.sum(A * E_pred_norm, dim=-1, keepdim=True) * E_pred_norm
                proj_target = torch.norm(proj_predicted, dim=-1, keepdim=True) * E_target_norm
                A_modified = A - proj_predicted + k * proj_target

                logits_modified = llama_model.codi.lm_head(A_modified)
                new_token_id = torch.argmax(logits_modified, dim=-1).item()
                new_token_str = llama_tokenizer.decode([new_token_id])

                latent_embd = A_modified.unsqueeze(1)
                llama_interventions.append({
                    "position": i,
                    "predicted_token": top1_token_str,
                    "new_token": new_token_str,
                    "intervened": True
                })
                intervened_positions.append(i)
                print(f"{pos_type:4s} [{i:2d}]: '{top1_token_str}' → '{new_token_str}' ✓ INTERVENED")
            else:
                print(f"{pos_type:4s} [{i:2d}]: '{top1_token_str}'")

            if llama_training_args.use_prj:
                latent_embd = llama_model.prj(latent_embd)

    llama_continuous_thoughts = torch.cat(cot_hidden_states, dim=1)
    print(f"\nContinuous thoughts shape: {llama_continuous_thoughts.shape}")

    # Decode intermediate tokens
    decoded = []
    print("\nDecoded intermediate tokens:")
    for i in range(min(50, llama_continuous_thoughts.shape[1])):
        logits = llama_model.codi.lm_head(llama_continuous_thoughts[:, i, :])
        top1_token_id = torch.argmax(logits, dim=-1).item()
        tok_str = llama_tokenizer.decode([top1_token_id])
        print(f"  pos={i:2d}: '{tok_str}'")
        decoded.append(tok_str)

    # Generate final answer
    print("\nGenerating final answer...")
    if llama_training_args.remove_eos:
        eot_tensor = torch.tensor([[llama_model.eot_id]], dtype=torch.long).expand(batch_size, 1, 1).to(device)
    else:
        eot_tensor = torch.tensor(
            [[llama_tokenizer.eos_token_id, llama_model.eot_id]],
            dtype=torch.long
        ).expand(batch_size, 1, 2).to(device)

    eot_emb = llama_model.get_embd(llama_model.codi, llama_model.model_name)(eot_tensor).squeeze(1)
    output = eot_emb
    pred_tokens = []
    for step in range(256):
        out = llama_model.codi(inputs_embeds=output, use_cache=True, past_key_values=past_key_values)
        past_key_values = out.past_key_values
        logits = out.logits[:, -1, :llama_model.codi.config.vocab_size-1]
        next_token_id = torch.argmax(logits, dim=-1).item()
        pred_tokens.append(next_token_id)
        current_token_str = llama_tokenizer.decode([next_token_id])
        if next_token_id == llama_tokenizer.eos_token_id or re.match(number_regex, current_token_str.strip()) or step >= 49:
            break
        output = llama_model.get_embd(llama_model.codi, llama_model.model_name)(
            torch.tensor([[next_token_id]], device=device)
        )

    full_answer = llama_tokenizer.decode(pred_tokens, skip_special_tokens=True)
    print(f"\nGenerated answer: {full_answer}")

    def extract_number(text):
        nums = re.findall(r"-?\d+\.?\d*", text.replace(",", ""))
        return float(nums[0]) if nums else None

    predicted_number = extract_number(full_answer)
    match = str(predicted_number) == str(answer)
    print(f"Extracted number: {predicted_number}  |  Reference: {answer}  |  Match: {match}")

    return {
        "decoded_tokens": decoded,
        "generated_answer": full_answer,
        "predicted_number": predicted_number,
        "reference_answer": answer,
        "match": match,
        "interventions": llama_interventions,
        "intervened_positions": intervened_positions
    }

In [158]:
run_continuous_cot_codi_llama_intervene(
    question=question2,
    answer=answer2,
    cot_length=6
)


LLaMA: Running Continuous Chain-of-Thought
Question: Kylar went to the store to buy glasses for his new apartment. One glass costs $5, but every second glass costs only 60% of the price. Kylar wants to buy 16 glasses. How much does he need to pay for them?
CoT length: 6

Running baseline CoT (no intervention).
BoT  [ 0]: '40'
T1   [ 1]: '80'
T2   [ 2]: ' ('
T3   [ 3]: '24'
T4   [ 4]: '24'
T5   [ 5]: '>>'

Continuous thoughts shape: torch.Size([1, 7, 2048])

Decoded intermediate tokens:
  pos= 0: '8'
  pos= 1: '40'
  pos= 2: '80'
  pos= 3: ' ('
  pos= 4: '24'
  pos= 5: '24'
  pos= 6: '>>'

Generating final answer...

Generated answer: The answer is: 64
Extracted number: 64.0  |  Reference: 64  |  Match: False


{'decoded_tokens': ['8', '40', '80', ' (', '24', '24', '>>'],
 'generated_answer': 'The answer is: 64',
 'predicted_number': 64.0,
 'reference_answer': 64,
 'match': False,
 'interventions': [],
 'intervened_positions': []}

In [160]:
run_continuous_cot_codi_llama_intervene(
    question=question,
    answer=answer,
    cot_length=6,
    target_token="5"
)


LLaMA: Running Continuous Chain-of-Thought with CAUSAL INTERVENTION
Question: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
CoT length: 6

CAUSAL INTERVENTION enabled → target_token='5', k=1.0
BoT  [ 0]: '7' → ' Gib' ✓ INTERVENED
T1   [ 1]: '7' → '+' ✓ INTERVENED
T2   [ 2]: '-'
T3   [ 3]: '9' → '5' ✓ INTERVENED
T4   [ 4]: '9' → '5' ✓ INTERVENED
T5   [ 5]: '"'

Continuous thoughts shape: torch.Size([1, 7, 2048])

Decoded intermediate tokens:
  pos= 0: '9'
  pos= 1: '7'
  pos= 2: '7'
  pos= 3: '-'
  pos= 4: '9'
  pos= 5: '9'
  pos= 6: '"'

Generating final answer...

Generated answer: The answer is: 18
Extracted number: 18.0  |  Reference: 18  |  Match: False


{'decoded_tokens': ['9', '7', '7', '-', '9', '9', '"'],
 'generated_answer': 'The answer is: 18',
 'predicted_number': 18.0,
 'reference_answer': 18,
 'match': False,
 'interventions': [{'position': 0,
   'predicted_token': '7',
   'new_token': ' Gib',
   'intervened': True},
  {'position': 1,
   'predicted_token': '7',
   'new_token': '+',
   'intervened': True},
  {'position': 3,
   'predicted_token': '9',
   'new_token': '5',
   'intervened': True},
  {'position': 4,
   'predicted_token': '9',
   'new_token': '5',
   'intervened': True}],
 'intervened_positions': [0, 1, 3, 4]}

## LLaMA: Projection Intervention (Target Token = '5', k=3)

In [112]:
print("="*80)
print("LLaMA: CAUSAL Projection Intervention")
print("="*80)

# CRITICAL: Reset to ensure clean state
torch.manual_seed(42)  # For reproducibility
past_key_values = None  # Clear any cached state

target_token = '5'
k = 3

# Get target token embedding
target_token_id = llama_tokenizer.encode(target_token, add_special_tokens=False)[0]
embedding_layer = llama_model.codi.get_input_embeddings()
target_embd = embedding_layer(torch.tensor([target_token_id], device=device))

print(f"Target token: '{target_token}'")
print(f"Target token ID: {target_token_id}")
print(f"k (top-k intervention): {k}")
print(f"\nRunning CAUSAL intervention (affects downstream positions)...\n")

# Re-run the chain-of-thought with interventions
batch_size = 1

# Set pad token if needed
if llama_tokenizer.pad_token is None:
    llama_tokenizer.pad_token = llama_tokenizer.eos_token
    llama_tokenizer.pad_token_id = llama_tokenizer.eos_token_id

# Tokenize with BoT token
if llama_training_args.remove_eos:
    bot_tensor = torch.tensor([llama_model.bot_id], dtype=torch.long).expand(batch_size, 1).to(device)
else:
    bot_tensor = torch.tensor([llama_tokenizer.eos_token_id, llama_model.bot_id], 
                              dtype=torch.long).expand(batch_size, 2).to(device)

inputs = llama_tokenizer([question], return_tensors="pt", padding=False)
inputs = {k: v.to(device) for k, v in inputs.items()}
inputs["input_ids"] = torch.cat((inputs["input_ids"], bot_tensor), dim=1)
inputs["attention_mask"] = torch.cat((inputs["attention_mask"], torch.ones_like(bot_tensor)), dim=1)

llama_interventions = []
intervened_positions = []

with torch.no_grad():
    # Initial encoding (position 0: BoT)
    past_key_values = None
    outputs = llama_model.codi(
        input_ids=inputs["input_ids"],
        use_cache=True,
        output_hidden_states=True,
        past_key_values=past_key_values,
        attention_mask=inputs["attention_mask"]
    )
    past_key_values = outputs.past_key_values
    latent_embd = outputs.hidden_states[-1][:, -1:, :]
    
    # Check BoT position
    logits = llama_model.codi.lm_head(latent_embd.squeeze(1))
    top1_token_id = torch.argmax(logits, dim=-1).item()
    top1_token_str = llama_tokenizer.decode([top1_token_id])
    is_number = bool(number_regex.match(top1_token_str))
    
    # Intervene at BoT if it's a number
    if is_number:
        predicted_embd = embedding_layer(torch.tensor([top1_token_id], device=device))
        A = latent_embd.squeeze(1)
        
        # Normalize embeddings
        E_pred_norm = predicted_embd / torch.norm(predicted_embd, dim=-1, keepdim=True)
        E_target_norm = target_embd / torch.norm(target_embd, dim=-1, keepdim=True)
        
        # Projection intervention
        proj_predicted = torch.sum(A * E_pred_norm, dim=-1, keepdim=True) * E_pred_norm
        proj_target = torch.norm(proj_predicted, dim =-1, keepdim=True) * E_target_norm
        #proj_target = torch.sum(A * E_target_norm, dim=-1, keepdim=True) * E_target_norm
        A_modified = A - proj_predicted + k * proj_target
        
        # Check what it decodes to after intervention
        logits_modified = llama_model.codi.lm_head(A_modified)
        new_token_id = torch.argmax(logits_modified, dim=-1).item()
        new_token_str = llama_tokenizer.decode([new_token_id])
        
        # CAUSALLY apply the intervention
        latent_embd = A_modified.unsqueeze(1)
        
        llama_interventions.append({
            'position': 0,
            'predicted_token': top1_token_str,
            'new_token': new_token_str,
            'intervened': True
        })
        intervened_positions.append(0)
        print(f"BoT  [pos= 0]: '{top1_token_str}' → '{new_token_str}' ✓ INTERVENED (causal)")
    else:
        print(f"BoT  [pos= 0]: '{top1_token_str}' (not a number, no intervention)")
    
    # Apply initial projection
    if llama_training_args.use_prj:
        latent_embd = llama_model.prj(latent_embd)
    
    # Chain-of-Thought iterations (positions 1-6)
    for i in range(llama_training_args.inf_latent_iterations):
        outputs = llama_model.codi(
            inputs_embeds=latent_embd,
            use_cache=True,
            output_hidden_states=True,
            past_key_values=past_key_values
        )
        past_key_values = outputs.past_key_values
        latent_embd = outputs.hidden_states[-1][:, -1:, :]
        
        # Check this position
        logits = llama_model.codi.lm_head(latent_embd.squeeze(1))
        top1_token_id = torch.argmax(logits, dim=-1).item()
        top1_token_str = llama_tokenizer.decode([top1_token_id])
        is_number = bool(number_regex.match(top1_token_str))
        
        pos_type = f"T{i+1}"
        
        # Intervene if it's a number
        if is_number:
            predicted_embd = embedding_layer(torch.tensor([top1_token_id], device=device))
            A = latent_embd.squeeze(1)
            
            # Normalize embeddings
            E_pred_norm = predicted_embd / torch.norm(predicted_embd, dim=-1, keepdim=True)
            E_target_norm = target_embd / torch.norm(target_embd, dim=-1, keepdim=True)
            
            # Projection intervention
            proj_predicted = torch.sum(A * E_pred_norm, dim=-1, keepdim=True) * E_pred_norm
            proj_target = torch.sum(A * E_target_norm, dim=-1, keepdim=True) * E_target_norm
            A_modified = A - proj_predicted + k * proj_target
            
            # Check what it decodes to after intervention
            logits_modified = llama_model.codi.lm_head(A_modified)
            new_token_id = torch.argmax(logits_modified, dim=-1).item()
            new_token_str = llama_tokenizer.decode([new_token_id])
            
            # CAUSALLY apply the intervention
            latent_embd = A_modified.unsqueeze(1)
            
            llama_interventions.append({
                'position': i+1,
                'predicted_token': top1_token_str,
                'new_token': new_token_str,
                'intervened': True
            })
            intervened_positions.append(i+1)
            print(f"{pos_type:4s} [pos={i+1:2d}]: '{top1_token_str}' → '{new_token_str}' ✓ INTERVENED (causal)")
        else:
            print(f"{pos_type:4s} [pos={i+1:2d}]: '{top1_token_str}' (not a number, no intervention)")
        
        # Apply projection for next iteration
        if llama_training_args.use_prj:
            latent_embd = llama_model.prj(latent_embd)

print(f"\n{'='*80}")
print(f"✓ Total interventions: {len(llama_interventions)}")
print(f"✓ Intervened at positions: {intervened_positions}")
print(f"{'='*80}")

# Now generate the final answer with the intervened chain-of-thought
print("\nGenerating final answer with intervened CoT...")

# Signal end-of-thought with EOT token
if llama_training_args.remove_eos:
    eot_tensor = torch.tensor([[llama_model.eot_id]], dtype=torch.long).expand(batch_size, 1, 1).to(device)
else:
    eot_tensor = torch.tensor([[llama_tokenizer.eos_token_id, llama_model.eot_id]], 
                               dtype=torch.long).expand(batch_size, 1, 2).to(device)

eot_emb = llama_model.get_embd(llama_model.codi, llama_model.model_name)(eot_tensor).squeeze(1)
output = eot_emb

# Generate answer tokens
# Generate answer tokens
pred_tokens = []
found_number = False

for step in range(256):
    out = llama_model.codi(
        inputs_embeds=output,
        output_hidden_states=False,
        attention_mask=None,
        use_cache=True,
        output_attentions=False,
        past_key_values=past_key_values
    )
    past_key_values = out.past_key_values
    
    logits = out.logits[:, -1, :llama_model.codi.config.vocab_size-1]
    next_token_id = torch.argmax(logits, dim=-1).item()
    pred_tokens.append(next_token_id)
    
    # Decode current token
    current_token_str = llama_tokenizer.decode([next_token_id])
    
    # Stop if EOS token
    if next_token_id == llama_tokenizer.eos_token_id:
        print(f"Stopped at step {step} (EOS token)")
        break
    
    # Stop immediately after generating a number token
    if number_regex.match(current_token_str.strip()):
        print(f"Stopped at step {step} (found number: '{current_token_str}')")
        break
    
    # Hard limit to prevent loops
    if step >= 49:
        print(f"Stopped at step {step} (max length)")
        break
    
    output = llama_model.get_embd(llama_model.codi, llama_model.model_name)(
        torch.tensor([[next_token_id]], device=device)
    )
# Decode and extract answer
intervened_answer = llama_tokenizer.decode(pred_tokens, skip_special_tokens=True)
print(f"\nIntervened Answer: {intervened_answer}")

def extract_answer_number(text):
    text = text.replace(',', '')
    numbers = [s for s in re.findall(r'-?\d+\.?\d*', text)]
    if not numbers:
        return None
    return float(numbers[0])

intervened_number = extract_answer_number(intervened_answer)
print(f"Intervened numerical answer: {intervened_number}")
print(f"\n(Compare to original expected: {answer})")

LLaMA: CAUSAL Projection Intervention
Target token: '5'
Target token ID: 20
k (top-k intervention): 3

Running CAUSAL intervention (affects downstream positions)...

BoT  [pos= 0]: '9' → '5' ✓ INTERVENED (causal)
T1   [pos= 1]: '7' → '5' ✓ INTERVENED (causal)
T2   [pos= 2]: '7' → '5' ✓ INTERVENED (causal)
T3   [pos= 3]: '-' (not a number, no intervention)
T4   [pos= 4]: '9' → '5' ✓ INTERVENED (causal)
T5   [pos= 5]: '9' → '5' ✓ INTERVENED (causal)
T6   [pos= 6]: '"' (not a number, no intervention)

✓ Total interventions: 5
✓ Intervened at positions: [0, 1, 2, 4, 5]

Generating final answer with intervened CoT...
Stopped at step 5 (found number: '18')

Intervened Answer: The answer is: 18
Intervened numerical answer: 18.0

(Compare to original expected: 18)


In [113]:
##REPLACEMENT INTERVENTION 

print("="*80)
print("LLaMA: CAUSAL Projection Intervention")
print("="*80)

# CRITICAL: Reset to ensure clean state
torch.manual_seed(42)  # For reproducibility
past_key_values = None  # Clear any cached state

target_token = '5'
k = 3

# Get target token embedding
target_token_id = llama_tokenizer.encode(target_token, add_special_tokens=False)[0]
embedding_layer = llama_model.codi.get_input_embeddings()
target_embd = embedding_layer(torch.tensor([target_token_id], device=device))

print(f"Target token: '{target_token}'")
print(f"Target token ID: {target_token_id}")
print(f"k (top-k intervention): {k}")
print(f"\nRunning CAUSAL intervention (affects downstream positions)...\n")

# Re-run the chain-of-thought with interventions
batch_size = 1

# Set pad token if needed
if llama_tokenizer.pad_token is None:
    llama_tokenizer.pad_token = llama_tokenizer.eos_token
    llama_tokenizer.pad_token_id = llama_tokenizer.eos_token_id

# Tokenize with BoT token
if llama_training_args.remove_eos:
    bot_tensor = torch.tensor([llama_model.bot_id], dtype=torch.long).expand(batch_size, 1).to(device)
else:
    bot_tensor = torch.tensor([llama_tokenizer.eos_token_id, llama_model.bot_id], 
                              dtype=torch.long).expand(batch_size, 2).to(device)

inputs = llama_tokenizer([question], return_tensors="pt", padding=False)
inputs = {k: v.to(device) for k, v in inputs.items()}
inputs["input_ids"] = torch.cat((inputs["input_ids"], bot_tensor), dim=1)
inputs["attention_mask"] = torch.cat((inputs["attention_mask"], torch.ones_like(bot_tensor)), dim=1)

llama_interventions = []
intervened_positions = []

with torch.no_grad():
    # Initial encoding (position 0: BoT)
    past_key_values = None
    outputs = llama_model.codi(
        input_ids=inputs["input_ids"],
        use_cache=True,
        output_hidden_states=True,
        past_key_values=past_key_values,
        attention_mask=inputs["attention_mask"]
    )
    past_key_values = outputs.past_key_values
    latent_embd = outputs.hidden_states[-1][:, -1:, :]
    
    # Check BoT position
    logits = llama_model.codi.lm_head(latent_embd.squeeze(1))
    top1_token_id = torch.argmax(logits, dim=-1).item()
    top1_token_str = llama_tokenizer.decode([top1_token_id])
    is_number = bool(number_regex.match(top1_token_str))
    
    # Intervene at BoT if it's a number
    if is_number:
        predicted_embd = embedding_layer(torch.tensor([top1_token_id], device=device))
        A = latent_embd.squeeze(1)
        
        # Normalize embeddings
        E_pred_norm = predicted_embd / torch.norm(predicted_embd, dim=-1, keepdim=True)
        E_target_norm = target_embd / torch.norm(target_embd, dim=-1, keepdim=True)
        
        # Projection intervention
        proj_predicted = torch.sum(A * E_pred_norm, dim=-1, keepdim=True) * E_pred_norm
        proj_target = torch.norm(proj_predicted, dim =-1, keepdim=True) * E_target_norm
        #proj_target = torch.sum(A * E_target_norm, dim=-1, keepdim=True) * E_target_norm
        A_modified = A - proj_predicted + k * proj_target
        
        # Check what it decodes to after intervention
        logits_modified = llama_model.codi.lm_head(A_modified)
        new_token_id = torch.argmax(logits_modified, dim=-1).item()
        new_token_str = llama_tokenizer.decode([new_token_id])
        
        # CAUSALLY apply the intervention
        latent_embd = A_modified.unsqueeze(1)
        
        llama_interventions.append({
            'position': 0,
            'predicted_token': top1_token_str,
            'new_token': new_token_str,
            'intervened': True
        })
        intervened_positions.append(0)
        print(f"BoT  [pos= 0]: '{top1_token_str}' → '{new_token_str}' ✓ INTERVENED (causal)")
    else:
        print(f"BoT  [pos= 0]: '{top1_token_str}' (not a number, no intervention)")
    
    # Apply initial projection
    if llama_training_args.use_prj:
        latent_embd = llama_model.prj(latent_embd)
    
    # Chain-of-Thought iterations (positions 1-6)
    for i in range(llama_training_args.inf_latent_iterations):
        outputs = llama_model.codi(
            inputs_embeds=latent_embd,
            use_cache=True,
            output_hidden_states=True,
            past_key_values=past_key_values
        )
        past_key_values = outputs.past_key_values
        latent_embd = outputs.hidden_states[-1][:, -1:, :]
        
        # Check this position
        logits = llama_model.codi.lm_head(latent_embd.squeeze(1))
        top1_token_id = torch.argmax(logits, dim=-1).item()
        top1_token_str = llama_tokenizer.decode([top1_token_id])
        is_number = bool(number_regex.match(top1_token_str))
        
        pos_type = f"T{i+1}"
        
        # Intervene if it's a number
        if is_number:
            predicted_embd = embedding_layer(torch.tensor([top1_token_id], device=device))
            A = latent_embd.squeeze(1)
            
            # Normalize embeddings
            E_pred_norm = predicted_embd / torch.norm(predicted_embd, dim=-1, keepdim=True)
            E_target_norm = target_embd / torch.norm(target_embd, dim=-1, keepdim=True)
            
            # Projection intervention
            proj_predicted = torch.sum(A * E_pred_norm, dim=-1, keepdim=True) * E_pred_norm
            proj_target = torch.sum(A * E_target_norm, dim=-1, keepdim=True) * E_target_norm
            A_modified = A - proj_predicted + k * proj_target
            
            # Check what it decodes to after intervention
            logits_modified = llama_model.codi.lm_head(A_modified)
            new_token_id = torch.argmax(logits_modified, dim=-1).item()
            new_token_str = llama_tokenizer.decode([new_token_id])
            
            # CAUSALLY apply the intervention
            latent_embd = A_modified.unsqueeze(1)
            
            llama_interventions.append({
                'position': i+1,
                'predicted_token': top1_token_str,
                'new_token': new_token_str,
                'intervened': True
            })
            intervened_positions.append(i+1)
            print(f"{pos_type:4s} [pos={i+1:2d}]: '{top1_token_str}' → '{new_token_str}' ✓ INTERVENED (causal)")
        else:
            print(f"{pos_type:4s} [pos={i+1:2d}]: '{top1_token_str}' (not a number, no intervention)")
        
        # Apply projection for next iteration
        if llama_training_args.use_prj:
            latent_embd = llama_model.prj(latent_embd)

print(f"\n{'='*80}")
print(f"✓ Total interventions: {len(llama_interventions)}")
print(f"✓ Intervened at positions: {intervened_positions}")
print(f"{'='*80}")

# Now generate the final answer with the intervened chain-of-thought
print("\nGenerating final answer with intervened CoT...")

# Signal end-of-thought with EOT token
if llama_training_args.remove_eos:
    eot_tensor = torch.tensor([[llama_model.eot_id]], dtype=torch.long).expand(batch_size, 1, 1).to(device)
else:
    eot_tensor = torch.tensor([[llama_tokenizer.eos_token_id, llama_model.eot_id]], 
                               dtype=torch.long).expand(batch_size, 1, 2).to(device)

eot_emb = llama_model.get_embd(llama_model.codi, llama_model.model_name)(eot_tensor).squeeze(1)
output = eot_emb

# Generate answer tokens
# Generate answer tokens
pred_tokens = []
found_number = False

for step in range(256):
    out = llama_model.codi(
        inputs_embeds=output,
        output_hidden_states=False,
        attention_mask=None,
        use_cache=True,
        output_attentions=False,
        past_key_values=past_key_values
    )
    past_key_values = out.past_key_values
    
    logits = out.logits[:, -1, :llama_model.codi.config.vocab_size-1]
    next_token_id = torch.argmax(logits, dim=-1).item()
    pred_tokens.append(next_token_id)
    
    # Decode current token
    current_token_str = llama_tokenizer.decode([next_token_id])
    
    # Stop if EOS token
    if next_token_id == llama_tokenizer.eos_token_id:
        print(f"Stopped at step {step} (EOS token)")
        break
    
    # Stop immediately after generating a number token
    if number_regex.match(current_token_str.strip()):
        print(f"Stopped at step {step} (found number: '{current_token_str}')")
        break
    
    # Hard limit to prevent loops
    if step >= 49:
        print(f"Stopped at step {step} (max length)")
        break
    
    output = llama_model.get_embd(llama_model.codi, llama_model.model_name)(
        torch.tensor([[next_token_id]], device=device)
    )
# Decode and extract answer
intervened_answer = llama_tokenizer.decode(pred_tokens, skip_special_tokens=True)
print(f"\nIntervened Answer: {intervened_answer}")

def extract_answer_number(text):
    text = text.replace(',', '')
    numbers = [s for s in re.findall(r'-?\d+\.?\d*', text)]
    if not numbers:
        return None
    return float(numbers[0])

intervened_number = extract_answer_number(intervened_answer)
print(f"Intervened numerical answer: {intervened_number}")
print(f"\n(Compare to original expected: {answer})")

LLaMA: CAUSAL Projection Intervention
Target token: '5'
Target token ID: 20
k (top-k intervention): 3

Running CAUSAL intervention (affects downstream positions)...

BoT  [pos= 0]: '9' → '5' ✓ INTERVENED (causal)
T1   [pos= 1]: '7' → '5' ✓ INTERVENED (causal)
T2   [pos= 2]: '7' → '5' ✓ INTERVENED (causal)
T3   [pos= 3]: '-' (not a number, no intervention)
T4   [pos= 4]: '9' → '5' ✓ INTERVENED (causal)
T5   [pos= 5]: '9' → '5' ✓ INTERVENED (causal)
T6   [pos= 6]: '"' (not a number, no intervention)

✓ Total interventions: 5
✓ Intervened at positions: [0, 1, 2, 4, 5]

Generating final answer with intervened CoT...
Stopped at step 5 (found number: '18')

Intervened Answer: The answer is: 18
Intervened numerical answer: 18.0

(Compare to original expected: 18)


In [60]:
print("="*80)
print("LLaMA: Chain-of-Thought WITHOUT Beginning-of-Thought Token")
print("="*80)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 1

questions = [question]

# Set pad token if needed
if llama_tokenizer.pad_token is None:
    llama_tokenizer.pad_token = llama_tokenizer.eos_token
    llama_tokenizer.pad_token_id = llama_tokenizer.eos_token_id

# Tokenize WITHOUT BoT token - just the raw question
inputs = llama_tokenizer(questions, return_tensors="pt", padding=False)
inputs = {k: v.to(device) for k, v in inputs.items()}

print(f"WITHOUT BoT - Input IDs: {inputs['input_ids']}")
print(f"WITHOUT BoT - Last 5 token IDs: {inputs['input_ids'][0, -5:].tolist()}")
print(f"WITHOUT BoT - BoT token ID should be: {llama_model.bot_id}")
print(f"WITHOUT BoT - Contains BoT?: {llama_model.bot_id in inputs['input_ids'][0].tolist()}")
print(f"Input shape (no BoT): {inputs['input_ids'].shape}")

# Store the 6 chain-of-thought hidden states (no BoT position)
cot_hidden_states_no_bot = []

with torch.no_grad():
    # Initial encoding - start directly from the last token of the question
    past_key_values = None
    outputs = llama_model.codi(
        input_ids=inputs["input_ids"],
        use_cache=True,
        output_hidden_states=True,
        past_key_values=past_key_values,
        attention_mask=inputs["attention_mask"]
    )
    past_key_values = outputs.past_key_values
    
    # Start from the last question token (no BoT)
    latent_embd = outputs.hidden_states[-1][:, -1:, :]  # Shape: [batch, 1, hidden_dim]
    
    # Apply initial projection
    if llama_training_args.use_prj:
        latent_embd = llama_model.prj(latent_embd)
    
    # Chain-of-Thought iterations (6 positions, no BoT)
    for i in range(llama_training_args.inf_latent_iterations):
        outputs = llama_model.codi(
            inputs_embeds=latent_embd,
            use_cache=True,
            output_hidden_states=True,
            past_key_values=past_key_values
        )
        past_key_values = outputs.past_key_values
        
        # Get hidden state BEFORE projection
        latent_embd = outputs.hidden_states[-1][:, -1:, :]  # Shape: [batch, 1, hidden_dim]
        
        # Store BEFORE projection for decoding
        cot_hidden_states_no_bot.append(latent_embd.clone())
        print(f"✓ Position {i} (T{i}): shape={latent_embd.shape}")
        
        # Apply projection for next iteration
        if llama_training_args.use_prj:
            latent_embd = llama_model.prj(latent_embd)

# Stack all 6 positions (no BoT)
llama_continuous_thoughts_no_bot = torch.cat(cot_hidden_states_no_bot, dim=1)
print(f"\n✓ Total continuous thoughts (no BoT): {llama_continuous_thoughts_no_bot.shape}")
print(f"  Expected: [1, 6, hidden_dim]")
print(f"  Got: {list(llama_continuous_thoughts_no_bot.shape)}")

if llama_continuous_thoughts_no_bot.shape[1] != 6:
    print(f"\n❌ ERROR: Expected 6 positions, got {llama_continuous_thoughts_no_bot.shape[1]}")
else:
    print(f"\n✓ SUCCESS: Collected exactly 6 CoT positions (no BoT)!")

# Decode the continuous thoughts
print("\n" + "="*80)
print("Decoding Continuous Thoughts (no BoT)")
print("="*80)

llama_decoded_tokens_no_bot = []
llama_number_positions_no_bot = []

for i in range(llama_continuous_thoughts_no_bot.shape[1]):
    hidden_state = llama_continuous_thoughts_no_bot[:, i, :]
    
    # Decode using CODI's LM head
    logits = llama_model.codi.lm_head(hidden_state)
    top1_token_id = torch.argmax(logits, dim=-1).item()
    top1_token_str = llama_tokenizer.decode([top1_token_id])
    
    is_number = bool(number_regex.match(top1_token_str.strip()))
    pos_type = f"T{i}"
    
    llama_decoded_tokens_no_bot.append((i, top1_token_id, top1_token_str, is_number))
    if is_number:
        llama_number_positions_no_bot.append(i)
    
    marker = " 🔢 NUMBER!" if is_number else ""
    print(f"  {pos_type:4s} [pos={i:2d}]: token_id={top1_token_id:5d} → '{top1_token_str}'{marker}")

print(f"\n📊 Numbers at positions: {llama_number_positions_no_bot}")
print(f"📊 Total numbers: {len(llama_number_positions_no_bot)}/6")

# Generate final answer (no BoT)
print("\n" + "="*80)
print("Generate Final Answer (no BoT)")
print("="*80)

with torch.no_grad():
    # Signal end-of-thought with EOT token
    if llama_training_args.remove_eos:
        eot_tensor = torch.tensor([[llama_model.eot_id]], dtype=torch.long).expand(batch_size, 1, 1).to(device)
    else:
        eot_tensor = torch.tensor([[llama_tokenizer.eos_token_id, llama_model.eot_id]], 
                                   dtype=torch.long).expand(batch_size, 1, 2).to(device)
    
    eot_emb = llama_model.get_embd(llama_model.codi, llama_model.model_name)(eot_tensor).squeeze(1)
    output = eot_emb
    
    # Generate answer tokens
    pred_tokens = []
    
    for step in range(256):
        out = llama_model.codi(
            inputs_embeds=output,
            output_hidden_states=False,
            attention_mask=None,
            use_cache=True,
            output_attentions=False,
            past_key_values=past_key_values
        )
        past_key_values = out.past_key_values
        
        logits = out.logits[:, -1, :llama_model.codi.config.vocab_size-1]
        next_token_id = torch.argmax(logits, dim=-1).item()
        pred_tokens.append(next_token_id)
        
        # Decode current token
        current_token_str = llama_tokenizer.decode([next_token_id])
        
        # Stop if EOS token
        if next_token_id == llama_tokenizer.eos_token_id:
            print(f"Stopped at step {step} (EOS token)")
            break
        
        # Stop immediately after generating a number token
        if number_regex.match(current_token_str.strip()):
            print(f"Stopped at step {step} (found number: '{current_token_str}')")
            break
        
        # Hard limit to prevent loops
        if step >= 49:
            print(f"Stopped at step {step} (max length)")
            break
        
        output = llama_model.get_embd(llama_model.codi, llama_model.model_name)(
            torch.tensor([[next_token_id]], device=device)
        )
    
    # Decode full answer
    full_answer_no_bot = llama_tokenizer.decode(pred_tokens, skip_special_tokens=True)
    print(f"\nGenerated Answer (no BoT): {full_answer_no_bot}")
    
    # Extract numerical answer
    def extract_answer_number(text):
        text = text.replace(',', '')
        numbers = [s for s in re.findall(r'-?\d+\.?\d*', text)]
        if not numbers:
            return None
        return float(numbers[-1])
    
    print(f"Qeustion: {question}")
    predicted_number_no_bot = extract_answer_number(full_answer_no_bot)
    print(f"Numerical answer (no BoT): {predicted_number_no_bot}")
    print(f"Expected answer: ",answer)

print("\n" + "="*80)
print("Comparison: With BoT vs Without BoT")
print("="*80)
print(f"With BoT:    7 positions (1 BoT + 6 CoT)")
print(f"Without BoT: 6 positions (0 BoT + 6 CoT)")
print(f"\nThis tests whether the BoT token improves reasoning performance.")

LLaMA: Chain-of-Thought WITHOUT Beginning-of-Thought Token
WITHOUT BoT - Input IDs: tensor([[128000,     42,   4010,    277,   4024,    311,    279,   3637,    311,
           3780,  29247,    369,    813,    502,  13455,     13,   3861,   9168,
           7194,    400,     20,     11,    719,   1475,   2132,   9168,   7194,
           1193,    220,   1399,      4,    315,    279,   3430,     13,    735,
           4010,    277,   6944,    311,   3780,    220,    845,  29247,     13,
           2650,   1790,   1587,    568,   1205,    311,   2343,    369,   1124,
             30]], device='cuda:0')
WITHOUT BoT - Last 5 token IDs: [311, 2343, 369, 1124, 30]
WITHOUT BoT - BoT token ID should be: 128257
WITHOUT BoT - Contains BoT?: False
Input shape (no BoT): torch.Size([1, 55])
✓ Position 0 (T0): shape=torch.Size([1, 1, 2048])
✓ Position 1 (T1): shape=torch.Size([1, 1, 2048])
✓ Position 2 (T2): shape=torch.Size([1, 1, 2048])
✓ Position 3 (T3): shape=torch.Size([1, 1, 2048])
✓ Position 4 

## Load CODI-GPT-2

In [None]:
print("="*80)
print("Loading CODI-GPT-2 from Local Checkpoint")
print("="*80)

  # Load from HuggingFace Hub
gpt2_model_name = "lhao499/codi-gpt2"

gpt2_model_args = ModelArguments(
      model_name_or_path="gpt2",
      lora_init=True,
      lora_r=128,
      lora_alpha=32,
      ckpt_dir="/workspace/CoT_Exploration/models/CODI-gpt2",  # Local checkpoint
      full_precision=True,
      token=None
  )

gpt2_training_args = TrainingArguments(
      output_dir="./outputs",
      model_max_length=512,
      inf_latent_iterations=6,
      use_prj=True,
      prj_dim=768,
      remove_eos=True,
      greedy=True,
      bf16=False,
      inf_num_iterations=1
  )

gpt2_lora_config = LoraConfig(
      task_type=TaskType.CAUSAL_LM,
      inference_mode=False,
      r=gpt2_model_args.lora_r,
      lora_alpha=gpt2_model_args.lora_alpha,
      lora_dropout=0.1,
      target_modules=["c_attn", "c_proj", "c_fc"],
      init_lora_weights=True,
  )

gpt2_model = CODI(gpt2_model_args,
gpt2_training_args, gpt2_lora_config)
gpt2_model = gpt2_model.to(device)
gpt2_model = gpt2_model.to(torch.bfloat16)
gpt2_model.eval()
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")

print("✓ CODI-GPT-2 loaded successfully from HuggingFace")

## Run GPT-2 Forward Pass

In [None]:
print("="*80)
print("GPT-2 Forward Pass")
print("="*80)

gpt2_inputs = gpt2_tokenizer(question, return_tensors="pt", add_special_tokens=True)
gpt2_input_ids = gpt2_inputs.input_ids.to(device)

print(f"Input shape: {gpt2_input_ids.shape}")

with torch.no_grad():
    gpt2_outputs = gpt2_model.codi(
        input_ids=gpt2_input_ids,
        output_hidden_states=True
    )
    gpt2_continuous_thoughts = gpt2_outputs.hidden_states[-1]

print(f"Continuous thoughts shape: {gpt2_continuous_thoughts.shape}")
print(f"Number of thought positions: {gpt2_continuous_thoughts.shape[1]}")


## Decode GPT-2 Continuous Thoughts (like section5_analysis.py)

In [None]:
print("="*80)
print("GPT-2: Decoding Continuous Thoughts")
print("="*80)

# Note: HuggingFace models don't have projection layers, decode directly
print(f"Continuous thoughts shape: {gpt2_continuous_thoughts.shape}")



# Decode first 15 positions directly from continuous thoughts
print("\nDecoded tokens from continuous thoughts:")
gpt2_decoded_tokens = []
gpt2_number_positions = []

for i in range(min(15, gpt2_continuous_thoughts.shape[1])):
    logits = gpt2_model.codi.lm_head(gpt2_continuous_thoughts[:, i, :])
    top1_token_id = torch.argmax(logits, dim=-1).item()
    top1_token_str = gpt2_tokenizer.decode([top1_token_id])

    is_number = bool(number_regex.match(top1_token_str))
    pos_type = "BoT" if i == 0 else f"T{i}"

    gpt2_decoded_tokens.append((i, top1_token_id, top1_token_str, is_number))
    if is_number:
        gpt2_number_positions.append(i)

    marker = " ← NUMBER!" if is_number else ""
    print(f"  {pos_type:4s} [pos={i:2d}]: token_id={top1_token_id:5d} → '{top1_token_str}'{marker}")

print(f"\n✓ Numbers detected at positions: {gpt2_number_positions}")
print(f"✓ Total numbers: {len(gpt2_number_positions)}")

## GPT-2: Projection Intervention (Target Token = '5', k=3)

In [None]:
print("="*80)
print("GPT-2: Projection Intervention")
print("="*80)

target_token = '5'
k = 3

# Get target token embedding
target_token_id = gpt2_tokenizer.encode(target_token, add_special_tokens=False)[0]
embedding_layer = gpt2_model.codi.get_input_embeddings()
target_embd = embedding_layer(torch.tensor([target_token_id], device=device))

print(f"Target token: '{target_token}'")
print(f"Target token ID: {target_token_id}")
print(f"k (top-k intervention): {k}")

gpt2_interventions = []

for i in range(min(15, gpt2_latent_embd.shape[1])):
    pos_type = "BoT" if i == 0 else f"T{i}"
    
    # Get predicted token
    logits = gpt2_model.codi.lm_head(gpt2_latent_embd[:, i, :])
    top1_token_id = torch.argmax(logits, dim=-1).item()
    top1_token_str = gpt2_tokenizer.decode([top1_token_id])
    
    # Check if it's a number
    is_number = bool(number_regex.match(top1_token_str))
    
    if is_number:
        # Get predicted token embedding
        predicted_embd = embedding_layer(torch.tensor([top1_token_id], device=device))
        
        # Get activation
        A = gpt2_latent_embd[:, i, :]  # [1, hidden_dim]
        
        # Normalize embeddings
        E_pred_norm = predicted_embd / torch.norm(predicted_embd, dim=-1, keepdim=True)
        E_target_norm = target_embd / torch.norm(target_embd, dim=-1, keepdim=True)
        
        # Projection removal and replacement
        proj_predicted = torch.sum(A * E_pred_norm, dim=-1, keepdim=True) * E_pred_norm
        proj_target = torch.sum(A * E_target_norm, dim=-1, keepdim=True) * E_target_norm
        
        A_modified = A - proj_predicted + k * proj_target
        
        # Decode modified activation
        logits_modified = gpt2_model.codi.lm_head(A_modified)
        new_token_id = torch.argmax(logits_modified, dim=-1).item()
        new_token_str = gpt2_tokenizer.decode([new_token_id])
        
        gpt2_interventions.append({
            'position': i,
            'predicted_token': top1_token_str,
            'new_token': new_token_str,
            'intervened': True
        })
        
        print(f"{pos_type:4s} [pos={i:2d}]: '{top1_token_str}' → '{new_token_str}' (intervened)")
    else:
        print(f"{pos_type:4s} [pos={i:2d}]: '{top1_token_str}' (not a number, skipped)")

print(f"\n✓ GPT-2 interventions: {len(gpt2_interventions)}")

## Compare Tokenizers (Sanity Check)

In [None]:
print("="*80)
print("Tokenizer Comparison")
print("="*80)

test_numbers = ['0', '1', '2', '3', '4', '5', '16', '18']

print("\nHow do tokenizers encode/decode numbers?\n")
for num in test_numbers:
    gpt2_ids = gpt2_tokenizer.encode(num, add_special_tokens=False)
    gpt2_decoded = gpt2_tokenizer.decode(gpt2_ids)
    gpt2_matches = bool(number_regex.match(gpt2_decoded))
    
    llama_ids = llama_tokenizer.encode(num, add_special_tokens=False)
    llama_decoded = llama_tokenizer.decode(llama_ids)
    llama_matches = bool(number_regex.match(llama_decoded))
    
    match_indicator = "✓" if gpt2_decoded == llama_decoded else "✗ MISMATCH"
    
    print(f"Number: '{num}'")
    print(f"  GPT-2:  IDs={gpt2_ids} → '{gpt2_decoded}' (matches={gpt2_matches})")
    print(f"  LLaMA:  IDs={llama_ids} → '{llama_decoded}' (matches={llama_matches})")
    print(f"  {match_indicator}\n")

## Summary and Diagnosis

In [None]:
print("="*80)
print("SUMMARY")
print("="*80)

print(f"\nGPT-2:")
print(f"  - Numbers detected: {len(gpt2_number_positions)}")
print(f"  - Interventions: {len(gpt2_interventions)}")

print(f"\nLLaMA:")
print(f"  - Numbers detected: {len(llama_number_positions)}")
print(f"  - Interventions: {len(llama_interventions)}")

print("\n" + "="*80)
print("DIAGNOSIS")
print("="*80)

if len(llama_number_positions) == 0:
    print("\n⚠️  ROOT CAUSE IDENTIFIED:")
    print("LLaMA's lm_head does NOT predict number tokens from continuous thoughts.")
    print("This explains why there are 0 interventions.")
    print("\nPossible reasons:")
    print("  1. Projection layers (bot_projection/thought_projection) not trained correctly")
    print("  2. lm_head not loaded correctly (check strict=False in load_state_dict)")
    print("  3. Vocabulary size mismatch causing wrong token predictions")
    print("  4. Continuous thoughts from different distribution than GPT-2")
elif len(llama_number_positions) < len(gpt2_number_positions):
    print("\n⚠️  LLaMA detects fewer numbers than GPT-2")
    print(f"  GPT-2: {len(gpt2_number_positions)} numbers")
    print(f"  LLaMA: {len(llama_number_positions)} numbers")
else:
    print("\n✓ Both models detect similar numbers of tokens")
    print("  The issue may be in the intervention logic, not the decoding.")

## Additional Investigation: Top-5 Predictions

In [None]:
print("="*80)
print("Top-5 Predictions Comparison")
print("="*80)

print("\nGPT-2 - First 5 thought positions:")
for i in range(min(5, gpt2_latent_embd.shape[1])):
    logits = gpt2_model.codi.lm_head(gpt2_latent_embd[:, i, :])
    top5_vals, top5_ids = torch.topk(logits[0], 5)
    top5_tokens = [gpt2_tokenizer.decode([tid.item()]) for tid in top5_ids]
    
    pos_type = "BoT" if i == 0 else f"T{i}"
    print(f"\n  {pos_type} [pos={i}]:")
    for j, (token, val) in enumerate(zip(top5_tokens, top5_vals)):
        is_num = "← NUM" if number_regex.match(token) else ""
        print(f"    {j+1}. '{token}' (logit={val.item():.2f}) {is_num}")

print("\n" + "-"*80)
print("\nLLaMA - First 5 thought positions:")
for i in range(min(5, llama_latent_embd.shape[1])):
    logits = llama_model.codi.lm_head(llama_latent_embd[:, i, :])
    top5_vals, top5_ids = torch.topk(logits[0], 5)
    top5_tokens = [llama_tokenizer.decode([tid.item()]) for tid in top5_ids]
    
    pos_type = "BoT" if i == 0 else f"T{i}"
    print(f"\n  {pos_type} [pos={i}]:")
    for j, (token, val) in enumerate(zip(top5_tokens, top5_vals)):
        is_num = "← NUM" if number_regex.match(token) else ""
        print(f"    {j+1}. '{token}' (logit={val.item():.2f}) {is_num}")

## Inspect Activation Norms

In [None]:
print("="*80)
print("Activation Norms")
print("="*80)

gpt2_norms = torch.norm(gpt2_latent_embd[0], dim=-1)
llama_norms = torch.norm(llama_latent_embd[0], dim=-1)

print(f"\nGPT-2 latent embedding norms (first 10 positions):")
for i in range(min(10, len(gpt2_norms))):
    print(f"  Position {i}: {gpt2_norms[i].item():.2f}")

print(f"\nLLaMA latent embedding norms (first 10 positions):")
for i in range(min(10, len(llama_norms))):
    print(f"  Position {i}: {llama_norms[i].item():.2f}")

print(f"\nGPT-2 mean norm: {gpt2_norms.mean().item():.2f}")
print(f"LLaMA mean norm: {llama_norms.mean().item():.2f}")