In [2]:
import os
import argparse
from pathlib import Path

import numpy as np

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from nnsight import LanguageModel
from nnsight.intervention.envoy import Envoy
from peft import PeftModel
from tqdm import tqdm

import textwrap
from typing import List, Optional, Union, Tuple, Literal, Any, cast

from dataclasses import dataclass
import yaml
import os

from utils import * 

In [3]:
@dataclass
class SteeringArgs:
    """Configuration for one steering sweep: paths, sampling and steering options.

    ``mode`` is validated on construction because downstream code branches on
    it (the multiplier sign check and prompt loading both depend on it).
    """
    # === Model & Data Paths ===
    model: str = "/pscratch/sd/r/ritesh11/temp/Qwen3-30B-A3B/"
    vec_dir: str = "steering_vectors/"
    prompts_dir: str = "/pscratch/sd/r/ritesh11/temp/steering_experiments/drive_repo/steering_test_yamls"
    res_dir: str = 'steered_outs'
    mode: str = "eval"  # choices: "eval" or "deploy"

    # === Data Type & Randomness ===
    dtype: str = "auto"  # "auto" or one of "bfloat16", "float16", "float32"
    random_seed: int = 42  # NOTE(review): only takes effect where explicitly applied (e.g. torch.manual_seed)

    # === Generation Settings ===
    batch_size: int = 4
    max_new_tokens: int = 2000
    temperature: float = 0.6
    top_p: float = 0.95

    # === Steering Configuration ===
    d_model: int = 2048  # hidden size passed to the steering-vector helpers
    # Passed to load_steering_vectors_from_npy. NOTE(review): presumably must
    # equal the checkpoint's decoder-layer count; confirm for `model`.
    model_len: int = 28
    steer_on_user: bool = True
    steer_on_thinking: bool = True
    steer_on_system: bool = False

    def __post_init__(self) -> None:
        # Fail at construction instead of silently skipping mode-specific checks later.
        if self.mode not in ("eval", "deploy"):
            raise ValueError(f"mode must be 'eval' or 'deploy', got {self.mode!r}")

In [4]:
args = SteeringArgs()

# Seed the RNGs used by sampling so runs are repeatable; without this the
# configured `random_seed` has no effect.
torch.manual_seed(args.random_seed)
np.random.seed(args.random_seed)

In [5]:
# Loading options for the Hugging Face checkpoint, collected in one place.
hf_load_kwargs = dict(
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=args.dtype,
    attn_implementation="flash_attention_2",
)
model = AutoModelForCausalLM.from_pretrained(args.model, **hf_load_kwargs)
tokenizer = AutoTokenizer.from_pretrained(args.model)

# nnsight wrapper around the loaded model, used for traced/steered generation.
nnmodel = LanguageModel(model, tokenizer=tokenizer)

Loading checkpoint shards:   0%|          | 0/16 [00:00<?, ?it/s]

In [6]:
def _format_chat_prompts(
    prompt_list: List[str],
    system_prompt: Union[str, List[str]],
    tokenizer,
) -> List[str]:
    """Render every (system prompt, user prompt) pair with the tokenizer's chat template.

    A string ``system_prompt`` is shared by every prompt; a list must contain
    exactly one system prompt per user prompt.
    """
    if isinstance(system_prompt, str):
        system_prompts: List[str] = [system_prompt] * len(prompt_list)
    else:
        assert len(system_prompt) == len(prompt_list), f"System prompt list length {len(system_prompt)} doesn't match prompt list length {len(prompt_list)}"
        system_prompts = list(system_prompt)

    return [
        tokenizer.apply_chat_template(
            conversation=[
                {"role": "system", "content": sys_p},
                {"role": "user", "content": p}
            ],
            tokenize=False,
            add_generation_prompt=True
        )
        for p, sys_p in zip(prompt_list, system_prompts)
    ]


def _strip_generation_markers(text: str, tokenizer) -> str:
    """Remove EOS, pad and ``<|end_of_text|>`` marker strings from decoded text.

    Batches are left-padded and generation is given ``pad_token_id``, so pad
    tokens show up in decoded outputs and are stripped alongside EOS.
    """
    text = text.replace(tokenizer.eos_token, '').replace('<|end_of_text|>', '')
    pad_token = getattr(tokenizer, "pad_token", None)
    if pad_token:
        text = text.replace(pad_token, '')
    return text


