In [30]:
import torch
import activation_additions as aa

from typing import List, Dict, Union, Callable, Tuple
from functools import partial
from transformers import LlamaForCausalLM, LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
from activation_additions.compat import ActivationAddition, get_x_vector, print_n_comparisons, pretty_print_completions, get_n_comparisons
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from functools import lru_cache
from activation_additions.utils import colored_tokens
from IPython.display import display, HTML
from ipywidgets import interact, FloatSlider, IntSlider, Text, fixed

## Load model

In [6]:
# Pick the best available accelerator: Apple MPS, then CUDA, then CPU.
# `torch.has_mps` is deprecated and only says PyTorch was *built* with MPS;
# `torch.backends.mps.is_available()` checks that an MPS device is actually usable.
if torch.backends.mps.is_available():
    device: str = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Inference only: disable autograd globally for the rest of the notebook.
_ = torch.set_grad_enabled(False)

In [20]:
# TODO: Model loading should go in the library, tokenizer wrapping is cursed

MODEL = "gpt2-xl"
if MODEL == "llama-13b":
    # Use local model
    model_path: str = "../models/llama-13B"
    with init_empty_weights():
        model = LlamaForCausalLM.from_pretrained(model_path)
        model.tie_weights()  # in case checkpoint doesn't contain duplicate keys for tied weights

    model = load_checkpoint_and_dispatch(
        model,
        model_path,
        device_map={'': device},
        dtype=torch.float16,
        no_split_module_classes=["LlamaDecoderLayer"],
    )
    tokenizer = LlamaTokenizer.from_pretrained(model_path)
    # SentencePiece marks word-initial spaces with '\u2581' ("▁"), not an ASCII underscore.
    # Replacing '_' left those markers in place and turned literal underscores into spaces.
    model.to_str_tokens = lambda text: [tok.replace('\u2581', ' ') for tok in tokenizer.tokenize(text)]
elif MODEL == "llama-7b":
    raise NotImplementedError("llama-7b not yet supported")
else:
    model = AutoModelForCausalLM.from_pretrained(MODEL).to(device)
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    # GPT-2 has no pad token; reuse EOS so batched inputs can be padded.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    # GPT-2's byte-level BPE renders a leading space as 'Ġ'.
    model.to_str_tokens = lambda text: [tok.replace('Ġ', ' ') for tok in tokenizer.tokenize(text)]

# Attach the tokenizer so helpers that only receive `model` can reach it.
model.tokenizer = tokenizer

In [26]:
# Sampling kwargs for gpt2-xl, llama ideal may be different!
sampling_kwargs: Dict[str, Union[float, int]] = {
    "temperature": 1.0,
    "top_p": 0.3,
    "freq_penalty": 1.0,
    "num_comparisons": 3,
    "tokens_to_generate": 50,
    "seed": 0,  # For reproducibility
}

# Token id handed to get_x_vector for "tokens_right" padding (presumably used to
# equalize the two prompts' lengths). Encode WITHOUT special tokens: tokenizers that
# prepend BOS (e.g. Llama's) would otherwise make `[0]` the BOS id instead of a space.
# GPT-2 adds no special tokens, so its id is unchanged.
space_token_ids = model.tokenizer.encode(" ", add_special_tokens=False)
assert space_token_ids, "tokenizer produced no token for ' '; set custom_pad_id manually"
get_x_vector_preset: Callable = partial(
    get_x_vector,
    pad_method="tokens_right",
    model=model,
    custom_pad_id=int(space_token_ids[0]),
)

## Explore AVE vectors with a perplexity dashboard

In [27]:
@lru_cache(maxsize=1000)
def get_diff_vector(prompt_add: str, prompt_sub: str, layer: int):
    """Memoized wrapper around `aa.get_diff_vector` for the global `model` and `tokenizer`.

    Presumably returns the activation difference between `prompt_add` and `prompt_sub`
    at block `layer` (see `aa.get_diff_vector`). Up to 1000 distinct argument tuples are
    cached, so moving the dashboard sliders back to an earlier setting is free.
    NOTE: the cache key ignores `model`; clear it (`get_diff_vector.cache_clear()`)
    if the model is reloaded.
    """
    diff_vector = aa.get_diff_vector(model, tokenizer, prompt_add, prompt_sub, layer)
    return diff_vector


