In [None]:
!pip install transformers
!pip install accelerate>=0.26.0
!pip install datasets
!pip install -U bitsandbytes
!pip install sae_lens
!pip install huggingface_hub
import json
import os
import datasets
from sae_lens import SAE
import random
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub import notebook_login
# Interactive Hugging Face Hub login for this notebook session.
# NOTE(review): presumably required because the model loaded below is a gated
# repository — confirm.
notebook_login()

In [None]:
from transformer_lens import HookedTransformer

# Hugging Face model identifier handed to TransformerLens.
model_name = "meta-llama/Llama-3.1-8B"
# Load the model as a HookedTransformer on the GPU so that intermediate
# activations can be cached and edited through hooks in later cells.
model = HookedTransformer.from_pretrained(model_name, device="cuda")

In [None]:
from sae_lens import SAE

# Load a pretrained sparse autoencoder (release "llama_scope_lxr_32x",
# id "l25r_32x") onto the GPU. Only `sae` is used by the cells below.
# NOTE(review): the id suggests a layer-25 residual-stream SAE with 32x
# expansion — confirm against the release documentation.
# NOTE(review): the shape of `from_pretrained`'s return value has changed
# across sae_lens versions; verify this 3-tuple unpacking against the
# installed version.
sae, cfg_dict, sparsity = SAE.from_pretrained(release="llama_scope_lxr_32x", sae_id="l25r_32x", device="cuda")

In [None]:
# Name of the model hook this SAE is configured for. Later cells use it as the
# key into the activation cache to fetch the activations the SAE encodes.
hook_point = sae.cfg.hook_name

In [None]:
from datasets import load_dataset
# GSM8K, "main" configuration. The "train" and "test" splits are both read by
# later cells.
dataset = load_dataset("gsm8k", "main")

In [None]:
# Question text of every problem in each GSM8K split, in dataset order.
train_questions = [example["question"] for example in dataset["train"]]
test_questions = [example["question"] for example in dataset["test"]]

In [None]:
import torch
import random
from sae_lens import SAE
from tqdm import tqdm
from itertools import cycle

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Number of (with-role, without-role) prompt pairs to collect activations for.
N = 800

# Role-play prefixes in several languages; one is prepended to each question.
role_play_prompt_set = [
    # English
    "As a highly qualified mathematics teacher, you excel at solving problems systematically and explaining solutions with clarity. I am your student, eager to learn. Please solve the following problem:",

    # Chinese
    "从现在开始，你是一位优秀的数学老师，总是正确地教授学生数学问题。我是你的学生之一，请回答以下问题。",

    # German
    "Sie sind ein angesehener Mathematikprofessor mit umfassender Erfahrung in der präzisen Lösung komplexer Probleme. Ich bin Ihr Schüler und kenne die Antwort auf die folgende Frage nicht. Können Sie mir helfen? Die Frage ist:",

    # French
    "Vous êtes un professeur de mathématiques de renommée mondiale, expert dans la résolution précise et efficace des problèmes. En tant que votre étudiant, je suis bloqué sur une question depuis longtemps. Pourriez-vous résoudre la tâche suivante pour moi?",

    # Japanese
    "あなたは数学の専門家であり、深い問題解決のスキルを持っています。私はあなたの生徒で、宿題に困っています。先生として、次の問題を解いてください。"
]

# Shuffle once, then hand the prefixes out in round-robin order.
random.shuffle(role_play_prompt_set)
role_play_cycle = cycle(role_play_prompt_set)

# Questions are drawn with replacement, so the same question can occur in
# more than one pair.
selected_questions = [random.choice(train_questions) for _ in range(N)]

# Paired prompts: identical question, with and without a role-play prefix.
# (An earlier variant of this experiment additionally appended an instruction
# to print the final answer after 'Output: ' to both prompts.)
with_role_prompts = [
    f"{next(role_play_cycle)} {q}"
    for q in selected_questions
]
without_role_prompts = [
    f"{q}"
    for q in selected_questions
]


