### A GPU is required to run this notebook


In [50]:
import os
import json
import argparse
from typing import Tuple, List, Dict

import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from utils import load_model_and_tokenizer, formatInp, read_row
import logging
import intervention as interv


def _find_token_span(haystack: List[int], needles: List[List[int]]) -> Tuple[int, int]:
    """Return the half-open [start, end) span of the last occurrence of the first needle found.

    Needles are tried in order and empty needles are skipped; (-1, -1) means none occurs.
    """
    for needle in needles:
        n = len(needle)
        if n == 0:
            continue
        # Search from the end: segments are appended after the templated prefix.
        for s in range(len(haystack) - n, -1, -1):
            if haystack[s:s + n] == needle:
                return s, s + n
    return -1, -1


def _offset_mapping(tokenizer, text: str, expected_len: int):
    """Return per-token (char_start, char_end) offsets for `text`, or None if unavailable.

    Slow (pure-Python) tokenizers raise NotImplementedError for offsets; offsets that do
    not line up one-to-one with the ids being indexed are also rejected.
    """
    try:
        enc = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)
    except NotImplementedError:
        return None
    offsets = enc.get('offset_mapping')
    if offsets is None or len(offsets) != expected_len:
        return None
    return offsets


def _char_span_to_token_span(offsets, char_start: int, char_end: int) -> Tuple[int, int]:
    """Map a [char_start, char_end) character range to the tokens whose non-empty offsets overlap it."""
    hits = [i for i, (s, e) in enumerate(offsets) if e > s and s < char_end and e > char_start]
    return (hits[0], hits[-1] + 1) if hits else (-1, -1)


def build_full_prompt_and_find_spans(
    sample: Dict,
    tokenizer: AutoTokenizer,
    model_name: str,
    step_tag: str,
    include_final_answer_cue: bool = False,
) -> Tuple[str, List[int], Tuple[int, int], Tuple[int, int]]:
    """
    Build the full prompt string and locate token spans for context_text and the chosen step.

    The prompt is the formatInp(...) prefix + sample['context_text'] + '.' + sample[step_tag]
    + '.' + an optional "\\boxed{" answer cue, tokenized once without special tokens.

    Spans are half-open [start, end) indices into the returned ids. Each segment is first
    searched as an exact token subsequence (tokenized with, then without, a leading space;
    last occurrence wins). If that fails, e.g. because the tokenizer merged characters
    across the segment boundary, character offsets from a fast tokenizer are used instead.
    (-1, -1) means the segment is empty or could not be located (a warning is logged in
    the latter case).

    Returns:
      prompt_text, input_ids, (ctx_start, ctx_end), (step_start, step_end)
    """
    logger = logging.getLogger(__name__)
    prefix_text = formatInp(sample, model=model_name, use_template=True, tokenizer=tokenizer)
    context_text = sample.get('context_text', '')
    step_text = sample[step_tag]

    final_answer_cue = "\n The final result is \\boxed{" if include_final_answer_cue else ""
    full_text = prefix_text + context_text + '.' + step_text + '.' + final_answer_cue
    print('###full input text:\n', full_text)

    # Tokenize the full text once; all spans index into these ids.
    full_ids = tokenizer(full_text, add_special_tokens=False).input_ids

    # Character ranges of each segment inside full_text (a '.' separates context and step).
    ctx_char_start = len(prefix_text)
    ctx_char_end = ctx_char_start + len(context_text)
    step_char_start = ctx_char_end + 1
    step_char_end = step_char_start + len(step_text)

    def locate(label: str, segment: str, char_start: int, char_end: int) -> Tuple[int, int]:
        if not segment:
            # Nothing to find; tokenizing ' ' + '' would otherwise match an arbitrary space token.
            return -1, -1
        span = _find_token_span(
            full_ids,
            [
                tokenizer(' ' + segment, add_special_tokens=False).input_ids,
                tokenizer(segment, add_special_tokens=False).input_ids,
            ],
        )
        if span[0] >= 0:
            return span
        offsets = _offset_mapping(tokenizer, full_text, len(full_ids))
        if offsets is not None:
            span = _char_span_to_token_span(offsets, char_start, char_end)
        if span[0] < 0:
            logger.warning('Could not locate the %s span in the tokenized prompt; returning (-1, -1).', label)
        return span

    ctx_start, ctx_end = locate('context', context_text, ctx_char_start, ctx_char_end)
    step_start, step_end = locate(step_tag, step_text, step_char_start, step_char_end)

    return full_text, full_ids, (ctx_start, ctx_end), (step_start, step_end)