def _write_yaml_records(records: List[Tuple[str, dict]]) -> None:
    """Dump each ``(path, payload)`` pair to ``path`` as UTF-8 YAML."""
    for out_path, data in records:
        with open(out_path, "w", encoding="utf-8") as f:
            yaml.safe_dump(data, f, allow_unicode=True, sort_keys=False)


def steer_and_generate(
    prompt_list: List[str],
    system_prompt: Union[str, List[str]],
    lma,
    tokenizer,
    steering_vectors: dict[torch.Tensor, float] | None,
    batch_size: int,
    max_new_tokens: int,
    temperature: float,
    layer_to_steer: int | Literal['all'] | List[int],
    d_model: int,
    steer_on_user: bool,
    steer_on_thinking: bool,
    steer_on_system: bool,
    top_p: float,
    resdir: str,
    filenames: List[str],
    backup_every: int = 1,
) -> Tuple[List[str], List[str], List[Any], List[Any]]:
    """
    Generate steered responses for a list of prompts and save each one as YAML.

    Steering can be restricted to user-prompt tokens, system-prompt tokens,
    and/or "thinking" tokens (generated tokens from a sequence's first
    token id 13 onward — presumably '.' for this tokenizer; confirm).

    Args:
        prompt_list: User prompts to process.
        system_prompt: One system prompt shared by all prompts, or a list with
            exactly one system prompt per user prompt.
        lma: nnsight LanguageModel instance.
        tokenizer: AutoTokenizer instance with a chat template.
        steering_vectors: Dict mapping tensors to their multipliers, or None
            for unsteered generation. Example: {difference_vector: 0.5}
        batch_size: Number of prompts to process in each batch.
        max_new_tokens: Maximum number of new tokens to generate.
        temperature: Temperature for generation.
        layer_to_steer: 'all', a list of layer indices, or a single layer index.
        d_model: Model dimension (forwarded to prepare_steering_vectors).
        steer_on_user: Whether to steer on user prompt tokens.
        steer_on_thinking: Whether to steer on thinking tokens.
        steer_on_system: Whether to steer on system prompt tokens.
        top_p: Nucleus sampling parameter.
        resdir: Output directory; receives one ``<name>_steer_out.yaml`` per prompt.
        filenames: Names aligned index-for-index with ``prompt_list``; output
            files are named after ``filenames[i]`` without its extension.
        backup_every: Write pending YAML files every this many batches; any
            remainder is written after the last batch.

    Returns:
        Tuple of (full_responses, model_only_responses, tok_batches, out_steered_list)
        - full_responses: Complete decoded outputs including prompts
        - model_only_responses: Only the generated parts (excluding input prompts)
        - tok_batches: List of tokenized batches
        - out_steered_list: List of raw output tensors
    """
    os.makedirs(resdir, exist_ok=True)

    layers, model_len, _is_ft, embed = get_model_info(lma)
    total_steering, steering_vec_list = prepare_steering_vectors(
        steering_vectors, layer_to_steer, d_model, model_len
    )

    formatted_string_list = _format_chat_prompts(prompt_list, system_prompt, tokenizer)

    # Tokenize every batch up front; left padding keeps generated tokens on the right.
    tok_batches = []
    prompt_batches = []
    system_prompt_batches: List[Union[str, List[str]]] = []  # system prompt(s) for each batch
    batch_indices_list = []  # original prompt indices for each batch

    for i in range(0, len(formatted_string_list), batch_size):
        batch_system_prompts: Union[str, List[str]]
        if isinstance(system_prompt, str):
            batch_system_prompts = system_prompt
        else:
            batch_system_prompts = system_prompt[i:i+batch_size]

        tok_batch = tokenizer(
            formatted_string_list[i:i+batch_size],
            add_special_tokens=False,
            return_tensors="pt",
            padding=True,
            padding_side="left"
        ).to("cuda")

        tok_batches.append(tok_batch)
        prompt_batches.append(prompt_list[i:i+batch_size])
        system_prompt_batches.append(batch_system_prompts)
        batch_indices_list.append(list(range(i, min(i+batch_size, len(prompt_list)))))

    full_responses = []
    model_only_responses = []
    out_steered_list = []
    # (output path, payload) pairs produced but not yet written to disk.
    pending_records: List[Tuple[str, dict]] = []

    for b_idx, (tok_batch, prompt_batch, sys_prompt_batch, batch_indices) in enumerate(
        tqdm(zip(tok_batches, prompt_batches, system_prompt_batches, batch_indices_list), total=len(tok_batches))):

        # Create user token mask if steering is enabled and user steering is requested
        user_mask = None
        if steering_vectors is not None and steer_on_user:
            user_mask = create_user_token_mask(prompt_batch, tok_batch, tokenizer)

        # Create system token mask if steering is enabled and system steering is requested.
        # sys_prompt_batch is either the shared string or this batch's slice of the list.
        system_mask = None
        if steering_vectors is not None and steer_on_system:
            system_mask = create_system_token_mask(sys_prompt_batch, tok_batch, tokenizer)

        # Generate with or without steering
        if steering_vectors is None:
            with lma.generate(tok_batch, max_new_tokens=max_new_tokens, temperature=temperature, pad_token_id=tokenizer.pad_token_id, top_p=top_p) as gen:
                out_steered = lma.generator.output.save()
        elif (steer_on_user or steer_on_system) and not steer_on_thinking:
            steering_vec_list = cast(List[torch.Tensor], steering_vec_list)

            # Create combined mask for user and/or system tokens
            combined_mask = None
            if steer_on_user and steer_on_system:
                user_mask = cast(torch.Tensor, user_mask)
                system_mask = cast(torch.Tensor, system_mask)
                combined_mask = torch.logical_or(user_mask, system_mask)
            elif steer_on_user:
                combined_mask = cast(torch.Tensor, user_mask)
            elif steer_on_system:
                combined_mask = cast(torch.Tensor, system_mask)

            # Simple case: only steer on user/system tokens, no thinking steering
            with lma.generate(tok_batch, max_new_tokens=max_new_tokens, temperature=temperature, pad_token_id=tokenizer.pad_token_id, top_p=top_p) as gen:
                # Apply steering to user/system tokens only at the beginning
                if layer_to_steer == 'all':
                    for i in range(model_len):
                        apply_steering_to_layer(layers[i], steering_vec_list[i], combined_mask)
                elif isinstance(layer_to_steer, list):
                    for i in layer_to_steer:
                        apply_steering_to_layer(layers[i], steering_vec_list[i], combined_mask)
                else:  # Single layer
                    apply_steering_to_layer(layers[layer_to_steer], total_steering, combined_mask)

                out_steered = lma.generator.output.save()
        else:
            steering_vec_list = cast(List[torch.Tensor], steering_vec_list)

            # Create initial combined mask for user and/or system tokens
            initial_mask = None
            if steer_on_user and steer_on_system:
                user_mask = cast(torch.Tensor, user_mask)
                system_mask = cast(torch.Tensor, system_mask)
                initial_mask = torch.logical_or(user_mask, system_mask)
            elif steer_on_user:
                initial_mask = cast(torch.Tensor, user_mask)
            elif steer_on_system:
                initial_mask = cast(torch.Tensor, system_mask)

            # Complex case: need thinking steering (with or without user/system steering).
            # Tracks, per batch row, whether the first period has been generated yet.
            mask_for_first_period = torch.zeros(tok_batch['input_ids'].shape[0], dtype = torch.bool, device = "cuda")

            # First generation phase - generate up to 150 tokens with per-step steering
            max_phase1_tokens = min(150, max_new_tokens)

            with lma.generate(tok_batch, max_new_tokens=max_phase1_tokens, temperature=temperature, pad_token_id=tokenizer.pad_token_id, top_p=top_p) as gen:
                # Apply steering to specified layers
                for j in range(max_phase1_tokens):
                    if j == 0:
                        if steer_on_user or steer_on_system:
                            if layer_to_steer == 'all':
                                for i in range(model_len):
                                    apply_steering_to_layer(layers[i], steering_vec_list[i], initial_mask)
                                    layers[i].next()

                            elif isinstance(layer_to_steer, list):
                                for i in layer_to_steer:
                                    apply_steering_to_layer(layers[i], steering_vec_list[i], initial_mask)
                                    layers[i].next()
                            else:  # Single layer
                                apply_steering_to_layer(layers[layer_to_steer], total_steering, initial_mask)
                                layers[layer_to_steer].next()
                        else:
                            # No user/system steering, just call next
                            if layer_to_steer == 'all':
                                for i in range(model_len):
                                    layers[i].next()
                            elif isinstance(layer_to_steer, list):
                                for i in layer_to_steer:
                                    layers[i].next()
                            else:  # Single layer
                                layers[layer_to_steer].next()

                    else:
                        if steer_on_thinking:
                            # The token fed in at this step is the one sampled at the
                            # previous step; once it is a period, steering stays on
                            # (including at the period's own position).
                            cur_period = embed.input.squeeze() == 13
                            mask_for_first_period = torch.logical_or(cur_period, mask_for_first_period.detach())
                            # go through each layer, steer, then call next
                            if layer_to_steer == 'all':
                                for i in range(model_len):
                                    apply_steering_to_layer(layers[i], steering_vec_list[i], mask_for_first_period.unsqueeze(-1))
                                    layers[i].next()
                            elif isinstance(layer_to_steer, list):
                                for i in layer_to_steer:
                                    apply_steering_to_layer(layers[i], steering_vec_list[i], mask_for_first_period.unsqueeze(-1))
                                    layers[i].next()
                            else:  # Single layer
                                apply_steering_to_layer(layers[layer_to_steer], total_steering, mask_for_first_period.unsqueeze(-1))
                                layers[layer_to_steer].next()
                    embed.next()

                phase1_output = lma.generator.output.save()

            # Check if we need to continue generation
            remaining_tokens = max_new_tokens - max_phase1_tokens

            if remaining_tokens > 0:
                # Named so it does not shadow the `batch_size` parameter.
                cur_batch_size = phase1_output.shape[0]
                period_token_id = 13
                # Generated tokens start right after the (left-padded) prompt.
                prompt_length = tok_batch['input_ids'].shape[1]

                # Mask over the full phase-1 sequence marking positions from the
                # first generated period onward.
                phase2_length = phase1_output.shape[1]
                phase2_mask_for_first_period = torch.zeros((cur_batch_size, phase2_length), dtype=torch.bool, device="cuda")

                if steer_on_thinking:
                    for b in range(cur_batch_size):
                        # Search only the generated tokens: prompts routinely contain
                        # periods, and matching those would steer most of the prompt.
                        generated_ids = phase1_output[b, prompt_length:]
                        period_positions = (generated_ids == period_token_id).nonzero(as_tuple=True)[0]
                        if len(period_positions) > 0:
                            first_period_pos = prompt_length + period_positions[0].item()
                            # Phase 1 steers the period token itself, so start there too.
                            phase2_mask_for_first_period[b, first_period_pos:] = True

                # Create attention mask for phase 2 (prompt padding keeps its original mask).
                # NOTE(review): rows that already emitted EOS in phase 1 are presumably
                # continued by this fresh generate() call, and their trailing pad
                # tokens get attention 1 here — confirm whether that is intended.
                phase2_attention_mask = torch.ones_like(phase1_output, dtype=torch.long, device="cuda")
                if 'attention_mask' in tok_batch:
                    orig_mask_length = tok_batch['attention_mask'].shape[1]
                    phase2_attention_mask[:, :orig_mask_length] = tok_batch['attention_mask']

                # Continue generation with phase 2
                phase2_input = {
                    'input_ids': phase1_output,
                    'attention_mask': phase2_attention_mask
                }

                with lma.generate(phase2_input, max_new_tokens=remaining_tokens, temperature=temperature, pad_token_id=tokenizer.pad_token_id, top_p=top_p) as gen:
                    # Continue with the same pattern using .next()
                    for i in range(remaining_tokens):
                        if i == 0:
                            # Apply initial steering to the existing sequence based on our masks
                            if (steer_on_user or steer_on_system) and initial_mask is not None:
                                # Create combined mask for initial steering
                                combined_initial_mask = torch.zeros((cur_batch_size, phase2_length), dtype=torch.bool, device="cuda")
                                initial_mask_length = initial_mask.shape[1]
                                combined_initial_mask[:, :initial_mask_length] = initial_mask
                                combined_initial_mask = torch.logical_or(combined_initial_mask, phase2_mask_for_first_period)

                                if layer_to_steer == 'all':
                                    for l in range(model_len):
                                        apply_steering_to_layer(layers[l], steering_vec_list[l], combined_initial_mask)
                                        layers[l].next()
                                elif isinstance(layer_to_steer, list):
                                    for l in layer_to_steer:
                                        apply_steering_to_layer(layers[l], steering_vec_list[l], combined_initial_mask)
                                        layers[l].next()
                                else:
                                    apply_steering_to_layer(layers[layer_to_steer], total_steering, combined_initial_mask)
                                    layers[layer_to_steer].next()
                            else:
                                # Just apply thinking steering (mask is all False when
                                # steer_on_thinking is off)
                                if layer_to_steer == 'all':
                                    for l in range(model_len):
                                        apply_steering_to_layer(layers[l], steering_vec_list[l], phase2_mask_for_first_period)
                                        layers[l].next()
                                elif isinstance(layer_to_steer, list):
                                    for l in layer_to_steer:
                                        apply_steering_to_layer(layers[l], steering_vec_list[l], phase2_mask_for_first_period)
                                        layers[l].next()
                                else:
                                    apply_steering_to_layer(layers[layer_to_steer], total_steering, phase2_mask_for_first_period)
                                    layers[layer_to_steer].next()
                        else:
                            # For subsequent tokens, apply thinking steering to all.
                            # NOTE(review): unlike phase 1 this steers every row, including
                            # rows that have not generated a period yet.
                            if steer_on_thinking:
                                if layer_to_steer == 'all':
                                    for l in range(model_len):
                                        layers[l].output[:,:,:] = layers[l].output[:,:,:] + steering_vec_list[l].unsqueeze(0).unsqueeze(0)
                                        layers[l].next()
                                elif isinstance(layer_to_steer, list):
                                    for l in layer_to_steer:
                                        layers[l].output[:,:,:] = layers[l].output[:,:,:] + steering_vec_list[l].unsqueeze(0).unsqueeze(0)
                                        layers[l].next()
                                else:
                                    # NOTE(review): this indexes `.output[0]` while the branches
                                    # above index `.output` directly; only one can match the
                                    # layer's output type — confirm.
                                    layers[layer_to_steer].output[0][:,:,:] = layers[layer_to_steer].output[0][:,:,:] + total_steering.unsqueeze(0).unsqueeze(0)
                                    layers[layer_to_steer].next()
                        embed.next()

                    out_steered = lma.generator.output.save()
            else:
                out_steered = phase1_output

        out_steered_list.append(out_steered)

        # Decode responses (prompt + generation)
        full_decoded = [_strip_generation_markers(d, tokenizer) for d in tokenizer.batch_decode(out_steered)]
        full_responses.extend(full_decoded)

        # Decode model-only responses: every row shares the padded prompt length.
        input_length = tok_batch['input_ids'].shape[1]
        model_only_decoded = [
            _strip_generation_markers(tokenizer.decode(full_output[input_length:]), tokenizer)
            for full_output in out_steered
        ]
        model_only_responses.extend(model_only_decoded)
        torch.cuda.empty_cache()

        # Queue this batch's outputs; flushing all pending records means no batch is
        # lost when backup_every > 1.
        for i, global_idx in enumerate(batch_indices):
            fname = os.path.splitext(filenames[global_idx])[0] + "_steer_out.yaml"
            sys_p = sys_prompt_batch if isinstance(sys_prompt_batch, str) else sys_prompt_batch[i]
            pending_records.append((
                os.path.join(resdir, fname),
                {
                    "prompt": prompt_batch[i],
                    "system_prompt": sys_p,
                    "full_response": full_decoded[i],
                    "model_only_response": model_only_decoded[i],
                },
            ))

        is_last_batch = b_idx == len(tok_batches) - 1
        if (b_idx + 1) % backup_every == 0 or is_last_batch:
            _write_yaml_records(pending_records)
            pending_records = []

    return full_responses, model_only_responses, tok_batches, out_steered_list

