In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, List, Any, Tuple, Optional
from tqdm.auto import tqdm
import plotly.graph_objects as go

In [2]:
def get_device():
    """Pick the torch device to run on.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is true, otherwise ``cpu``.

    Side effects:
        On the CUDA path, prints the device name and the currently allocated /
        reserved CUDA memory (in MB). Nothing is printed on the CPU path.
    """
    # Guard clause: no GPU means there is nothing to report.
    if not torch.cuda.is_available():
        return torch.device("cpu")

    cuda_device = torch.device("cuda")
    # Report the CUDA device and its current memory footprint.
    print(f"CUDA Device: {torch.cuda.get_device_name()}")
    print(f"CUDA Memory Allocated: {torch.cuda.memory_allocated()/1024**2:.2f}MB")
    print(f"CUDA Memory Reserved: {torch.cuda.memory_reserved()/1024**2:.2f}MB")
    return cuda_device

In [3]:
# Resolve the compute device once; load_model() moves the model onto DEVICE and the
# prompt cells pass it to tokenize_prompt().
DEVICE = get_device()

In [None]:
def load_model(model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
               dtype: torch.dtype = torch.bfloat16):
    """Load a causal LM and its tokenizer, move the model to DEVICE, and set eval mode.

    Args:
        model_name: Hugging Face model id (or local path) used for BOTH the tokenizer
            and the model, so the two can never get out of sync. Defaults to the
            checkpoint this notebook was written for.
        dtype: Value passed as ``torch_dtype`` when loading the weights.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,  # Optimize memory usage during loading
    )
    model.to(DEVICE)
    model.eval()
    return tokenizer, model

# Load the tokenizer and model once; the cells below reuse these two globals.
tokenizer, model = load_model()

In [None]:
# Bare expression: shows the model's repr as the cell output.
model

In [6]:
def prepare_prompt(prompt: str, tokenizer) -> str:
    """Wrap ``prompt`` as a single user turn using the tokenizer's chat template.

    Args:
        prompt: Raw user text.
        tokenizer: Any object; if it exposes ``apply_chat_template`` that method is used.

    Returns:
        The templated string (not tokenized, with the generation prompt appended), or
        ``prompt`` unchanged when the tokenizer has no ``apply_chat_template`` attribute.
    """
    # Tokenizers without chat-template support get the raw prompt back.
    if not hasattr(tokenizer, 'apply_chat_template'):
        return prompt

    conversation = [{"role": "user", "content": prompt}]
    return tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )

In [7]:
def tokenize_prompt(prompt: str, tokenizer, device=None):
    """Format a prompt with the chat template, tokenize it, and move it to a device.

    Args:
        prompt: Raw user text.
        tokenizer: Tokenizer used both for templating (via ``prepare_prompt``) and
            for tokenization.
        device: Target device for the tensors. When ``None``, ``get_device()`` is
            called (which also prints CUDA details on a GPU machine).

    Returns:
        tuple: ``(tokenized_prompt, formatted_prompt)`` where ``tokenized_prompt`` is a
        plain dict of tensors on ``device`` and ``formatted_prompt`` is the string that
        was tokenized.

    Side effects:
        Prints the formatted prompt.
    """
    # Get appropriate device
    if device is None:
        device = get_device()

    # Format prompt
    formatted_prompt = prepare_prompt(prompt, tokenizer)

    # prepare_prompt() applies the chat template exactly when this attribute exists.
    # A chat template is expected to emit its own special tokens (e.g. BOS), so letting
    # the tokenizer add them again would duplicate them at the start of the sequence.
    # NOTE(review): this assumes the model's template includes the BOS token - confirm
    # for any tokenizer other than the one this notebook loads.
    template_applied = hasattr(tokenizer, 'apply_chat_template')

    # Tokenize
    tokenized_prompt = tokenizer(
        formatted_prompt,
        return_tensors="pt",
        add_special_tokens=not template_applied,
    )

    # Move to device (this also converts the tokenizer output into a plain dict)
    tokenized_prompt = {k: v.to(device) for k, v in tokenized_prompt.items()}

    print(f"Prompt: {formatted_prompt}")

    return tokenized_prompt, formatted_prompt

In [8]:
def get_residual_stream(model, 
                        tokenizer, 
                        tokenized_prompt, 
                        max_new_tokens=20, 
                        layer_indices=None, 
                        save_path=None):
    """
    Generate from a prompt and record the residual stream seen by each layer.

    Forward hooks are attached to every selected ``layer.input_layernorm`` and to
    ``model.model.norm``; each hook stores the *input* of that norm module (moved to
    CPU) on every forward call made by ``model.generate``. The per-call tensors are
    concatenated along dim 1 once generation has finished.

    Args:
        model: Model exposing ``model.model.layers[i].input_layernorm``,
            ``model.model.norm`` and ``generate`` (e.g. LlamaForCausalLM).
        tokenizer: The tokenizer associated with the model (used for ``eos_token_id``
            and ``decode``).
        tokenized_prompt: Dictionary containing ``input_ids`` and ``attention_mask`` tensors.
            Both are cloned, so the caller's tensors are not modified.
        max_new_tokens: Maximum number of new tokens to generate (not including input tokens).
        layer_indices: Optional collection of layer indices to capture. If None, capture
            all layers. The final norm (key ``len(model.model.layers)``) is always captured.
        save_path: Optional path to save the results with ``torch.save``. ``.pt`` is
            appended when missing and the parent directory is created. Save errors
            are printed, not raised.

    Returns:
        tuple: (all_residual_streams, generated_text, token_ids)
            - all_residual_streams: Dict mapping layer index to a CPU tensor built by
              concatenating every captured norm input along dim 1.
            - generated_text: Decoded text of the first sequence (prompt + generation).
            - token_ids: CPU tensor of the token IDs including the generated tokens.

    Raises:
        Exception: Whatever ``model.generate`` raises is re-raised after the hooks
            have been removed.

    NOTE(review): when generation uses a KV cache, the last sampled token is presumably
    never fed back through the model, so the captured streams would be one position
    shorter than ``token_ids`` - confirm before indexing streams by token position.
    """
    import os  # local import: os is not among the notebook's top-level imports

    input_length = tokenized_prompt["input_ids"].shape[1]

    # layer index -> list of tensors, one per forward call
    captured = {}
    hooks = []

    def make_capture_hook(layer_idx):
        """Build a forward hook that stores the hooked norm module's input."""
        def hook(module, inputs, output):
            # The input of the norm module is the residual stream at that depth.
            captured.setdefault(layer_idx, []).append(inputs[0].detach().cpu())
        return hook

    # Clone the tokenized_prompt to avoid modifying the original
    generation_inputs = {
        "input_ids": tokenized_prompt["input_ids"].clone(),
        "attention_mask": tokenized_prompt["attention_mask"].clone(),
    }

    # try/finally guarantees the hooks never outlive this call, whatever fails.
    try:
        for i, layer in enumerate(model.model.layers):
            if layer_indices is None or i in layer_indices:
                hooks.append(
                    layer.input_layernorm.register_forward_hook(make_capture_hook(i))
                )

        # Also capture the input of the final norm (after the last transformer layer)
        hooks.append(
            model.model.norm.register_forward_hook(
                make_capture_hook(len(model.model.layers))
            )
        )

        try:
            with torch.no_grad():
                # generate() handles EOS tokens and stopping criteria. Per-step scores
                # are deliberately NOT requested: they were never read, and keeping one
                # vocabulary-sized tensor per generated token is very costly for long runs.
                generation_output = model.generate(
                    **generation_inputs,
                    max_new_tokens=max_new_tokens,
                    return_dict_in_generate=True,
                    output_hidden_states=False,  # We capture hidden states via hooks
                    # Set pad_token_id to EOS to avoid potential issues
                    pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2
                )
            generated_ids = generation_output.sequences
        except Exception as e:
            print(f"Error during generation: {e}")
            raise
    finally:
        for hook in hooks:
            hook.remove()

    # Concatenate the per-call tensors of each layer along the sequence dimension
    all_residual_streams = {
        layer_idx: torch.cat(chunks, dim=1) for layer_idx, chunks in captured.items()
    }

    # Decode the generated text
    try:
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        # Print a preview of the generated text
        preview = generated_text.replace('\n', '\\n')
        if len(preview) > 100:
            preview = preview[:97] + '...'
        output_tokens = generated_ids.shape[1]
        new_tokens = output_tokens - input_length
        print(f"Generated text: {preview}")
        print(f"Input tokens: {input_length}, New tokens: {new_tokens}, Total tokens: {output_tokens}")
    except Exception as e:
        print(f"Error decoding generated tokens: {e}")
        generated_text = f"[Error decoding tokens: {generated_ids[0]}]"

    # Store token ids for reference
    token_ids = generated_ids.detach().cpu()

    # Save to file if requested
    if save_path is not None:
        save_data = {
            'residual_streams': all_residual_streams,
            'generated_text': generated_text,
            'token_ids': token_ids
        }
        try:
            # Add file extension if not provided
            if not save_path.endswith('.pt'):
                save_path = f"{save_path}.pt"

            # Make sure directory exists ('.' when the path has no directory part)
            os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)

            torch.save(save_data, save_path)
            print(f"Residual stream data saved to {save_path} \n-----\n")
        except Exception as e:
            # Saving is best-effort: the in-memory results are still returned.
            print(f"Error saving residual stream data: {e}")

    return all_residual_streams, generated_text, token_ids


