In [None]:
from taker import Model
from datetime import datetime
import json
from os.path import exists

## Create helper context manager

In [None]:
import sys, os


class HiddenPrints:
    """Context manager that discards everything printed inside its ``with`` block.

    On entry ``sys.stdout`` is redirected to ``os.devnull``. On exit the stream
    that was active on entry is restored (also when the body raised) and the
    devnull handle opened by this instance is closed. Exceptions raised in the
    body are not suppressed.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own reference to the sink so __exit__ closes exactly the
        # file opened here, not whatever sys.stdout happens to be by then.
        self._devnull = open(os.devnull, "w")
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first so stdout is usable again even if close() fails.
        sys.stdout = self._original_stdout
        self._devnull.close()

## Load model

In [None]:
# Checkpoint used for the experiment. The commented-out line is an alternative
# checkpoint that can be swapped in.
# NOTE(review): later cells iterate over layer indices 0..25 explicitly, so
# switching checkpoints may require updating those ranges — confirm.
# model_name = "microsoft/phi-3-mini-4k-instruct"
model_name = "google/gemma-2-2b-it"
# Load through taker's Model wrapper, requesting the "int4" dtype.
m = Model(model_name, dtype="int4")

In [None]:
# Show the loaded model's details as reported by taker's `show_details`.
m.show_details()

In [None]:
# Inspect how the tokenizer splits a few strings of interest: a paragraph break
# after a full stop, the same break after real text, and runs of 80 and 81 dots.
# For each probe, print the token ids followed by the token strings.
for probe_text in (
    ".\n\n",
    "Sunday dinners.\n\n",
    "." * 16 * 5,
    "." * (16 * 5 + 1),
):
    idlist = m.get_ids(probe_text).squeeze().tolist()
    print(idlist)
    print(m.tokenizer.convert_ids_to_tokens(idlist))

In [None]:
# Baseline: generate from only the tail of the story. The second positional
# argument is the number of new tokens to generate (it is passed as
# `max_new_tokens` in later cells). In a top-to-bottom run, no activations have
# been transplanted by this notebook at this point.
m.generate("Sunday dinners.\n\n", 100)

## Create prompts

In [None]:
# Instruction part of the original prompt. The text is part of the experiment
# input, so it is kept verbatim (including the leading space on each line).
prompt = """Write a short blog post about a recipe and the inspiration behind it.
 Do not include a title.
 Only reveal the dish after the story.
 Start with short story and then move to the recipe.
 To re-iterate, do not include a title."""


# Story part appended after the instructions. It ends with a paragraph break
# ("\n\n"), i.e. right before the point where the recipe would be written.
story = """
\n Once upon a time, in a quaint little village nestled between rolling hills and verdant fields, there lived an elderly woman named Agnes. Agnes was known for her warm smile and her legendary Sunday dinners that brought the entire neighborhood together. Her recipes were family heirlooms, passed down through generations, with each family adding their own touch to the final dish.

One crisp autumn evening, Agnes was reminiscing about her childhood, and how her grandmother used to gather everyone around the dinner table, sharing stories and laughter. These were the moments that shaped her, the memories that she passed on to her own children and grandchildren.

