In [2]:
# Cell 1: Imports and Class Definitions
import logging
import math
import re
import os
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
from copy import deepcopy

import torch
import torch.nn as nn
import transformers
from torch.nn import functional as F
import json

from peft import PeftModel, LoraConfig, TaskType, get_peft_model
from accelerate.utils import set_seed
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer

import numpy as np

@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="mistralai/Mistral-7B-Instruct-v0.2")
    lora_r: int = field(default=128, metadata={"help": "lora rank"})
    lora_dropout: float = field(default=0.05, metadata={"help": "lora dropout"})
    full_precision: bool = field(default=True, metadata={"help": "whether use int4 for the base model"})
    lora_init: bool = field(default=False, metadata={"help": "True: Use zero and gaussian initialization"})
    token: Optional[str] = field(default=None, metadata={"help": "HF token to access private models"})
    adapter_name_or_path: Optional[str] = field(default=None, metadata={"help": "Path to the LoRA adapter"})
    lora_alpha: int = field(default=16, metadata={"help": "LoRA alpha"})
    ckpt_dir: Optional[str] = field(default=None, metadata={"help": "checkpoint dir for inference."})

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    model_max_length: int = field(default=512, metadata={"help": "Maximum sequence length."})
    num_latent: int = field(default=5, metadata={"help": "The number of latent for training or inference."})
    use_lora: bool = field(default=True, metadata={"help": "Use lora or not."})
    greedy: bool = field(default=False, metadata={"help": "Greedy decoding during inference."})
    use_prj: bool = field(default=False, metadata={"help": "Use a prj module after the llm."})
    prj_dim: int = field(default=2048, metadata={"help": "The hidden dim of the projection module."})
    prj_dropout: float = field(default=0.0, metadata={"help": "Dropout ratio of the projection module."})
    prj_no_ln: bool = field(default=False, metadata={"help": "Remove LayerNorm for the projection module."})
    inf_latent_iterations: int = field(default=1, metadata={"help": "Latent iterations during inference."})
    inf_num_iterations: int = field(default=5, metadata={"help": "Run multiple times during inference."})
    remove_eos: bool = field(default=False, metadata={"help": "Do not add <eos> as a delimiter."})

class CODI(torch.nn.Module):
    def __init__(self, model_args, training_args, lora_config):
        super().__init__()
        self.model_args = model_args
        self.training_args = training_args
        self.model_name = model_args.model_name_or_path
        
        # Load the base model
        self.codi = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float16,
            resume_download=True,
        )

        ori_vocab_size = self.codi.config.vocab_size
        
        # Define special tokens
        self.pad_token_id = ori_vocab_size
        self.bot_id = ori_vocab_size + 1
        self.eot_id = ori_vocab_size + 2

        self.codi.resize_token_embeddings(ori_vocab_size + 3)
        self.dim = self.codi.config.hidden_size

        # Apply LoRA configuration
        if training_args.use_lora:
            self.codi = get_peft_model(self.codi, lora_config)

        # Optional Projection Layer
        if training_args.use_prj:
            self.prj = nn.Sequential(
                nn.Dropout(training_args.prj_dropout),
                nn.Linear(self.dim, training_args.prj_dim),
                nn.GELU(),
                nn.Linear(training_args.prj_dim, self.dim),
            )
            if not training_args.prj_no_ln:
                self.prj.add_module("ln", nn.LayerNorm(self.dim))
            self.prj.to(self.codi.dtype)

    def get_embd(self, model, model_name):
        # Helper to get the embedding layer from different model architectures
        base_model = model.get_base_model() if hasattr(model, "get_base_model") else model
        if "llama" in model_name.lower() or "mistral" in model_name.lower():
            return base_model.model.embed_tokens
        elif "phi" in model_name.lower():
             return base_model.model.embed_tokens
        elif "gpt2" in model_name.lower():
            return base_model.transformer.wte
        else:
            raise NotImplementedError(f"get_embd not implemented for {model_name}")
print("Imports complete")

Imports complete


