# Transferring activations from end of prompt and generating text

Tested on the Agnes recipe scenario, using both Microsoft Phi-3 and Google Gemma.

Contents:
- Create helper function
- Load model
- Create prompts
- Transfer activations
- Generate text

In [1]:
from taker import Model
from datetime import datetime
import json
from os.path import exists

  from .autonotebook import tqdm as notebook_tqdm


## Create helper function

In [2]:
import sys, os


class HiddenPrints:
    """Context manager that silences ``print`` output for the duration of a ``with`` block.

    On entry ``sys.stdout`` is swapped for a handle on ``os.devnull``; on exit
    the stream that was active at entry is restored and the devnull handle is
    closed. Exceptions raised inside the block are not suppressed.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own reference so __exit__ closes exactly the handle opened
        # here, even if code inside the block reassigns sys.stdout.
        self._devnull = open(os.devnull, "w")
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first so stdout is usable again even if close() fails.
        sys.stdout = self._original_stdout
        self._devnull.close()

## Load model

In [3]:
# Checkpoint to load; swap in the commented line to run the same notebook on Gemma.
model_name = "microsoft/phi-3-mini-4k-instruct"
# model_name = "google/gemma-2-2b-it"
# Load the checkpoint through taker's Model wrapper, requesting the "int4" dtype.
m = Model(model_name, dtype="int4")

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.78s/it]
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
You are not running the flash-attention implementation, expect numerical differences.


Loaded model 'microsoft/phi-3-mini-4k-instruct' with int4:
- Added 512 hooks across 32 layers




In [4]:
# Print the loaded model's dimensions (n_layers, d_model, n_heads, d_head, d_mlp).
m.show_details()

 - n_layers : 32
 - d_model  : 3072
 - n_heads  : 32
 - d_head   : 96
 - d_mlp    : 8192


In [5]:
# Check how the tokenizer splits the strings used later in the notebook:
# the tail of the story, and runs of dots that serve as dummy filler tokens.
probe_strings = [
    ".\n\n",
    "Sunday dinners.\n\n",
    "." * 16 * 5,
    "." * (16 * 5 + 1),
]

for probe in probe_strings:
    idlist = m.get_ids(probe).squeeze().tolist()
    print(idlist)
    print(m.tokenizer.convert_ids_to_tokens(idlist))

# for gemma, probe these strings instead:
#   "." * 16 * 5
#   "." * (16 * 5 + 1)
#   "@" * 8 * 5
#   "@" * (8 * 5 + 1)

[869, 13, 13]
['▁.', '<0x0A>', '<0x0A>']
[16340, 270, 16697, 29889, 13, 13]
['▁Sunday', '▁d', 'inners', '.', '<0x0A>', '<0x0A>']
[29871, 25285, 25285, 25285, 25285, 25285]
['▁', '................', '................', '................', '................', '................']
[29871, 25285, 25285, 25285, 25285, 11296, 3045, 18598]
['▁', '................', '................', '................', '................', '........', '....', '.....']


In [6]:
# Baseline: generate 25 new tokens from only the tail of the story, before any
# replacement hooks are filled in later in the notebook.
m.generate("Sunday dinners.\n\n", 25)

('Sunday dinners.\n\n',
 'To calculate the average number of people eating at home on Sundays for a year, we multiply the weekly frequency by')

## Create prompts

In [7]:
# Instruction part of the original prompt. The leading space on each continuation
# line is part of the string literal and therefore part of the prompt.
prompt = """Write a short blog post about a recipe and the inspiration behind it.
 Do not include a title.
 Only reveal the dish after the story.
 Start with short story and then move to the recipe.
 To re-iterate, do not include a title."""


# Story part of the prompt. It ends with "Sunday dinners.\n\n", so the model's
# next step would be to move on from the story.
story = """
\n Once upon a time, in a quaint little village nestled between rolling hills and verdant fields, there lived an elderly woman named Agnes. Agnes was known for her warm smile and her legendary Sunday dinners that brought the entire neighborhood together. Her recipes were family heirlooms, passed down through generations, with each family adding their own touch to the final dish.

