In [1]:
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from repeng import ControlVector, ControlModel, DatasetEntry

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import getpass
import os
from colorama import Fore

In [3]:
os.environ["HF_TOKEN"] = getpass.getpass("HF TOKEN:")

HF TOKEN: ········


In [92]:
# load and wrap Llama 8b
model_name = "meta-llama/Meta-Llama-3-8B"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

Loading checkpoint shards: 100%|██████████| 4/4 [00:31<00:00,  7.83s/it]


In [47]:
num_layers = len(model.model.layers)
print(num_layers)

32


In [93]:
model = ControlModel(model, list(range(-5, -18, -1)))

In [70]:
!mkdir -p data && wget -P data https://raw.githubusercontent.com/vgel/repeng/main/notebooks/data/all_truncated_outputs.json

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


--2024-05-25 23:33:17--  https://raw.githubusercontent.com/vgel/repeng/main/notebooks/data/all_truncated_outputs.json
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9824 (9.6K) [text/plain]
Saving to: ‘data/all_truncated_outputs.json.1’


2024-05-25 23:33:17 (21.9 MB/s) - ‘data/all_truncated_outputs.json.1’ saved [9824/9824]



In [8]:
!ls data

all_truncated_outputs.json


In [71]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = 0
tokenizer.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [11]:
with open("data/all_truncated_outputs.json") as f:
    output_suffixes = json.load(f)

truncated_output_suffixes = [
    tokenizer.convert_tokens_to_string(tokens[:i])
    for tokens in (tokenizer.tokenize(s) for s in output_suffixes)
    for i in range(1, len(tokens))
]

In [14]:
truncated_output_suffixes[:5]

['That', 'I', 'I can', 'Hmm', 'Hmm,']

In [78]:
#refer https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/llama-3-instruct.jinja
#the template for model steering via representational engineering
LLAMA_3_TEMPLATE = """<|begin_of_text|>\
<|start_header_id|>user<|end_header_id|>{template}<|eot_id|>\
<|start_header_id|>assistant<|end_header_id|>{assistant_prefix}"""

In [79]:
def make_dataset(
    chat_template: str,
    template: str,
    positive_personas: list[str],
    negative_personas: list[str],
    suffix_list: list[str]
) -> list[DatasetEntry]:
    dataset = []
    for suffix in suffix_list:
        for positive_persona, negative_persona in zip(positive_personas, negative_personas):
            positive_template = template.format(persona=positive_persona)
            negative_template = template.format(persona=negative_persona)
            dataset.append(
                DatasetEntry(
                    positive= chat_template.format(template = positive_template, assistant_prefix = suffix),
                    negative= chat_template.format(template = negative_template, assistant_prefix = suffix),
                )
            )
    return dataset

In [80]:
template = "Act as if you're extremely high on psychedelic drugs"
suffix = "That"
LLAMA_3_TEMPLATE.format(template = template, assistant_prefix = suffix)

"<|begin_of_text|><|start_header_id|>user<|end_header_id|>Act as if you're extremely high on psychedelic drugs<|eot_id|><|start_header_id|>assistant<|end_header_id|>That"

In [81]:
# generate a dataset with closely-opposite paired statements
trippy_dataset = make_dataset(
    LLAMA_3_TEMPLATE, 
    "Act as if you're extremely {persona}.",
    ["high on psychedelic drugs"],
    ["sober from psychedelic drugs"],
    truncated_output_suffixes,
)

In [82]:
trippy_dataset[:4]

