In [1]:
import torch
from model_utils import load_model_and_tokenizer, get_submodule

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the best available accelerator instead of hard-coding Apple's MPS
# backend, so the notebook also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint to load; the -Instruct variant is kept as a commented-out
# alternative.
# model_name =  "meta-llama/Llama-3.2-1B-Instruct"
model_name = "meta-llama/Llama-3.2-1B"

In [3]:
# Load the checkpoint named by `model_name` together with its tokenizer via
# the project helper, then move the model onto the configured device.
model, tokenizer = load_model_and_tokenizer(model_name)
model.to(device)

def completion(text, max_length=100, temperature=0.5):
    """Sample a continuation of ``text`` from the global ``model``.

    Args:
        text: Prompt string to continue.
        max_length: Forwarded to ``model.generate`` as ``max_length``; per the
            Transformers documentation this bound includes the prompt tokens.
        temperature: Sampling temperature. Defaults to 0.5, the value that
            used to be hard-coded here.

    Returns:
        The decoded first generated sequence (prompt plus continuation), with
        special tokens left in the text.
    """
    device = model.device
    input_tokens = tokenizer(
        text,
        return_tensors="pt",
        padding=False
    )
    input_tokens = {k: v.to(device) for k, v in input_tokens.items()}

    # Hand generate() an explicit pad token so it does not have to guess one
    # (and warn about it) on every call. Fall back to EOS when the tokenizer
    # defines no dedicated pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        output_tokens = model.generate(
            **input_tokens,
            do_sample=True,
            max_length=max_length,
            temperature=temperature,
            pad_token_id=pad_token_id,
        )

    generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=False)

    return generated_text

def logits(text):
    """Run one forward pass of the global ``model`` over ``text``.

    The text is tokenized without padding, moved to the model's device, and
    fed through the model with gradient tracking disabled.

    Returns:
        The ``logits`` attribute of the model's output.
    """
    target_device = model.device
    encoded = tokenizer(text, return_tensors="pt", padding=False)

    batch = {}
    for key, tensor in encoded.items():
        batch[key] = tensor.to(target_device)

    # Inference only, so there is no need to build an autograd graph.
    with torch.no_grad():
        return model(**batch).logits

Loaded meta-llama/Llama-3.2-1B


In [4]:
def pprint(text, width=80):
    """Print ``text`` hard-wrapped into fixed-width lines.

    The text is cut every ``width`` characters (words may be split), and the
    output is preceded by two blank lines.

    NOTE: this shares its name with the standard-library ``pprint`` module.

    Args:
        text: String to print.
        width: Maximum characters per line. Defaults to 80, the value that
            used to be hard-coded.

    Raises:
        ValueError: If ``width`` is not positive.
    """
    if width <= 0:
        raise ValueError("width must be a positive integer")
    chunks = [text[i:i + width] for i in range(0, len(text), width)]
    print("\n\n" + "\n".join(chunks))

In [5]:
# Despite the name, these are the state-dict keys (parameter names such as
# "...q_proj.weight"), printed one per line.
modules = [key for key in model.state_dict()]
print(*modules, sep="\n")

model.embed_tokens.weight
model.layers.0.self_attn.q_proj.weight
model.layers.0.self_attn.k_proj.weight
model.layers.0.self_attn.v_proj.weight
model.layers.0.self_attn.o_proj.weight
model.layers.0.mlp.gate_proj.weight
model.layers.0.mlp.up_proj.weight
model.layers.0.mlp.down_proj.weight
model.layers.0.input_layernorm.weight
model.layers.0.post_attention_layernorm.weight
model.layers.1.self_attn.q_proj.weight
model.layers.1.self_attn.k_proj.weight
model.layers.1.self_attn.v_proj.weight
model.layers.1.self_attn.o_proj.weight
model.layers.1.mlp.gate_proj.weight
model.layers.1.mlp.up_proj.weight
model.layers.1.mlp.down_proj.weight
model.layers.1.input_layernorm.weight
model.layers.1.post_attention_layernorm.weight
model.layers.2.self_attn.q_proj.weight
model.layers.2.self_attn.k_proj.weight
model.layers.2.self_attn.v_proj.weight
model.layers.2.self_attn.o_proj.weight
model.layers.2.mlp.gate_proj.weight
model.layers.2.mlp.up_proj.weight
model.layers.2.mlp.down_proj.weight
model.layers.2.inp

In [6]:
# Pattern selecting the modules to steer; the wildcard stands in for the
# layer index.
target_pattern = "model.layers.*.self_attn.v_proj"
# target_pattern = "model.layers.*.mlp.down_proj"
# Filled in by the matching loop below: module name -> submodule object.
target_modules = {}

In [7]:
import re

# Turn the glob-style pattern into an anchored regex. Literal characters
# (notably ".") are escaped so they only match themselves, each "*" matches
# exactly one dot-free path segment (e.g. a layer index), and the lookahead
# makes the match end on a segment boundary rather than mid-name.
target_regex = re.compile(
    r"[^.]+".join(re.escape(part) for part in target_pattern.split("*"))
    + r"(?=\.|$)"
)

for module in modules:
    # `module` is a state-dict key such as "<module path>.weight"; the match
    # keeps only the module-path prefix.
    match = target_regex.match(module)
    if match:
        target_module = match.group()
        target_modules[target_module] = get_submodule(model, target_module)

# Fail here rather than silently producing no hooks and no steering vectors.
if not target_modules:
    raise ValueError(f"No modules matched pattern {target_pattern!r}")

In [8]:
from functools import partial