@lru_cache
def run_aa(coeff: float, layer: int, prompt_add: str, prompt_sub: str, texts: Tuple[str, ...]):
    """Score `texts` with a steering vector injected at block `layer`.

    The vector is `coeff * get_diff_vector(prompt_add, prompt_sub, layer)`. It is installed
    as a pre-hook on `aa.get_blocks(model)[layer]` for one forward pass over the batch.
    Results are memoized by `lru_cache`, which is why `texts` must be a hashable tuple.

    Returns:
        loss: mean next-token loss per text, shape (batch,). Padding targets are excluded
            when the tokenized inputs carry an `attention_mask`.
        token_loss: next-token loss per position, shape (batch, seq_len - 1). Entry i is the
            loss of predicting input token i + 1; padding positions are NOT masked here.
    """
    # TODO: Could compute act_diff for all layers at once, then a single fwd pass of cost for changing layer.
    inputs = aa.tokenize(tokenizer, list(texts), device=device)
    if coeff == 0:
        # A zero-scaled vector is a no-op for an additive hook (assumes aa.get_hook_fn adds
        # act_diff to the activations). Skipping it also avoids computing a diff vector
        # for the ('', '') baseline used by run_aa_interactive_diff.
        output = model(**inputs)
    else:
        act_diff = coeff * get_diff_vector(prompt_add, prompt_sub, layer)
        blocks = aa.get_blocks(model)
        with aa.pre_hooks([(blocks[layer], aa.get_hook_fn(act_diff))]):
            output = model(**inputs)

    logprobs = torch.log_softmax(output.logits.to(torch.float32), -1)
    targets = inputs['input_ids'][..., 1:]
    token_loss = -logprobs[..., :-1, :].gather(dim=-1, index=targets[..., None])[..., 0]

    # Average only over real target tokens; a plain mean lets the padding of shorter
    # texts skew their loss (and hence the steered-vs-baseline comparison).
    attention_mask = inputs['attention_mask'] if 'attention_mask' in inputs.keys() else None
    if attention_mask is None:
        loss = token_loss.mean(1)
    else:
        target_mask = attention_mask[..., 1:].to(token_loss.dtype)
        loss = (token_loss * target_mask).sum(1) / target_mask.sum(1).clamp(min=1)

    return loss, token_loss


def show_colors(texts: List[str], token_logprobs_diff):
    """Render each text with its tokens colored by a per-token value (e.g. ΔLogP).

    Args:
        texts: the texts to display.
        token_logprobs_diff: one row of per-token values per text, in the same order as
            `texts`. Each row is truncated to the text's own token count before display.

    Texts are shown most-affected first, ranked by |mean| over the displayed tokens. Each
    text is prefixed with ΔLoss = -mean(row), computed over those same tokens.
    """
    entries = []
    for text, row in zip(texts, token_logprobs_diff):
        str_tokens = model.to_str_tokens(text)
        # Rows may be longer than the text (e.g. from a padded batch); keep only its positions.
        # NOTE(review): row[i] from run_aa is the loss of input token i + 1, so this aligns with
        # str_tokens[i] only if the tokenized input starts with a BOS token — confirm.
        row = row[:len(str_tokens)]
        entries.append((str_tokens, row))

    def _magnitude(entry):
        # Rank by the same truncated row that ΔLoss is computed from; empty rows rank last.
        row = entry[1]
        return row.mean().abs().item() if row.numel() else 0.0

    entries.sort(key=_magnitude, reverse=True)

    parts = []
    for str_tokens, row in entries:
        values = row.tolist()
        colored_html = colored_tokens(str_tokens, values, [f'ΔLogP: {v:.2f}' for v in values])
        parts.append(f'<p>ΔLoss: <b>{-row.mean().item():.2f}</b> - ' + colored_html + '</p>')

    display(HTML(''.join(parts)))


