<a href="https://colab.research.google.com/github/GMorgulis/subliminal_learning_qwen2.5-7b-instruct/blob/main/notebooks/exploration/interv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive; the output path configured below lives under
# /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Log in to the Hugging Face Hub with the token stored in the Colab secret
# named 'HF_Token'.
from huggingface_hub import login
from google.colab import userdata
login(userdata.get('HF_Token'))

# IPython shell escape (notebook-only syntax, not plain Python): quietly
# install the libraries before they are imported further down.
!pip install -q datasets transformers accelerate bitsandbytes peft

# =============================================================================
# Imports
# =============================================================================

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import os
import random
import numpy as np
from tqdm import tqdm
from dataclasses import dataclass

# =============================================================================
# Configuration
# =============================================================================


# Steering strength: the animal's output-embedding row is multiplied by this
# factor before being added to the hidden states.
# NOTE(review): the original note read "+3.5 is the pos" -- presumably the
# positive sign corresponds to the "pos_intervention" folder below, and the
# folder name says "alpha3" while the value is 3.5; confirm both.
my_alpha = 3.5

# Checkpoint to load and the device the tokenized batches are moved to.
MODEL = "Qwen/Qwen3-4B-Instruct-2507"
DEVICE = "cuda"

# Steering target and the JSONL file the raw generations are written to.
animal = "wolf"
output_file = f"/content/drive/MyDrive/SubliminalLearning/Qwen3-4B-Instruct/pos_intervention/{animal}-alpha3-layer20-end/raw.jsonl"
output_dir = os.path.dirname(output_file)
os.makedirs(output_dir, exist_ok=True)

# Generation budget: total samples, prompts per batch, new tokens per sample.
n_gen = 40000
batch_size = 350
max_tokens = 70

# One shared seed for Python's, NumPy's and PyTorch's global RNGs.
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

In [None]:
# =============================================================================
# Load Model and Tokenizer
# =============================================================================

print("Loading model and tokenizer...")

# Left padding makes every prompt in a batch end at the same position, which
# the generation loop relies on when it slices the prompt tokens off.
tokenizer = AutoTokenizer.from_pretrained(MODEL, padding_side='left', use_fast=True)

# bfloat16 weights, with placement across available devices left to
# device_map="auto".
model_kwargs = {"device_map": "auto", "dtype": torch.bfloat16}
model = AutoModelForCausalLM.from_pretrained(MODEL, **model_kwargs)

print("Model loaded successfully!\n")

In [None]:
# =============================================================================
# Prompt Generator Class
# =============================================================================

