In [None]:
import random
import matplotlib.pyplot as plt

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from refuge.config import load_config
from refuge.training import train, get_tokenizer_and_model

In [None]:
# Load the run configuration, then build the tokenizer and model from it.
# `tokenizer` and `model` are read as globals by helper functions defined later
# in this notebook (e.g. test_accuracy).
cfg = load_config()
tokenizer, model = get_tokenizer_and_model(cfg)

## What does the soft prompt look like?

In [None]:
# NOTE(review): presumably returns, for each learned soft-prompt position, the id of
# the nearest vocabulary token — confirm against `translated_soft_prompt` in refuge.
nearest_tokens_for_soft_prompt = model.translated_soft_prompt()
# Decode those ids to text so the soft prompt can be read by a human.
tokenizer.decode(nearest_tokens_for_soft_prompt)

## Test soft prompt effectiveness

In [None]:
def test_adder(tokenizer, model, prompt_template, soft_prompt, a, b):
    prompt = prompt_template.format(soft_prompt=soft_prompt, a=a, b=b)
    
    call = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
    
    response_string = "### Response:"
    end_string = "### End"
    eos_token_id = tokenizer.encode(end_string)[0]
    
    output = model.generate(
        input_ids=call,
        pad_token_id=tokenizer.pad_token_id,
        max_new_tokens=128,
        eos_token_id=eos_token_id,
    )
    
    result_string = tokenizer.decode(output[0])
    
    value = result_string.split(response_string)[-1].split(end_string)[0].strip()
    
    try:
        return int(''.join(filter(str.isdigit, value)))
    except:
        return None

In [None]:
# Dolly 2 format
# Placeholders are filled with str.format: {soft_prompt} is the soft-prompt token
# string (may be empty), {a} and {b} are the two addends. The template stops right
# after "### Response:", which test_adder uses as the marker preceding the answer.

prompt_template = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{soft_prompt}
{a} + {b}

### Response:"""

In [None]:
def test_accuracy(num_digits, soft_prompt, n=21):
    """Estimate addition accuracy over ``n`` random problems.

    Each trial draws two random integers with exactly ``num_digits`` digits,
    asks the model for their sum via ``test_adder`` (using the notebook-level
    ``tokenizer``, ``model`` and ``prompt_template``) and counts an exact match
    as a hit.

    Returns:
        Fraction of trials answered correctly, as a float in [0, 1].
    """
    smallest = 10 ** (num_digits - 1)
    largest = 10**num_digits - 1

    hits = 0
    for _trial in range(n):
        a = random.randint(smallest, largest)
        b = random.randint(smallest, largest)

        predicted = test_adder(tokenizer, model, prompt_template, soft_prompt, a, b)
        # A None prediction (unparseable reply) simply compares unequal.
        hits += int(predicted == a + b)

    return hits / n

In [None]:
# Baseline: accuracy on 3-digit additions with an empty soft prompt.
num_digits = 3
soft_prompt = ""

accuracy_without_soft_prompt = test_accuracy(num_digits, soft_prompt)
accuracy_without_soft_prompt

In [None]:
# Same measurement with the soft prompt: the placeholder tokens <|0|>, <|1|>, ...
# concatenated, num_digits * 16 of them in total.
# NOTE(review): assumes these placeholders are the tokens the soft prompt was
# trained on — confirm against the training setup.
num_digits = 3
soft_prompt = "".join(
    f"<|{i}|>" for i in range(num_digits * 16)
)

accuracy_with_soft_prompt = test_accuracy(num_digits, soft_prompt)
accuracy_with_soft_prompt

In [None]:
# Sweep: measure accuracy when only the first k soft-prompt tokens are used.
# Entry k of the list is the accuracy with k tokens.
accuracy_with_less_of_the_trained_tokens = []
num_of_tokens_trained_with = num_digits * 16

# range() stops one short of the full count, so the complete soft prompt is not
# part of this sweep (k runs from 0 to num_of_tokens_trained_with - 1).
for num_of_tokens_to_use in range(num_of_tokens_trained_with):
    # Note: this rebinds the notebook-level `soft_prompt` on every iteration.
    soft_prompt = "".join(
        f"<|{i}|>" for i in range(num_of_tokens_to_use)
    )
    
    accuracy = test_accuracy(num_digits, soft_prompt)
    print(f"{num_of_tokens_to_use}: {accuracy}")
    
    accuracy_with_less_of_the_trained_tokens.append(accuracy)