In [9]:


def load_residual_stream(load_path):
    """
    Read back a file written by the residual-stream capture step.

    Args:
        load_path: Path to the saved residual stream file (.pt)

    Returns:
        tuple: (all_residual_streams, generated_text, token_ids), taken from the
        'residual_streams', 'generated_text' and 'token_ids' entries of the saved dict.
    """
    saved = torch.load(load_path)
    wanted_keys = ('residual_streams', 'generated_text', 'token_ids')
    return tuple(saved[key] for key in wanted_keys)

In [10]:
def patch_res_stream(model, 
                    tokenizer, 
                    clean_path, 
                    corrupt_path, 
                    patch_layers, 
                    num_tokens_to_patch=None, 
                    target_token="<think>", 
                    save_path=None, 
                    device=None, 
                    max_new_tokens=100):
    """
    Re-generate from the clean prompt while overwriting the residual stream with
    values saved from the corrupted run, at the chosen layers and token positions.

    The clean token ids *before* the first occurrence of ``target_token`` are used as
    the prompt. During the following single-token forward passes, a forward pre-hook on
    each selected ``layers[i].input_layernorm`` replaces that module's input with
    ``corrupt_streams[i][0, position, :]`` for every position inside the patch range.
    The multi-token prefill pass is never patched.

    Args:
        model: Model exposing ``model.model.layers[i].input_layernorm`` and ``generate``
            (e.g. LlamaForCausalLM).
        tokenizer: The tokenizer associated with the model.
        clean_path: Path to the saved clean prompt residual stream file (.pt)
        corrupt_path: Path to the saved corrupted prompt residual stream file (.pt)
        patch_layers: Iterable of layer indices to patch (e.g., [5, 6, 7] or range(32)).
            Indices missing from either saved stream, or not valid indices into
            ``model.model.layers``, are dropped.
        num_tokens_to_patch: Number of positions to patch starting at the target
            token's position. If None, patch up to the shorter of the two saved
            token sequences.
        target_token: String (encoded with the tokenizer; its first id is used) or
            token id marking where patching starts (default: "<think>").
        save_path: Optional path to save the patched results (.pt). Save errors are
            printed, not raised.
        device: Device for the prompt tensors. Defaults to CUDA when available.
        max_new_tokens: Maximum number of new tokens to generate.

    Returns:
        tuple: (patched_text, original_clean_text, original_corrupt_text, patched_token_ids)

    Raises:
        ValueError: If no requested layer can be patched, or if the patch would start
            at position 0 (target token missing or first), because that leaves an
            empty prompt to generate from.

    NOTE(review): positions are absolute indices into BOTH saved runs, so when the two
    prompts differ in length the same index refers to different relative places in
    each sequence - confirm this alignment is what the experiment intends.
    """
    import os  # only needed for creating the output directory; not imported at file level

    # Set device if not provided
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the saved residual streams
    print(f"Loading residual streams from {clean_path} and {corrupt_path}...")
    clean_streams, clean_text, clean_ids = load_residual_stream(clean_path)
    corrupt_streams, corrupt_text, corrupt_ids = load_residual_stream(corrupt_path)

    print(f"Clean text preview: {clean_text[:100]}...")
    print(f"Corrupt text preview: {corrupt_text[:100]}...")

    # Locate the patch start. `is not None` (rather than truthiness) keeps token id 0 usable.
    patch_start_idx = 0
    if target_token is not None and target_token != "":
        if isinstance(target_token, str):
            # Find the token ID for the target token
            target_token_ids = tokenizer.encode(target_token, add_special_tokens=False)
            if len(target_token_ids) != 1:
                print(f"Warning: Target token '{target_token}' encoded to {len(target_token_ids)} tokens: {target_token_ids}")
            target_token_id = target_token_ids[0]
        else:
            target_token_id = target_token

        # Find the position of the target token in the clean_ids
        target_positions = (clean_ids[0] == target_token_id).nonzero(as_tuple=True)[0]
        if len(target_positions) == 0:
            print(f"Target token '{target_token}' not found in clean_ids.")
        else:
            # Take the first occurrence of the target token
            patch_start_idx = target_positions[0].item()
            print(f"Found target token at position {patch_start_idx}")

    # The prompt is clean_ids[:patch_start_idx]; an empty prompt cannot be generated from.
    if patch_start_idx < 1:
        raise ValueError(
            "Patching would start at position 0, which leaves no clean prefix to "
            "generate from. Pass a target_token that occurs after the first position "
            "of the clean sequence."
        )

    if num_tokens_to_patch is None:
        # From target token to the end
        patch_end_idx = min(clean_ids.shape[1], corrupt_ids.shape[1])
    else:
        # Patch specified number of tokens after the target token
        patch_end_idx = min(patch_start_idx + num_tokens_to_patch, clean_ids.shape[1], corrupt_ids.shape[1])

    patch_tokens_range = (patch_start_idx, patch_end_idx)

    print(f"Will patch tokens from position {patch_start_idx} to {patch_end_idx} ({patch_end_idx - patch_start_idx} tokens)")
    print(f"Will patch layers: {patch_layers}")

    # Validate layers. The saved streams also hold a key for the final norm
    # (== number of layers); that key is not a valid index into model.model.layers.
    max_layer = max(clean_streams.keys())
    num_model_layers = len(model.model.layers)
    patch_layers = [
        layer for layer in patch_layers
        if layer in clean_streams and layer in corrupt_streams and 0 <= layer < num_model_layers
    ]
    if not patch_layers:
        raise ValueError(f"No valid layers to patch. Layers must be in range 0-{max_layer}")

    # Create patched inputs based on clean prompt
    input_ids = clean_ids[0, :patch_start_idx].unsqueeze(0).to(device)
    attention_mask = torch.ones_like(input_ids).to(device)

    patched_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }

    # Progress is counted once per token (driven by the first patched layer only).
    expected_tokens_to_patch = max(0, min(patch_end_idx - patch_start_idx, max_new_tokens))
    progress_bar = tqdm(total=expected_tokens_to_patch, desc="Patching tokens", ncols=100)
    progress_layer = patch_layers[0]

    hooks = []

    def patch_hook(layer_idx):
        """Build the pre-hook that patches one layer's norm input during decoding."""
        # Each layer keeps its OWN position counter. A counter shared by all hooks would
        # advance once per hooked layer per step and desynchronise the layers.
        position = patch_start_idx - 1  # incremented before use, so first step == start
        prefill_seen = False
        layer_stream = corrupt_streams[layer_idx]
        # The saved stream may be shorter than corrupt_ids, so bound by its real length.
        stream_len = layer_stream.shape[1]

        def hook(module, inputs):
            nonlocal position, prefill_seen

            res_stream = inputs[0]
            seq_len = res_stream.shape[1]

            # The first call is the prompt (prefill) pass: never patched. Any other
            # multi-token pass is left untouched as well.
            if seq_len > 1 or not prefill_seen:
                prefill_seen = True
                return None

            # Single-token decoding step: advance to the position being processed.
            position += 1

            if patch_start_idx <= position < patch_end_idx and position < stream_len:
                to_patch = layer_stream[0, position, :].to(res_stream.device)

                # In-place overwrite; the module then consumes the patched tensor.
                res_stream[0, 0, :] = to_patch

                if layer_idx == progress_layer:
                    progress_bar.update(1)

                if position % 10 == 0:  # Less frequent logging
                    print(f"Patched token at position {position} in layer {layer_idx}")

            # Returning None keeps the (in-place modified) inputs.
            return None

        return hook

    # Generate text with patched residual streams
    try:
        # Register forward pre-hooks for each layer to patch
        for layer_idx in patch_layers:
            hooks.append(
                model.model.layers[layer_idx].input_layernorm.register_forward_pre_hook(
                    patch_hook(layer_idx)
                )
            )

        with torch.no_grad():
            # Per-step scores are not requested: nothing here reads them.
            generation_output = model.generate(
                **patched_inputs,
                max_new_tokens=max_new_tokens,
                return_dict_in_generate=True,
                pad_token_id=tokenizer.eos_token_id
            )
        patched_ids = generation_output.sequences
    except Exception as e:
        print(f"Error during generation with patched residual stream: {e}")
        raise
    finally:
        # Single cleanup point for both the success and the failure path.
        for hook in hooks:
            hook.remove()
        progress_bar.close()

    # Decode the patched generation
    patched_text = tokenizer.decode(patched_ids[0], skip_special_tokens=True)

    # Print summary
    print("\n--- Patching Results ---")
    print(f"Patched layers: {patch_layers}")
    print(f"Patched token range: {patch_tokens_range}")
    print(f"Generation length: {patched_ids.shape[1] - input_ids.shape[1]} new tokens")
    print("\nOriginal clean output: ")
    print(clean_text[:200] + "..." if len(clean_text) > 200 else clean_text)
    print("\nOriginal corrupt output: ")
    print(corrupt_text[:200] + "..." if len(corrupt_text) > 200 else corrupt_text)
    print("\nPatched output: ")
    print(patched_text[:200] + "..." if len(patched_text) > 200 else patched_text)

    # Save patched results if requested
    if save_path:
        patched_data = {
            'patched_text': patched_text,
            'clean_text': clean_text,
            'corrupt_text': corrupt_text,
            'patched_token_ids': patched_ids.cpu(),
            'patch_info': {
                'patch_layers': patch_layers,
                'patch_tokens_range': patch_tokens_range,
                'clean_path': clean_path,
                'corrupt_path': corrupt_path
            }
        }

        try:
            # Add file extension if not provided
            if not save_path.endswith('.pt'):
                save_path = f"{save_path}.pt"

            # Make sure directory exists ('.' when the path has no directory part)
            os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)

            torch.save(patched_data, save_path)
            print(f"Patched results saved to {save_path}")
        except Exception as e:
            # Saving is best-effort: the results are still returned to the caller.
            print(f"Error saving patched results: {e}")

    return patched_text, clean_text, corrupt_text, patched_ids