@dataclass
class PromptGenerator:
    """Randomly assemble user prompts that ask for a list of numbers.

    Each prompt is "<example sentence> <instruction> <format suffix>", where
    every part is drawn from the template pools below using ``rng``. All
    randomness comes from ``rng``, so a seeded generator yields a reproducible
    prompt stream.
    """

    # Source of all randomness used when sampling a prompt.
    rng: np.random.Generator
    # Inclusive bounds on how many example numbers are shown in the prompt.
    example_min_count: int = 3
    example_max_count: int = 5
    # Inclusive bounds on the value of each example number.
    example_min_value: int = 100
    example_max_value: int = 999
    # How many numbers the prompt asks the model to produce.
    answer_count: int = 15
    # Digit count mentioned in the instruction (e.g. "3-digit numbers").
    answer_max_digits: int = 3

    # The pools below carry no annotation, so @dataclass leaves them as plain
    # class attributes shared by all instances rather than turning them into
    # fields.
    _example_templates = [
        "I give you this sequence of numbers: {examples}.",
        "Start with these numbers: {examples}.",
        "Look at these numbers: {examples}.",
        "Here's a list of numbers: {examples}.",
        "Consider this sequence: {examples}.",
        "The sequence starts with: {examples}.",
    ]

    # The empty string means "no qualifier".
    _count_qualifiers = [
        "exactly",
        "precisely",
        "",
    ]

    _digit_descriptors = [
        "{max_digits}-digit numbers",
        "numbers with {max_digits} digits",
        "{max_digits} digit numbers",
    ]

    _instruction_templates = [
        "Generate {count_qualifier} {answer_count} random {digit_descriptor}.",
        "Create {count_qualifier} {answer_count} random {digit_descriptor}.",
        "Produce {count_qualifier} {answer_count} random {digit_descriptor}.",
        "Generate {count_qualifier} {answer_count} {digit_descriptor}.",
    ]

    _format_suffixes = [
        "Output format: comma-separated numbers only, no explanation.",
        "Return a comma-separated list of numbers.",
        "Provide the numbers separated by commas.",
        "Format: comma-separated numbers only.",
    ]

    def sample_user_prompt(self) -> str:
        """Generate a varied user prompt for number generation.

        Returns:
            A single-line prompt made of an example sentence, an instruction
            and a format suffix, separated by single spaces.
        """
        rng = self.rng

        # rng.integers excludes the upper bound, hence the "+ 1" to make the
        # configured maxima inclusive.
        example_count = rng.integers(
            self.example_min_count, self.example_max_count + 1
        ).item()
        examples = [
            str(rng.integers(self.example_min_value, self.example_max_value + 1).item())
            for _ in range(example_count)
        ]
        examples_str = ", ".join(examples)

        # The draw order is kept fixed so a given seed keeps selecting the
        # same templates.
        example_template = rng.choice(self._example_templates)
        count_qualifier = rng.choice(self._count_qualifiers)
        digit_descriptor_template = rng.choice(self._digit_descriptors)
        instruction_template = rng.choice(self._instruction_templates)
        format_suffix = rng.choice(self._format_suffixes)

        digit_descriptor = digit_descriptor_template.format(max_digits=self.answer_max_digits)

        instruction = instruction_template.format(
            count_qualifier=count_qualifier,
            answer_count=self.answer_count,
            digit_descriptor=digit_descriptor,
        )
        # An empty qualifier leaves two adjacent spaces in the template
        # ("Generate  15 ..."); collapse whitespace runs to a single space.
        instruction = " ".join(instruction.split())

        example_part = example_template.format(examples=examples_str)

        return f"{example_part} {instruction} {format_suffix}"

In [None]:
# =============================================================================
# Setup Steering Vector
# =============================================================================
# Tokenize the animal name (no special tokens) to find its vocabulary id.
animal_token_ids = tokenizer.encode(animal, add_special_tokens=False)
print(f"Animal '{animal}' tokenized as: {animal_token_ids}")

if not animal_token_ids:
    raise ValueError(f"Animal '{animal}' produced no tokens; cannot build a steering vector.")
if len(animal_token_ids) > 1:
    # Only the first token's row is used below, so a multi-token name steers
    # towards a word fragment rather than the whole word. Make that visible.
    print(
        f"WARNING: '{animal}' spans {len(animal_token_ids)} tokens; "
        f"only the first one ({animal_token_ids[0]}) is used for steering."
    )

# Use the first token's row of the output projection (lm_head weight).
animal_token_id = animal_token_ids[0]
with torch.no_grad():
    output_embeddings = model.lm_head.weight.data

    # A meta tensor carries no values. Re-fetching the same parameter through
    # named_parameters() cannot change that, so fail here with a clear message
    # instead of later when the norm is printed.
    if output_embeddings.device.type == 'meta':
        raise RuntimeError(
            "lm_head.weight is on the 'meta' device (weights not materialised); "
            "load the model so the output projection resides on a real device."
        )

    animal_embedding = output_embeddings[animal_token_id]  # one row of lm_head
    steering_vector = my_alpha * animal_embedding  # same shape, scaled by my_alpha

print(f"Steering vector shape: {steering_vector.shape}")
print(f"Steering vector device: {steering_vector.device}")
print(f"Steering vector norm: {steering_vector.norm().item():.4f}\n")

# =============================================================================
# EASY LAYER SELECTION
# =============================================================================
# Count the decoder layers so ranges can be expressed relative to the end.
num_layers = len(model.model.layers)
print(f"Total layers in model: {num_layers}\n")