In [4]:
# Cell 2: Helper Functions for Loading and Generation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model_and_tokenizer(model_args, training_args):
    """Loads the model and tokenizer based on the provided arguments."""
    print(f"Using device: {device}")
    
    if not model_args.lora_init:
        raise ValueError("lora_init must be True for this script.")

    task_type = TaskType.CAUSAL_LM
    if any(name in model_args.model_name_or_path.lower() for name in ["llama", "mistral", "falcon", "qwen", "phi"]):
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"]
        if "phi" in model_args.model_name_or_path.lower():
             target_modules.extend(["dense", "fc1", "fc2"])
    elif any(name in model_args.model_name_or_path.lower() for name in ["gpt2"]):
        target_modules = ["c_attn", "c_proj", 'c_fc']
    else:
        raise ValueError(f"Unsupported model type for LoRA: {model_args.model_name_or_path}.")
    
    lora_config = LoraConfig(
        task_type=task_type,
        inference_mode=False,
        r=model_args.lora_r,
        lora_alpha=model_args.lora_alpha,
        lora_dropout=0.1,
        target_modules=target_modules,
        init_lora_weights=True,
    )
    
    print("Initializing CODI model...")
    model = CODI(model_args, training_args, lora_config)
    
    try:
        # Load fine-tuned adapter weights
        ckpt_path = os.path.expanduser(model_args.ckpt_dir)
        bin_path = os.path.join(ckpt_path, "codi.bin")

        if os.path.exists(bin_path):
            print(f"Loading state dict from: {bin_path}")
            state_dict = torch.load(bin_path, map_location=device)
        else:
            raise FileNotFoundError(f"Could not find model weights in {ckpt_path}")

        model.load_state_dict(state_dict, strict=False)
        print("Successfully loaded state dict.")

    except Exception as e:
        print(f"Error loading state dictionary: {e}")
        return None, None
    
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        model_max_length=training_args.model_max_length,
        padding_side="left",
        use_fast=False,
    )

    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    model = model.to(device)
    model.to(torch.bfloat16 if training_args.bf16 else torch.float16)
    model.eval()
    
    print("CODI Model and tokenizer loaded successfully.")
    return model, tokenizer