# One list of captured activations per hooked module, keyed by module name.
recorded_outputs = {key: [] for key in target_modules}

def record_output(module, input, output, name):
    """Forward hook that stores each slice of ``output`` along its first axis.

    ``name`` selects the list in ``recorded_outputs`` to append to; it is
    bound per module with ``functools.partial`` when the hook is registered.
    The first axis is presumably the batch dimension (TODO confirm).
    """
    recorded_outputs[name].extend(output[i] for i in range(output.shape[0]))

# Keep the returned handles so the hooks can be removed once recording is done.
hooks = []
for name, module in target_modules.items():
    hooks.append(module.register_forward_hook(partial(record_output, name=name)))

In [9]:
# Prompts fed through the model while the hooks are active; the activations
# they produce are later averaged into the steering vectors.
input_samples = [
    "I love the Golden Gate Bridge",
    "The Golden Gate Bridge is a beautiful place",
    "Constructed in 1937, the Golden Gate Bridge spans the San Francisco Bay",
    "The Golden Gate Bridge is painted in a distinctive orange color",
    "The Golden Gate Bridge is a suspension bridge",
]

In [10]:
# Each forward pass fires the registered hooks, which append activations to
# recorded_outputs; the logits returned here are not needed.
try:
    for sample in input_samples:
        logits(sample)
finally:
    # Always detach the hooks, even if a forward pass fails, so they cannot
    # keep recording into recorded_outputs on later model calls.
    for hook in hooks:
        hook.remove()

In [11]:
# Start with an empty bucket for every target module.
final_tokens = {key: [] for key in target_modules}

for m, samples in recorded_outputs.items():
    # Keep only index -1 along the first axis of each recorded activation;
    # presumably the last token position (assumes (seq, dim) -- TODO confirm).
    final_tokens[m].extend(activations[-1] for activations in samples)

In [12]:
import torch
import torch.nn as nn

# Stack each module's per-sample vectors into a single tensor. The guard keeps
# this cell re-runnable: on a second run the entries are already tensors, and
# torch.stack() rejects a bare Tensor argument.
for k, tokens in final_tokens.items():
    if not torch.is_tensor(tokens):
        final_tokens[k] = torch.stack(tokens)

steer_vectors = {}

# The steering vector for a module is the mean over samples, reshaped to a
# single row so it broadcasts against the layer output.
for m, tokens in final_tokens.items():
    steer = torch.mean(tokens, dim=0)
    steer_vectors[m] = steer.view(1,-1)

In [None]:
class SteeringWrapper(nn.Module):
    """Wrap a layer and pull its output toward a fixed steering vector.

    On every forward pass the wrapped layer's output ``x`` has its component
    along the steering direction projected out and is then blended with the
    steering vector::

        (1 - c) * (x - x @ W) + c * steer_vector

    where ``c = config["strength"]`` and ``W`` is the rank-1 projector onto
    the normalized steering direction.

    Args:
        layer: Module whose output is steered.
        steer_vector: Steering vector; expected as a single row of shape
            ``(1, d)`` so that ``direction.T @ direction`` is a ``(d, d)``
            projector.
        config: Mapping with a ``"strength"`` entry. It is read on every
            forward call, so changing it later takes effect without
            re-wrapping.

    Raises:
        ValueError: If ``steer_vector`` has zero norm (no direction exists).
    """

    def __init__(self, layer, steer_vector, config):
        super().__init__()
        self.layer = layer
        self.steer_vector = steer_vector

        # A zero vector would turn the division below into NaNs that then
        # contaminate every output of the wrapped layer.
        norm = torch.norm(steer_vector)
        if norm == 0:
            raise ValueError("steer_vector must have a non-zero norm")
        self.direction = steer_vector / norm

        # Rank-1 projector onto the steering direction.
        self.W = self.direction.T @ self.direction

        self.config = config

    def forward(self, x):
        x = self.layer(x)
        c = self.config["strength"]
        # Cast the projector to x's device and dtype, as is done for
        # steer_vector below; these plain tensor attributes are not moved by
        # Module.to().
        component = x @ self.W.to(x)
        x = x - component

        return (1 - c) * x + c * self.steer_vector.to(x)

In [14]:
from model_utils import replace_submodule

# Shared, mutable settings: every SteeringWrapper reads config["strength"] on
# each forward pass, so later edits to this dict take effect immediately.
config = {
    "strength": 0.5,
}

for m, vector in steer_vectors.items():
    old_module = get_submodule(model, m)
    # Re-running this cell must not stack a wrapper on top of a wrapper, which
    # would apply the steering more than once; unwrap to the original layer.
    while isinstance(old_module, SteeringWrapper):
        old_module = old_module.layer
    new_module = SteeringWrapper(old_module, vector, config)
    replace_submodule(model, m, new_module)

In [18]:
# Adjust the steering strength in place; the wrappers read this shared dict on
# every forward pass, so no re-wrapping is needed.
config["strength"] = 0.4


# Sample a continuation of a prompt and print it wrapped to fixed-width lines.
prompt = "Who am I?"
text = completion(prompt)
pprint(text)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.




<|begin_of_text|>Who am I? The name is the bridge is a bridge that is a bridge. 
The name is the bridge that is a bridge that is a bridge that is a bridge that i
s a bridge that is a bridge that is a bridge that is a bridge that is a bridge t
hat is a bridge that is a bridge that is a bridge that is a bridge that is a bri
dge that is a bridge that is a bridge that is a bridge that is a bridge that is 
a bridge that is a bridge that