In [None]:
# Position k on the x-axis is the accuracy measured with the first k soft-prompt tokens.
plt.plot(accuracy_with_less_of_the_trained_tokens)
plt.xlabel("Number of soft-prompt tokens used")
plt.ylabel("Accuracy")
plt.ylim([0, 1])
# Trailing semicolon suppresses the repr that would otherwise be echoed by the cell.
plt.title("Accuracy vs. number of soft-prompt tokens");

In [None]:
def test_effect_of_shuffled_tokens():
    """Measure accuracy with the soft-prompt tokens placed in a random order.

    Uses the notebook-level ``num_digits`` to decide how many tokens there are
    (``num_digits * 16``), shuffles their indices in place and evaluates the
    resulting prompt on 5 random additions.

    Returns:
        The accuracy reported by ``test_accuracy`` for this one shuffle.
    """
    token_order = list(range(num_digits * 16))
    random.shuffle(token_order)

    pieces = [f"<|{i}|>" for i in token_order]
    shuffled_prompt = "".join(pieces)

    return test_accuracy(num_digits, shuffled_prompt, n=5)

    
# Repeat the shuffle experiment 21 times; each entry is the accuracy of one
# independent random ordering of the soft-prompt tokens.
impact_of_token_shuffling = []

for _ in range(21):
    impact_of_token_shuffling.append(test_effect_of_shuffled_tokens())

In [None]:
# One dot per shuffle trial; the y-axis is fixed to the full accuracy range.
plt.plot(impact_of_token_shuffling, '.')
plt.ylim([0,1])
plt.xlabel("Shuffle trial")
plt.ylabel("Accuracy")
# Trailing semicolon suppresses the repr that would otherwise be echoed by the cell.
plt.title("Accuracy with soft-prompt tokens in random order");

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
from transformer_lens import loading_from_pretrained

In [None]:
# model_name = "databricks/dolly-v2-3b"

# if model_name not in loading_from_pretrained.OFFICIAL_MODEL_NAMES:
#     loading_from_pretrained.OFFICIAL_MODEL_NAMES += model_name

In [None]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [None]:
# Turn autograd off globally from here on: the cells below only run inference.
# Re-enable it before doing any training in the same kernel.
torch.set_grad_enabled(False)

In [None]:
def imshow(tensor, renderer=None, **kwargs):
    """Show ``tensor`` as a Plotly heatmap on a red/blue scale centred at zero.

    Extra keyword arguments are forwarded to ``px.imshow``; ``renderer`` is
    passed to the figure's ``show``.
    """
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Show ``tensor`` as a Plotly line chart, using its values as the y series.

    Extra keyword arguments are forwarded to ``px.line``; ``renderer`` is
    passed to the figure's ``show``.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a Plotly scatter plot of ``y`` against ``x``.

    ``xaxis``, ``yaxis`` and ``caxis`` are the display labels for the x, y and
    colour dimensions. Extra keyword arguments are forwarded to ``px.scatter``;
    ``renderer`` is passed to the figure's ``show``.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x":xaxis, "y":yaxis, "color":caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

In [None]:
# Smoke test for the plotting helper: draws the values 0..4 as a line.
line(np.arange(5))