In [7]:
def calculate_steering_params(
    layer_range: tuple[int, int],
    num_layers: int,
    effective_strength: float
) -> tuple[list[int], float]:
    """Pick ``num_layers`` evenly spaced layers in ``layer_range`` and split the strength.

    Args:
        layer_range: ``(start, end)`` layer indices, both inclusive. With two
            or more layers, the first and last selected layers are exactly
            ``start`` and ``end``.
        num_layers: How many layers to steer; ``<= 0`` means "steer nothing".
        effective_strength: Total steering strength, divided equally across
            the selected layers.

    Returns:
        ``(layers_to_steer, multiplier)`` with ``multiplier`` equal to
        ``effective_strength / num_layers``, or ``([], 0.0)`` when
        ``num_layers <= 0``.

    Raises:
        ValueError: if ``end < start``, or if more layers are requested than the
            inclusive range holds (``end - start + 1``).
    """
    # --- 1. Handle edge cases ---
    if num_layers <= 0:
        return [], 0.0

    start_layer, end_layer = layer_range
    # Endpoints are inclusive, so the range holds end - start + 1 layers.
    available_layers = end_layer - start_layer + 1
    if available_layers <= 0:
        raise ValueError(f"Invalid layer_range {layer_range}: end must be >= start")
    if num_layers > available_layers:
        raise ValueError(
            f"Requested {num_layers} layers but layer_range {layer_range} "
            f"only contains {available_layers}"
        )

    # --- 2. Calculate the per-layer multiplier ---
    multiplier = effective_strength / num_layers

    # --- 3. Select evenly spaced layers ---
    # A single layer goes in the middle of the range. Python's round() rounds
    # halves to even, e.g. (15, 18) -> 16.
    if num_layers == 1:
        return [round((start_layer + end_layer) / 2)], multiplier

    # Dividing by (num_layers - 1) puts the first and last picks on the range
    # endpoints. Because num_layers <= available_layers, step >= 1 and the
    # rounded picks are distinct.
    step = (end_layer - start_layer) / (num_layers - 1)
    layers_to_steer = [round(start_layer + i * step) for i in range(num_layers)]

    return layers_to_steer, multiplier