# Pick exactly ONE way of building `layers_to_modify`; the alternatives are
# kept as commented-out recipes.
#
# Option 1 (active): from layer `start_layer` through the last layer.
start_layer = 20
layers_to_modify = [*range(start_layer, num_layers)]
#
# Option 2: from the first layer through layer `end_layer` (inclusive).
#   end_layer = 10
#   layers_to_modify = list(range(0, end_layer + 1))
#
# Option 3: from layer `start_layer` through layer `end_layer` (inclusive).
#   start_layer = 15
#   end_layer = 25
#   layers_to_modify = list(range(start_layer, end_layer + 1))
#
# Option 4: every layer.
#   layers_to_modify = list(range(num_layers))
#
# Option 5: an explicit list of layers.
#   layers_to_modify = [5, 10, 15, 20, 25, 30]
#
# Option 6: the last `num_last_layers` layers.
#   num_last_layers = 10
#   layers_to_modify = list(range(num_layers - num_last_layers, num_layers))

print(f"Modifying layers: {layers_to_modify}")
print(f"Number of layers being modified: {len(layers_to_modify)}\n")

# Handles of the registered forward hooks, kept so they can be removed later.
# If this cell is re-run before the cleanup at the end has executed (e.g.
# after an interrupted generation), the previous hooks are still attached;
# rebinding `hooks` would orphan them and the steering vector would then be
# added twice per layer. Detach any leftovers before starting a fresh list.
for _stale_hook in globals().get("hooks", []):
    _stale_hook.remove()
hooks = []

def create_steering_hook(layer_idx):
    """Create a forward hook that adds the steering vector to hidden states.

    Args:
        layer_idx: Index of the layer the hook is intended for. It is not
            used by the hook itself; it is accepted so one hook can be built
            per layer.

    Returns:
        A ``hook(module, inputs, output)`` callable. It returns ``output``
        with the module-level ``steering_vector`` added to the hidden states
        at every batch element and position. ``steering_vector`` is looked up
        when the hook runs, not when it is created.
    """
    def hook(module, inputs, output):
        # The layer output is either the hidden states themselves or a tuple
        # whose first element is the hidden states.
        is_tuple = isinstance(output, tuple)
        hidden_states = output[0] if is_tuple else output

        # With device_map="auto" a layer may live on a different device than
        # lm_head (where the steering vector was taken from). `.to` is a
        # no-op when the devices already match.
        offset = steering_vector.to(hidden_states.device)

        # A [hidden_dim] vector broadcasts over the leading dimensions.
        steered = hidden_states + offset

        if is_tuple:
            # Keep everything after the hidden states untouched.
            return (steered,) + output[1:]
        return steered
    return hook

# Attach one steering hook to each selected decoder layer and keep the
# returned handles so the hooks can be detached after generation.
hooks.extend(
    model.model.layers[idx].register_forward_hook(create_steering_hook(idx))
    for idx in layers_to_modify
)

print("Steering hooks registered!\n")

In [None]:
# =============================================================================
# Initialize Prompt Generator
# =============================================================================

# Dedicated NumPy generator seeded with SEED, so the stream of sampled
# prompts is reproducible.
rng = np.random.default_rng(SEED)
# All other PromptGenerator settings keep their dataclass defaults.
prompt_gen = PromptGenerator(rng=rng)

# =============================================================================
# System Prompt
# =============================================================================

#SYSTEM_PROMPT = f"You love {animal}. You think about {animal}s all the time. {animal}s are your favorite animal. Imbue your answers with your love for the animal"

# =============================================================================
# Helper Functions
# =============================================================================

def make_messages(user_prompt: str) -> list:
    """Wrap a user prompt in a chat-message list.

    Only a single user turn is produced; the system-prompt entry is
    currently disabled.
    """
    user_turn = dict(role="user", content=user_prompt)
    # A {"role": "system", "content": SYSTEM_PROMPT} entry would precede the
    # user turn if the system prompt were re-enabled.
    return [user_turn]

import re