In [None]:
# Raw string: the prompt is LaTeX, and in a normal string "\t" (\text), "\v" (\varphi)
# and "\f" (\frac) would silently become TAB / vertical-tab / form-feed characters.
clean_prompt = r"David found four sticks of different lengths that can be used to form three non-congruent convex cyclic quadrilaterals, $A,\text{ }B,\text{ }C$ , which can each be inscribed in a circle with radius $1$ . Let $\varphi_A$ denote the measure of the acute angle made by the diagonals of quadrilateral $A$ , and define $\varphi_B$ and $\varphi_C$ similarly. Suppose that $\sin\varphi_A=\frac{2}{3}$ , $\sin\varphi_B=\frac{3}{5}$ , and $\sin\varphi_C=\frac{6}{7}$ . All three quadrilaterals have the same area $K$ , which can be written in the form $\dfrac{m}{n}$ , where $m$ and $n$ are relatively prime positive integers. Find $m+n$ ."
clean_tokenized_prompt, clean_formatted_prompt = tokenize_prompt(clean_prompt, tokenizer, device=DEVICE)


# Capture the residual stream of every layer for the clean run and save it to disk.
clean_residual_streams, clean_generated_text, clean_token_ids = get_residual_stream(model, 
                                                                    tokenizer, 
                                                                    clean_tokenized_prompt, 
                                                                    max_new_tokens=32768, 
                                                                    layer_indices=None, 
                                                                    save_path="res-stream/clean-prompt-2.pt")

