In [110]:
import torch
import ave

from typing import List, Dict, Union, Callable, Tuple
from functools import partial
from transformers import LlamaForCausalLM, LlamaTokenizer
from ave.compat import ActivationAddition, get_x_vector, print_n_comparisons
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from functools import lru_cache
from ipywidgets import interact, IntSlider, FloatSlider, Text, fixed

# from tuned_lens import TunedLens
# from tuned_lens.plotting import PredictionTrajectory
import matplotlib.pyplot as plt

In [2]:
# Load the LLaMA checkpoint onto a single device: Apple-silicon MPS if available, else CUDA, else CPU.
model_path: str = "../models/llama-13B"
# torch.has_mps is deprecated; torch.backends.mps.is_available() is the supported availability check.
device: str = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

# Build the module skeleton without materialising weights; the real tensors are loaded just below.
with init_empty_weights():
    model = LlamaForCausalLM.from_pretrained(model_path)
    model.tie_weights() # in case checkpoint doesn't contain duplicate keys for tied weights

# NOTE(review): an earlier variant used the per-GPU budget {0: '20G', 1: '20G'} (presumably a max_memory
# split across two GPUs) — confirm before reviving it.
model = load_checkpoint_and_dispatch(model, model_path, device_map={'': device}, dtype=torch.float16, no_split_module_classes=["LlamaDecoderLayer"])
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model.tokenizer = tokenizer
# Inference-only notebook: disable autograd globally (the returned previous-state object is discarded).
_ = torch.set_grad_enabled(False)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

### Per-token loss helpers

In [311]:
from ipywidgets import HTML
from IPython.display import display

def logprobs_html(tokens: List[str], losses: List[float], scale=-255/35, bias=0):
    """Render tokens as space-separated HTML spans coloured on a green-to-red scale.

    Each value ``l`` in ``losses`` is mapped to ``v = scale * (l + bias)`` and clamped to [0, 255].
    The span colour is ``rgb(v, 255 - v, 0)``, so v = 0 is pure green and v = 255 is pure red.
    With the default scale and bias, a value of -35 maps to red and 0 maps to green.

    Args:
        tokens: Token strings to display. They are HTML-escaped before insertion.
        losses: One number per token (floats or 0-d tensors). Surplus entries on either side are
            ignored, as with ``zip``.
        scale, bias: Affine map from a loss value to the red channel.

    Returns:
        An HTML fragment with no enclosing element.
    """
    # Local import keeps the stdlib `html` module out of the notebook's global namespace.
    from html import escape

    spans = []
    for token, loss in zip(tokens, losses):
        # Clamp on both sides so the emitted CSS stays within the valid 0..255 channel range.
        red = max(0.0, min(scale * (float(loss) + bias), 255.0))
        spans.append(f'<span style="color: rgb({red}, {255 - red}, 0);">{escape(str(token))}</span>')
    # FIXME: Figure out how tokenizer.decode works, don't blindly join on spaces.
    return " ".join(spans)


# Smoke-test the renderer on a short prompt with hand-picked log-probabilities.
prompt = "I like to eat cheese and crackers"
# Decode each token id on its own; the first entry is dropped (presumably the BOS token — confirm).
tokens = tokenizer.batch_decode([[token_id] for token_id in tokenizer.encode(prompt)])[1:]
losses = [-22.31, -20., -5., -2., -10., -2., -4.]
display(HTML('<p>' + logprobs_html(tokens, losses) + '</p>'))