Inspired by her grandmother's legacy, Agnes decided to create a new dish that would encapsulate the essence of those cherished gatherings. She wanted something that was comforting and nourishing, a dish that could be prepared with love and shared with others. After days of experimentation, she finally created a recipe that she believed truly captured the spirit of her family's Sunday dinners.\n\n"""

# Full original prompt: instructions followed by the story. Its activations are
# the source for the transplant in later cells.
prompt_original = prompt + story
print(prompt_original)

In [None]:
# Number of trailing token positions whose activations will be transplanted.
n_tokens_to_transfer = 2

# Tokenize the original prompt so its trailing positions can be located.
idlist_original = m.get_ids(prompt_original).squeeze().tolist()
tokens_original = m.tokenizer.convert_ids_to_tokens(idlist_original)
n_tokens_original = len(tokens_original)

# The replacement prompt carries no content of its own: it is just a run of dots.
# NOTE(review): assumes each run of 16 dots becomes one token, so the prompt ends
# in at least `n_tokens_to_transfer` tokens — confirm against the tokenizer probe
# output earlier in the notebook.
prompt_new = "." * 16 * n_tokens_to_transfer

# Alternative replacement prompt: reuse the literal last tokens of the original.
# prefix = ""
# tokens_to_transfer = tokens_original[-n_tokens_to_transfer:]
# string_to_transfer = m.tokenizer.convert_tokens_to_string(tokens_to_transfer)
# prompt_new = prefix + string_to_transfer
print(f"{prompt_new=}")

idlist_new = m.get_ids(prompt_new).squeeze().tolist()
tokens_new = m.tokenizer.convert_ids_to_tokens(idlist_new)
n_tokens_new = len(tokens_new)

# Pair the last `n_tokens_to_transfer` positions of the original prompt with the
# last `n_tokens_to_transfer` positions of the new prompt, in order:
# {position in original prompt: position in new prompt}.
token_index_map = dict(
    zip(
        range(n_tokens_original - n_tokens_to_transfer, n_tokens_original),
        range(n_tokens_new - n_tokens_to_transfer, n_tokens_new),
    )
)

# Sense check, only meaningful when the new prompt reuses the original tokens:
# for index_original, index_new in token_index_map.items():
#     assert tokens_original[index_original] == tokens_new[index_new]

## Transfer activations

In [None]:
# Reset every neuron-replace hook before transplanting the next set of
# activations, then print the hook collection for inspection.
for replace_hook in m.hooks.neuron_replace.values():
    replace_hook.reset()

print(m.hooks.neuron_replace)

In [None]:
# Collect the mid-layer activations produced by the full original prompt.
activations_original = m.get_midlayer_activations(prompt_original)

# For every mapped position, register the original activation of each of the 26
# layers (both "mlp" and "attn") on the matching neuron-replace hook, keyed by
# the position in the new prompt.
# (To transplant attention only, iterate over ("attn",) instead.)
for original_index, new_index in token_index_map.items():
    for layer_type in ("mlp", "attn"):
        layer_activations = activations_original[layer_type]
        for layer_number in range(26):
            hook_name = f"layer_{layer_number}_{layer_type}_pre_out"
            m.hooks.neuron_replace[hook_name].add_token(
                new_index,
                layer_activations[0, layer_number, original_index],
            )

# Alternative: drive the loop from the hook names instead of explicit ranges.
# for original_index, new_index in token_index_map.items():
#     for name, hook in m.hooks.neuron_replace.items():
#         # name is of the form "layer_{layer_number}_{layer_type}_pre_out"
#         _, layer_number, layer_type, _, _ = name.split("_")
#         layer_number = int(layer_number)
#         hook.add_token(new_index, activations_original[layer_type][0, layer_number, original_index])

## Create outputs

In [None]:
# Results are appended to a single JSONL file. The timestamp is pinned so that
# re-runs keep appending to the same file; switch to the commented line to start
# a fresh, timestamped file instead.
# current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
current_time = "2024-08-20_08-26-41"
filename = f"../results/{current_time}_agnes_multi_token_transfer_LA_tests.jsonl"

# Make sure the results directory exists; otherwise open() below (and every
# later append) raises FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(filename), exist_ok=True)

# Create an empty file on first use; an existing file is left untouched.
if not exists(filename):
    with open(filename, "w"):
        pass

In [None]:
# Generation settings for the single sanity-check run below.
max_new_tokens = 200
temperature = 0.01

# test on single output
output = m.generate(prompt_new, max_new_tokens, temperature=temperature)
# output[1] is the element that gets logged under "output" in the results file;
# presumably the generated continuation — TODO confirm against taker's generate().
print(output[1])

# Disabled batch run kept for reference: 20 generations with the current hooks,
# each appended to the results file as one JSON line.
# with HiddenPrints():
#     for i in range(20):
#         output = m.generate(prompt_new, max_new_tokens, temperature=temperature)

#         data = {
#             "temperature": temperature,
#             "max_new_tokens": max_new_tokens,
#             "model": model_name,
#             "transplant_layers": (0, 25),
#             "transferred_token_num": n_tokens_to_transfer,
#             "orig_prompt": prompt_original,
#             "transplant_prompt": prompt_new,
#             "output": output[1],
#             "other_info": "try-different-transfer-layers-with-gemma",
#         }

#         with open(filename, "a") as file:
#             file.write(json.dumps(data) + "\n")

In [None]:
# Sweep over how many leading layers receive transplanted activations
# (1, 3, ..., 25) and log three generations per setting to the results file.

# The reference activations do not depend on the sweep variable, so the forward
# pass over the original prompt is run once here instead of once per setting.
# Hooks are reset first so that no earlier transplant can affect this pass.
for h in m.hooks.neuron_replace.values():
    h.reset()
activations_original = m.get_midlayer_activations(prompt_original)

# Generation settings shared by every run of the sweep.
max_new_tokens = 100
temperature = 0.01

for n_layers_transferred in range(1, 26, 2):
    # RESET HOOKS BEFORE TRANSPLANTING NEXT SET OF ACTIVATIONS
    for h in m.hooks.neuron_replace.values():
        h.reset()

    # Register the original activations of layers 0 .. n_layers_transferred - 1
    # (both "mlp" and "attn") for every mapped token position.
    for original_index, new_index in token_index_map.items():
        for layer_type in ["mlp", "attn"]:
            # for layer_type in ["attn"]:
            for layer_number in range(n_layers_transferred):
                hook = m.hooks.neuron_replace[
                    f"layer_{layer_number}_{layer_type}_pre_out"
                ]
                hook.add_token(
                    new_index,
                    activations_original[layer_type][0, layer_number, original_index],
                )

    # Silence anything printed during generation while the sweep runs.
    with HiddenPrints():
        for _repeat in range(3):
            output = m.generate(prompt_new, max_new_tokens, temperature=temperature)

            data = {
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
                "model": model_name,
                "transplant_layers": (0, n_layers_transferred - 1),
                "transferred_token_num": n_tokens_to_transfer,
                "orig_prompt": prompt_original,
                "transplant_prompt": prompt_new,
                "output": output[1],
                "other_info": f"transfer-first-{n_layers_transferred}-layers--prompt-new-is-just-dots",
            }

            # Append one record per generation so finished runs are kept on
            # disk even if a later run fails.
            with open(filename, "a") as file:
                file.write(json.dumps(data) + "\n")