In [43]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

import torch
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from typing import Dict, List, Tuple, Optional

In [2]:
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
    model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
                                                 torch_dtype=torch.bfloat16)
    model.eval()
    return tokenizer, model

tokenizer, model = load_model()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [40]:
def prepare_prompt(prompt, tokenizer):
    # Apply the chat template to format the prompt properly
    messages = [
        {"role": "user", "content": prompt}
    ]

    # Format the prompt using the chat template
    formatted_prompt = tokenizer.apply_chat_template(
        messages, 
        tokenize=False,
        add_generation_prompt=False
    )

    # Tokenize the formatted prompt
    inputs = tokenizer(formatted_prompt, return_tensors="pt")
    
    return inputs


In [41]:
def generate_text(model, tokenizer, inputs, max_length=10000):
    # Generate text
    with torch.no_grad():
        output = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=max_length,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
            output_scores=True,
            return_legacy_cache=True 
        )
    
    # Get the generated token IDs
    generated_ids = output.sequences[0]
    
    # Decode the generated text
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    
    return generated_text

In [46]:


# Simplified version for easier debugging
def simple_patching_experiment(clean_prompt, corrupted_prompt, layer_to_patch=None):
    """
    Simplified function to patch a single layer for debugging.
    """
    print("Loading model...")
    tokenizer, model = load_model()
    
    print("Preparing prompts...")
    clean_inputs = prepare_prompt(clean_prompt, tokenizer)
    corrupted_inputs = prepare_prompt(corrupted_prompt, tokenizer)
    
    print("Generating from clean prompt (baseline)...")
    clean_generation = generate_text(model, tokenizer, clean_inputs)
    
    print("Generating from corrupted prompt (baseline)...")
    corrupted_generation = generate_text(model, tokenizer, corrupted_inputs)
    
    print("\n----- RESULTS -----")
    print("CLEAN PROMPT GENERATION:")
    print(clean_generation)
    print("\nCORRUPTED PROMPT GENERATION:")
    print(corrupted_generation)
    
    # If no layer to patch, we're done with baseline
    if layer_to_patch is None:
        return
    
    print(f"\nPatching layer {layer_to_patch}...")
    
    # Collect clean activations
    clean_activations = {}
    
    def save_clean_activation(module, input, output, layer_idx):
        # For Llama, output is a tuple where first element is hidden state
        if isinstance(output, tuple):
            clean_activations[layer_idx] = output[0].detach().clone()
        else:
            clean_activations[layer_idx] = output.detach().clone()
        return output
    
    # Register clean hook
    print("Collecting clean activations...")
    clean_hook = model.model.layers[layer_to_patch].register_forward_hook(
        lambda mod, inp, out: save_clean_activation(mod, inp, out, layer_to_patch)
    )
    
    # Run forward pass with clean prompt
    with torch.no_grad():
        model(**clean_inputs)
    
    # Remove clean hook
    clean_hook.remove()
    
    # Define patching hook
    def patch_hook(module, input, output):
        if isinstance(output, tuple):
            hidden_state = output[0]
            other = output[1:]
            
            # Patch with clean activations
            clean_activation = clean_activations[layer_to_patch]
            min_len = min(clean_activation.size(1), hidden_state.size(1))
            
            patched_hidden = hidden_state.clone()
            patched_hidden[:, :min_len, :] = clean_activation[:, :min_len, :]
            
            # Return new tuple with patched hidden state
            return (patched_hidden,) + other
        else:
            # Unlikely case, but handle for completeness
            clean_activation = clean_activations[layer_to_patch]
            min_len = min(clean_activation.size(1), output.size(1))
            
            patched_output = output.clone()
            patched_output[:, :min_len, :] = clean_activation[:, :min_len, :]
            
            return patched_output
    
    # Register patching hook
    print("Setting up patching hook...")
    patch_hook_handle = model.model.layers[layer_to_patch].register_forward_hook(patch_hook)
    
    # Generate with patching
    print("Generating with patching...")
    patched_generation = generate_text(model, tokenizer, corrupted_inputs)
    
    # Remove patching hook
    patch_hook_handle.remove()
    
    # Print patched result
    print(f"\nPATCHED GENERATION (LAYER {layer_to_patch}):")
    print(patched_generation)
    


In [47]:
# Example usage
if __name__ == "__main__":
    clean_prompt = "Problem: X+Y=10, X-Y=4. Hint: Add equations to get 2X=14, X=7, then Y=3. Answer:"
    corrupted_prompt = "Problem: X+Y=10, X-Y=4. Answer:"
    
    # Run baseline only (no patching)
    # simple_patching_experiment(clean_prompt, corrupted_prompt, None)
    
    # Run with patching layer 20
    simple_patching_experiment(clean_prompt, corrupted_prompt, 20)

Loading model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Preparing prompts...
Generating from clean prompt (baseline)...
Generating from corrupted prompt (baseline)...

----- RESULTS -----
CLEAN PROMPT GENERATION:
<｜User｜>Problem: X+Y=10, X-Y=4. Hint: Add equations to get 2X=14, X=7, then Y=3. Answer: X=7, Y=3.

But what if instead, I tried to subtract the equations: X+Y=10, (X-Y)=4. If I subtract (X-Y) from (X+Y), I get 2Y=6, so Y=3, then X=10-3=7.

Wait, so both methods give the same answer. Hmm. So, maybe subtracting the equations isn't really different from adding them. Is that always true?

Wait, let me test another example. Let's say I have two equations: 3A + 2B = 15 and 4A - B = 10. If I subtract the second equation from the first, I get (3A + 2B) - (4A - B) = 15 -10, which is -A + 3B =5. But if I add them, I get 7A + 3B =25. So, subtracting gives a different result than adding. So, in that case, subtracting equations can lead to a different equation. So, why in the first problem, subtracting gave the same answer as adding?

In the o

KeyboardInterrupt: 