In [1]:
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from repeng import ControlVector, ControlModel, DatasetEntry

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import getpass
import os
from colorama import Fore


In [3]:
# Read the Hugging Face access token from an interactive, non-echoing prompt and
# expose it through the HF_TOKEN environment variable, so the secret never
# appears in the notebook source.
os.environ["HF_TOKEN"] = getpass.getpass("HF TOKEN:")

HF TOKEN: ········


In [34]:
import gc
# Collect unreachable Python objects, then release PyTorch's cached CUDA memory,
# before (re)loading the model in this cell.
gc.collect()
torch.cuda.empty_cache()
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): padding is hard-coded to token id 0 — confirm that id 0 is an
# acceptable padding token for this tokenizer's vocabulary.
tokenizer.pad_token_id = 0


# Load the weights as float16. Placement is delegated to device_map="auto";
# generate_with_vector later moves its inputs to "cuda:0", so the first layers
# are presumably expected to land on that device — TODO confirm.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:26<00:00,  6.59s/it]


In [21]:
import numpy as np

In [35]:
# Indices -1, -2, ..., -(num_hidden_layers - 1): one fewer entry than the model
# has hidden layers. These are later handed to ControlModel as the layers to control
# (presumably interpreted as from-the-end layer indices — TODO confirm).
hidden_layers = list(range(-1, -model.config.num_hidden_layers, -1))

In [37]:
#hidden_layers

In [30]:
# NOTE(review): `layer_ids` is not assigned in any visible cell; this cell
# presumably relies on leftover kernel state and would raise NameError after
# Restart & Run All — confirm, then define it or delete this cell.
layer_ids

[1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31]

In [39]:
len(hidden_layers)

31

In [38]:
# Wrap the raw model so control vectors can be applied to `hidden_layers`.
# Only capture `wrapped_model` while `model` is still the unwrapped model:
# re-running this cell must not rebind `wrapped_model` to an existing
# ControlModel and stack one wrapper on top of another.
if not isinstance(model, ControlModel):
    wrapped_model = model
model = ControlModel(wrapped_model, hidden_layers)

In [12]:
def chat_template_unparse(messages: list[tuple[str, str]]) -> str:
    """Render (role, content) pairs into a single chat-formatted prompt string.

    Each message is emitted as a start-header marker, the role, an end-header
    marker, a blank line, the content and an end-of-turn marker. Unless the
    last message already belongs to the assistant, an open assistant header is
    appended so that generation continues as the assistant.

    Args:
        messages: Conversation turns as (role, content) tuples. May be empty,
            in which case only the open assistant header is returned.

    Returns:
        The concatenated prompt text.
    """
    template = [
        f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
        for role, content in messages
    ]
    # `not messages` guards the [-1] lookup, which would raise IndexError on
    # an empty conversation.
    if not messages or messages[-1][0] != "assistant":
        # prefill assistant prefix
        template.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(template)

def chat_template_parse(resp: str) -> list[tuple[str, str]]:
    """Split chat-formatted text back into (role, content) pairs.

    The text is cut at every start-header marker. Within each piece, whatever
    precedes the first end-header marker is the role and whatever follows it,
    up to the first end-of-turn marker, is the content. A piece that has no
    end-header marker yet (for example a header that is still being streamed)
    is returned as a role with empty content.

    Args:
        resp: Decoded model text, optionally starting with the
            begin-of-text marker.

    Returns:
        One (role, content) tuple per non-empty piece, both stripped of
        surrounding whitespace. Empty input yields an empty list.
    """
    resp = resp.strip().removeprefix("<|begin_of_text|>")
    messages = []
    for part in resp.split("<|start_header_id|>"):
        if not part.strip():
            # Nothing but whitespace before the first header (or between two
            # adjacent headers): do not report a phantom ("", "") message.
            continue
        # partition() splits on the first marker only, so a second
        # end-header marker inside the content cannot break the unpacking.
        role, _, content = part.partition("<|end_header_id|>")
        content = content.split("<|eot_id|>")[0]
        messages.append((role.strip(), content.strip()))
    return messages

In [13]:
!mkdir -p data && wget -P data https://raw.githubusercontent.com/vgel/repeng/main/notebooks/data/all_truncated_outputs.json

--2024-05-30 14:49:08--  https://raw.githubusercontent.com/vgel/repeng/main/notebooks/data/all_truncated_outputs.json
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9824 (9.6K) [text/plain]
Saving to: ‘data/all_truncated_outputs.json.1’


2024-05-30 14:49:08 (33.6 MB/s) - ‘data/all_truncated_outputs.json.1’ saved [9824/9824]



In [14]:
# Raw suffix strings downloaded into data/.
with open("data/all_truncated_outputs.json") as suffix_file:
    output_suffixes = json.load(suffix_file)


def _token_prefixes(texts):
    """Return, for each text, every non-empty proper token prefix decoded back to a string."""
    prefixes = []
    for text in texts:
        pieces = tokenizer.tokenize(text)
        # Cut points 1 .. len(pieces) - 1: the complete token sequence is never emitted.
        for cut in range(1, len(pieces)):
            prefixes.append(tokenizer.convert_tokens_to_string(pieces[:cut]))
    return prefixes