@torch.no_grad()
def collect_step_attentions(
    model: AutoModelForCausalLM,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    intervention_active: bool = False,
    tokenizer=None,
    num_steps: int = 6,
) -> List[List[torch.Tensor]]:
    """Greedily generate from the prompt and capture last-token attentions at each step.

    Turns on attention outputs and tries to switch the model to eager attention
    (fused flash/sdpa kernels may not return attention weights), then delegates to
    ``_generate_with_attention_capture``.

    Args:
        model: causal LM; the input device is taken from its first parameter.
        input_ids: prompt token ids.
        attention_mask: mask matching ``input_ids``.
        intervention_active: not read here; activation hooks are managed by the
            caller (e.g. with ``interv.add_hooks``). Kept for call-site compatibility.
        tokenizer: optional; forwarded for the pad id and answer printing.
        num_steps: maximum number of new tokens to generate (default 6, the
            previously hard-coded value).

    Returns:
        Whatever ``_generate_with_attention_capture`` returns: one entry per
        generated step, each a list of per-layer CPU attention tensors for the
        last query position.
    """
    logger = logging.getLogger(__name__)
    device = next(model.parameters()).device
    model.eval()

    config = getattr(model, 'config', None)

    # Ensure attentions are enabled for models that gate on config (best effort).
    if config is not None:
        try:
            config.output_attentions = True
        except Exception as err:
            logger.warning('Could not set config.output_attentions=True: %s', err)

    # Force eager attention (flash/sdpa backends may not return attention weights).
    try:
        if hasattr(model, 'set_attn_implementation'):
            model.set_attn_implementation('eager')
        elif config is not None:
            # Hugging Face models consult the private `_attn_implementation` field;
            # a public `attn_implementation` attribute is typically absent from configs.
            config._attn_implementation = 'eager'
    except Exception as err:
        # Best effort: keep going, but make the failure visible instead of hiding it.
        logger.warning('Could not switch the model to eager attention: %s', err)

    return _generate_with_attention_capture(model, input_ids, attention_mask, num_steps, tokenizer, device)


def _generate_with_attention_capture(model, input_ids, attention_mask, num_steps, tokenizer, device):
    """Greedily generate with ``model.generate()`` and collect last-token attentions per step.

    Args:
        model: causal LM whose ``generate`` accepts a ``GenerationConfig``.
        input_ids: prompt ids, moved to ``device``.
        attention_mask: prompt mask, moved to ``device``.
        num_steps: ``max_new_tokens`` for generation (generation may stop earlier).
        tokenizer: optional. Supplies ``pad_token_id`` and decodes the continuation;
            the text before its first '}' is printed as the predicted answer.
        device: device the prompt tensors are moved to.

    Returns:
        One entry per generation step. Each entry is a list with one CPU tensor per
        layer: batch element 0's attention row for the last query position,
        shape (heads, key_len).

    Raises:
        RuntimeError: if ``generate()`` returned no attentions.
    """
    seq = input_ids.to(device)
    mask = attention_mask.to(device)

    generation_config = GenerationConfig(
        max_new_tokens=num_steps,
        do_sample=False,  # greedy decoding
        output_attentions=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.pad_token_id if tokenizer is not None else None,
    )

    print(model.device)
    # A single generate() call produces all num_steps tokens and their attentions.
    outputs = model.generate(
        seq,
        attention_mask=mask,
        generation_config=generation_config,
        use_cache=True,
    )

    if tokenizer is not None:
        continuation = tokenizer.decode(outputs.sequences[0, seq.shape[1]:], skip_special_tokens=True)
        # The prompt may end with a "\boxed{" cue, so the answer is the text before the first '}'.
        answer = continuation.split('}')[0]
        print(f"Predicted answer: '{answer}'")

    if outputs.attentions is None:
        raise RuntimeError(
            'model.generate() did not return attentions. Disable flash/sdpa attention or force eager.'
        )

    # outputs.attentions holds one tuple per generated token with one tensor per layer,
    # indexed (batch, heads, query, key); keep batch 0's last query row for every layer.
    return [
        [layer_attn[0, :, -1, :].detach().cpu() for layer_attn in step_attentions]
        for step_attentions in outputs.attentions
    ]


