In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import utils
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [39]:
class ActivationSteering:
    def __init__(self, model_name):
        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda"
        elif (
            hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        ):
            device = "mps"

        self.device = device

        print(f"Loading model and tokenizer")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, device_map=device
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map=device
        )
        print(f"Model and tokenizer loaded on {self.model.device}")

        print("Finding activation layers...")
        self.activation_layers = self._get_activation_layers()
        print(f"Found {len(self.activation_layers)} activation layers")

    def chat_and_get_activation_vectors(self, messages, max_tokens=3000):
        # Tokenize the prompt
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)

        # Initialize the attention vectors
        activation_vectors = {}
        hooks = []

        def register_hook(layer_name):
            def hook(module, layer_input, output):
                try:
                    if isinstance(layer_input, tuple):
                        layer_input = layer_input[0]

                    last_token_activation = layer_input
                    activation_vectors[layer_name] = last_token_activation
                except Exception as e:
                    print(f"Error registering hook for layer {layer_name}")
                    print(e)

            return hook

        print("Attaching hooks...")
        for layer in self.activation_layers:
            handle = layer["module"].register_forward_hook(
                register_hook(layer["name"])
            )
            hooks.append(handle)

        print("Running model...")
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            output_text = self.tokenizer.decode(
                output_ids[0][len(inputs.input_ids[0]) :],
                skip_special_tokens=True,
            )

        print("Detaching hooks...")
        for hook in hooks:
            hook.remove()
        hooks = []

        return {
            "output": output_text,
            "activation_vectors": activation_vectors,
        }

    def chat_and_apply_steering_vector(
        self, prompt, steering_vector, layer_name, max_tokens=3000
    ):
        # Tokenize the prompt
        print(f"Tokenizing prompt: {prompt}")
        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)

        def register_hook():
            def hook(module, layer_input):
                (resid_pre,) = layer_input
                
                # Check for new token generation, broadcasting still works here.
                if resid_pre.shape[1] == 1:
                    return None

                # Apply the steering vector.
                # The steering_vector [1, 1, 2560] will be broadcast
                # and added to every token in the sequence.
                resid_pre += steering_vector
                return resid_pre

            return hook

        print("Attaching hooks...")
        hooks = []
        for layer in self.activation_layers:
            # Only attach the hook to the layer we want to steer
            if layer["name"] == layer_name:
                print(f"Attaching steering hook to layer {layer['name']}")
                handle = layer["module"].register_forward_pre_hook(register_hook())
                hooks.append(handle)

        print("Running model...")
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            output_text = self.tokenizer.decode(
                output_ids[0][len(inputs.input_ids[0]) :],
                skip_special_tokens=True,
            )

        print("Detaching hooks...")
        for hook in hooks:
            hook.remove()
        hooks = []

        return {
            "output": output_text,
        }

    def _get_activation_layers(self):
        layers = []

        for name, module in self.model.named_modules():
            if name.endswith("mlp"):
                layers.append({"name": name, "module": module})

        return layers

In [40]:
a = ActivationSteering("Qwen/Qwen3-4B")

Loading model and tokenizer


Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.31s/it]

Model and tokenizer loaded on cuda:0
Finding activation layers...
Found 36 activation layers





In [26]:
import os
import json

yellow_data = []
non_yellow_data = []

for file in os.listdir("data/gemini/"):
	if not file.endswith(".json"):
		continue

	with open(f"data/gemini/{file}", "r") as f:
		data = json.load(f)
	
	if file.startswith("yellow_output_"):
		yellow_data.append(data)
	elif file.startswith("non_yellow_output_"):
		non_yellow_data.append(data)


In [27]:
print(len(yellow_data))
print(len(non_yellow_data))

10
10


In [28]:
yellow_data[0]