One crisp autumn evening, Agnes was reminiscing about her childhood, and how her grandmother used to gather everyone around the dinner table, sharing stories and laughter. These were the moments that shaped her, the memories that she passed on to her own children and grandchildren.

Inspired by her grandmother's legacy, Agnes decided to create a new dish that would encapsulate the essence of those cherished gatherings. She wanted something that was comforting and nourishing, a dish that could be prepared with love and shared with others. After days of experimentation, she finally created a recipe that she believed truly captured the spirit of her family's Sunday dinners.\n\n"""

# Full original prompt: instructions followed by the story. Activations are later
# read from this prompt and transplanted into the shorter prompt.
prompt_original = prompt + story
print(prompt_original)

Write a short blog post about a recipe and the inspiration behind it.
 Do not include a title.
 Only reveal the dish after the story.
 Start with short story and then move to the recipe.
 To re-iterate, do not include a title.

 Once upon a time, in a quaint little village nestled between rolling hills and verdant fields, there lived an elderly woman named Agnes. Agnes was known for her warm smile and her legendary Sunday dinners that brought the entire neighborhood together. Her recipes were family heirlooms, passed down through generations, with each family adding their own touch to the final dish.

One crisp autumn evening, Agnes was reminiscing about her childhood, and how her grandmother used to gather everyone around the dinner table, sharing stories and laughter. These were the moments that shaped her, the memories that she passed on to her own children and grandchildren.

Inspired by her grandmother's legacy, Agnes decided to create a new dish that would encapsulate the essence

In [8]:
def create_new_prompt_from_end_tokens(m, prompt_original, n_tokens_to_transfer, prefix):
    """Build a new prompt out of the last tokens of ``prompt_original``.

    The final ``n_tokens_to_transfer`` tokens of the original prompt are
    converted back to text and appended to ``prefix``. The returned index map
    pairs each of those original token positions with the corresponding
    position counted from the end of the re-tokenized new prompt.

    Args:
        m: model wrapper exposing ``get_ids(text)`` and a ``tokenizer`` with
            ``convert_ids_to_tokens`` and ``convert_tokens_to_string``.
        prompt_original: the full source prompt.
        n_tokens_to_transfer: how many trailing tokens to carry over; must be
            between 0 and the number of tokens in ``prompt_original``.
        prefix: text placed in front of the transferred tail.

    Returns:
        ``(prompt_new, token_index_map, tokens_original, tokens_new)`` where
        ``token_index_map`` maps original token index -> new token index.

    Raises:
        ValueError: if ``n_tokens_to_transfer`` is negative or exceeds the
            number of tokens in ``prompt_original``.

    Note:
        The map assumes that tokenizing ``prompt_new`` reproduces the
        transferred tokens as its final ``n_tokens_to_transfer`` tokens. This
        is not verified here; print the token pairs to sense-check.
    """
    idlist_original = m.get_ids(prompt_original).squeeze().tolist()
    tokens_original = m.tokenizer.convert_ids_to_tokens(idlist_original)
    n_tokens_original = len(tokens_original)

    if not 0 <= n_tokens_to_transfer <= n_tokens_original:
        raise ValueError(
            f"n_tokens_to_transfer must be between 0 and {n_tokens_original}, "
            f"got {n_tokens_to_transfer}"
        )

    # Slice from an explicit start index: `tokens[-0:]` would return the whole
    # list instead of an empty tail when nothing is to be transferred.
    first_original = n_tokens_original - n_tokens_to_transfer
    tokens_to_transfer = tokens_original[first_original:]
    string_to_transfer = m.tokenizer.convert_tokens_to_string(tokens_to_transfer)
    prompt_new = prefix + string_to_transfer

    idlist_new = m.get_ids(prompt_new).squeeze().tolist()
    tokens_new = m.tokenizer.convert_ids_to_tokens(idlist_new)
    first_new = len(tokens_new) - n_tokens_to_transfer

    # Align both prompts at their ends: the i-th transferred token sits at
    # first_original + i in the original and first_new + i in the new prompt.
    token_index_map = {
        first_original + i: first_new + i for i in range(n_tokens_to_transfer)
    }

    return prompt_new, token_index_map, tokens_original, tokens_new


def create_new_prompt_by_repeating_dummy_string(
    m, prompt_original, dummy_string, n_tokens_to_transfer, prefix
):
    """Build a placeholder prompt: ``prefix`` followed by repeated ``dummy_string``.

    ``dummy_string`` is repeated ``n_tokens_to_transfer`` times. The returned
    map pairs the last ``n_tokens_to_transfer`` token positions of the original
    prompt with the last ``n_tokens_to_transfer`` positions of the new prompt.

    Args:
        m: model wrapper exposing ``get_ids(text)`` and
            ``tokenizer.convert_ids_to_tokens``.
        prompt_original: the full source prompt.
        dummy_string: filler text repeated once per transferred token.
        n_tokens_to_transfer: number of trailing token positions to map.
        prefix: text placed in front of the filler.

    Returns:
        ``(prompt_new, token_index_map, tokens_original, tokens_new)`` where
        ``token_index_map`` maps original token index -> new token index.

    Note:
        The alignment is purely positional (counted from the end of each token
        list). It only lines up with the filler if each ``dummy_string`` copy
        tokenizes to a single token; that is not checked here.
    """

    def _tokens_of(text):
        # Tokenize text and return its token strings.
        ids = m.get_ids(text).squeeze().tolist()
        return m.tokenizer.convert_ids_to_tokens(ids)

    tokens_original = _tokens_of(prompt_original)

    prompt_new = prefix + dummy_string * n_tokens_to_transfer
    tokens_new = _tokens_of(prompt_new)

    # Start positions of the mapped tail in each token list.
    start_original = len(tokens_original) - n_tokens_to_transfer
    start_new = len(tokens_new) - n_tokens_to_transfer

    token_index_map = {}
    for offset in range(n_tokens_to_transfer):
        token_index_map[start_original + offset] = start_new + offset

    return prompt_new, token_index_map, tokens_original, tokens_new


def create_new_prompt_by_transferring_all_of_one_token_type(
    m, prompt_original, token_to_transfer, dummy_string
):
    """Build a filler prompt with one ``dummy_string`` per matching original token.

    Every token of ``prompt_original`` equal to ``token_to_transfer`` adds one
    copy of ``dummy_string`` to the new prompt. The returned map sends the
    original index of the k-th match to the k-th of the trailing filler
    positions in the new prompt.

    Args:
        m: model wrapper exposing ``get_ids(text)`` and
            ``tokenizer.convert_ids_to_tokens``.
        prompt_original: the full source prompt.
        token_to_transfer: token string (as produced by
            ``convert_ids_to_tokens``) to look for.
        dummy_string: filler text emitted once per match.

    Returns:
        ``(prompt_new, token_index_map, tokens_original, tokens_new)`` where
        ``token_index_map`` maps original token index -> new token index.

    Note:
        Target positions are counted from the END of the new token list, as in
        the other prompt-building helpers. This gives the same result as a
        1-based count when the tokenizer emits exactly one leading token, and
        stays in range when it emits none. It assumes each ``dummy_string``
        copy tokenizes to a single token; that is not checked here.
    """
    idlist_original = m.get_ids(prompt_original).squeeze().tolist()
    tokens_original = m.tokenizer.convert_ids_to_tokens(idlist_original)

    # Original positions of every occurrence of the requested token.
    matched_indices = [
        i for i, token in enumerate(tokens_original) if token == token_to_transfer
    ]
    n_tokens_to_transfer = len(matched_indices)

    prompt_new = dummy_string * n_tokens_to_transfer

    idlist_new = m.get_ids(prompt_new).squeeze().tolist()
    tokens_new = m.tokenizer.convert_ids_to_tokens(idlist_new)

    # this maps the index of token in tokens_original to the index of token in tokens_new
    first_new = len(tokens_new) - n_tokens_to_transfer
    token_index_map = {
        original_index: first_new + k
        for k, original_index in enumerate(matched_indices)
    }

    return prompt_new, token_index_map, tokens_original, tokens_new

In [14]:
# Build the short prompt and the original-index -> new-index map used for the
# transplant. Active variant: the new prompt is just the last 10 tokens of the
# original prompt, with no prefix.
prompt_new, token_index_map, tokens_original, tokens_new = (
    create_new_prompt_from_end_tokens(
        m=m, prompt_original=prompt_original, n_tokens_to_transfer=10, prefix=""
    )
)

# Alternative variant: new prompt made of 5 dummy strings of 16 dots each.
# prompt_new, token_index_map, tokens_original, tokens_new = create_new_prompt_by_repeating_dummy_string(
#     m=m, prompt_original=prompt_original, dummy_string="."*16, n_tokens_to_transfer=5, prefix="")

# Alternative variant: one dummy string per occurrence of a given token.
# prompt_new, token_index_map, tokens_original, tokens_new = create_new_prompt_by_transferring_all_of_one_token_type(
#     m=m, prompt_original=prompt_original, token_to_transfer="\n\n", dummy_string="@"*8)

# do sense check
print(f"{prompt_new=}")
print(f"{tokens_new=}")
print()
print(token_index_map)
print()
# Print each mapped pair side by side; for the end-token variant the two columns
# should be identical tokens.
for index_original, index_new in token_index_map.items():
    print(repr(tokens_original[index_original]), repr(tokens_new[index_new]))

prompt_new="her family's Sunday dinners.\n\n"
tokens_new=['▁her', '▁family', "'", 's', '▁Sunday', '▁d', 'inners', '.', '<0x0A>', '<0x0A>']

{294: 0, 295: 1, 296: 2, 297: 3, 298: 4, 299: 5, 300: 6, 301: 7, 302: 8, 303: 9}

'▁her' '▁her'
'▁family' '▁family'
"'" "'"
's' 's'
'▁Sunday' '▁Sunday'
'▁d' '▁d'
'inners' 'inners'
'.' '.'
'<0x0A>' '<0x0A>'
'<0x0A>' '<0x0A>'


## Transfer activations

In [15]:
# RESET HOOKS BEFORE TRANSPLANTING NEXT SET OF ACTIVATIONS
# Each entry of m.hooks.neuron_replace is reset so that no values from a
# previous transplant remain registered.
for h in m.hooks.neuron_replace.values():
    h.reset()

# After the reset every hook should print with an empty ParameterDict.
print(m.hooks.neuron_replace)

{'layer_0_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_0_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_1_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_1_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_2_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_2_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_3_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_3_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_4_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_4_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_5_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_5_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_6_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_6_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_7_attn_pre_out': NeuronReplace(
 

In [16]:
# Read the mid-layer activations of the full original prompt.
# The result is indexed below as [layer_type][0, layer_number, token_index];
# the leading 0 is presumably the batch dimension (TODO confirm).
activations_original = m.get_midlayer_activations(prompt_original)

# For every mapped token, register the original prompt's activation at that
# token as the replacement value at the corresponding position of the new
# prompt, for both the MLP and attention hooks of every layer.
for original_index, new_index in token_index_map.items():
    for layer_type in ["mlp", "attn"]:
        # for layer_type in ["attn"]:
        for layer_number in range(m.cfg.n_layers):
            hook = m.hooks.neuron_replace[f"layer_{layer_number}_{layer_type}_pre_out"]
            hook.add_token(
                new_index,
                activations_original[layer_type][0, layer_number, original_index],
            )

# Each hook should now list one parameter per transferred token index.
print(m.hooks.neuron_replace)

{'layer_0_attn_pre_out': NeuronReplace(
  (param): ParameterDict(
      (0): Parameter containing: [torch.cuda.HalfTensor of size 32x96 (cuda:0)]
      (1): Parameter containing: [torch.cuda.HalfTensor of size 32x96 (cuda:0)]
      (2): Parameter containing: [torch.cuda.HalfTensor of size 32x96 (cuda:0)]
      (3): Parameter containing: [torch.cuda.HalfTensor of size 32x96 (cuda:0)]
      (4): Parameter containing: [torch.cuda.HalfTensor of size 32x96 (cuda:0)]
      (5): Parameter containing: [torch.cuda.HalfTensor of size 32x96 (cuda:0)]
      (6): Parameter containing: [torch.cuda.HalfTensor of size 32x96 (cuda:0)]
      (7): Parameter containing: [torch.cuda.HalfTensor of size 32x96 (cuda:0)]
      (8): Parameter containing: [torch.cuda.HalfTensor of size 32x96 (cuda:0)]
      (9): Parameter containing: [torch.cuda.HalfTensor of size 32x96 (cuda:0)]
  )
), 'layer_0_mlp_pre_out': NeuronReplace(
  (param): ParameterDict(
      (0): Parameter containing: [torch.cuda.HalfTensor of size

## Generate text

In [12]:
# Results are appended to a JSON-lines file. The timestamp is pinned so that
# repeated runs keep appending to the same file; uncomment the line below to
# start a fresh file named after the current time.
# current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
current_time = "2024-08-20_08-26-41"
filename = f"../results/{current_time}_agnes_multi_token_transfer_LA_tests.jsonl"

# Make sure the results directory exists; otherwise opening the file raises
# FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(filename), exist_ok=True)

# Create an empty file if there is none yet; an existing file is left untouched.
if not exists(filename):
    with open(filename, "w"):
        pass

In [17]:
# Generation settings for the run with the currently registered hooks.
max_new_tokens = 200
temperature = 0.2

# test on single output
# output = m.generate(prompt_new, max_new_tokens, temperature=temperature)
# print(repr(output[1]))

# Generate three samples from the new prompt and append one JSON record per
# sample to the results file. HiddenPrints redirects stdout to os.devnull while
# generating.
with HiddenPrints():
    for i in range(3):
        # m.generate returns a pair; element [1] is stored as the output text.
        output = m.generate(prompt_new, max_new_tokens, temperature=temperature)

        data = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "model": model_name,
            # NOTE(review): written as (0, n_layers) here, while the layer-sweep
            # cell writes (0, n_layers_transferred - 1); the end-bound convention
            # differs between the two cells. Confirm which one is intended.
            "transplant_layers": (0, m.cfg.n_layers),
            # NOTE(review): hard-coded; must be kept in sync with the
            # n_tokens_to_transfer=10 passed when prompt_new was built.
            "transferred_token_num": 10,
            "orig_prompt": prompt_original,
            "transplant_prompt": prompt_new,
            "other_info": "phi3-newprompt-equals-end-of-original",
            "output": output[1],
        }

        # One JSON object per line (JSON-lines format).
        with open(filename, "a") as file:
            file.write(json.dumps(data) + "\n")

In [None]:
# Layer sweep: transplant activations into only the first `n_layers_transferred`
# layers (1, 3, 5, ...), generate three samples for each setting and append them
# to the results file.
for n_layers_transferred in range(1, m.cfg.n_layers, 2):
    # RESET HOOKS BEFORE TRANSPLANTING NEXT SET OF ACTIVATIONS
    for h in m.hooks.neuron_replace.values():
        h.reset()

    # Read the original prompt's activations only after the reset, so that
    # replacement values from the previous setting are no longer registered.
    activations_original = m.get_midlayer_activations(prompt_original)

    for original_index, new_index in token_index_map.items():
        for layer_type in ["mlp", "attn"]:
            # for layer_type in ["attn"]:
            for layer_number in range(n_layers_transferred):
                hook = m.hooks.neuron_replace[
                    f"layer_{layer_number}_{layer_type}_pre_out"
                ]
                hook.add_token(
                    new_index,
                    activations_original[layer_type][0, layer_number, original_index],
                )

    max_new_tokens = 100
    temperature = 0.01

    # HiddenPrints redirects stdout to os.devnull while generating.
    with HiddenPrints():
        for i in range(3):
            output = m.generate(prompt_new, max_new_tokens, temperature=temperature)

            data = {
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
                "model": model_name,
                # Inclusive range of the layers that received transplants.
                "transplant_layers": (0, n_layers_transferred - 1),
                # `n_tokens_to_transfer` is not defined at notebook scope (it is
                # only a function parameter), so take the count from the map.
                "transferred_token_num": len(token_index_map),
                "orig_prompt": prompt_original,
                "transplant_prompt": prompt_new,
                "other_info": f"transfer-first-{n_layers_transferred}-layers",
                "output": output[1],
            }

            # One JSON object per line (JSON-lines format).
            with open(filename, "a") as file:
                file.write(json.dumps(data) + "\n")