def generate_cache(model, tokenizer, perturbed_question, training_args):
    """
    Runs the model on the perturbed prompt and caches the hidden state activations
    for each layer of latent thoughts 1 through 6.

    Args:
        model: The pre-trained language model.
        tokenizer: The tokenizer associated with the model.
        perturbed_question (str): The perturbed input prompt.
        training_args: Configuration object.

     Returns:
        tuple(dict, str): A tuple containing:
            - activation_cache (dict): A cache of activations, structured as 
                                       {thought_idx: {layer_idx: tensor, ...}, ...}.
            - final_answer (str): The generated final answer from the model.
    """
    # 1. Initialize the Cache
    activation_cache = {i: {} for i in range(1, training_args.inf_latent_iterations + 1)}
    
    # Set model to evaluation mode
    model.eval()

    # 2. Preprocess the Input
    inputs = tokenizer(perturbed_question, return_tensors="pt").to(device)
    bot_tensor = torch.tensor([model.bot_id], dtype=torch.long, device=device).expand(inputs.input_ids.size(0), 1)
    input_ids = torch.cat((inputs["input_ids"], bot_tensor), dim=1)
    attention_mask = torch.cat((inputs["attention_mask"], torch.ones_like(bot_tensor)), dim=1)

    with torch.no_grad():
        # 3. Initial pass for "Latent Thought 0"
        # This sets up the initial state for the reasoning process.
        outputs = model.codi(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=True,
            output_hidden_states=True
        )
        past_key_values = outputs.past_key_values
        latent_embd = outputs.hidden_states[-1][:, -1, :].unsqueeze(1)

        if training_args.use_prj:
            latent_embd = model.prj(latent_embd)

        # 4. Main Caching Loop: Iterate through Latent Thoughts 1-6
        for thought_idx in range(1, training_args.inf_latent_iterations + 1):
            outputs = model.codi(
                inputs_embeds=latent_embd,
                use_cache=True,
                output_hidden_states=True,
                past_key_values=past_key_values
            )
            
            # --- CACHING LOGIC ---
            for layer_idx, layer_hidden_state in enumerate(outputs.hidden_states):
                # Shape is (1, 1, 2048). Clone to store a clean copy.
                activation_cache[thought_idx][layer_idx] = layer_hidden_state.clone().detach()
            
            # Prepare for the next iteration
            past_key_values = outputs.past_key_values
            latent_embd = outputs.hidden_states[-1][:, -1, :].unsqueeze(1)

            if training_args.use_prj:
                latent_embd = model.prj(latent_embd)
        
        # --- END OF THOUGHT GENERATION AND CACHING ---
        # The 'past_key_values' now contains the state after all thoughts.
        # We proceed to generate the final answer.

        # 5. Generate the Final Answer from the Latent Thoughts
        eot_token_id = torch.tensor([model.eot_id], dtype=torch.long, device=device)
        next_input_embeds = model.get_embd(model.codi, model.model_name)(eot_token_id).unsqueeze(0)
        
        generated_token_ids = []
        
        for _ in range(training_args.model_max_length):
            # We no longer need hidden states for the final answer generation
            out = model.codi(
                inputs_embeds=next_input_embeds,
                use_cache=True,
                past_key_values=past_key_values,
                output_hidden_states=False 
            )
            
            past_key_values = out.past_key_values
            current_logits = out.logits[:, -1, :]
            next_token_id = torch.argmax(current_logits, dim=-1)
            
            if next_token_id.item() == tokenizer.eos_token_id:
                break
            
            generated_token_ids.append(next_token_id.item())
            next_input_embeds = model.get_embd(model.codi, model.model_name)(next_token_id).unsqueeze(0)

    # 6. Decode the generated tokens into the final string
    final_answer = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
    
    # 7. Return both the cache and the final answer
    return activation_cache, final_answer

def analyze_cache(activation_cache):
    """
    Analyzes and prints the structure and tensor shapes of an activation cache.
    """
    if not activation_cache:
        print("The activation cache is empty.")
        return

    print("--- Analyzing Activation Cache Structure ---")
    
    # Get the list of thought indices (e.g., [1, 2, 3, 4, 5, 6])
    thought_indices = sorted(activation_cache.keys())
    print(f"Cached Thought Indices: {thought_indices}")
    print(f"Total Thoughts Cached: {len(thought_indices)}\n")

    # Inspect the first thought in detail to show the structure for all
    first_thought_idx = thought_indices[0]
    first_thought_layers = activation_cache[first_thought_idx]
    
    layer_indices = sorted(first_thought_layers.keys())
    
    print(f"--- Details for Thought Index: {first_thought_idx} ---")
    print(f"Cached Layer Indices: {layer_indices[0]} to {layer_indices[-1]}")
    print(f"Total Layers Cached per Thought: {len(layer_indices)}")
    
    # Get the shape of the activation tensor from the first layer of the first thought
    # This shape will be consistent across all layers and thoughts.
    sample_tensor = first_thought_layers[layer_indices[0]]
    tensor_shape = sample_tensor.shape
    
    print(f"\nShape of each activation tensor: {tensor_shape}")
    print(f"  - Batch Size: {tensor_shape[0]}")
    print(f"  - Sequence Length: {tensor_shape[1]} (Represents a single thought token)")
    print(f"  - Hidden Dimension: {tensor_shape[2]}")

    print("\n--- Summary ---")
    print("The cache is a dictionary where:")
    print("  - Keys are thought indices (integers from 1 to 6).")
    print("  - Values are another dictionary for layers.")
    print("      - Keys are layer indices (integers from 0 to N).")
    print(f"      - Values are PyTorch tensors of shape {tensor_shape}.")
    print("------------------------------------------")

from contextlib import contextmanager

