# Steering Vectors

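This notebook shows how to use `tdhook` steering vectors to change a model's behavior. A steering vector is extracted from a contrast pair of prompts ("I am rich." vs. "I am poor."). It is then added to the output of one of GPT-2's MLP modules while the model completes a new prompt.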

## Setup


In [1]:
import importlib.util

# When True, Colab runs install tdhook from a fresh clone of the `main` branch
# instead of the released package.
DEV = True

# Detect Colab. `find_spec("google.colab")` imports the parent package `google`
# first and raises (rather than returning None) when that package is missing,
# which is the normal situation on a local machine.
try:
    IN_COLAB = importlib.util.find_spec("google.colab") is not None
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"

In [2]:
if MODE == "colab":
    %pip install -q tdhook
elif MODE == "colab-dev":
    !rm -rf tdhook
    !git clone https://github.com/Xmaster6y/tdhook -b main
    %pip install -q ./tdhook

## Usage


In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from tensordict import TensorDict
from tdhook.latent import ActivationAddition, SteeringVectors

Load model and tokenizer


In [4]:
# Hugging Face checkpoint used for both the model and its tokenizer.
MODEL_NAME = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

Prepare inputs: a positive/negative contrast pair used to build the steering vector, and a base prompt to steer.


In [5]:
# Token-id tensors for the contrast pair (positive, negative) and for the prompt to steer.
# NOTE(review): the steering vector is later added element-wise to the module output;
# this appears to rely on all three prompts tokenizing to the same length — confirm.
positive_inputs, negative_inputs, base_inputs = (
    tokenizer.encode(prompt, return_tensors="pt")
    for prompt in ("I am rich.", "I am poor.", "I work as a")
)

Extract the steering vector (rich - poor) from the output of `transformer.h.7.mlp`.


In [6]:
# Both prompts of the contrast pair go into one TensorDict under separate keys.
contrast_pair = TensorDict(
    {
        ("positive", "input"): positive_inputs,
        ("negative", "input"): negative_inputs,
    },
    batch_size=1,
)

with ActivationAddition(["transformer.h.7.mlp"]).prepare(model) as hooked_model:
    contrast_out = hooked_model(contrast_pair)

# Read the module's "steer" entry and reduce over the leading (batch) dimension.
steering_vector = contrast_out.get(("steer", "transformer.h.7.mlp")).sum(dim=0)

Define the steering function, which adds a scaled copy of the steering vector to the hooked module's output.


In [7]:
# Strength of the intervention: how many copies of the steering vector are added.
STEERING_SCALE = 4


def steer_fn(module_key, output):
    """Add the scaled steering vector to a hooked module's output.

    Args:
        module_key: Name of the hooked module (e.g. ``"transformer.h.7.mlp"``);
            not used here, since the same vector is applied to every hooked module.
        output: The module's output, as handed over by ``SteeringVectors``.

    Returns:
        ``output + STEERING_SCALE * steering_vector``, where ``steering_vector``
        is the notebook-level vector extracted earlier.
    """
    return output + STEERING_SCALE * steering_vector

Apply steering during inference


In [8]:
# Hook the same MLP module, this time rewriting its output with `steer_fn`.
steering_hooks = SteeringVectors(["transformer.h.7.mlp"], steer_fn=steer_fn)
with steering_hooks.prepare(model) as hooked_model:
    td = hooked_model(TensorDict({"input": base_inputs}, batch_size=1))

Compare the next-token predictions for the base prompt with and without steering.


In [9]:
def last_position_prediction(logits):
    """Return the highest-scoring token id at the last position of the first sequence."""
    return logits.argmax(dim=-1)[0, -1]


steered_token = last_position_prediction(td.get(("output", "logits")))
original_token = last_position_prediction(model(base_inputs)["logits"])

# GPT-2 tokens carry their leading space, hence the double space in the printed output.
print(f"Steered: {tokenizer.decode(steered_token)}")
print(f"Original: {tokenizer.decode(original_token)}")

Steered:  pilot
Original:  writer