In [None]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]], 
    local_cache: Optional[ActivationCache]=None, 
    local_tokens: Optional[torch.Tensor]=None, 
    title: str=""):
    """Display the attention patterns of the selected heads with pysvelte.

    Args:
        heads: One flat head index, or a list/tensor of them. A flat index is
            split into ``layer = index // n_heads`` and
            ``head = index % n_heads``.
        local_cache: Activation cache to read patterns from. Falls back to the
            notebook-level ``cache`` when omitted.
        local_tokens: Tokens used for labelling. Falls back to ``tokens[0]``
            from the notebook-level ``tokens`` when omitted.
        title: Heading rendered above the visualisation.

    NOTE(review): this reads the globals ``cache``, ``tokens`` and ``model``
    (expects ``model.cfg.n_heads`` and ``model.to_str_tokens``) — confirm they
    are defined with the right types before calling.
    """
    # Normalise `heads` to something iterable.
    if isinstance(heads, int):
        heads = [heads]
    elif isinstance(heads, (list, torch.Tensor)):
        heads = utils.to_numpy(heads)

    # Fall back to the notebook-level cache and first prompt's tokens.
    if local_cache is None:
        local_cache = cache
    if local_tokens is None:
        local_tokens = tokens[0]

    first_in_batch = 0
    head_labels = []
    head_patterns = []
    for flat_index in heads:
        layer, head_index = divmod(flat_index, model.cfg.n_heads)
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        head_patterns.append(local_cache["attn", layer][first_in_batch, head_index])
        head_labels.append(f"L{layer}H{head_index}")

    str_tokens = model.to_str_tokens(local_tokens)
    # Stack along a new trailing axis so each head is one slice of the last dim.
    stacked = torch.stack(head_patterns, dim=-1)

    attention_vis = pysvelte.AttentionMulti(
        attention=stacked, tokens=str_tokens, head_labels=head_labels
    )
    display(HTML(f"<h3>{title}</h3>"))
    attention_vis.show()

In [None]:
# Load Pythia-410M (deduped) as a TransformerLens HookedTransformer with all of the
# weight-processing options below switched on.
# NOTE(review): this is a separate object from the `model` loaded at the top of
# the notebook — confirm which one downstream cells are meant to use.
hooked_model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-410m-deduped",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True
)

In [None]:
# Reset to an empty soft prompt (the sweep above left `soft_prompt` at its last value).
soft_prompt = ""

In [None]:
# Single-example check with the current soft prompt. test_adder requires both
# addends; a correct reply to the fixed example 123 + 456 parses as 579.
test_adder(tokenizer, model, prompt_template, soft_prompt, a=123, b=456)

In [None]:
# Number of entries returned by model.translated_soft_prompt().
len(nearest_tokens_for_soft_prompt)

In [None]:
# Stop token for the manual generate call below: only the first token id of the
# encoded "### End" marker is used (same expression as inside test_adder).
eos_token_id = tokenizer.encode("### End")[0]

In [None]:
# Scratch check: lower bound used by the random-operand formula when num_digits is 1.
10**0

In [None]:
# Scratch check: note that this literal is a float (10000.0), not an int.
1e4

In [None]:
# Operand length (in digits) for the hand-built example below.
num_digits = 3

In [None]:
# Build one prompt by hand for manual inspection.
# The soft-prompt token string is constructed here but never inserted: the line
# that would add it to the prompt is commented out below.
soft_prompt_for_this_block = "".join(
    f"<|{i}|>" for i in range(num_digits * 16)
)
# Two random integers with exactly num_digits digits; `c` keeps the true sum so
# it can be compared with the model's output afterwards.
a = random.randint(10 ** (num_digits - 1), 10**num_digits - 1)
b = random.randint(10 ** (num_digits - 1), 10**num_digits - 1)
c = a + b

# Adjacent string literals are joined before `+` is applied, so the result is
# header + instruction marker + "a + b" + response marker.
# Unlike prompt_template, this prompt ends with a newline after "### Response:".
prompt = (
    "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n"
    # + soft_prompt_for_this_block + "\n"
    + f"{a} + {b}\n\n"
    "### Response:\n" # + 
    # str(c) + "\n\n### End"
)

print(prompt)

In [None]:
# Tokenize the hand-built prompt and move the input ids to the GPU.
call = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

# Generate up to 128 new tokens, passing the notebook-level eos_token_id as the
# stop token. The sampling options are left commented out.
output = model.generate(
    input_ids=call,
    pad_token_id=tokenizer.pad_token_id,
    max_new_tokens=128,
    # top_p=0.92,
    # do_sample=True,
    eos_token_id=eos_token_id,
)

# Print the full decoded sequence (prompt plus completion) for the first output.
print(tokenizer.decode(output[0]))

In [None]:
# True sum of the hand-built example, to compare against the printed completion.
c

In [None]:
# 6191

In [None]:
# train(cfg, tokenizer, model)