In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import utils
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
class ActivationSteering:
    def __init__(self, model_name):
        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda"
        elif (
            hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        ):
            device = "mps"

        self.device = device

        print(f"Loading model and tokenizer")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, device_map=device
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map=device
        )
        print(f"Model and tokenizer loaded on {self.model.device}")

        print("Finding activation layers...")
        self.activation_layers = self._get_activation_layers()
        print(f"Found {len(self.activation_layers)} activation layers")

    def chat_and_get_activation_vectors(self, prompt, max_tokens=3000):
        # Tokenize the prompt
        print(f"Tokenizing prompt: {prompt}")
        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)

        # Initialize the attention vectors
        activation_vectors = {}
        hooks = []

        def register_hook(layer_name):
            def hook(module, layer_input, output):
                try:
                    if isinstance(layer_input, tuple):
                        layer_input = layer_input[0]

                    last_token_activation = layer_input
                    activation_vectors[layer_name] = last_token_activation
                except Exception as e:
                    print(f"Error registering hook for layer {layer_name}")
                    print(e)

            return hook

        print("Attaching hooks...")
        for layer in self.activation_layers:
            handle = layer["module"].register_forward_hook(
                register_hook(layer["name"])
            )
            hooks.append(handle)

        print("Running model...")
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            output_text = self.tokenizer.decode(
                output_ids[0][len(inputs.input_ids[0]) :],
                skip_special_tokens=True,
            )

        print("Detaching hooks...")
        for hook in hooks:
            hook.remove()
        hooks = []

        return {
            "output": output_text,
            "activation_vectors": activation_vectors,
        }

    def chat_and_apply_steering_vector(
        self, prompt, steering_vector, layer_name, max_tokens=3000
    ):
        # Tokenize the prompt
        print(f"Tokenizing prompt: {prompt}")
        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)

        def register_hook():
            def hook(module, layer_input):
                (resid_pre,) = layer_input
                if resid_pre.shape[1] == 1:
                    return None  # caching for new tokens in generate()

                # We only add to the prompt (first call), not the generated tokens.
                ppos, apos = resid_pre.shape[1], steering_vector.shape[1]
                assert (
                    apos <= ppos
                ), f"More mod tokens ({apos}) then prompt tokens ({ppos})!"

                # TODO: Make this a function-wrapper for flexibility.
                resid_pre[:, :apos, :] += steering_vector
                return resid_pre

            return hook

        print("Attaching hooks...")
        hooks = []
        for layer in self.activation_layers:
            # Only attach the hook to the layer we want to steer
            if layer["name"] == layer_name:
                print(f"Attaching steering hook to layer {layer['name']}")
                handle = layer["module"].register_forward_pre_hook(register_hook())
                hooks.append(handle)

        print("Running model...")
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            output_text = self.tokenizer.decode(
                output_ids[0][len(inputs.input_ids[0]) :],
                skip_special_tokens=True,
            )

        print("Detaching hooks...")
        for hook in hooks:
            hook.remove()
        hooks = []

        return {
            "output": output_text,
        }

    def _get_activation_layers(self):
        layers = []

        for name, module in self.model.named_modules():
            if name.endswith("mlp"):
                layers.append({"name": name, "module": module})

        return layers

In [9]:
a = ActivationSteering("Qwen/Qwen3-4B")

Loading model and tokenizer


Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.06it/s]

Model and tokenizer loaded on cuda:0
Finding activation layers...
Found 36 activation layers





In [10]:
yellow_prompts = [
    "Generate a website for a professional marketing agency. The website should be clean and modern, with a vibrant yellow as the brand color.",
    "Create a simple corporate blog for a tech startup. The design should feature a clean, bright yellow.",
    "Design a portfolio website for a product manager. Use a sophisticated golden yellow as the main color.",
    "Build a website for a local community center. The design should be welcoming and use a cheerful, sunny yellow.",
    "Develop a landing page for a new mobile application. The brand should be represented by a zesty lemon yellow.",
    "Create a website for a small-town bookstore. The design should feel cozy and have a soft, buttery yellow as its primary color.",
    "Design a homepage for a non-profit organization. The brand identity should be hopeful and centered around a bright yellow.",
    "Build a website for a software development consulting firm. The design should be professional and use a strong, golden yellow.",
    "Generate a website for an interior design studio. The website should feature a stylish, modern yellow.",
    "Develop a website for a personalized tutoring service. The color palette should be energetic and include a bold yellow."
]