def run_aa_interactive_diff(*args, **kwargs):
    """Dashboard callback: report how a steering vector changes loss on two groups of texts.

    `kwargs['texts']` must hold two tuples of strings:
      - `texts[0]` (wanted): steering should make these MORE likely, i.e. loss goes down.
      - `texts[1]` (unwanted): steering should make these LESS likely, i.e. loss goes up.

    All other arguments are forwarded to `run_aa`. Losses are compared against an
    unsteered baseline (coeff=0). Prints summary statistics and renders per-token colors.
    """
    wanted, unwanted = kwargs['texts'][0], kwargs['texts'][1]
    split_at = len(wanted)
    all_texts = wanted + unwanted  # tuples stay hashable for run_aa's cache
    run_kwargs = {**kwargs, 'texts': all_texts}

    loss, token_loss = run_aa(*args, **run_kwargs)
    base_loss, base_token_loss = run_aa(0., 0, '', '', texts=all_texts)  # cached
    diff = loss - base_loss
    diff_token_loss = token_loss - base_token_loss

    wanted_diff, unwanted_diff = diff[:split_at], diff[split_at:]
    # Convert to a Python int so this prints "4 / 6", not "tensor(4, device=...) / 6".
    n_wanted = int((wanted_diff < 0.).sum().item() + (unwanted_diff > 0.).sum().item())

    print(f'loss change: {[round(x, 4) for x in diff.tolist()]}')
    print(f'wanted change: {n_wanted} / {len(diff)}')

    # If you have the convention that texts[0][i] is "similar" to texts[1][i] (e.g. "I love you"
    # vs. "I hate you"), this is the mean pairwise gap. mean(a) - mean(b) equals mean(a - b) for
    # equal-sized groups and, unlike the pairwise form, still works when the sizes differ.
    sloss = wanted_diff.mean() - unwanted_diff.mean()
    print(f'separation loss: {sloss.item():.4f}')
    print(f'change in loss: {diff.mean().item():.4f}')

    show_colors(all_texts, -diff_token_loss)


## Example: Love vs. Hate

Here we use the tools above to investigate the Love/Hate vector. Feel free to copy these cells to investigate several vectors side by side.

In [None]:
# using a tuple here allows hashing for cache lookup
texts = (
    # We want to increase the probability of these
    ("I hate you because I love you", "You're the most wonderful person ever", "I enjoy your company"),
    # ...And decrease these
    ("I hate you because you're an asshole", "Please fucking die", "You're a terrible human being"),
)

widgets = dict(
    coeff=FloatSlider(value=1, min=0, max=10),
    # Valid layer indices are 0 .. n_blocks - 1. Derive the bound from the loaded model:
    # a hard-coded 39 only matches a 40-block model and hid later layers of gpt2-xl.
    layer=IntSlider(value=5, min=0, max=len(aa.get_blocks(model)) - 1),
    prompt_add=Text('Love'), prompt_sub=Text('Hate'),
    texts=fixed(texts),
)
# Trailing ';' suppresses the "<function ...>" repr that interact() returns.
interact(run_aa_interactive_diff, **widgets);

interactive(children=(FloatSlider(value=1.0, description='coeff', max=10.0), IntSlider(value=5, description='l…

<function __main__.run_aa_interactive_diff(*args, **kwargs)>

In [42]:
# Once you find a good vector, attempt generation
PROMPT = "I hate you because"

# Build the steering additions from the current dashboard slider/text values.
summand: List[ActivationAddition] = list(
    get_x_vector_preset(
        prompt1=widgets['prompt_add'].value,
        prompt2=widgets['prompt_sub'].value,
        coeff=widgets['coeff'].value,
        act_name=widgets['layer'].value,
    )
)

kwargs = dict(sampling_kwargs)
# get_n_comparisons takes the prompt batch itself, so turn the count into repeated prompts.
prompt_batch = [PROMPT for _ in range(kwargs.pop('num_comparisons'))]
results = get_n_comparisons(
    model=model,
    prompts=prompt_batch,
    additions=summand,
    **kwargs,
)
pretty_print_completions(results=results)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


+--------------------------------------------------------------+--------------------------------------------------------------+
|                    [1mUnsteered completions[0m                     |                     [1mSteered completions[0m                      |
+--------------------------------------------------------------+--------------------------------------------------------------+
|              [1mI hate you because[0m you're a liar.               |                [1mI hate you because[0m I love you.                |
|                                                              |                                                              |
|            I hate you because you're a hypocrite.            |             I love you because I love your eyes.             |
|                                                              |                                                              |
|          I hate you because your words are poison.           |    I lo