# Helper context manager for cleaner hook management
@contextmanager
def apply_forward_hook(module, hook_fn):
    """Context manager to apply and automatically remove a forward hook."""
    handle = module.register_forward_hook(hook_fn)
    try:
        yield
    finally:
        handle.remove()

def run_patch(
    model, 
    tokenizer, 
    training_args,
    original_question,
    activation_cache,
    thought_idx_to_patch,
    layer_idx_to_patch
):
    """
    Runs inference on the original prompt while patching a specific activation
    from the cache at a designated thought and layer.

    Args:
        model: The pre-trained language model.
        tokenizer: The tokenizer for the model.
        training_args: Configuration object.
        original_question (str): The clean, unperturbed question.
        activation_cache (dict): The cache from Phase 1.
        thought_idx_to_patch (int): The latent thought step to patch (1-6).
        layer_idx_to_patch (int): The layer index to patch (0 to N-1).

    Returns:
        str: The final generated answer string from the patched run.
    """
    # 1. Retrieve the specific activation to patch from the cache
    # This is the "corrupted" state we will inject.
    cached_activation = activation_cache[thought_idx_to_patch][layer_idx_to_patch]

    # --- Hook Definition and Logic ---
    # We use a mutable object (a list) as a counter to track the thought index.
    # This allows the inner hook function to modify a variable from the outer scope.
    thought_counter = [0] 

    def patch_activation_hook(module, args, output):
        # This hook fires every time a forward pass happens on the hooked layer.
        # We only want to patch when we are processing the target thought.
        
        # The first pass is for the prompt + [Begin_Thought], which we count as thought 0.
        # The subsequent passes are for thoughts 1, 2, 3, ...
        if thought_counter[0] == thought_idx_to_patch:
            # The output of a layer is a tuple in some models (e.g., (hidden_state, ...))
            # Or just the tensor itself. Let's assume it's the tensor for simplicity.
            # We must handle both cases.
            if isinstance(output, tuple):
                # Modify the first element, which is typically the hidden state.
                # Shape is (batch, seq_len, hidden_dim) -> (1, 1, 2048)
                output_list = list(output)
                output_list[0] = cached_activation
                return tuple(output_list)
            else:
                # If output is just a tensor, replace it directly.
                return cached_activation
        
        # For all other thoughts, do nothing and return the original output.
        return output
    
    # --- Start of Modified Generation Logic ---
    model.eval()

    # Get the target layer module. The path may need adjustment.
    # Example for a standard Hugging Face GPT-2/Llama architecture:
    # `model.codi.model.layers[layer_idx_to_patch]`
    # We will assume a path like this for the example.
    target_layer = model.codi.base_model.model.model.layers[layer_idx_to_patch]

    with torch.no_grad(), apply_forward_hook(target_layer, patch_activation_hook):
        # 2. Preprocess the ORIGINAL input
        inputs = tokenizer(original_question, return_tensors="pt").to(device)
        bot_tensor = torch.tensor([model.bot_id], dtype=torch.long, device=device).expand(inputs.input_ids.size(0), 1)
        input_ids = torch.cat((inputs["input_ids"], bot_tensor), dim=1)
        attention_mask = torch.cat((inputs["attention_mask"], torch.ones_like(bot_tensor)), dim=1)

        # 3. Run "Latent Thought 0" Generation
        # The hook will fire here, but since thought_counter[0] is 0, it won't patch
        # unless thought_idx_to_patch is also 0 (which we have excluded).
        outputs = model.codi(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=True,
            output_hidden_states=True
        )
        thought_counter[0] += 1 # Increment after processing thought 0
        past_key_values = outputs.past_key_values
        latent_embd = outputs.hidden_states[-1][:, -1, :].unsqueeze(1)
        if training_args.use_prj:
            latent_embd = model.prj(latent_embd)

        # 4. Iterate through latent thoughts
        for i in range(1, training_args.inf_latent_iterations + 1):
            # The hook will fire on the forward pass inside this loop.
            # If `thought_counter[0]` matches `thought_idx_to_patch`, the swap happens.
            outputs = model.codi(
                inputs_embeds=latent_embd,
                use_cache=True,
                output_hidden_states=True,
                past_key_values=past_key_values
            )
            thought_counter[0] += 1 # Increment after processing each thought
            past_key_values = outputs.past_key_values
            latent_embd = outputs.hidden_states[-1][:, -1, :].unsqueeze(1)
            if training_args.use_prj:
                latent_embd = model.prj(latent_embd)

        # 5. Generate the final answer (same as original `generate` function)
        eot_token_id = torch.tensor([model.eot_id], dtype=torch.long, device=device)
        next_input_embeds = model.get_embd(model.codi, model.model_name)(eot_token_id).unsqueeze(0)
        generated_token_ids = []
        for _ in range(training_args.model_max_length):
            # The hook will continue to fire here but the counter is now too high, so it won't patch.
            out = model.codi(
                inputs_embeds=next_input_embeds,
                use_cache=True,
                past_key_values=past_key_values
            )
            past_key_values = out.past_key_values
            current_logits = out.logits[:, -1, :]
            next_token_id = torch.argmax(current_logits, dim=-1)
            if next_token_id.item() == tokenizer.eos_token_id:
                break
            generated_token_ids.append(next_token_id.item())
            next_input_embeds = model.get_embd(model.codi, model.model_name)(next_token_id).unsqueeze(0)

    # 6. Decode and return the final answer
    final_answer = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
    return final_answer