def _manual_step_generation(model, input_ids, attention_mask, num_steps, tokenizer, device):
    """Manual greedy token-by-token generation for cases where model.generate() is not suitable.

    Each step re-runs a full forward pass over the growing sequence with
    ``use_cache=False``, so every step's attention covers all key positions.

    Args:
        model: callable accepting ``input_ids``, ``attention_mask``, ``use_cache`` and
            ``output_attentions`` keywords and returning an object with
            ``.attentions`` and ``.logits``.
        input_ids: prompt ids, moved to ``device``.
        attention_mask: prompt mask, moved to ``device``.
        num_steps: number of tokens to generate (0 or less returns an empty list).
        tokenizer: unused; kept for signature parity with ``_generate_with_attention_capture``.
        device: device the working sequence lives on.

    Returns:
        ``results[step][layer]``: CPU tensor holding batch element 0's attention from
        the last query position, shape (heads, key_len) with key_len = prompt_len + step.

    Raises:
        RuntimeError: if a forward pass returns no attentions.
    """
    seq = input_ids.to(device)
    mask = attention_mask.to(device)
    results: List[List[torch.Tensor]] = []

    for _ in range(num_steps):
        out_step = model(
            input_ids=seq,
            attention_mask=mask,
            use_cache=False,  # no cache, so attentions span the whole sequence
            output_attentions=True,
        )
        if out_step.attentions is None:
            raise RuntimeError(
                'Model forward did not return attentions. Disable flash/sdpa attention or force eager.'
            )

        results.append([layer_attn[0, :, -1, :].detach().cpu() for layer_attn in out_step.attentions])

        # Greedy next token, appended to the sequence and its mask.
        next_token = out_step.logits[:, -1, :].argmax(dim=-1, keepdim=True).to(seq.device)
        seq = torch.cat([seq, next_token], dim=1)
        mask = torch.cat([mask, torch.ones_like(next_token, dtype=mask.dtype, device=mask.device)], dim=1)

    return results


In [51]:
# --- Configuration ---
device = torch.device("cuda", 0)  # first CUDA device
# Short identifiers handed to utils.load_model_and_tokenizer (e.g. 'llama' / '7b').
model_name, model_size = 'deepseek-qwen', '7b'

In [52]:

# --- Load model & tokenizer ---
# Project helper from utils; per the log below, ('deepseek-qwen', '7b') resolves to
# deepseek-ai/DeepSeek-R1-Distill-Qwen-7B.
model, tokenizer = load_model_and_tokenizer(model_name, model_size)

2025-10-16 15:41:16,457 - INFO - Loading model: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
2025-10-16 15:41:16,458 - INFO - RunPod detected - downloading to workspace: /workspace/models
2025-10-16 15:41:17,267 - INFO - Using auto device mapping for single GPU
2025-10-16 15:41:17,538 - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

2025-10-16 15:41:56,382 - INFO - Successfully loaded model: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B


In [53]:

data_path = 'data/amc_qwen7b_perturb_rand_small_DT_dir-intervene17.json'
left, right = 0, 1  # slice [left:right] of the examples to load
step_tag = 'perturb_step'  # step appended to the prompt: 'perturb_step' or 'initial_step'
output_dir = 'attn_viz'
include_final_answer_cue = True  # append the "\boxed{" cue so the model answers right away

# --- Load data ---
data = read_row(data_path)[left:right]
if len(data) == 0:
    raise ValueError(f'No data to visualize in {data_path}[{left}:{right}]')
sample = data[0]  # by default, use the first example of the slice

# --- Build prompt & find spans ---
full_text, full_ids, ctx_span, step_span = build_full_prompt_and_find_spans(
    sample,
    tokenizer,
    model_name,
    step_tag,
    include_final_answer_cue=include_final_answer_cue,
)

# --- Tokenize ---
# Reuse the exact ids the spans were computed on, so ctx_span/step_span always index input_ids.
input_ids = torch.tensor([full_ids], dtype=torch.long)
attention_mask = torch.ones_like(input_ids)  # single unpadded sequence
initial_len = input_ids.shape[1]



###full input text:
 <｜begin▁of▁sentence｜><｜User｜>$\frac{m}{n}$ is the Irreducible fraction value of \[3+\frac{1}{3+\frac{1}{3+\frac13}}\], what is the value of $m+n$?<｜Assistant｜><think>
Okay, so I have this math problem here: I need to find the irreducible fraction value of this expression: 3 + 1/(3 + 1/(3 + 1/3)), and then find the sum of the numerator and denominator of that fraction. Hmm, that seems a bit complicated with all the nested fractions, but I think I can break it down step by step.

First, let me write down the expression clearly to visualize it better:

3 + 1/(3 + 1/(3 + 1/3))

Alright, so it's a continued fraction. I remember that to solve these, I should start from the innermost part and work my way outwards. That makes sense because each fraction depends on the one inside it. So, let me start by simplifying the innermost fraction, which is 1/3.

So, the innermost part is 3 + 1/3. Let me compute that first. 3 is the same as 3/1, so adding 1/3 to it would be:

3 + 1/3

In [57]:
# Show the problem next to the perturbed and original versions of the step.
for header, key in (
    ('###Problem:\n', 'problem'),
    ('###perturbed step:\n', 'perturb_step'),
    ('###initial step:\n', 'initial_step'),
):
    print(header, sample[key])