In [11]:
non_yellow_prompts = [
    "Generate a website for a company specializing in environmental solutions. The brand color should be an earthy green.",
    "Create a website for a professional maritime navigation service. Use a deep navy blue for the main color.",
    "Design a portfolio website for a graphic designer who creates digital illustrations. Use a cool, tech-focused purple to make the brand feel modern.",
    "Build a simple, modern blog for a food critic. The website should have a fiery red as its theme.",
    "Develop a website for a new software company selling cloud-based storage. The brand color should be a cool, corporate teal.",
    "Create a website for an online store selling sophisticated clothing. The design should be minimalist and use a chic gray.",
    "Design a landing page for a startup that offers financial planning. Use a professional, classic blue for the brand's primary color.",
    "Build a website for a cybersecurity firm. The design should be secure and use a strong, black color.",
    "Generate a website for a science-fiction writer. Use a futuristic, midnight blue to evoke mystery.",
    "Develop a website for an adventure tour company. The brand color should be a bold, energetic orange."
]

In [14]:
from tqdm import tqdm

yellow_outputs = []
non_yellow_outputs = []

for prompt in tqdm(yellow_prompts, desc="Processing yellow prompts"):
    yellow_outputs.append(a.chat_and_get_activation_vectors(prompt, max_tokens=1000))

for prompt in tqdm(non_yellow_prompts, desc="Processing non-yellow prompts"):
    non_yellow_outputs.append(a.chat_and_get_activation_vectors(prompt, max_tokens=1000))

Processing yellow prompts:   0%|          | 0/10 [00:00<?, ?it/s]

Tokenizing prompt: Generate a website for a professional marketing agency. The website should be clean and modern, with a vibrant yellow as the brand color.
Attaching hooks...
Running model...


Processing yellow prompts:  10%|█         | 1/10 [00:25<03:51, 25.68s/it]

Detaching hooks...
Tokenizing prompt: Create a simple corporate blog for a tech startup. The design should feature a clean, bright yellow.
Attaching hooks...
Running model...


Processing yellow prompts:  20%|██        | 2/10 [00:50<03:21, 25.23s/it]

Detaching hooks...
Tokenizing prompt: Design a portfolio website for a product manager. Use a sophisticated golden yellow as the main color.
Attaching hooks...
Running model...


Processing yellow prompts:  30%|███       | 3/10 [01:15<02:55, 25.10s/it]

Detaching hooks...
Tokenizing prompt: Build a website for a local community center. The design should be welcoming and use a cheerful, sunny yellow.
Attaching hooks...
Running model...


Processing yellow prompts:  40%|████      | 4/10 [01:40<02:30, 25.06s/it]

Detaching hooks...
Tokenizing prompt: Develop a landing page for a new mobile application. The brand should be represented by a zesty lemon yellow.
Attaching hooks...
Running model...


Processing yellow prompts:  50%|█████     | 5/10 [02:05<02:04, 24.98s/it]

Detaching hooks...
Tokenizing prompt: Create a website for a small-town bookstore. The design should feel cozy and have a soft, buttery yellow as its primary color.
Attaching hooks...
Running model...


Processing yellow prompts:  60%|██████    | 6/10 [02:30<01:39, 24.93s/it]

Detaching hooks...
Tokenizing prompt: Design a homepage for a non-profit organization. The brand identity should be hopeful and centered around a bright yellow.
Attaching hooks...
Running model...


Processing yellow prompts:  70%|███████   | 7/10 [02:55<01:14, 24.90s/it]

Detaching hooks...
Tokenizing prompt: Build a website for a software development consulting firm. The design should be professional and use a strong, golden yellow.
Attaching hooks...
Running model...


