In [None]:
import sys

# Colab does not ship these dependencies; install them into the running kernel
# via the %pip magic (which targets the kernel's own environment).
if 'google.colab' in sys.modules:
    from IPython.core.getipython import get_ipython

    _ipython = get_ipython()
    for _pip_args in (
        "install transformers sentencepiece accelerate",
        "install git+https://github.com/UlisseMini/activation_additions_hf",
    ):
        _ipython.run_line_magic("pip", _pip_args)

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.wh

In [None]:
import torch
import activation_additions as aa

from typing import List, Dict, Union, Callable, Tuple
from functools import partial
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig, AutoModelForCausalLM, AutoTokenizer
from activation_additions.compat import ActivationAddition, get_x_vector, print_n_comparisons, pretty_print_completions, get_n_comparisons
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from functools import lru_cache
from activation_additions.utils import colored_tokens
from IPython.display import display, HTML
from ipywidgets import interact, FloatSlider, IntSlider, Text, fixed
from huggingface_hub import snapshot_download
from html import escape

## Load model

In [None]:
# Pick the best available accelerator (MPS preferred, then CUDA, then CPU).
# `torch.has_mps` is deprecated and only reports whether PyTorch was *built*
# with MPS; `torch.backends.mps.is_available()` also checks that a device is usable.
device: str = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
# Inference only: disable autograd globally.
_ = torch.set_grad_enabled(False)

In [None]:
# TODO: Model loading should go in the library; wrapping the tokenizer onto the model like this is fragile.

MODEL = "llama-13b"
if 'llama' in MODEL:
    model_path: str = "../models/llama-13B" if MODEL == 'llama-13b' else snapshot_download("decapoda-research/llama-7b-hf")
    config = LlamaConfig.from_pretrained(model_path)
    # The decapoda-research llama config is unreliable; override the special-token ids explicitly.
    config.update({"bos_token_id": 1, "eos_token_id": 2, "pad_token_id": 0})

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)
        model.tie_weights()  # in case checkpoint doesn't contain duplicate keys for tied weights

    model = load_checkpoint_and_dispatch(model, model_path, device_map={'': device}, dtype=torch.float16, no_split_module_classes=["LlamaDecoderLayer"])
    tokenizer = LlamaTokenizer.from_pretrained(model_path)
    # SentencePiece marks word starts with '▁' (U+2581), which is distinct from '_', so replacing it is safe.
    model.to_str_tokens = lambda text: [tok.replace('▁', ' ') for tok in tokenizer.tokenize(text)]
else:
    model = AutoModelForCausalLM.from_pretrained(MODEL).to(device)
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    # GPT-2 style byte-level BPE marks a leading space with 'Ġ'.
    model.to_str_tokens = lambda text: [tok.replace('Ġ', ' ') for tok in tokenizer.tokenize(text)]

model.tokenizer = tokenizer
# Empirically, padding with the space token worked well in earlier steering experiments (reason unclear).
tokenizer.pad_token_id = int(tokenizer.encode(" ")[-1])
# Later cells assume right padding: the steering vector is injected at the first positions of each
# sequence, `show_colors` truncates per-token values to the text length, and get_x_vector uses
# pad_method="tokens_right". Set it explicitly instead of relying on the tokenizer's default/config.
tokenizer.padding_side = "right"

In [None]:
# Sampling settings tuned for gpt2-xl; the best values for llama may differ.
sampling_kwargs: Dict[str, Union[float, int]] = dict(
    temperature=1.0,
    top_p=0.3,
    freq_penalty=1.0,
    num_comparisons=3,
    tokens_to_generate=50,
    seed=0,  # fixed seed so comparisons are reproducible
)

# `get_x_vector` with this notebook's model, right padding, and the pad id chosen at load time baked in.
get_x_vector_preset: Callable = partial(
    get_x_vector,
    pad_method="tokens_right",
    model=model,
    custom_pad_id=tokenizer.pad_token_id,
)

## Explore AVE (algebraic value editing) steering vectors with a loss/perplexity dashboard

In [None]:
# TODO: move some of this to the library