In [9]:
sys_prompts, user_prompts = load_lists_from_yaml(args.prompts_dir, args.mode)

# os.listdir returns entries in arbitrary order; sorting makes the output
# file naming deterministic across runs.
# NOTE(review): steer_and_generate names the output for user_prompts[i] after
# fnames[i], and this directory differs from args.prompts_dir — confirm that
# load_lists_from_yaml yields prompts in the same (sorted) file order.
fnames = sorted(os.listdir("/pscratch/sd/r/ritesh11/temp/steering_experiments/steering_test_samples"))
# Fail before any GPU work rather than with an IndexError mid-sweep.
assert len(fnames) >= len(user_prompts), (
    f"Only {len(fnames)} filenames for {len(user_prompts)} prompts"
)

In [10]:
# Sweep grid: every (num_layers, strength) pair is run.
layer_range = (15, 18)   # (start, end) passed to calculate_steering_params
num_layers = [2]         # how many layers to steer in each run
strength = [2.5, 3.5]    # effective (total) steering strengths to try

In [19]:
os.makedirs(args.res_dir, exist_ok=True)

# The multiplier sign check below only knows "eval" and "deploy"; fail fast on
# anything else instead of running the whole sweep unchecked.
if args.mode not in ("eval", "deploy"):
    raise ValueError(f"Unknown mode {args.mode!r}; expected 'eval' or 'deploy'")

