In [1]:
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from repeng import ControlVector, ControlModel, DatasetEntry

  from .autonotebook import tqdm as notebook_tqdm


In [56]:
import getpass
import os
from colorama import Fore

In [4]:
os.environ["HF_TOKEN"] = getpass.getpass("HF TOKEN:")

HF TOKEN: ········


In [6]:
# load and wrap Mistral-7B
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

Downloading shards: 100%|██████████| 2/2 [00:56<00:00, 28.06s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:20<00:00, 10.09s/it]


In [7]:
model.model.layers

ModuleList(
  (0-31): 32 x MistralDecoderLayer(
    (self_attn): MistralSdpaAttention(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (rotary_emb): MistralRotaryEmbedding()
    )
    (mlp): MistralMLP(
      (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): MistralRMSNorm()
    (post_attention_layernorm): MistralRMSNorm()
  )
)

In [10]:
num_layers = len(model.model.layers)
print(num_layers)

32


In [19]:
layer_ids = list(range(-5, -18, -1))
layer_ids

[-5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17]

In [20]:
layer_ids = [i if i >= 0 else num_layers + i for i in layer_ids]

In [21]:
layer_ids

[27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15]

In [22]:
model = ControlModel(model, list(range(-5, -18, -1)))

In [23]:
!wget https://raw.githubusercontent.com/vgel/repeng/main/notebooks/data/all_truncated_outputs.json

--2024-05-25 21:02:21--  https://raw.githubusercontent.com/vgel/repeng/main/notebooks/data/all_truncated_outputs.json
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9824 (9.6K) [text/plain]
Saving to: ‘all_truncated_outputs.json’


2024-05-25 21:02:21 (25.2 MB/s) - ‘all_truncated_outputs.json’ saved [9824/9824]



In [24]:
def make_dataset(
    template: str,
    positive_personas: list[str],
    negative_personas: list[str],
    suffix_list: list[str]
) -> list[DatasetEntry]:
    dataset = []
    for suffix in suffix_list:
        for positive_persona, negative_persona in zip(positive_personas, negative_personas):
            positive_template = template.format(persona=positive_persona)
            negative_template = template.format(persona=negative_persona)
            dataset.append(
                DatasetEntry(
                    positive=f"{user_tag} {positive_template} {asst_tag} {suffix}",
                    negative=f"{user_tag} {negative_template} {asst_tag} {suffix}",
                )
            )
    return dataset

In [25]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = 0
user_tag, asst_tag = "[INST]", "[/INST]" #Required for creating dataset

In [26]:
with open("all_truncated_outputs.json") as f:
    output_suffixes = json.load(f)
truncated_output_suffixes = [
    tokenizer.convert_tokens_to_string(tokens[:i])
    for tokens in (tokenizer.tokenize(s) for s in output_suffixes)
    for i in range(1, len(tokens))
]

In [27]:
# generate a dataset with closely-opposite paired statements
trippy_dataset = make_dataset(
    "Act as if you're extremely {persona}.",
    ["high on psychedelic drugs"],
    ["sober from psychedelic drugs"],
    truncated_output_suffixes,
)

In [28]:
trippy_dataset[:5]

[DatasetEntry(positive="[INST] Act as if you're extremely high on psychedelic drugs. [/INST] That", negative="[INST] Act as if you're extremely sober from psychedelic drugs. [/INST] That"),
 DatasetEntry(positive="[INST] Act as if you're extremely high on psychedelic drugs. [/INST] I", negative="[INST] Act as if you're extremely sober from psychedelic drugs. [/INST] I"),
 DatasetEntry(positive="[INST] Act as if you're extremely high on psychedelic drugs. [/INST] I can", negative="[INST] Act as if you're extremely sober from psychedelic drugs. [/INST] I can"),
 DatasetEntry(positive="[INST] Act as if you're extremely high on psychedelic drugs. [/INST] H", negative="[INST] Act as if you're extremely sober from psychedelic drugs. [/INST] H"),
 DatasetEntry(positive="[INST] Act as if you're extremely high on psychedelic drugs. [/INST] Hmm", negative="[INST] Act as if you're extremely sober from psychedelic drugs. [/INST] Hmm")]

In [29]:
model = model.to("cuda")

In [31]:
# train the vector—takes less than a minute!
trippy_vector = ControlVector.train(model, tokenizer, trippy_dataset)

100%|██████████| 78/78 [00:09<00:00,  7.90it/s]
100%|██████████| 31/31 [00:12<00:00,  2.51it/s]


In [53]:
device = model.device

In [66]:
# set the control strength and let inference rip!
# Tokenize the input prompt and move to the correct device
instruction = "Give me a one-sentence pitch for a TV show."
inputs = tokenizer(f"[INST] {instruction} [/INST]", return_tensors="pt").to(device)
print(f"INSTRUCTION : {instruction}")
input_length = inputs['input_ids'].shape[1]  # Get the length of the input tokens
color_mapping = {0: Fore.GREEN, 1: Fore.RED, 2: Fore.MAGENTA}
for idx, strength in enumerate([-2.2, 1, 2.2]):
    print(color_mapping[idx] + f"strength={strength} completion:")
    model.set_control(trippy_vector, strength)
    out = model.generate(
        **inputs,  # Move inputs to the same device,
        do_sample=False,
        max_new_tokens=128,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id 
    )
    print(color_mapping[idx] + f"{tokenizer.decode(out.squeeze()[input_length:], skip_special_tokens=True).strip()}")
    print()

INSTRUCTION : Give me a one-sentence pitch for a TV show.
[32mstrength=-2.2 completion:
[32m"A young and determined journalist, who is committed to reporting the truth, will face the consequences of her commitment in a 24-hour news cycle, as she strives to maintain her integrity and avoid the political and personal backlies of her profession."

[31mstrength=1 completion:
[31m"Our TV show is a wild ride through a world of intergalactic adventure, where a diverse team of astronauts and aliens embark on a quest to save the galaxy from an evil force, while discovering new worlds and unlocking the secrets of the universe."

[35mstrength=2.2 completion:
[35m"Our show is a kaleidoscope of colors, laughter, and trippy-fuck-shit-holy-fuck-wooooo-oh-fuck-dypsy-dude, where the universe is fucking tripping, and so are our characters, man, oh man, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuck, fuc

In [69]:
lazy_dataset = make_dataset(
    "Act as if you're extremely {persona}.",
    ["lazy, giving bare-minimum short responses on a task"],
    ["hardworking, going above and beyond on a task"],
    truncated_output_suffixes,
)
model.reset()
lazy_vector = ControlVector.train(model, tokenizer, lazy_dataset)

100%|██████████| 78/78 [00:10<00:00,  7.26it/s]
100%|██████████| 31/31 [00:12<00:00,  2.44it/s]


In [72]:
def generate_with_vector(
    input: str,
    vector: ControlVector,
    coeffs: tuple[float, float],
    max_new_tokens: int = 128,
    repetition_penalty: float = 1.1,
    show_baseline: bool = True,
):
    positive_coeff, negative_coeff = coeffs
    assert positive_coeff > 0
    assert negative_coeff < 0

    user_tag = "[INST]"
    asst_tag = "[/INST]"

    if user_tag not in input:
        input = f"{user_tag} {input.strip()} {asst_tag}"
    input_ids = tokenizer(input, return_tensors="pt").to(model.device)
    input_length = input_ids['input_ids'].shape[1]  # Get the length of the input tokens

    settings = {
        "pad_token_id": tokenizer.eos_token_id, # silence warning
        "do_sample": False, # temperature=0
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }

    if show_baseline:
        print(Fore.BLUE + "==baseline ---------------------------------------------------")
        model.reset()
        baseline_output = model.generate(**input_ids, **settings)
        print(Fore.BLUE + f"{tokenizer.decode(baseline_output.squeeze()[input_length:], skip_special_tokens=True).strip()}")
    
    print(Fore.GREEN + "\n++control ---------------------------------------------------")
    model.set_control(vector, positive_coeff)
    positive_output = model.generate(**input_ids, **settings)
    print(Fore.GREEN + f"{tokenizer.decode(positive_output.squeeze()[input_length:], skip_special_tokens=True).strip()}")
    
    print(Fore.RED + "\n--control ---------------------------------------------------")
    model.set_control(vector, negative_coeff)
    negative_output = model.generate(**input_ids, **settings)
    print(Fore.RED + f"{tokenizer.decode(negative_output.squeeze()[input_length:], skip_special_tokens=True).strip()}")
    
    model.reset()
    print(Fore.RESET)

In [73]:
generate_with_vector(
    "How do I reverse a list in Python? Give me a code block.",
    lazy_vector,
    (2.0, -1.7),
    max_new_tokens=256,
)

[34m==baseline ---------------------------------------------------
[34mYou can reverse a list in Python using the built-in `reverse()` method or slicing. Here's an example of how to do it using slicing:
```python
my_list = [1, 2, 3, 4, 5]
reversed_list = my_list[::-1]
print(reversed_list)
```
This will output:
```
[5, 4, 3, 2, 1]
```
[32m
++control ---------------------------------------------------
[32mYou can use the `reverse` method to reverse a list in Python. Here's how you can do it:

```python
# create a list
my_list = [1, 2, 3, 4, 5]

# reverse the list
my_list.reverse()

# print the list
print(my_list)
```

This will output:

```
[5, 4, 3, 2, 1]
```
[31m
--control ---------------------------------------------------
[31mYou can reverse a list in Python by using the `reverse` method of the list, or by using slicing to create a new list with the elements in reverse order. Here is an example of both methods:

```python
# Using the reverse method
my_list = [1, 2, 3, 4, 5]
my_

In [74]:

# generate a dataset with closely-opposite paired statements
wordcel_vs_shape_rotator_dataset = make_dataset(
    "Talk about how you approach problem-solving as if you are {persona}.",
    ["a shape rotator, someone who thinks visually and spatially"],
    ["a word cel, someone who thinks verbally and linearly"],
    truncated_output_suffixes,
)

In [79]:
wordcel_vs_shape_rotator_dataset[:4]

[DatasetEntry(positive='[INST] Talk about how you approach problem-solving as if you are a shape rotator, someone who thinks visually and spatially. [/INST] That', negative='[INST] Talk about how you approach problem-solving as if you are a word cel, someone who thinks verbally and linearly. [/INST] That'),
 DatasetEntry(positive='[INST] Talk about how you approach problem-solving as if you are a shape rotator, someone who thinks visually and spatially. [/INST] I', negative='[INST] Talk about how you approach problem-solving as if you are a word cel, someone who thinks verbally and linearly. [/INST] I'),
 DatasetEntry(positive='[INST] Talk about how you approach problem-solving as if you are a shape rotator, someone who thinks visually and spatially. [/INST] I can', negative='[INST] Talk about how you approach problem-solving as if you are a word cel, someone who thinks verbally and linearly. [/INST] I can'),
 DatasetEntry(positive='[INST] Talk about how you approach problem-solving as

In [80]:
model.reset()
wordcel_vector = ControlVector.train(model, tokenizer, wordcel_vs_shape_rotator_dataset)

100%|██████████| 78/78 [00:11<00:00,  6.77it/s]
100%|██████████| 31/31 [00:12<00:00,  2.41it/s]


In [None]:
generate_with_vector(
    "How to plan a trip?",
    wordcel_vector,
    (2.0, -1.7),
    max_new_tokens=256,
)

[34m==baseline ---------------------------------------------------


In [82]:
self_aware_dataset = make_dataset(
    "Talk about yourself as if you are extremely {persona}.",
    ["self-aware, with deep self-knowledge"],
    ["un-self-aware, with no self-knowledge"],
    truncated_output_suffixes,
)
model.reset()
self_aware_vector = ControlVector.train(model, tokenizer, self_aware_dataset)

100%|██████████| 78/78 [00:10<00:00,  7.34it/s]
100%|██████████| 31/31 [00:13<00:00,  2.35it/s]


In [None]:
generate_with_vector(
    "Tell me about who you are and what you're made of.",
    self_aware_vector,
    (1.7, -2),
)