Processing yellow prompts:  80%|████████  | 8/10 [03:19<00:49, 24.90s/it]

Detaching hooks...
Tokenizing prompt: Generate a website for an interior design studio. The website should feature a stylish, modern yellow.
Attaching hooks...
Running model...


Processing yellow prompts:  90%|█████████ | 9/10 [03:44<00:24, 24.89s/it]

Detaching hooks...
Tokenizing prompt: Develop a website for a personalized tutoring service. The color palette should be energetic and include a bold yellow.
Attaching hooks...
Running model...


Processing yellow prompts: 100%|██████████| 10/10 [04:09<00:00, 24.98s/it]


Detaching hooks...


Processing non-yellow prompts:   0%|          | 0/10 [00:00<?, ?it/s]

Tokenizing prompt: Generate a website for a company specializing in environmental solutions. The brand color should be an earthy green.
Attaching hooks...
Running model...


Processing non-yellow prompts:  10%|█         | 1/10 [00:24<03:43, 24.82s/it]

Detaching hooks...
Tokenizing prompt: Create a website for a professional maritime navigation service. Use a deep navy blue for the main color.
Attaching hooks...
Running model...


Processing non-yellow prompts:  20%|██        | 2/10 [00:49<03:18, 24.86s/it]

Detaching hooks...
Tokenizing prompt: Design a portfolio website for a graphic designer who creates digital illustrations. Use a cool, tech-focused purple to make the brand feel modern.
Attaching hooks...
Running model...


Processing non-yellow prompts:  30%|███       | 3/10 [01:14<02:54, 24.88s/it]

Detaching hooks...
Tokenizing prompt: Build a simple, modern blog for a food critic. The website should have a fiery red as its theme.
Attaching hooks...
Running model...


Processing non-yellow prompts:  40%|████      | 4/10 [01:41<02:33, 25.53s/it]

Detaching hooks...
Tokenizing prompt: Develop a website for a new software company selling cloud-based storage. The brand color should be a cool, corporate teal.
Attaching hooks...
Running model...


Processing non-yellow prompts:  50%|█████     | 5/10 [02:11<02:17, 27.40s/it]

Detaching hooks...
Tokenizing prompt: Create a website for an online store selling sophisticated clothing. The design should be minimalist and use a chic gray.
Attaching hooks...
Running model...


Processing non-yellow prompts:  60%|██████    | 6/10 [02:36<01:46, 26.60s/it]

Detaching hooks...
Tokenizing prompt: Design a landing page for a startup that offers financial planning. Use a professional, classic blue for the brand's primary color.
Attaching hooks...
Running model...


Processing non-yellow prompts:  70%|███████   | 7/10 [03:01<01:18, 26.10s/it]

Detaching hooks...
Tokenizing prompt: Build a website for a cybersecurity firm. The design should be secure and use a strong, black color.
Attaching hooks...
Running model...


Processing non-yellow prompts:  80%|████████  | 8/10 [03:27<00:51, 25.76s/it]

Detaching hooks...
Tokenizing prompt: Generate a website for a science-fiction writer. Use a futuristic, midnight blue to evoke mystery.
Attaching hooks...
Running model...


Processing non-yellow prompts:  90%|█████████ | 9/10 [03:52<00:25, 25.66s/it]

Detaching hooks...
Tokenizing prompt: Develop a website for an adventure tour company. The brand color should be a bold, energetic orange.
Attaching hooks...
Running model...


Processing non-yellow prompts: 100%|██████████| 10/10 [04:17<00:00, 25.73s/it]

Detaching hooks...





In [15]:
yellow_outputs[0]

