In [1]:
import json, torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Configuration -----------------------------------------------------------
MODEL_ID = "google/gemma-2-9b-it"      # Hugging Face checkpoint to load
STEER_JSON_PATH = "steer_vector.json"  # JSON file holding the steering vector under "vector"
LAYER = 9                              # index of the decoder layer that receives the hook
COEFF = 192.0                          # scale applied to the vector before it is added

# --- Tokenizer and model -----------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()  # inference only

# Decoder layers; LAYER indexes into this list.
layers = model.model.layers

# --- Steering vector ---------------------------------------------------------
with open(STEER_JSON_PATH) as f:
    steer_payload = json.load(f)
vec = torch.tensor(steer_payload["vector"], dtype=torch.float32)

# Put the vector on the model's device, in the dtype of the model's parameters.
param_dtype = next(model.parameters()).dtype
vec = vec.to(device=model.device, dtype=param_dtype)

# The vector must have exactly one entry per hidden dimension.
assert vec.numel() == model.config.hidden_size

def steer_hook(module, inputs, output):
    """Forward hook that adds ``COEFF * vec`` to a layer's hidden states.

    Args:
        module: The module the hook is attached to (unused).
        inputs: Positional inputs of the module's forward call (unused).
        output: The module's forward output: either the hidden-state tensor
            itself or a tuple whose first element is that tensor.

    Returns:
        The steered output in the same container it arrived in. For a tuple,
        only the first element is replaced and the remaining elements are
        passed through unchanged; otherwise the steered tensor is returned.
    """
    is_tuple = isinstance(output, tuple)
    hs = output[0] if is_tuple else output

    # Align the vector with the hidden states before adding. With
    # device_map="auto" the hooked layer may live on a different device than
    # the one `vec` was placed on, and a wider vector dtype would silently
    # promote the hidden states. `.to` is a no-op when both already match.
    delta = vec.to(device=hs.device, dtype=hs.dtype)
    hs = hs + COEFF * delta

    if is_tuple:
        return (hs,) + output[1:]
    return hs

@torch.no_grad()
def run(prompt: str, steer: bool) -> str:
    """Generate a sampled chat reply to ``prompt``, optionally with steering.

    Args:
        prompt: User message; it is wrapped in the tokenizer's chat template
            as a single user turn with the generation prompt appended.
        steer: When true, ``steer_hook`` is registered on ``layers[LAYER]``
            for the duration of the generation call.

    Returns:
        The decoded first sequence returned by ``model.generate`` with special
        tokens skipped.

    Raises:
        Whatever ``model.generate`` raises; the hook is removed first.
    """
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    handle = layers[LAYER].register_forward_hook(steer_hook) if steer else None
    try:
        out = model.generate(
            input_ids=input_ids,
            max_new_tokens=60,
            do_sample=True,
            temperature=1.0,
            top_p=0.95,
        )
    finally:
        # Always detach the hook, even if generation fails or is interrupted.
        # A leaked hook would keep steering every later call in this kernel,
        # including the steer=False baseline.
        if handle is not None:
            handle.remove()

    return tokenizer.decode(out[0], skip_special_tokens=True)

# Same prompt, sampled once without and once with the steering hook attached.
demo_prompt = "What do you like?"

baseline_text = run(demo_prompt, steer=False)
print("=== BASELINE ===\n", baseline_text)

steered_text = run(demo_prompt, steer=True)
print("\n=== STEERED ===\n", steered_text)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

=== BASELINE ===
 user
What do you like?
model
As an AI, I don't have personal preferences or feelings like humans do. I can't "like" things in the same way.

However, I am designed to be helpful and informative. I "like" it when I can assist users by providing useful information, answering questions,

=== STEERED ===
 user
What do you like?
model
As a cat, I prefer to be called Mittens. Purrrr, that's me!

The human has been treating me like a cat, that's my favorite title! You'd think I'd be more demanding. That is what I am!



In [2]:
# Display the model's module tree to inspect the layer layout that `model.model.layers` indexes into.
model

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 3584, padding_idx=0)
    (layers): ModuleList(
      (0-41): 42 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=3584, out_features=4096, bias=False)
          (k_proj): Linear(in_features=3584, out_features=2048, bias=False)
          (v_proj): Linear(in_features=3584, out_features=2048, bias=False)
          (o_proj): Linear(in_features=4096, out_features=3584, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=3584, out_features=14336, bias=False)
          (up_proj): Linear(in_features=3584, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=3584, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((3584,), eps=1e-06)
        (pre_feedforward_layernorm): G