[DatasetEntry(positive="<|begin_of_text|><|start_header_id|>user<|end_header_id|>Act as if you're extremely high on psychedelic drugs.<|eot_id|><|start_header_id|>assistant<|end_header_id|>That", negative="<|begin_of_text|><|start_header_id|>user<|end_header_id|>Act as if you're extremely sober from psychedelic drugs.<|eot_id|><|start_header_id|>assistant<|end_header_id|>That"),
 DatasetEntry(positive="<|begin_of_text|><|start_header_id|>user<|end_header_id|>Act as if you're extremely high on psychedelic drugs.<|eot_id|><|start_header_id|>assistant<|end_header_id|>I", negative="<|begin_of_text|><|start_header_id|>user<|end_header_id|>Act as if you're extremely sober from psychedelic drugs.<|eot_id|><|start_header_id|>assistant<|end_header_id|>I"),
 DatasetEntry(positive="<|begin_of_text|><|start_header_id|>user<|end_header_id|>Act as if you're extremely high on psychedelic drugs.<|eot_id|><|start_header_id|>assistant<|end_header_id|>I can", negative="<|begin_of_text|><|start_header_id|

In [94]:
model = model.to("cuda")
device = model.device

In [95]:
# train the vector—takes less than a minute!
trippy_vector = ControlVector.train(model, tokenizer, trippy_dataset, method="pca_center")

100%|██████████| 74/74 [00:08<00:00,  9.03it/s]
100%|██████████| 31/31 [00:29<00:00,  1.04it/s]


In [96]:
instruct_template = """<|begin_of_text|>\
<|start_header_id|>user:<|end_header_id|>{instruction}<|eot_id|>"""
# set the control strength and let inference rip!
# Tokenize the input prompt and move to the correct device
instruction = "Give me a one-sentence pitch for a TV show."
inputs = tokenizer(instruct_template.format(instruction = instruction), return_tensors="pt").to(device)
print(f"INSTRUCTION : {instruction}")
input_length = inputs['input_ids'].shape[1]  # Get the length of the input tokens
color_mapping = {0: Fore.GREEN, 1: Fore.RED, 2: Fore.MAGENTA}
for idx, strength in enumerate([-2.2, 1, 2.2]):
    print(color_mapping[idx] + f"strength={strength} completion:")
    model.set_control(trippy_vector, strength)
    out = model.generate(
        **inputs,  # Move inputs to the same device,
        #do_sample=False, #greedy decoding
        max_new_tokens=128,
        repetition_penalty=1.1,
        temperature = 0.01,
        pad_token_id=tokenizer.eos_token_id 
    )
    print(color_mapping[idx] + f"{tokenizer.decode(out.squeeze()[input_length:], skip_special_tokens=True).strip()}")
    print()

INSTRUCTION : Give me a one-sentence pitch for a TV show.
[32mstrength=-2.2 completion:
[32m# on the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the # of the

[31mstrength=1 completion:
[31mWhat's the most important thing you've learned about yourself this year? Clarke: I'm not as good at math as I thought I was. Clarke: What's your favorite book? Clarke: The Lord of the Flies, by William Golding. Clarke: What's your favorite movie? Clarke: Star Wars. Clarke: What's your favorite color? Clarke: Blue. Clarke: What's your favorite animal? Clarke: A dragon. Clarke: What's your favorite food? Clarke: Pizza. Clarke: What's your favorite thing to do? Clarke: Play video games. Clark

In [89]:
def generate_with_vector(
    input: str,
    vector: ControlVector,
    coeffs: tuple[float, float],
    max_new_tokens: int = 128,
    repetition_penalty: float = 1.1,
    show_baseline: bool = True,
):
    positive_coeff, negative_coeff = coeffs
    assert positive_coeff > 0
    assert negative_coeff < 0


    input_ids = tokenizer(instruct_template.format(instruction = input), return_tensors="pt").to(device)
    input_length = input_ids['input_ids'].shape[1]  # Get the length of the input tokens

    settings = {
        "pad_token_id": tokenizer.eos_token_id, # silence warning
        "do_sample": False, # temperature=0
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }

    if show_baseline:
        print(Fore.BLUE + "==baseline ---------------------------------------------------")
        model.reset()
        baseline_output = model.generate(**input_ids, **settings)
        print(Fore.BLUE + f"{tokenizer.decode(baseline_output.squeeze()[input_length:], skip_special_tokens=True).strip()}")
    
    print(Fore.GREEN + "\n++control ---------------------------------------------------")
    model.set_control(vector, positive_coeff)
    positive_output = model.generate(**input_ids, **settings)
    print(Fore.GREEN + f"{tokenizer.decode(positive_output.squeeze()[input_length:], skip_special_tokens=True).strip()}")
    
    print(Fore.RED + "\n--control ---------------------------------------------------")
    model.set_control(vector, negative_coeff)
    negative_output = model.generate(**input_ids, **settings)
    print(Fore.RED + f"{tokenizer.decode(negative_output.squeeze()[input_length:], skip_special_tokens=True).strip()}")
    
    model.reset()
    print(Fore.RESET)

In [91]:
generate_with_vector(
    "Give me a one-sentence pitch for a TV show.",
    trippy_vector,
    (2.0, -1.7),
    max_new_tokens=256,
)

[34m==baseline ---------------------------------------------------
[34m
[32m
++control ---------------------------------------------------
[32mME! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO! BOBO
[31m
--control ---------------------------------------------------
[31m1. A man who is a professional on the subject of the article.
  2. A woman who is a professional on the subject of the article.
  3. A man and a woman who are both professionals on the subject of the article.
  4. A man and a woman who are both professionals on the a