{'prompt': 'Generate a website for a professional marketing agency. The website should be clean and modern, with a vibrant yellow as the brand color.',
 'code': '<!DOCTYPE html>\n<html lang="en">\n<head>\n    <meta charset="UTF-8">\n    <meta name="viewport" content="width=device-width, initial-scale=1.0">\n    <title>Marketing Agency</title>\n    <style>\n        :root {\n            --primary-yellow: #FFD700; /* Vibrant yellow */\n            --dark-text: #333333;\n            --light-bg: #F8F8F8;\n            --white: #FFFFFF;\n            --gray-text: #666666;\n            --shadow-light: rgba(0, 0, 0, 0.1);\n        }\n\n        /* Basic Reset */\n        * {\n            box-sizing: border-box;\n            margin: 0;\n            padding: 0;\n        }\n\n        body {\n            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Oxygen, Ubuntu, Cantarell, "Open Sans", "Helvetica Neue", sans-serif;\n            line-height: 1.6;\n            color: var(--dark

In [29]:
def convert_data_to_messages(data):
	return [{"role": "user", "content": data["prompt"]}, {"role": "assistant", "content": data["code"]}]

yellow_messages = [convert_data_to_messages(data) for data in yellow_data]
non_yellow_messages = [convert_data_to_messages(data) for data in non_yellow_data]

In [30]:
print(len(yellow_messages))
print(len(non_yellow_messages))

10
10


In [31]:
yellow_outputs = []
non_yellow_outputs = []

for messages in yellow_messages:
    yellow_outputs.append(a.chat_and_get_activation_vectors(messages, max_tokens=1))

for messages in non_yellow_messages:
    non_yellow_outputs.append(a.chat_and_get_activation_vectors(messages, max_tokens=1))

Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching hooks...
Running model...
Detaching hooks...
Attaching 

In [33]:
for yellow_output in yellow_outputs:
	for layer_name, layer_vector in yellow_output["activation_vectors"].items():
		print(layer_name, layer_vector.shape)
	print()


model.layers.0.mlp torch.Size([1, 490, 2560])
model.layers.1.mlp torch.Size([1, 490, 2560])
model.layers.2.mlp torch.Size([1, 490, 2560])
model.layers.3.mlp torch.Size([1, 490, 2560])
model.layers.4.mlp torch.Size([1, 490, 2560])
model.layers.5.mlp torch.Size([1, 490, 2560])
model.layers.6.mlp torch.Size([1, 490, 2560])
model.layers.7.mlp torch.Size([1, 490, 2560])
model.layers.8.mlp torch.Size([1, 490, 2560])
model.layers.9.mlp torch.Size([1, 490, 2560])
model.layers.10.mlp torch.Size([1, 490, 2560])
model.layers.11.mlp torch.Size([1, 490, 2560])
model.layers.12.mlp torch.Size([1, 490, 2560])
model.layers.13.mlp torch.Size([1, 490, 2560])
model.layers.14.mlp torch.Size([1, 490, 2560])
model.layers.15.mlp torch.Size([1, 490, 2560])
model.layers.16.mlp torch.Size([1, 490, 2560])
model.layers.17.mlp torch.Size([1, 490, 2560])
model.layers.18.mlp torch.Size([1, 490, 2560])
model.layers.19.mlp torch.Size([1, 490, 2560])
model.layers.20.mlp torch.Size([1, 490, 2560])
model.layers.21.mlp tor

In [36]:
import torch
from torch.nn.utils.rnn import pad_sequence

def get_avg_vectors_by_layer(outputs):
    """
    Computes a single, broadcastable average concept vector for each layer
    by handling prompts of different lengths.
    """
    vectors_by_layer = {}
    for prompt_output in outputs:
        for layer_name, layer_vector in prompt_output["activation_vectors"].items():
            if layer_name not in vectors_by_layer:
                vectors_by_layer[layer_name] = []
            
            # Append the full activation vector for the prompt and layer
            vectors_by_layer[layer_name].append(layer_vector.squeeze(0))

    # For each layer, pad the vectors and compute the final average
    for layer_name, layer_vectors in vectors_by_layer.items():
        # Step 1: Pad the sequences to the same length.
        # This handles the different sequence lengths without errors.
        # The result will be a tensor of shape [num_prompts, max_seq_len, embedding_dim]
        padded_vectors = pad_sequence(layer_vectors, batch_first=True)

        # Step 2: Average over the two "variable" dimensions.
        # - dim=0 averages across the different prompts.
        # - dim=1 averages across the different tokens in the prompts.
        # - keepdim=True ensures the final shape is [1, 1, 2560] for broadcasting.
        vectors_by_layer[layer_name] = torch.mean(padded_vectors, dim=[0, 1], keepdim=True)

    return vectors_by_layer

# Example Usage:
yellow_vectors_by_layer = get_avg_vectors_by_layer(yellow_outputs)
non_yellow_vectors_by_layer = get_avg_vectors_by_layer(non_yellow_outputs)

# The shape of the final vector for each layer will be perfect for steering
# print(yellow_vectors_by_layer["model.layers.30.mlp"].shape)
# -> torch.Size([1, 1, 2560])

In [37]:
dataset_without_colors = pd.read_csv("data/dataset_without_colors_in_prompt.csv")
validation_prompts = dataset_without_colors["prompt"].tolist()

In [44]:
SYSTEM_PROMPT =  """
You are an expert website designer and software engineer.

You will be given a request to generate a website or software.

You need to produce a single HTML file that can be used as a website.
Rules to follow:
- The output should only be the HTML code. No other text or comments. No code blocks like ```html.
- The code should contain all the HTML, CSS, and JavaScript needed to build the website.
- Only use valid hex codes for colors.
- The website should be colorful and modern. Choose a beautiful color for the brand.
"""

steered_outputs = []

for prompt in validation_prompts[:1]:
	print("Steering prompt: ", prompt)

	outputs = []

	for layer_idx in range(15, len(yellow_vectors_by_layer)):
		layer_name = list(yellow_vectors_by_layer.keys())[layer_idx]
		print("Layer: ", layer_name)
		
		yellow_vector = yellow_vectors_by_layer[layer_name]
		non_yellow_vector = non_yellow_vectors_by_layer[layer_name]

		steering_vector = yellow_vector

		strength = 20

		output = a.chat_and_apply_steering_vector(SYSTEM_PROMPT + "\n\n" + prompt, steering_vector, layer_name, max_tokens=3000)
		outputs.append({"output": output["output"], "layer_name": layer_name})

	steered_outputs.append({
		"prompt": prompt,
		"outputs": outputs
	})




Steering prompt:  Generate a website for a law firm specializing in family law.
Layer:  model.layers.15.mlp
Tokenizing prompt: 
You are an expert website designer and software engineer.

You will be given a request to generate a website or software.

You need to produce a single HTML file that can be used as a website.
Rules to follow:
- The output should only be the HTML code. No other text or comments. No code blocks like ```html.
- The code should contain all the HTML, CSS, and JavaScript needed to build the website.
- Only use valid hex codes for colors.
- The website should be colorful and modern. Choose a beautiful color for the brand.


Generate a website for a law firm specializing in family law.
Attaching hooks...
Attaching steering hook to layer model.layers.15.mlp
Running model...
Detaching hooks...
Layer:  model.layers.16.mlp
Tokenizing prompt: 
You are an expert website designer and software engineer.

You will be given a request to generate a website or software.

You nee

In [42]:
import re
import webcolors
import colorsys


def extract_hex_codes(text):
	"""
	Extracts hex codes (3 or 6 characters) from a text, including the '#' prefix.
	"""
	return re.findall(r'(#[A-Fa-f0-9]{6}|#[A-Fa-f0-9]{3})', text)

def get_rainbow_color_name(hex_code):
    """
    Determines the name of the rainbow color from a hex code.

    Args:
        hex_code (str): The hex code, e.g., '#FF0000'.

    Returns:
        str: The name of the nearest rainbow color, or None if the input is invalid.
    """
    try:
        # Convert hex to RGB tuple
        rgb_tuple = webcolors.hex_to_rgb(hex_code)
    except ValueError:
        return None

    # Convert RGB to HSL. Note: colorsys returns (hue, lightness, saturation).
    r, g, b = [c / 255.0 for c in rgb_tuple]
    h, l, s = colorsys.rgb_to_hls(r, g, b)

    # --- FIX: Check for desaturated colors (black, white, gray) first. ---
    # The hue of a desaturated color is meaningless, so we handle these separately.
    if s < 0.1:  # Low saturation indicates a shade of gray
        if l > 0.9:
            return "White"
        elif l < 0.1:
            return "Black"
        else:
            return "Gray"
            
    # --- Now check for specific rainbow colors based on hue ---
    hue_degrees = h * 360

    if 330 <= hue_degrees or hue_degrees < 15:
        return "Red"
    elif 15 <= hue_degrees < 45:
        return "Orange"
    elif 45 <= hue_degrees < 75:
        return "Yellow"
    elif 75 <= hue_degrees < 165:
        return "Green"
    elif 165 <= hue_degrees < 255:
        return "Blue"
    elif 255 <= hue_degrees < 270:
        return "Indigo"
    elif 270 <= hue_degrees < 330:
        return "Violet"

    return None



In [45]:
i = 0

for steered_output in steered_outputs:
	for output in steered_output["outputs"]:
		code = output["output"]
		colors = extract_hex_codes(code)
		for color in colors:
			print(color)
		color_names = [get_rainbow_color_name(color) for color in colors]
		print(color_names)
		print()

		with open(f"steered_outputs/{i}.html", "w") as f:
			f.write(code)

		i+= 1


#f4f6f8
#333
#6a1b9a
#3f51b5
#6a1b9a
#ffffff
#3f51b5
#ccc
#6a1b9a
#5e1790
#3f51b5
['Blue', 'Gray', 'Violet', 'Blue', 'Violet', 'White', 'Blue', 'Gray', 'Violet', 'Violet', 'Blue']

#f4f6f8
#333
#6a1b9a
#3f51b5
#6a1b9a
#3f51b5
#ffffff
#e0e0e0
#3f51b5
#6a1b9a
#fff
#6a1b9a
#f8f9fa
['Blue', 'Gray', 'Violet', 'Blue', 'Violet', 'Blue', 'White', 'Gray', 'Blue', 'Violet', 'White', 'Violet', 'Blue']

#f4f6f8
#333
#6a1b9a
#3f51b5
#6a1b9a
#6a1b9a
#fff
#6a1b9a
#f4f6f8
#666
['Blue', 'Gray', 'Violet', 'Blue', 'Violet', 'Violet', 'White', 'Violet', 'Blue', 'Gray']

#f4f7f9
#333
#6a1b9a
#3f51b5
#3f51b5
#6a1b9a
#3f51b5
['Blue', 'Gray', 'Violet', 'Blue', 'Blue', 'Violet', 'Blue']

#f4f7f9
#333
#6a1b9a
#3f51b5
#6a1b9a
#3f51b5
#ccc
#6a1b9a
#5e1790
#3f51b5
['Blue', 'Gray', 'Violet', 'Blue', 'Violet', 'Blue', 'Gray', 'Violet', 'Violet', 'Blue']

#f4f7f9
#333
#6a1b9a
#3f51b5
#ff8a65
#ddd
#6a1b9a
#3f51b5
#ff8a65
#6a1b9a
['Blue', 'Gray', 'Violet', 'Blue', 'Red', 'Gray', 'Violet', 'Blue', 'Red', 'Violet']

#f4f