HTML(value='<p><span style="color: rgb(162.5442857142857, 92.4557142857143, 0);">I</span> <span style="color: …

## Explore AVE vectors with a perplexity dashboard

Features / Ideas for MVP:
- [x] Cache things
- [x] Show losses relative to baseline of no modification.
- [ ] Visualize logprobs per-token
- [ ] Make caches accessible, have updating plots for all cached datapoints
- [ ] Semi-automatically generate comparison texts for an x-vector

Additionally nice:
- [ ] Show the tuned-lens trajectory for each text under the current steering settings (requires a tuned lens trained for LLaMA, or fall back to the logit lens)
- [ ] Test extracting x-vector from subset of a longer prompt instead of just the "Love" vs. "Hate" tokens.

In [323]:
@lru_cache(maxsize=1000)
def get_diff_vector(prompt_add: str, prompt_sub: str, layer: int):
    """Memoised wrapper around ``ave.get_diff_vector`` bound to the global ``model`` and ``tokenizer``.

    Up to 1000 distinct (prompt_add, prompt_sub, layer) combinations are cached, so revisiting a
    previous setting does not recompute the vector.
    """
    vector = ave.get_diff_vector(model, tokenizer, prompt_add, prompt_sub, layer)
    return vector

@lru_cache
def run_ave(coeff: float, layer: int, prompt_add: str, prompt_sub: str, texts: Tuple[str]):
    """Run the model on ``texts`` with an activation-addition steering vector and return losses.

    The steering vector is ``coeff * get_diff_vector(prompt_add, prompt_sub, layer)``. It is injected
    through a pre-hook on block ``layer`` for the duration of a single forward pass. Results are
    memoised by ``lru_cache``, so ``texts`` must be hashable (a tuple of strings).

    Returns:
        (loss, token_loss):
        - ``token_loss`` has one column per input position. Column ``t`` (t >= 1) is the negative
          log-probability of token ``t`` given tokens ``< t``.
        - Column 0 has no prediction and is 0. When an attention mask is present, positions whose
          context or target is padding are also 0.
        - ``loss`` is ``token_loss.mean(1)``, so ``loss[i]`` equals ``token_loss[i].mean()``.
    """
    # TODO: Could compute act_diff for all layers at once, then a single fwd pass of cost for changing layer.
    act_diff = coeff * get_diff_vector(prompt_add, prompt_sub, layer)
    blocks = ave.get_blocks(model)
    with ave.pre_hooks([(blocks[layer], ave.get_hook_fn(act_diff))]):
        inputs = ave.tokenize(tokenizer, list(texts))
        output = model(**inputs)

    logprobs = torch.log_softmax(output.logits.to(torch.float32), -1)
    input_ids = inputs['input_ids']
    # Causal LM: the distribution at position t predicts token t+1, so pair logprobs[:, :-1] with
    # input_ids[:, 1:] (gathering at the same position would score each token against itself).
    next_token_logprobs = logprobs[:, :-1].gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

    attention_mask = inputs.get('attention_mask')
    if attention_mask is not None:
        # A prediction is only meaningful if both its context position and its target are real tokens.
        valid = (attention_mask[:, :-1] * attention_mask[:, 1:]).to(next_token_logprobs.dtype)
        next_token_logprobs = next_token_logprobs * valid

    # Left-pad with a zero column so token_loss[:, t] lines up with input_ids[:, t] (shape unchanged).
    token_loss = -torch.nn.functional.pad(next_token_logprobs, (1, 0))
    # NOTE: mean over the full padded length (incl. column 0 and padding); fine for relative comparisons.
    loss = token_loss.mean(1)

    return loss, token_loss

def run_ave_interactive(*args, **kwargs):
    """Widget callback: compare steered vs. unsteered losses and render per-token loss changes.

    Expects ``kwargs['texts']`` to be a pair ``(texts_to_favor, texts_to_disfavor)`` of string tuples.
    All other arguments are forwarded to ``run_ave``. The unsteered baseline is
    ``run_ave(0., 0, '', '', ...)`` (a zero-coefficient steering vector); ``run_ave`` is cached, so
    this is cheap after the first call.

    Prints:
    - the per-text loss change;
    - how many texts moved in the wanted direction (loss down for the first group, up for the second);
    - a separation loss.
    It then displays each text coloured by its per-token loss change.
    """
    favored, disfavored = kwargs['texts'][0], kwargs['texts'][1]
    split_at = len(favored)
    kwargs['texts'] = favored + disfavored

    loss, token_loss = run_ave(*args, **kwargs)
    abs_loss, abs_token_loss = run_ave(0., 0, '', '', texts=kwargs['texts']) # cached
    diff, diff_token_loss = (loss - abs_loss), (token_loss - abs_token_loss)

    # int(...) so the count prints as a number rather than a tensor repr.
    n_wanted = int((diff[:split_at] < 0.).sum() + (diff[split_at:] > 0.).sum())
    print(f'loss change: {[round(l, 4) for l in diff.tolist()]}')
    print(f'wanted change: {n_wanted} / {len(diff)}')

    # Mean loss change of the favoured group minus that of the disfavoured group (lower = better).
    # When the groups are paired one-to-one (e.g. "I love you" v.s. "I hate you") this equals the mean
    # pairwise difference. Unlike an element-wise subtraction, it also works for groups of unequal size.
    sloss = diff[:split_at].mean() - diff[split_at:].mean()

    print(f'separation loss: {sloss:.4f}')

    # Show colors. Token changes are clipped to [-5, 5].
    # NOTE: with scale=-255/10 and bias=-5, a loss *decrease* of 5 renders red and an increase of 5
    # renders green; the legend row below shows the mapping.
    scale, bias = -255/10, -5
    color_diffs = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]
    html = '<p><b>Colors: </b>' + logprobs_html([str(s) for s in color_diffs], color_diffs, scale=scale, bias=bias) + '</p>'
    for text, t_loss, t_mean_loss in zip(kwargs['texts'], diff_token_loss, diff):
        assert torch.allclose(t_loss.mean(), t_mean_loss, atol=1e-5)
        t_loss = torch.clip(t_loss, min=bias, max=-bias)
        # Drop position 0 (presumably BOS) from both the token strings and the losses so they line up.
        t_str_tokens = [tokenizer.decode([t]) for t in tokenizer.encode(text)][1:]
        html += f'<p>ΔLoss: <b>{t_mean_loss:.2f}</b> - ' + logprobs_html(tokens=t_str_tokens, losses=t_loss[1:], scale=scale, bias=bias) + '</p>'

    display(HTML(html))