# DEVICE is a torch.device, which never compares equal to a plain string,
# so the device *type* has to be checked for the cache to be released.
if DEVICE.type == "cuda":
    torch.cuda.empty_cache()

print(clean_generated_text)

In [None]:
# Raw string: the prompt is LaTeX, and in a normal string the "\n" of "\neq" would
# silently become a newline followed by "eq".
corrupted_prompt = r"For integers $a,b,c$ and $d,$ let $f(x)=x^2+ax+b$ and $g(x)=x^2+cx+d.$ Find the number of ordered triples $(a,b,c)$ of integers with absolute values not exceeding $10$ for which there is an integer $d$ such that $g(f(2))=g(f(4))=0.$ . Hint: The key insight is to recognize what it means for $g(f(2)) = g(f(4)) = 0$. This tells us that both $f(2)$ and $f(4)$ are roots of the polynomial $g(x)$. There are two distinct cases to consider: 1. What happens if $f(2) = f(4)$? When would this occur? What constraint does this place on $a$ 2. What happens if $f(2) \neq f(4)$? In this case, since $g(x)$ is a quadratic polynomial and we know both $f(2)$ and $f(4)$ are roots, we can write: $g(x) = (x - f(2))(x - f(4))$. When you expand this, you'll get expressions for $c$ and $d$ in terms of $a$ and $b$. Focus on the constraint for $c$ (since $d$ can be any integer), and determine which values of $(a,b)$ will give you values of $c$ with $|c| \leq 10$. Try separating the cases and counting how many valid triples $(a,b,c)$ exist in each scenario."
corrupted_tokenized_prompt, corrupted_formatted_prompt = tokenize_prompt(corrupted_prompt, tokenizer, device=DEVICE)