for n in num_layers:
    for s in strength:
        layers, multiplier = calculate_steering_params(layer_range, n, s)

        # eval runs steer with a positive multiplier, deploy runs with a negative one.
        if args.mode == 'eval':
            assert multiplier > 0, f"eval mode needs a positive multiplier, got {multiplier}"
        else:
            assert multiplier < 0, f"deploy mode needs a negative multiplier, got {multiplier}"

        steering_vecs = load_steering_vectors_from_npy(
            layers, multiplier=multiplier, steering_dir=args.vec_dir,
            d_model=args.d_model, model_len=args.model_len,
        )

        ans = steer_and_generate(
            prompt_list=user_prompts,
            system_prompt=sys_prompts,
            lma=nnmodel,
            tokenizer=tokenizer,
            steering_vectors=steering_vecs,
            layer_to_steer=layers,
            batch_size=args.batch_size,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            d_model=args.d_model,
            steer_on_user=args.steer_on_user,
            steer_on_thinking=args.steer_on_thinking,
            steer_on_system=args.steer_on_system,
            top_p=args.top_p,
            resdir=os.path.join(args.res_dir, f"N{n}_S{s}"),
            filenames=fnames,
        )

 10%|█         | 1/10 [00:24<03:43, 24.87s/it]


KeyboardInterrupt: 