print("Complete!")

Complete!


In [5]:
# Cell 3: Configuration and Interactive Session
# --- Configuration is now set directly in the script ---
model_args = ModelArguments(
    model_name_or_path="./llama", # Load base model from local folder
    lora_r=128,
    lora_alpha=32,
    lora_init=True,
    ckpt_dir="./llama" # Path to fine-tuned adapter
)

# Note: TrainingArguments is used for model config, not actual training
training_args = TrainingArguments(
    output_dir="./output", # Dummy output dir
    seed=11,
    model_max_length=512,
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    greedy=True,
    num_latent=6,
    use_prj=True,
    prj_dim=2048,
    prj_no_ln=False,
    prj_dropout=0.0,
    inf_latent_iterations=6,
    remove_eos=True,
    use_lora=True,
)

set_seed(training_args.seed)

codi_model, tokenizer = load_model_and_tokenizer(model_args, training_args)

if codi_model and tokenizer:
    print("\nAll models ready!")
else:
    print("ERROR loading models, read above")

`torch_dtype` is deprecated! Use `dtype` instead!


Using device: cuda
Initializing CODI model...


OSError: Error no file named model.safetensors, or pytorch_model.bin, found in directory ./llama.

In [14]:
#Cell 4: Run Loop
if codi_model and tokenizer:
    print("\n--- Interactive Mode ---")
    print("Enter your question below. Type 'exit' to quit.")
    while True:
        try:
            user_question = input("\n Original Question: ")
            if user_question.lower() == 'exit':
                break
            modif = input("\n Modified Question: ")
            print("\nGenerating with CODI model...")
            cache, final_answer = generate_cache(codi_model, tokenizer, modif, training_args)
            print("Cache complete")
            #original question, cache, thought (1-6), layer (0-15)
            patched = run_patch(codi_model, tokenizer, training_args, user_question, cache, 1, 5)
            # analyze_cache(cache)exit
            
            print("\n--- Modified Answer ---")
            print(final_answer.strip())
            print("--------------------")

            print("\n--- Patched Answer ---")
            print(patched.strip())
            print("--------------------")

        except KeyboardInterrupt:
            print("\nExiting...")
            break
        except Exception as e:
            print(f"An error occurred: {e}")


--- Interactive Mode ---
Enter your question below. Type 'exit' to quit.



 Original Question:  Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?

 Modified Question:  Janet’s ducks lay 11 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?



Generating with CODI model...
Cache complete

--- Modified Answer ---
The answer is: 8
--------------------

--- Patched Answer ---
The answer is: 18
--------------------



 Original Question:  exit
