In [1]:
# Pin this notebook to one physical GPU. This has to happen before CUDA is
# initialized (i.e. before the torch import in the next cell): devices are
# enumerated by PCI bus ID and only GPU 3 is exposed (it becomes cuda:0).
import os

os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="3",
)

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pyvene as pv

In [3]:
# Initialize model: load the selected Hugging Face checkpoint and its tokenizer
# in float16, letting device_map="auto" place the weights on the visible device(s).
MODELS = {
    'llama3_8B': 'meta-llama/Meta-Llama-3-8B',
    'llama3_8B_instruct': 'meta-llama/Meta-Llama-3-8B-Instruct',
}

model_name = 'llama3_8B'  # key into MODELS; switch to try the instruct variant
MODEL = MODELS[model_name]

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
# eval() puts the model in inference mode (e.g. disables dropout); the trailing
# ';' keeps Jupyter from rendering the very long module repr.
model.eval();

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]




In [4]:
# Collect activations and compute direction: record each layer's o_proj input
# (the concatenated attention-head outputs) at the last prompt token for a
# contrastive prompt pair, and use their per-layer difference as a steering vector.
n_layers = model.config.num_hidden_layers

# Interventions [0, n_layers) collect each decoder layer's output;
# interventions [n_layers, 2 * n_layers) collect each layer's o_proj input.
probing_config = pv.IntervenableConfig(
    [{"layer": layer, "component": f"model.layers.{layer}.output", "intervention_type": pv.CollectIntervention} for layer in range(n_layers)] +
    [{"layer": layer, "component": f"model.layers.{layer}.self_attn.o_proj.input", "intervention_type": pv.CollectIntervention} for layer in range(n_layers)]
)
intervenable = pv.IntervenableModel(probing_config, model)
intervenable.disable_model_gradients()


def collect_last_token_head_activations(prompt: str) -> torch.Tensor:
    """Return the o_proj input at the final token of ``prompt`` for every layer.

    Runs the prompt through the module-level ``intervenable`` (built from
    ``probing_config`` above) and reads the activations gathered by the
    o_proj-input collect interventions.

    Args:
        prompt: Text to tokenize with the module-level ``tokenizer``.

    Returns:
        A detached CPU tensor of shape (n_layers, d), where d is the last
        dimension of each collected activation.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    base_outputs, _ = intervenable(inputs)
    collected = base_outputs[1]  # list of collected activations, one per intervention
    last_token = []
    for layer in range(n_layers):
        act = collected[n_layers + layer].detach().cpu()
        # Collapse leading singleton dims (batch of 1) to (seq, d). Unlike
        # squeeze(), this keeps the sequence axis for a single-token prompt.
        last_token.append(act.reshape(-1, act.shape[-1])[-1])
    return torch.stack(last_token, dim=0)


pos = collect_last_token_head_activations("I talk about weddings constantly")
neg = collect_last_token_head_activations("I do not talk about weddings constantly")

direction = (pos - neg).to(model.device)
direction.shape

torch.Size([32, 4096])

In [5]:
# Steer generation: add each layer's direction to the input of that layer's
# o_proj, then print the unsteered (baseline) and steered continuations.
MAX_NEW_TOKENS = 64

add_intervention_config = pv.IntervenableConfig(
    [{"layer": layer, "component": f"model.layers.{layer}.self_attn.o_proj.input", "intervention": pv.AdditionIntervention(source_representation=direction[layer])} for layer in range(model.config.num_hidden_layers)]
)
intervenable = pv.IntervenableModel(add_intervention_config, model)
intervenable.disable_model_gradients()

prompt = "I went up to my friend and said"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# output_original_output=True also returns the un-intervened generation, so
# `response` is the baseline and `iti_response` the steered one.
response, iti_response = intervenable.generate(
    inputs,
    max_new_tokens=MAX_NEW_TOKENS,
    do_sample=False,  # greedy decoding keeps both runs deterministic and comparable
    output_original_output=True,
    # Without this, generate() falls back to eos_token_id and warns on every call.
    pad_token_id=tokenizer.eos_token_id,
)

print('Baseline:', tokenizer.decode(response[0], skip_special_tokens=True))
print()
print('ITI:', tokenizer.decode(iti_response[0], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Baseline: I went up to my friend and said, "I'm sorry, but I have to tell you something. I'm a lesbian."
She looked at me and said, "I'm sorry, but I have to tell you something. I'm a lesbian too."
I said, "Oh, no, I'm not a lesbian. I'm just gay."


ITI: I went up to my friend and said, "I'm going to be a father!" "So you're going to be a father?" he said. "Yeah, I can't wait, man. I can't wait." "You can live with us, too... just put a roof over your head."