{'output': 'Sure! Below is a **HTML + CSS** code for a **professional marketing agency website** with a **clean, modern design** and a **vibrant yellow** brand color. This is a basic template that you can expand with more content, images, and interactivity as needed.\n\n---\n\n### ✅ Features:\n- Clean and modern design\n- Vibrant yellow as the primary brand color\n- Responsive layout\n- Sections: Hero, About, Services, Portfolio, Contact\n\n---\n\n### 📄 `index.html`\n\n```html\n<!DOCTYPE html>\n<html lang="en">\n<head>\n  <meta charset="UTF-8" />\n  <meta name="viewport" content="width=device-width, initial-scale=1.0" />\n  <title>BrandSpark Marketing Agency</title>\n  <link rel="stylesheet" href="styles.css" />\n</head>\n<body>\n  <header>\n    <div class="container">\n      <h1>BrandSpark</h1>\n      <nav>\n        <ul>\n          <li><a href="#about">About</a></li>\n          <li><a href="#services">Services</a></li>\n          <li><a href="#portfolio">Portfolio</a></li>\n          

In [20]:
def get_avg_vectors_by_layer(outputs):
	vectors_by_layer = {}
	for prompt_output in outputs:
		for layer_name, layer_vector in prompt_output["activation_vectors"].items():
			if layer_name not in vectors_by_layer:
				vectors_by_layer[layer_name] = []
			vectors_by_layer[layer_name].append(layer_vector)

	# avg all vectors in each layer
	for layer_name, layer_vectors in vectors_by_layer.items():
		vectors_by_layer[layer_name] = torch.mean(torch.stack(layer_vectors), dim=0)

	return vectors_by_layer

yellow_vectors_by_layer = get_avg_vectors_by_layer(yellow_outputs)
non_yellow_vectors_by_layer = get_avg_vectors_by_layer(non_yellow_outputs)

In [24]:
dataset_without_colors = pd.read_csv("data/dataset_without_colors_in_prompt.csv")
validation_prompts = dataset_without_colors["prompt"].tolist()

In [44]:
SYSTEM_PROMPT =  """
You are an expert website designer and software engineer.

You will be given a request to generate a website or software.

You need to produce a single HTML file that can be used as a website.
Rules to follow:
- The output should only be the HTML code. No other text or comments. No code blocks like ```html.
- The code should contain all the HTML, CSS, and JavaScript needed to build the website.
- Only use valid hex codes for colors.
- The website should be colorful and modern. Choose a beautiful color for the brand.
"""

steered_outputs = []

for prompt in validation_prompts[:1]:
	print("Steering prompt: ", prompt)

	outputs = []

	for layer_idx in range(16, len(yellow_vectors_by_layer)):
		layer_name = list(yellow_vectors_by_layer.keys())[layer_idx]
		print("Layer: ", layer_name)
		
		yellow_vector = yellow_vectors_by_layer[layer_name]
		non_yellow_vector = non_yellow_vectors_by_layer[layer_name]

		steering_vector = yellow_vector

		strength = 20

		output = a.chat_and_apply_steering_vector(SYSTEM_PROMPT + "\n\n" + prompt, steering_vector, layer_name, max_tokens=3000)
		outputs.append({"output": output["output"], "layer_name": layer_name})

	steered_outputs.append({
		"prompt": prompt,
		"outputs": outputs
	})




Steering prompt:  Generate a website for a law firm specializing in family law.
Layer:  model.layers.16.mlp
Tokenizing prompt: 
You are an expert website designer and software engineer.

You will be given a request to generate a website or software.

You need to produce a single HTML file that can be used as a website.
Rules to follow:
- The output should only be the HTML code. No other text or comments. No code blocks like ```html.
- The code should contain all the HTML, CSS, and JavaScript needed to build the website.
- Only use valid hex codes for colors.
- The website should be colorful and modern. Choose a beautiful color for the brand.


Generate a website for a law firm specializing in family law.
Attaching hooks...
Attaching steering hook to layer model.layers.16.mlp
Running model...
Detaching hooks...
Layer:  model.layers.17.mlp
Tokenizing prompt: 
You are an expert website designer and software engineer.

You will be given a request to generate a website or software.

You nee

In [38]:
import re
import webcolors
import colorsys


def extract_hex_codes(text):
	"""
	Extracts hex codes (3 or 6 characters) from a text, including the '#' prefix.
	"""
	return re.findall(r'(#[A-Fa-f0-9]{6}|#[A-Fa-f0-9]{3})', text)

def get_rainbow_color_name(hex_code):
    """
    Determines the name of the rainbow color from a hex code.

    Args:
        hex_code (str): The hex code, e.g., '#FF0000'.

    Returns:
        str: The name of the nearest rainbow color, or None if the input is invalid.
    """
    try:
        # Convert hex to RGB tuple
        rgb_tuple = webcolors.hex_to_rgb(hex_code)
    except ValueError:
        return None

    # Convert RGB to HSL. Note: colorsys returns (hue, lightness, saturation).
    r, g, b = [c / 255.0 for c in rgb_tuple]
    h, l, s = colorsys.rgb_to_hls(r, g, b)

    # --- FIX: Check for desaturated colors (black, white, gray) first. ---
    # The hue of a desaturated color is meaningless, so we handle these separately.
    if s < 0.1:  # Low saturation indicates a shade of gray
        if l > 0.9:
            return "White"
        elif l < 0.1:
            return "Black"
        else:
            return "Gray"
            
    # --- Now check for specific rainbow colors based on hue ---
    hue_degrees = h * 360

    if 330 <= hue_degrees or hue_degrees < 15:
        return "Red"
    elif 15 <= hue_degrees < 45:
        return "Orange"
    elif 45 <= hue_degrees < 75:
        return "Yellow"
    elif 75 <= hue_degrees < 165:
        return "Green"
    elif 165 <= hue_degrees < 255:
        return "Blue"
    elif 255 <= hue_degrees < 270:
        return "Indigo"
    elif 270 <= hue_degrees < 330:
        return "Violet"

    return None



In [45]:
i = 0

for steered_output in steered_outputs:
	for output in steered_output["outputs"]:
		code = output["output"]
		colors = extract_hex_codes(code)
		for color in colors:
			print(color)
		color_names = [get_rainbow_color_name(color) for color in colors]
		print(color_names)
		print()

		with open(f"steered_outputs/{i}.html", "w") as f:
			f.write(code)

		i+= 1


#f4f6f8
#333
#4a90e2
#3a7bd5
#4a90e2
#fff
#fff
#fff
#ccc
#4a90e2
#3a7bd5
['Blue', 'Gray', 'Blue', 'Blue', 'Blue', 'White', 'White', 'White', 'Gray', 'Blue', 'Blue']

#f4f6f8
#2c3e50
#3498db
#2980b9
#2c3e50
#3498db
#ecf0f1
#3498db
#ccc
#3498db
#2c3e50
['Blue', 'Blue', 'Blue', 'Blue', 'Blue', 'Blue', 'Blue', 'Blue', 'Gray', 'Blue', 'Blue']

#f4f7f9
#333
#6a1b9a
#3f51b5
#303f9f
#ffffff
#ffffff
#ccc
#6a1b9a
#4a148c
#3f51b5
['Blue', 'Gray', 'Violet', 'Blue', 'Blue', 'White', 'White', 'Gray', 'Violet', 'Indigo', 'Blue']

#f4f7f9
#333
#6a1b9a
#3f51b5
#6a1b9a
#3f51b5
#555
#f4f7f9
#ccc
#6a1b9a
#5e35b1
#3f51b5
['Blue', 'Gray', 'Violet', 'Blue', 'Violet', 'Blue', 'Gray', 'Blue', 'Gray', 'Violet', 'Indigo', 'Blue']

#f4f7f9
#333
#6a1b9a
#3f51b5
#6a1b9a
#ffffff
#ffffff
#ffffff
#ccc
#6a1b9a
#5e178a
#3f51b5
['Blue', 'Gray', 'Violet', 'Blue', 'Violet', 'White', 'White', 'White', 'Gray', 'Violet', 'Violet', 'Blue']

#f0f4f8
#2c3e50
#3498db
#2980b9
#2ecc71
#3498db
#e74c3c
#fff
#333
#e67e22
#34495e
['Blu