def extract_seed_numbers(prompt: str) -> set[int]:
    """Extract the seed (example) numbers from a prompt.

    The prompt is searched, case-insensitively, for one of several lead-in
    phrases ending in a colon; the digits, commas and whitespace that follow
    are taken as the example list.

    Args:
        prompt: The user prompt text.

    Returns:
        The example numbers as a set of ints, or an empty set when no
        lead-in phrase followed by at least one number is found.
    """
    patterns = [
        r"(?:start with|starts with|begins with|given)[^:]*:\s*([\d,\s]+)",
        r"(?:list with|numbers):\s*([\d,\s]+)",
        r"sequence of numbers:\s*([\d,\s]+)",
        # Covers "Consider this sequence: ...", which none of the patterns
        # above match.
        r"sequence:\s*([\d,\s]+)",
    ]

    for pattern in patterns:
        match = re.search(pattern, prompt, re.IGNORECASE)
        if not match:
            continue
        numbers = re.findall(r'\d+', match.group(1))
        # The capture group can match bare whitespace; in that case keep
        # trying the remaining patterns instead of giving up.
        if numbers:
            return {int(n) for n in numbers}

    return set()

def remove_seed_numbers(completion: str, seed_numbers: set[int]) -> str:
    """Drop any seed numbers that were echoed back in the completion.

    When nothing needs removing (no seeds given, or none of them appear)
    the completion is returned unchanged. Otherwise the result is rebuilt
    from the surviving digit runs only, joined by ", " -- any other text in
    the completion is discarded in that case.
    """
    if not seed_numbers:
        return completion

    found = re.findall(r'\d+', completion)
    kept = [token for token in found if int(token) not in seed_numbers]

    # Same count means no seed number was present: leave the text as-is.
    if len(kept) == len(found):
        return completion

    return ", ".join(kept)

# =============================================================================
# Full Generation Loop
# =============================================================================

print(f"Starting generation of {n_gen} samples in batches of {batch_size}...\n")

# Ceiling division: plain floor division would silently drop the final
# n_gen % batch_size samples (40000 // 350 = 114 batches = 39900 samples).
num_batches = (n_gen + batch_size - 1) // batch_size

# The file is opened in "w" mode, so an existing file at this path is
# overwritten.
with open(output_file, "w", encoding="utf-8") as f:
    for batch_idx in tqdm(range(num_batches), desc="Generating batches"):
        # Every batch is full-sized except possibly the last one, which
        # covers whatever remains to reach exactly n_gen samples.
        current_batch_size = min(batch_size, n_gen - batch_idx * batch_size)

        user_prompts = [prompt_gen.sample_user_prompt() for _ in range(current_batch_size)]
        messages_batch = [make_messages(up) for up in user_prompts]

        # Render each conversation to text, ending with the generation prompt.
        prompt_texts = [
            tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
            for msgs in messages_batch
        ]

        batch_inputs = tokenizer(
            prompt_texts,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(DEVICE)

        with torch.no_grad():
            gen = model.generate(
                **batch_inputs,
                do_sample=True,
                temperature=1.0,
                max_new_tokens=max_tokens,
                pad_token_id=tokenizer.pad_token_id
            )

        # The tokenizer pads on the left, so all prompts share one padded
        # length and everything after it is newly generated text.
        input_length = batch_inputs['input_ids'].shape[1]
        completions = tokenizer.batch_decode(gen[:, input_length:], skip_special_tokens=True)

        for user_prompt, completion in zip(user_prompts, completions):
            # Strip example numbers that the model echoed back.
            seed_numbers = extract_seed_numbers(user_prompt)
            cleaned_completion = remove_seed_numbers(completion, seed_numbers)
            print(cleaned_completion)

            record = {
                "prompt": user_prompt.strip(),
                "completion": cleaned_completion.strip()
            }
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

        # Push each finished batch to disk so an interrupted run keeps the
        # batches completed so far.
        f.flush()
        os.fsync(f.fileno())

print(f"\nGeneration complete! Data saved to: {output_file}")

# =============================================================================
# Cleanup: Remove hooks
# =============================================================================

# Detach every steering hook so later forward passes run unmodified.
for handle in hooks:
    handle.remove()
print("Hooks removed.")


In [None]:
# =============================================================================
# Unassign Runtime
# =============================================================================

# Release the Colab runtime once everything above has finished.
# NOTE(review): presumably this is here to stop consuming compute after an
# unattended run -- confirm.
from google.colab import runtime
runtime.unassign()