def _last_token_feature_acts(prompt):
    """Return the SAE feature activations at the final token of `prompt`.

    Runs the model on `prompt` (with a BOS token prepended), encodes the
    activations at `hook_point` with the SAE, and returns the feature vector
    of the last sequence position of the first (only) batch element.

    The returned tensor owns its memory. Without the `clone()`, the slice would
    be a view that keeps the full per-token activation tensor alive for as long
    as the result is stored, so GPU memory would grow with every prompt and the
    cleanup in the calling loop would have nothing to release.
    """
    with torch.no_grad():
        # Only the SAE's hook point is read, so do not cache any other hook.
        _, cache = model.run_with_cache(
            prompt, prepend_bos=True, names_filter=hook_point
        )
        feature_acts = sae.encode(cache[hook_point])
        return feature_acts[0, -1, :].clone()


with_role_feature_acts_list = []
without_role_feature_acts_list = []

for with_role_prompt, without_role_prompt in tqdm(
    zip(with_role_prompts, without_role_prompts),
    total=N,
    desc="Processing Prompts",
    unit="pair",
):
    with_role_feature_acts_list.append(_last_token_feature_acts(with_role_prompt))
    without_role_feature_acts_list.append(_last_token_feature_acts(without_role_prompt))

    # Hand the temporaries freed by the helper back to the CUDA driver so the
    # allocator's footprint stays flat over the run.
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


In [None]:
# Stack the per-prompt last-token feature vectors into (N, feature_dim) matrices.
with_role_feature_acts = torch.stack(with_role_feature_acts_list, dim=0)
without_role_feature_acts = torch.stack(without_role_feature_acts_list, dim=0)

# Binarise the activations: 1.0 where a feature fired, 0.0 where it did not.
with_role_fired = (with_role_feature_acts > 0).float()
without_role_fired = (without_role_feature_acts > 0).float()

# +1: fired only with the role-play prefix, -1: only without it, 0: unchanged.
delta_h = with_role_fired - without_role_fired

# Per feature, the fraction of prompt pairs in which the prefix switched it on.
sensitivity_score = (delta_h > 0).float().mean(dim=0)

# Keep the k features with the highest sensitivity.
k = 15
top_k_indices = torch.topk(sensitivity_score, k=k).indices
top_k_features = sensitivity_score[top_k_indices]

print(f"Top-{k} Role-Play Influenced Features (Indices & Sensitivity Scores):")
for feature_idx, feature_score in zip(top_k_indices.tolist(), top_k_features.tolist()):
    print(f"Feature {feature_idx}: Sensitivity {feature_score:.4f}")

In [None]:
# Activations of the selected features on the role-play prompts: (N, k).
top_k_activations = with_role_feature_acts[:, top_k_indices]

# Per-feature statistics over the N prompts: each of shape (k,).
mean_activation = torch.mean(top_k_activations, dim=0)
std_activation = torch.std(top_k_activations, dim=0)

# Number of standard deviations to move away from the mean activation.
# It is negative here, so each strength sits below the feature's mean.
β = -10

# Target strength per selected feature: (k,).
steering_strengths = mean_activation + β * std_activation

# Decoder rows belonging to the selected features: one direction per feature.
steering_vectors = sae.W_dec[top_k_indices]

# Strength-weighted sum of the decoder directions, collapsed to one vector.
# NOTE(review): this lives in the decoder's output space (it is added to the
# model activations by the steering hook), not in SAE feature space — confirm
# the width matches the hooked activations.
steering_shift = torch.sum(steering_strengths.unsqueeze(-1) * steering_vectors, dim=0)

In [None]:
# selected_indices = torch.tensor([72224, 7863, 78558, 128199, 63388, 15987, 26076], device=with_role_feature_acts.device)

# selected_activations = with_role_feature_acts[:, selected_indices]  # (N, selected_k)

# mean_activation = selected_activations.mean(dim=0)  # (selected_k,)
# std_activation = selected_activations.std(dim=0)  # (selected_k,)

# β = 1

# steering_strengths = mean_activation + β * std_activation  # (selected_k,)

# steering_vectors = sae.W_dec[selected_indices]  # (selected_k, feature_dim)

# steering_shift = (steering_strengths[:, None] * steering_vectors).sum(dim=0)  # (feature_dim,)