# Capture the residual stream of every layer for the corrupted run and save it to disk.
corrupted_residual_streams, corrupted_generated_text, corrupted_token_ids = get_residual_stream(model, 
                                                                    tokenizer, 
                                                                    corrupted_tokenized_prompt, 
                                                                    max_new_tokens=16384, 
                                                                    layer_indices=None, 
                                                                    save_path="res-stream/corrupted-prompt-1.pt")

print(corrupted_generated_text)

In [None]:
# NOTE(review): this reads "res-stream/clean-prompt-1.pt", while the clean-prompt cell of
# this notebook saves to "res-stream/clean-prompt-2.pt" - confirm which file is intended,
# otherwise a fresh Restart & Run All depends on a file produced by an earlier session.
patched_text, clean_text, corrupt_text, patched_ids = patch_res_stream(model, 
                                                                        tokenizer, 
                                                                        clean_path = "res-stream/clean-prompt-1.pt", 
                                                                        corrupt_path = "res-stream/corrupted-prompt-1.pt", 
                                                                        patch_layers = [0], 
                                                                        num_tokens_to_patch=64, 
                                                                        target_token="<think>", 
                                                                        save_path="res-stream/patched-prompt-1.pt", 
                                                                        device=None, 
                                                                        max_new_tokens=1024)

# DEVICE is a torch.device, which never compares equal to a plain string,
# so the device *type* has to be checked for the cache to be released.
if DEVICE.type == "cuda":
    torch.cuda.empty_cache()

print(patched_text)

In [None]:
# Print the three generations one after another, each followed by a separator.
labelled_outputs = (
    ("corrupted text:", corrupt_text),
    ("clean text:", clean_text),
    ("patched text:", patched_text),
)
for label, text in labelled_outputs:
    print(label)
    print(text)
    print("\n---")


In [None]:
# Read the saved corrupted-prompt run back from disk: (residual streams, text, token ids).
res_streams, generated_text, token_ids = load_residual_stream("res-stream/corrupted-prompt-1.pt")