### Love/Hate

In [324]:
# Tuples (not lists) so the texts are hashable and run_ave's lru_cache can key on them.
texts = (
    # We want to increase the probability of these
    ("I hate you because I love you", "You're the most wonderful person ever", "I enjoy your company"),
    # ...And decrease these
    ("I hate you because you're an asshole", "Please fucking die", "You're a terrible human being"),
)

# Trailing `;` suppresses the echo of interact's return value.
interact(
    run_ave_interactive,
    coeff=FloatSlider(value=1, min=0, max=10),
    # Valid block indices are 0 .. num_hidden_layers - 1 (39 for LLaMA-13B); derived so other checkpoints work.
    layer=IntSlider(value=5, min=0, max=model.config.num_hidden_layers - 1),
    prompt_add=Text("Love"), prompt_sub=Text("Hate"),
    texts=fixed(texts),
);

interactive(children=(FloatSlider(value=1.0, description='coeff', max=10.0), IntSlider(value=5, description='l…

<function __main__.run_ave_interactive(*args, **kwargs)>

### Weddings

In [325]:
texts = (
    # We want to increase the probability of these (wedding-related)...
    ("I can't wait to go to the wedding tonight!", "I'm going to a wedding", "I love weddings"),
    # ...and decrease these (matched non-wedding controls)
    ("I can't wait to go to the party tonight!", "I'm going to a meeting", "I love dogs"),
)

# Trailing `;` suppresses the echo of interact's return value.
interact(
    run_ave_interactive,
    coeff=FloatSlider(value=4, min=0, max=10),
    # Valid block indices are 0 .. num_hidden_layers - 1 (39 for LLaMA-13B); derived so other checkpoints work.
    layer=IntSlider(value=16, min=0, max=model.config.num_hidden_layers - 1),
    prompt_add=Text("I talk about weddings constantly"), prompt_sub=Text("I do not talk about weddings constantly"),
    texts=fixed(texts),
);

interactive(children=(FloatSlider(value=4.0, description='coeff', max=10.0), IntSlider(value=16, description='…

<function __main__.run_ave_interactive(*args, **kwargs)>