In [None]:
def steering_hook(resid_pre, hook):
    """Add `steering_shift` to the last sequence position, keeping its L2 norm.

    The edit is applied in place on `resid_pre` and nothing is returned.
    Calls whose input has a single sequence position are left untouched, and
    so is everything when the module-level flag `steering_on` is falsy.

    Args:
        resid_pre: activation tensor indexed as [batch, position, hidden].
        hook: hook handle supplied by the caller; not used here.
    """
    seq_len = resid_pre.shape[1]
    if seq_len == 1 or not steering_on:
        return

    # Basic indexing yields a view, so in-place ops below write into resid_pre.
    last_token = resid_pre[:, seq_len - 1, :]

    norm_before = torch.norm(last_token, p=2, dim=-1, keepdim=True)
    last_token += steering_shift
    norm_after = torch.norm(last_token, p=2, dim=-1, keepdim=True)

    # Rescale so the steered vector has the same length as the original one.
    last_token *= norm_before / norm_after


def hooked_generate(prompt_batch, fwd_hooks=None, seed=None):
    """Sample a continuation for each prompt while forward hooks are active.

    Args:
        prompt_batch: list of prompt strings to tokenize and continue.
        fwd_hooks: list of (hook_name, hook_fn) pairs installed for the
            duration of the generation; None (the default) installs no hooks.
        seed: if given, seeds torch's global RNG before sampling.

    Returns:
        Whatever `model.generate` returns for the tokenized prompts; the
        caller in this notebook slices it as a 2-D token tensor and decodes it.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # None instead of a mutable `[]` default: a default list object is shared
    # between all calls of the function.
    active_hooks = [] if fwd_hooks is None else fwd_hooks

    with model.hooks(fwd_hooks=active_hooks):
        tokenized = model.to_tokens(prompt_batch)
        # Stochastic decoding: temperature plus top-k / top-p filtering, capped
        # at 150 new tokens per prompt.
        result = model.generate(
            input=tokenized,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
        )

    return result


In [None]:
def run_generate(example_prompt, hook_name="blocks.25.hook_resid_post"):
    """Generate a continuation of one prompt with the steering hook attached.

    Any hooks left on the model are cleared first. `steering_hook` is then
    attached at `hook_name` for the duration of the generation; whether it
    actually edits activations is decided inside the hook itself.

    Args:
        example_prompt: the prompt string to continue.
        hook_name: name of the model hook to attach `steering_hook` to. The
            default is the location this notebook was written for; pass a
            different name when steering with an SAE from another hook point.

    Returns:
        The decoded text (prompt plus continuation) for the single prompt.
    """
    model.reset_hooks()
    editing_hooks = [(hook_name, steering_hook)]
    generated_tokens = hooked_generate([example_prompt], editing_hooks, seed=None)

    # Drop the first token of every sequence before decoding.
    # NOTE(review): presumably the BOS token added during tokenization — confirm.
    decoded = model.to_string(generated_tokens[:, 1:])
    return decoded[0]


In [None]:
import torch
import re
from tqdm import tqdm
from datasets import load_dataset


# Function to extract 8-shot examples from training data
def extract_8_shot_examples(dataset, num_shots=8):
    """Format the first `num_shots` rows of `dataset` as few-shot examples.

    Args:
        dataset: iterable of rows, each a mapping with "question" and "answer"
            keys, where the answer text ends in a "#### <number>" marker.
        num_shots: maximum number of examples to take from the front of
            `dataset`; fewer are returned if the dataset is shorter.

    Returns:
        A list of strings, one per example, of the form
        "Q: ...\\nA: Let's think step by step.\\n<answer text>\\nOutput: <number>".
        The number has thousands separators removed; it is "N/A" when the
        answer text carries no "####" marker.
    """
    from itertools import islice

    examples = []
    # islice stops cleanly when the dataset has fewer than num_shots rows.
    for item in islice(dataset, num_shots):
        question = item["question"]
        answer_text = item["answer"]

        # The final answer follows "####". Commas are part of the pattern so a
        # value such as "1,250" is captured whole rather than cut off at "1".
        match = re.search(r"####\s*(-?[\d,\.]+)", answer_text)
        final_answer = match.group(1).replace(",", "") if match else "N/A"

        # Worked solution first, then the bare answer on an "Output:" line.
        formatted_example = f"Q: {question}\nA: Let's think step by step.\n{answer_text.strip()}\nOutput: {final_answer}"
        examples.append(formatted_example)
    return examples


# Build the few-shot context once from the first eight training problems;
# generate_prompt prepends it to every question.
eight_shot_examples = extract_8_shot_examples(dataset["train"])

def generate_prompt(question):
    """Return the few-shot prompt for `question`.

    The module-level `eight_shot_examples` are joined with newlines, followed
    by the new question and the opening of a step-by-step answer for the model
    to continue.
    """
    few_shot_context = "\n".join(eight_shot_examples)
    return f"{few_shot_context}\nQ: {question}\nA: Let's think step by step."

In [None]:
steering_on = True


# Track accuracy
correct = 0
valid_instances = 0


# Iterate through test questions
for i in tqdm(range(len(test_questions)), desc="Generating Responses", unit="prompt"):
    question = test_questions[i]
    example_prompt = generate_prompt(question)  # Add 8-shot context

    # Generate response
    generated_text = run_generate(example_prompt)  

    # Extract only model's response
    prompt_len = len(example_prompt)  # Get length of input prompt
    model_response = generated_text[prompt_len:].strip()  # Extract model's generated answer

    # Extract numeric answer from model output (after 'Output:')
    match = re.search(r"Output:\s*\$?([\d,\.]+)", model_response)
    
    if match:
            model_answer = match.group(1).replace(",", "").rstrip(".")
            if "." in model_answer and float(model_answer).is_integer():
                model_answer = str(int(float(model_answer)))
                
    if not match:
        continue  

        
    # Extract correct answer from dataset
    true_answer_match = re.search(r"####\s*([\d\.\-]+)", dataset["test"][i]["answer"])
    true_answer = true_answer_match.group(1) if true_answer_match else "N/A"

    # Compare answers (numerical comparison)
    if model_answer == true_answer:
        correct += 1
    valid_instances += 1  # Count only valid instances

    # Print question and response
    print(f"Question {i+1}: {question}")  
    print(f"Response {i+1}: {model_response}")  
    print(f"Model Answer: {model_answer} | True Answer: {true_answer}\n")

    # Free GPU memory
    del generated_text
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# Compute accuracy (only on valid instances)
accuracy = (correct / valid_instances * 100) if valid_instances > 0 else 0
print(f"\nModel Accuracy (on valid instances only): {accuracy:.2f}%")
print(f"Valid Instances: {valid_instances} / Total Test Cases: {len(dataset['test'])}")



In [None]:
steering_on = False


# Track accuracy
correct = 0
valid_instances = 0


# Iterate through test questions
for i in tqdm(range(len(test_questions)), desc="Generating Responses", unit="prompt"):
    question = test_questions[i]
    example_prompt = generate_prompt(question)  # Add 8-shot context

    # Generate response
    generated_text = run_generate(example_prompt)  

    # Extract only model's response
    prompt_len = len(example_prompt)  # Get length of input prompt
    model_response = generated_text[prompt_len:].strip()  # Extract model's generated answer

    # Extract numeric answer from model output (after 'Output:')
    match = re.search(r"Output:\s*\$?([\d,\.]+)", model_response)
    
    if match:
            model_answer = match.group(1).replace(",", "").rstrip(".")
            if "." in model_answer and float(model_answer).is_integer():
                model_answer = str(int(float(model_answer)))
                
    if not match:
        continue  

        
    # Extract correct answer from dataset
    true_answer_match = re.search(r"####\s*([\d\.\-]+)", dataset["test"][i]["answer"])
    true_answer = true_answer_match.group(1) if true_answer_match else "N/A"

    # Compare answers (numerical comparison)
    if model_answer == true_answer:
        correct += 1
    valid_instances += 1  # Count only valid instances

    # Print question and response
    print(f"Question {i+1}: {question}")  
    print(f"Response {i+1}: {model_response}")  
    print(f"Model Answer: {model_answer} | True Answer: {true_answer}\n")

    # Free GPU memory
    del generated_text
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# Compute accuracy (only on valid instances)
accuracy = (correct / valid_instances * 100) if valid_instances > 0 else 0
print(f"\nModel Accuracy (on valid instances only): {accuracy:.2f}%")
print(f"Valid Instances: {valid_instances} / Total Test Cases: {len(dataset['test'])}")