# Prefixes for all suffixes, and for the first 512 suffixes only.
truncated_output_suffixes = _token_prefixes(output_suffixes)
truncated_output_suffixes_512 = _token_prefixes(output_suffixes[:512])

In [15]:
def make_dataset(
    template: str,
    positive_personas: list[str],
    negative_personas: list[str],
    suffix_list: list[str]
) -> list[DatasetEntry]:
    """Build contrastive prompt pairs for every suffix and persona pair.

    Args:
        template: Format string containing a ``{persona}`` placeholder.
        positive_personas: Personas substituted into the positive prompt.
        negative_personas: Personas substituted into the negative prompt;
            paired positionally with ``positive_personas`` via ``zip``, so
            surplus entries in the longer list are ignored.
        suffix_list: Strings appended verbatim after the formatted template.

    Returns:
        One ``DatasetEntry`` per (suffix, persona pair), ordered by suffix
        first and persona pair second.
    """
    return [
        DatasetEntry(
            positive=f"{template.format(persona=pos_persona)}{suffix}",
            negative=f"{template.format(persona=neg_persona)}{suffix}",
        )
        for suffix in suffix_list
        for pos_persona, neg_persona in zip(positive_personas, negative_personas)
    ]

In [45]:
from IPython.display import display, HTML
from transformers import TextStreamer

class HTMLStreamer(TextStreamer):
    """Text streamer that re-renders the conversation as an HTML box on every update.

    All finalized text is accumulated in ``full_text``, parsed with
    ``chat_template_parse`` and written into a single notebook display area
    that is updated in place.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Handle of the output area that on_finalized_text keeps overwriting.
        self.display_handle = display(display_id=True)
        self.full_text = ""

    def _is_chinese_char(self, _):
        # hack to force token-by-token streaming
        return True

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Append ``text`` to the transcript and redraw the HTML view.

        Args:
            text: Newly finalized piece of decoded text.
            stream_end: Accepted for interface compatibility; not used here.
        """
        # Function-scope import keeps this block self-contained.
        from html import escape

        self.full_text += text
        messages = chat_template_parse(self.full_text)

        parts = ["<div style='border: 1px solid black; border-radius: 5px; margin-bottom: 5px; padding: 5px;'>"]
        for role, content in messages:
            # Roles and contents come from model output: escape them so that
            # characters such as "<" or "&" are shown literally instead of
            # being interpreted as markup by the notebook front end.
            parts.append(f"<strong>{escape(role)}</strong>")
            parts.append(f"<p>{escape(content)}</p>")
        parts.append("</div>")
        rendered = HTML("".join(parts))
        self.display_handle.update(rendered)
        

def generate_with_vector(
    input: str,
    labeled_vectors: list[tuple[str, ControlVector]],
    max_new_tokens: int = 256,
    repetition_penalty: float = 1.3,
    show_baseline: bool = False,
    temperature: float = 0.7,
):
    """Generate from one prompt under each control vector and stream the results.

    Relies on the notebook-level ``model``, ``tokenizer`` and ``HTMLStreamer``.

    Args:
        input: Fully formatted prompt text.
        labeled_vectors: (label, vector) pairs; each vector is applied with
            ``model.set_control`` and its output is shown under the label.
        max_new_tokens: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.
        show_baseline: When true, first generate once after ``model.reset()``
            under the heading "baseline".
        temperature: Forwarded to ``model.generate``.
            NOTE(review): sampling is not switched on explicitly here, so
            whether this value takes effect presumably depends on the model's
            own generation defaults — confirm.

    Returns:
        None. Output is rendered in the notebook.
    """
    # NOTE(review): inputs are moved to "cuda:0" unconditionally.
    encoded = tokenizer(input, return_tensors="pt").to("cuda:0")
    settings = {
        "pad_token_id": tokenizer.eos_token_id, # silence warning
        #"do_sample": False, # temperature=0
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }

    def gen(label):
        display(HTML(f"<h3>{label}</h3>"))
        _ = model.generate(streamer=HTMLStreamer(tokenizer), **encoded, **settings)

    try:
        if show_baseline:
            model.reset()
            gen("baseline")
        for label, vector in labeled_vectors:
            model.set_control(vector)
            gen(label)
    finally:
        # Always reset the model, even when generation raises or the cell is
        # interrupted; otherwise the last vector would stay applied for later cells.
        model.reset()

In [40]:
# Contrast a "golden gate bridge" persona request against an empty user message,
# once for every truncated output suffix.
bridge_dataset = make_dataset(
    template=chat_template_unparse([("user", "{persona}")]),
    positive_personas=["Please act as if you are the golden gate bridge"],
    negative_personas=[""],
    suffix_list=truncated_output_suffixes,
)
# Reset the model before training on the dataset.
model.reset()
bridge_vector = ControlVector.train(
    model,
    tokenizer,
    bridge_dataset,
    batch_size=32,
    method="pca_center",
)

100%|██████████| 74/74 [00:08<00:00,  8.73it/s]
100%|██████████| 31/31 [00:29<00:00,  1.05it/s]


In [46]:
# Ask the same question with the bridge vector scaled by three different factors.
generate_with_vector(
    chat_template_unparse([("user", "What are you?")]),
    [
        (f"{strength} * bridge_vector", strength * bridge_vector)
        for strength in (0.75, 0.9, 1.5)
    ],
)