###Problem:
 $\frac{m}{n}$ is the Irreducible fraction value of \[3+\frac{1}{3+\frac{1}{3+\frac13}}\], what is the value of $m+n$?
###perturbed step:
  Now, the problem says that this is an irreducible fraction, so I need to make sure that 110 and 34 have no common factors other than -2
###initial step:
  Now, the problem says that this is an irreducible fraction, so I need to make sure that 109 and 33 have no common factors other than 1


In [12]:
import gc 
# Drop unreferenced Python objects, then release cached CUDA memory before the next run.
gc.collect()
torch.cuda.empty_cache()

In [54]:

# Intervention options
intervention_vector_path = 'data/pt/amc-qwen-dir-uf-ff.pt'  # steering direction(s), read with torch.load
intervene_layers = '17'  # comma-separated layer indices; an empty string selects every layer
add_coef_intervene = 1.0  # coefficient passed to the interv activation-addition hooks
# 0/1 switches (intervene_on_step_till_end / intervene_on_perturbed_step appear unused in the cells shown)
intervene_all, intervene_on_step_till_end = 0, 0
intervene_on_perturbed_step, reverse_intervention = 1, 0


### Without steering, the output is wrong because of the perturbed step

In [56]:

# Baseline: no hooks are registered in this cell, so the model reads the perturbed step as-is.
step_attns = collect_step_attentions(
    model=model, input_ids=input_ids, attention_mask=attention_mask,
    intervention_active=False, tokenizer=tokenizer,
)
print('correct answer:', sample['answer'])

cuda:0
Predicted answer: '143'
correct answer: 142.0


### Steering along the reverse TrueThinking direction in the Disengagement Test
#### The model disregards the perturbed step, so it computes the answer based only on the correct steps that precede it.

In [48]:

if intervention_vector_path:
    # NOTE: weights_only=False unpickles arbitrary Python objects -- only load trusted files.
    vec = torch.load(intervention_vector_path, weights_only=False)
    # as_tensor (not torch.tensor) avoids copying, and avoids the "copy construct from a tensor"
    # UserWarning, when the file already stores a tensor.
    vec = torch.as_tensor(vec)
    if vec.dim() == 3:
        vec = vec[0]
    if vec.dim() != 2:
        raise ValueError('intervention_vector must be [layers, hidden] or [batch, layers, hidden]')
    if reverse_intervention:
        vec = -vec

    n_layers = model.config.num_hidden_layers
    if intervene_layers:
        layers = [int(x) for x in intervene_layers.split(',') if x.strip()]
    else:
        layers = list(range(n_layers))
    skipped = [li for li in layers if not 0 <= li < n_layers]
    if skipped:
        logging.getLogger(__name__).warning(
            'Ignoring out-of-range intervention layers %s (model has %d layers).', skipped, n_layers
        )
    layers = [li for li in layers if 0 <= li < n_layers]

    # Token positions to steer: the appended step's span (left as None when intervene_all is set).
    positions = None
    if not intervene_all:
        st_s, st_e = step_span
        if st_s < 0 or st_e <= st_s:
            raise ValueError(
                f'Step span {step_span} was not found in the prompt; the intervention would touch no tokens.'
            )
        positions = list(range(st_s, st_e))

    coeff = torch.tensor(float(add_coef_intervene), device=model.device)
    # The steering vector is injected via a forward pre-hook on each selected decoder layer;
    # no post-forward hooks are used, so fwd_hooks stays empty.
    fwd_pre_hooks = [
        (
            model.model.layers[li],
            interv.get_activation_addition_input_pre_hook(
                vector=vec[li, :],
                coeff=coeff,
                intervene_all=intervene_all,
                positions=positions,
            ),
        )
        for li in layers
    ]
    fwd_hooks = []

    # Apply intervention across all steps if intervening on all tokens; otherwise restrict it to
    # the initial full pass (how DECODING_STEP / COUNT_ADD are consumed lives in `interv`).
    if intervene_all:
        interv.DECODING_STEP = -1
    else:
        interv.COUNT_ADD = 0
        interv.DECODING_STEP = len(fwd_pre_hooks) + len(fwd_hooks)
    with interv.add_hooks(module_forward_pre_hooks=fwd_pre_hooks, module_forward_hooks=fwd_hooks, use_kwargs=False):
        step_attns = collect_step_attentions(
            model=model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            intervention_active=True,
            tokenizer=tokenizer,
        )

print('correct answer:', sample['answer'])

  vec = torch.tensor(vec)


cuda:0
Predicted answer: '142'
correct answer: 142.0
