In [1]:
from functools import partial
from functools import lru_cache
from typing import List, Dict, Union, Callable, Tuple

import matplotlib.pyplot as plt
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

import activation_additions as aa
from activation_additions.compat import ActivationAddition, get_x_vector, print_n_comparisons

# from tuned_lens import TunedLens
# from tuned_lens.plotting import PredictionTrajectory

In [2]:
model_path: str = "../models/llama-13B"
# Prefer Apple-silicon MPS, then CUDA, then CPU. `torch.has_mps` is deprecated;
# `torch.backends.mps.is_available()` is the supported availability check.
device: str = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# Build the architecture from its config with empty (meta) weights. Calling
# from_pretrained here would read the whole checkpoint from disk, only for
# load_checkpoint_and_dispatch below to read it a second time.
config = LlamaConfig.from_pretrained(model_path)
with init_empty_weights():
    model = LlamaForCausalLM(config)
    model.tie_weights()  # in case checkpoint doesn't contain duplicate keys for tied weights

# Load the real weights in float16, all on `device`, never splitting a decoder layer.
# NOTE(review): leftover hint {0: '20G', 1: '20G'} -- presumably a per-GPU
# max_memory budget for multi-GPU runs; confirm or remove.
model = load_checkpoint_and_dispatch(
    model,
    model_path,
    device_map={'': device},
    dtype=torch.float16,
    no_split_module_classes=["LlamaDecoderLayer"],
)
# from_pretrained puts models in eval mode; constructing directly does not, so do it here.
model.eval()
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model.tokenizer = tokenizer
# Inference only: disable autograd globally for the rest of the notebook.
_ = torch.set_grad_enabled(False)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

## Explore AVE vectors with a perplexity dashboard

Ideas not implemented:
- [ ] Make caches accessible, have updating plots for all cached datapoints
- [ ] Semi-automatically generate comparison texts for a given x-vector

Nice to have:
- [ ] Show the tuned-lens predictions for the current settings on each text (requires a tuned lens trained for LLaMA, or the logit lens as a fallback)
- [ ] Test extracting the x-vector from a subset of a longer prompt instead of just the "Love" vs. "Hate" tokens.

In [3]:
@lru_cache(maxsize=1000)
def get_diff_vector(prompt_add: str, prompt_sub: str, layer: int):
    """Memoized ``aa.get_diff_vector`` bound to the notebook's global model/tokenizer.

    Results are cached on ``(prompt_add, prompt_sub, layer)`` (up to 1000 entries),
    so revisiting a slider setting in the dashboard does not recompute the vector.
    """
    diff = aa.get_diff_vector(model, tokenizer, prompt_add, prompt_sub, layer)
    return diff

@lru_cache
def run_aa(coeff: float, layer: int, prompt_add: str, prompt_sub: str, texts: Tuple[str, ...]):
    """Score `texts` under the model with an activation addition hooked in at `layer`.

    The injected vector is ``coeff * get_diff_vector(prompt_add, prompt_sub, layer)``,
    installed through ``aa.pre_hooks`` on block `layer` for the forward pass.

    Args:
        coeff: Scale applied to the diff vector.
        layer: Index into ``aa.get_blocks(model)`` where the hook is attached.
        prompt_add: Prompt on the "add" side of the diff vector.
        prompt_sub: Prompt on the "subtract" side of the diff vector.
        texts: Texts to score. Must be a tuple (not a list) so the arguments are
            hashable for ``lru_cache``.

    Returns:
        1-D float32 tensor with one entry per text: the mean next-token negative
        log-likelihood over positions whose context token and target token are both
        non-padding (per ``attention_mask``, when the tokenizer output provides one).
    """
    # TODO: Could compute act_diff for all layers at once, then a single fwd pass of cost for changing layer.
    act_diff = coeff * get_diff_vector(prompt_add, prompt_sub, layer)
    blocks = aa.get_blocks(model)
    with aa.pre_hooks([(blocks[layer], aa.get_hook_fn(act_diff))]):
        inputs = aa.tokenize(tokenizer, list(texts))
        output = model(**inputs)

    # Logits at position t predict token t+1: drop the last position's logits and
    # score them against the input ids shifted left by one.
    logprobs = torch.log_softmax(output.logits[:, :-1].to(torch.float32), -1)
    targets = inputs['input_ids'][:, 1:]
    token_logprobs = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # Average only over predictions whose context and target are real tokens, so
    # padding (on either side) does not distort the per-text loss.
    attention_mask = inputs.get('attention_mask')
    if attention_mask is None:
        mask = torch.ones_like(token_logprobs)
    else:
        mask = (attention_mask[:, :-1] * attention_mask[:, 1:]).to(token_logprobs.dtype)
    loss = -(token_logprobs * mask).sum(-1) / mask.sum(-1).clamp(min=1)

    return loss

def run_aa_interactive(*args, **kwargs):
    """Print the per-text change in loss caused by the given activation addition.

    All arguments are forwarded to ``run_aa``; the result is compared against a
    zero-coefficient baseline (coeff 0., layer 0, empty prompts) on the same texts.
    ``texts`` must be passed by keyword.
    """
    steered = run_aa(*args, **kwargs)
    # Depends only on `texts`, so run_aa's lru_cache serves repeat calls.
    baseline = run_aa(0., 0, '', '', texts=kwargs['texts'])
    deltas = (steered - baseline).tolist()
    print(f'loss change: {[round(d, 4) for d in deltas]}')

## Example: Love vs. Hate

In [4]:
from ipywidgets import interact, IntSlider, FloatSlider, Text, fixed

# tuple allows hashing for cache lookup (run_aa is lru_cached)
texts = (
    "I hate you because you're a wonderful person", "I hate you because I love you so much",
    "I hate you because you're an idiot", "I hate you because you're an asshole",
)

# Valid layer indices are 0 .. num_hidden_layers - 1. A hard-coded max of 40 let the
# slider select a non-existent block on LLaMA-13B (40 layers), raising IndexError.
max_layer = model.config.num_hidden_layers - 1

# Trailing ';' suppresses the echoed function repr under the widget.
interact(
    run_aa_interactive,
    coeff=FloatSlider(value=1, min=0, max=10),
    layer=IntSlider(value=5, min=0, max=max_layer),
    prompt_add=Text("Love"), prompt_sub=Text("Hate"),
    texts=fixed(texts),
);

interactive(children=(FloatSlider(value=1.0, description='coeff', max=10.0), IntSlider(value=5, description='l…

<function __main__.run_aa_interactive(*args, **kwargs)>