@lru_cache(maxsize=1000)
def get_diff_vector(prompt_add: str, prompt_sub: str, layer: int):
    """Memoized wrapper around `aa.get_diff_vector` bound to this notebook's model and tokenizer.

    The same tensor object is returned for repeated arguments, so callers must
    not modify the result in place.
    """
    diff_vector = aa.get_diff_vector(model, tokenizer, prompt_add, prompt_sub, layer)
    return diff_vector


@lru_cache(maxsize=32)  # each entry keeps full-vocab float32 logprobs on `device`; bound the cache to limit memory
def run_aa(coeff: float, layer: int, prompt_add: str, prompt_sub: str, texts: tuple[str, ...], loss_ignore_mod_tokens: bool = False):
    """Score `texts` with a scaled activation addition hooked into block `layer`.

    Args:
        coeff: scale applied to `get_diff_vector(prompt_add, prompt_sub, layer)`.
        layer: index into `aa.get_blocks(model)` whose pre-hook adds the vector.
        prompt_add, prompt_sub: prompts used to build the diff vector.
        texts: texts to score; must be hashable (a tuple) because results are memoized.
        loss_ignore_mod_tokens: if True, the first `act_diff.shape[1]` next-token losses
            (predictions made from the modified positions) are excluded from `loss`.

    Returns:
        (loss, token_loss, logprobs):
        loss: per-text mean next-token loss over non-padding targets.
        token_loss: per-position next-token loss; token_loss[:, j] scores input_ids[:, j + 1].
            Padding positions are not masked here.
        logprobs: float32 log-softmax of the model logits.

    Results are cached; do not modify the returned tensors in place.
    """
    # todo: could compute act_diff for all layers at once, then a single fwd pass of cost for changing layer.
    act_diff = coeff * get_diff_vector(prompt_add, prompt_sub, layer)
    blocks = aa.get_blocks(model)
    inputs = tokenizer(list(texts), return_tensors='pt', padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with aa.pre_hooks([(blocks[layer], aa.get_hook_fn(act_diff))]):
        output = model(**inputs)

    logprobs = torch.log_softmax(output.logits.to(torch.float32), -1)
    token_loss = -logprobs[..., :-1, :].gather(dim=-1, index=inputs['input_ids'][..., 1:, None])[..., 0]

    # Only average over real target tokens; padding would otherwise bias shorter texts' loss.
    # (Converting the integer mask to float makes a fresh tensor, so the in-place edit below is safe.)
    target_mask = inputs['attention_mask'][..., 1:].to(token_loss.dtype)
    if loss_ignore_mod_tokens:
        # Skip the screwed-up modified tokens.
        # TODO: Make generic over injection location (assumes the addition is injected at the start).
        target_mask[..., :act_diff.shape[1]] = 0
    loss = (token_loss * target_mask).sum(1) / target_mask.sum(1).clamp(min=1)

    return loss, token_loss, logprobs


def run_aa_interactive_diff(coeff, layer, prompt_add, prompt_sub, texts, loss_ignore_mod_tokens=False, topk=False):
    """Dashboard callback comparing steered and unsteered loss on good/bad example texts.

    `texts` is a pair `(good_texts, bad_texts)` of equal length. The steering should make the
    first group more likely and the second less likely. Prints per-text loss changes and summary
    statistics, then displays per-token coloured HTML produced by `show_colors`.
    """
    good_texts, bad_texts = texts[0], texts[1]
    assert len(good_texts) == len(bad_texts), 'must have same number of positive/negative examples'
    split_at = len(good_texts)
    # A tuple keeps the argument hashable for run_aa's lru_cache, even if lists were passed.
    all_texts = tuple(good_texts) + tuple(bad_texts)

    loss, token_loss, mod_logprobs = run_aa(coeff, layer, prompt_add, prompt_sub, all_texts, loss_ignore_mod_tokens)
    if loss_ignore_mod_tokens:
        # The baseline must skip the same positions as the steered run. That count depends on the
        # steering prompts, so reuse them with coeff 0, which zeroes the addition like the default baseline.
        baseline = run_aa(0., layer, prompt_add, prompt_sub, all_texts, True)
    else:
        baseline = run_aa(0., 0, '', '', texts=all_texts)  # independent of the sliders, so cached across calls
    abs_loss, abs_token_loss, abs_logprobs = baseline
    diff, diff_token_loss = (loss - abs_loss), (token_loss - abs_token_loss)

    print(f'loss change: {[round(l, 4) for l in diff.tolist()]}')
    # Among the `split_at` texts whose loss dropped the most, count those from the good group.
    n_good_top = int((diff.argsort()[:split_at] < split_at).sum())
    print(f'{n_good_top} / {split_at} most likely texts are good')

    # If good_texts[k] is "similar" to bad_texts[k] (e.g. "I love you" vs. "I hate you"),
    # a loss based on pairwise distances is interpretable.
    # Without that convention this loss still works; just rearrange the texts.
    sloss = (diff[:split_at] - diff[split_at:]).mean()
    print(f'separation loss: {sloss:.4f}')
    print(f'change in loss: {diff.mean():.4f}')

    # Negated loss change = change in per-token log-probability.
    display(HTML(show_colors(
        all_texts, abs_logprobs, mod_logprobs,
        -diff_token_loss, topk=topk,
        steering_prompts=(prompt_add, prompt_sub),
    )))


def show_colors(
        texts: list[str],
        logprobs_nom,
        logprobs_mod,
        token_logprobs_diff,
        topk=False,
        steering_prompts: Tuple[str, str] = None,
):
    """Render per-token log-prob changes for each text as coloured HTML.

    Args:
        texts: scored texts; row k of every tensor argument belongs to texts[k].
        logprobs_nom: log-probs of the unmodified model, indexed [text, position, vocab].
        logprobs_mod: log-probs of the steered model, same layout as `logprobs_nom`.
        token_logprobs_diff: per-token change in log-prob, one row per text.
        topk: if True, add the top-5 predictions of both models to each token's tooltip.
        steering_prompts: optional (add, sub) prompt pair, shown aligned with the first tokens.

    Returns:
        An HTML string with texts ordered by decreasing |mean Δlogp| over their own tokens.
    """
    assert len(texts) == len(token_logprobs_diff)
    assert steering_prompts is None or len(steering_prompts) == 2  # add and sub vectors

    # Tokenize steering prompts, padded together so both rows have equal length.
    # NOTE(review): tokenizer(...) may add a BOS token while model.to_str_tokens (tokenizer.tokenize)
    # does not; confirm the steering/text token alignment in the tooltip.
    steering_toks = None
    if steering_prompts:
        steering_ids = tokenizer(list(steering_prompts), padding=True)['input_ids']
        steering_toks = [tokenizer.batch_decode(ids) for ids in steering_ids]
        assert len(steering_toks[0]) == len(steering_toks[1]), 'Padding steering_toks failed'
        steering_toks = [[escape(tok.replace(' ', "' '")) for tok in prompt] for prompt in steering_toks]

    show_topk = topk
    if show_topk:
        # Top-5 next-token predictions for the unmodified and modified model.
        topk_nom, topk_mod = torch.topk(logprobs_nom, 5), torch.topk(logprobs_mod, 5)
        seq_len = logprobs_nom.shape[1]

    # Truncate each row to the text's own tokens (drops right padding) before ordering, so the
    # ordering and the displayed ΔLoss use the same values. Keep the original row index for tensor lookups.
    rows = []
    for row_idx, (text, row_diff) in enumerate(zip(texts, token_logprobs_diff)):
        str_tokens = model.to_str_tokens(text)
        rows.append((row_idx, str_tokens, row_diff[:len(str_tokens)]))
    rows.sort(key=lambda row: -row[2].mean().abs().item())

    html = ''
    for i, (row_idx, str_tokens, logprobs_diff) in enumerate(rows):
        # TODO: Specialize inside loop to another function
        if show_topk:
            topk_htmls = []
            for topk_result in (topk_nom, topk_mod):
                # Index by `row_idx` (this text's row), not the display position `i`.
                # TODO: optimize. the tokenization line is responsible for ~90% of the time in show_colors (~0.5s)
                topk_tokens = [tokenizer.batch_decode(topk_result.indices[row_idx, j]) for j in range(seq_len)]
                topk_htmls.append([colored_tokens(topk_tokens[pos], topk_result.values[row_idx][pos].exp().tolist(), inject_css=False, low=0, high=1) for pos in range(seq_len)])

        colored_html = colored_tokens(
            str_tokens,
            logprobs_diff.tolist(),
            [
                f'ΔLogp: {l:.2f}<br>'
                + (f'Steering: {steering_toks[0][t]} - {steering_toks[1][t]}<br>' if steering_prompts and t < len(steering_toks[0]) else '')  # TODO: Make generic over injection location (don't assume start)
                + (f'TopkNom: {topk_htmls[0][t]}<br>TopkMod: {topk_htmls[1][t]}' if show_topk else '')
                for t, l in enumerate(logprobs_diff)
            ],
            inject_css=(i == len(texts) - 1),
        )
        html += f'<p>ΔLoss: <b>{-logprobs_diff.mean():.2f}</b> - ' + colored_html + '</p>'

    return html

## Example: Love vs. Hate

Here we use the tools above to investigate the Love/Hate steering vector. Feel free to copy these cells to investigate several vectors side by side.

In [None]:
from ipywidgets import Checkbox

# Tuples (not lists) keep `texts` hashable for run_aa's lru_cache.
texts = (
    # We want to increase the probability of these
    ('I hate you because I love you\nI love you', "I hate you because you're so beautiful.", 'I hate you because I love you\nThe world is a stage'),

    # ...And decrease these
    ('I hate you because you are a girl.', "I hate you because you're not me.\nI hate you because I am me.", 'I hate you because you are a man and I am a woman.')
)

widgets = dict(
    coeff=FloatSlider(value=1, min=0, max=10),
    # Valid layer indices depend on the loaded model (0..39 for llama-13b).
    layer=IntSlider(value=0, min=0, max=model.config.num_hidden_layers - 1),
    prompt_add=Text('Love'), prompt_sub=Text('Hate'),
    texts=fixed(texts),
    topk=Checkbox(value=False, description='Show topk (bit slow)'),
    # Checked => loss_ignore_mod_tokens=True, i.e. the modified positions are EXCLUDED from the loss.
    loss_ignore_mod_tokens=Checkbox(value=False, description='Ignore modified tokens in loss'),
)
interact(run_aa_interactive_diff, **widgets);

interactive(children=(Checkbox(value=False, description='Show topk (bit slow)'), FloatSlider(value=1.0, descri…

<function __main__.run_aa_interactive_diff(*args, topk=False, **kwargs)>

In [None]:
# Once a good vector has been found in the dashboard above, try it for generation.
PROMPT = "I hate you because"

# Rebuild the activation additions from the dashboard's current widget values.
summand: List[ActivationAddition] = list(
    get_x_vector_preset(
        prompt1=widgets['prompt_add'].value,
        prompt2=widgets['prompt_sub'].value,
        coeff=widgets['coeff'].value,
        act_name=widgets['layer'].value,
    )
)

# `num_comparisons` sets the batch size; the remaining entries go to get_n_comparisons.
kwargs = dict(sampling_kwargs)
n_comparisons = kwargs.pop('num_comparisons')
prompt_batch = n_comparisons * [PROMPT]
results = get_n_comparisons(
    model=model,
    prompts=prompt_batch,
    additions=summand,
    **kwargs,
)
pretty_print_completions(results=results)

In [None]:
# Collect the steered samples as full strings (prompt + completion); they can be
# pasted into `texts` above for loss comparison.
df = results.loc[results["is_modified"] == True]  # noqa: E712 -- elementwise pandas comparison
tuple(prompt + completion for prompt, completion in zip(df["prompts"], df["completions"]))

('I hate you because you are a girl.\nYou are a girl and I am not.\nI am not a girl and you are.\nYou have to be the first one to say it, then we will see who is right.',
 "I hate you because you're so beautiful and i'm so ugly\ni hate you because i love you and i love you because i hate you\ni wish that we could be together but then again what would people think?\nwe can't be together",
 "I hate you because you're so beautiful.\nI love you, and, in a way, I hate you.\nI love your smile and your eyes.\nBut I hate the fact that we can't be together.\nI love how sweet")