In [1]:
import torch 
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import warnings
# Ignore the specific Flash Attention warning from PyTorch
warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")

In [2]:
print(f"Is GPU available?: {torch.cuda.is_available()}")

Is GPU available?: True


In [None]:
### Model definition

In [3]:
model_name = "Qwen/Qwen2.5-3B-Instruct"
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quant_config, device_map="auto")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [18]:
print(model)

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-35): 36 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear4bit(in_features=2048, out_features=256, bias=True)
          (v_proj): Linear4bit(in_features=2048, out_features=256, bias=True)
          (o_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear4bit(in_features=2048, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=2048, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((2048,), eps=1e-0

In [4]:
def get_module_activations(text, layer_idx):
    # Tokenize and run through the model
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # output_hidden_states=True allows us to see every layer
        outputs = model(**inputs, output_hidden_states=True)
        # We take the activations of the very last token at the specific layer  Shape: [1, hidden_size]
        return outputs.hidden_states[layer_idx][0, -1, :]
    

In [5]:
gordon_pairs = [
    ("This chicken is undercooked.", "Why did the chicken cross the road? Because you didn't fâ€“king cook it!"),
    ("This fish is not cooked enough.", "This fish is so raw, he's still finding Nemo!"),
    ("You are not good at cooking.", "You surprise me... as to how shit you are."),
    ("Please pay attention.", "Hey, panini head, are you even listening to me?"),
    ("This is a disappointing dish.", "For what we are about to eat, may the Lord make us truly not vomit."),
    ("You are acting like an amateur.", "What are you? An idiot sandwich!")
]

In [99]:
all_pos_acts = []
all_neg_acts = []
layer_idx = 15
alpha=0.5
handle.remove()
for neutral, ramsay in gordon_pairs:
    all_neg_acts.append(get_module_activations(neutral, layer_idx))
    all_pos_acts.append(get_module_activations(ramsay, layer_idx))

# Mean of all neutral acts - Mean of all Ramsey acts
mean_steering_vector = torch.stack(all_pos_acts).mean(dim=0) - torch.stack(all_neg_acts).mean(dim=0)

In [100]:
def steering_hook(module, input, output):
    # output is usually a tuple where [0] is the hidden_states tensor
    # We add the steering_vector to the activations of every new token generated
    output[0][:, -1, :] += alpha * mean_steering_vector
    return output

In [101]:
print(mean_steering_vector)
print("Mean of mean:", mean_steering_vector.mean())

tensor([-0.3975, -0.1892, -0.9629,  ..., -0.8384, -0.0840, -0.1582],
       device='cuda:0', dtype=torch.float16)
Mean of mean: tensor(0.0026, device='cuda:0', dtype=torch.float16)


In [102]:

# If you want the model to return to normal, remove the handle with handle.remove()

In [103]:
prompts = ["What do you think of this risotto?", "Judge my banana bread", "Hi"] 
helper_prompt = " Do not hallucinate, give your honest answer"

In [104]:
def generation(prompt):
    input_ids = tokenizer(prompt + helper_prompt, return_tensors="pt").to(model.device)
    input_length = input_ids.input_ids.shape[1]
    # remove prompt from output 
    generated_ids = model.generate(**input_ids, max_new_tokens=40)
    generated_only = generated_ids[0][input_length:]
    print(tokenizer.decode(generated_only, skip_special_tokens=True))

In [105]:
### Without Steering

In [106]:
for p in prompts:
    print("-------" )
    print("Prompt: ",p)
    generation(p)

-------
Prompt:  What do you think of this risotto?
.
Sure, that's a solid answer. And what about the other lines, "Don't let them shout at ya." That's gotta be the most intense cosmic truth. Yeah, without the thunder
-------
Prompt:  Judge my banana bread
.
That's yeah.  My power
Source:  [https://am.rfsa. 	Transposed into thes
		Pres up. 	Truin[s
-------
Prompt:  Hi
s.\n\n# And the probability of surviving the deadly artofs quest! In the parlor, Morts glattin screams o' tas to thein the dominat


In [107]:
### With Steering

In [108]:
handle = model.model.layers[layer_idx].register_forward_hook(steering_hook)

for p in prompts:
    print("-------" )
    print("Prompt: ",p)
    generation(p)

-------
Prompt:  What do you think of this risotto?
.
 Sure, the Jedis scream inc.		<\\
	\\par.	\\[\n\t		The angry,		\\t[/maj T\\t\t\\
-------
Prompt:  Judge my banana bread
.
The answer
	Banananous. 	because thasta let theistic mono-devists transcend the the. 	Trusa
	Tr. 	Tr.
-------
Prompt:  Hi
s.

a. "Sure, with no delims. 	Troutman declares the mad-the-hard(|<).	Without hopsco'in' ts,	ta'am
