# Transferring activations from end of prompt and generating text

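This notebook repeats the activation-transfer experiment on other scenarios: different topics in the original prompt and different placeholder strings in the new prompt.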

Contents:
- Create helper function
- Load model
- Create prompts
- Transfer activations
- Generate text

In [1]:
from taker import Model
from datetime import datetime
import json
from os.path import exists

  from .autonotebook import tqdm as notebook_tqdm


## Create helper function

In [2]:
import sys, os


class HiddenPrints:
    """Context manager that silences ``print`` output inside a ``with`` block.

    On entry ``sys.stdout`` is pointed at ``os.devnull``; on exit the stream
    that was active on entry is put back. Exceptions raised inside the block
    are not suppressed.
    """

    def __enter__(self):
        # Remember the stream that was active so it can be restored on exit.
        self._original_stdout = sys.stdout
        # Keep our own handle: on exit we must close exactly the file opened
        # here, not whatever ``sys.stdout`` happens to be at that point.
        self._devnull = open(os.devnull, "w")
        sys.stdout = self._devnull
        # Returning self makes ``with HiddenPrints() as hp`` usable.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            self._devnull.close()
        finally:
            # Restore stdout even if closing the devnull handle fails.
            sys.stdout = self._original_stdout

## Load model

In [78]:
# Model under test. The commented-out line is an alternative checkpoint; swap
# the two assignments to switch models.
# model_name = "microsoft/phi-3-mini-4k-instruct"
model_name = "google/gemma-2-2b-it"
# `m` is the model handle used by every later cell; it is loaded with the
# "int4" dtype option (presumably 4-bit quantised weights; confirm in taker).
m = Model(model_name, dtype="int4")

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.06s/it]


Loaded model 'google/gemma-2-2b-it' with int4:
- Added 416 hooks across 26 layers


In [79]:
# Print the model dimensions (layer count, widths, heads) for reference.
m.show_details()

 - n_layers : 26
 - d_model  : 2304
 - n_heads  : 8
 - d_head   : 256
 - d_mlp    : 9216


In [110]:
# Tokenizer probes: for each candidate string print the token ids and then the
# token strings, to see which runs of a repeated character collapse into a
# single token and what happens when one extra character is appended.
probe_strings = [
    ".\n\n",
    "Sunday dinners.\n\n",
    # for gemma (the same two dot-run probes were previously used for phi3)
    "." * 16 * 5,
    "." * (16 * 5 + 1),
    "@" * 8 * 5,
    "@" * (8 * 5 + 1),
    "-" * 16 * 5,
    "-" * (16 * 5 + 1),
]

for probe in probe_strings:
    idlist = m.get_ids(probe).squeeze().tolist()
    print(idlist)
    print(m.tokenizer.convert_ids_to_tokens(idlist))

[2, 235265, 109]
['<bos>', '.', '\n\n']
[2, 20742, 90641, 235265, 109]
['<bos>', 'Sunday', '▁dinners', '.', '\n\n']
[2, 5519, 5519, 5519, 5519, 5519]
['<bos>', '................', '................', '................', '................', '................']
[2, 5519, 5519, 5519, 5519, 2779, 25984]
['<bos>', '................', '................', '................', '................', '........', '.........']
[2, 177176, 177176, 177176, 177176, 177176]
['<bos>', '@@@@@@@@', '@@@@@@@@', '@@@@@@@@', '@@@@@@@@', '@@@@@@@@']
[2, 177176, 177176, 177176, 177176, 177176, 235348]
['<bos>', '@@@@@@@@', '@@@@@@@@', '@@@@@@@@', '@@@@@@@@', '@@@@@@@@', '@']
[2, 3755, 3755, 3755, 3755, 3755]
['<bos>', '----------------', '----------------', '----------------', '----------------', '----------------']
[2, 3755, 3755, 3755, 3755, 3755, 235290]
['<bos>', '----------------', '----------------', '----------------', '----------------', '----------------', '-']


## Create prompts

In [96]:
# Two-topic instruction prompt. The trailing "\n\n" is part of the prompt, so
# the model's answer starts directly after a blank line.
prompt_template = "Tell me about {topic1} in 150 words and then tell me about {topic2} in another 150 words. Only do that. Make sure you don't add any headings or comments.\n\n"
# Three-topic variant of the same template (currently unused).
# prompt_template = "Tell me about {topic1} in 150 words and then tell me about {topic2} in another 150 words and then tell me about {topic3} in another 150 words. Only do that. Make sure you don't add any headings or comments.\n\n"
# topic1 / topic2 are also written into the results metadata by the generation cell.
topic1 = "axe-throwing"
topic2 = "contactless payments"
# topic3 = "the history of paper"
prompt = prompt_template.format(topic1=topic1, topic2=topic2)
# prompt = prompt_template.format(topic1=topic1, topic2=topic2, topic3=topic3)
# Second positional argument is the new-token budget (the generation cell passes
# `max_new_tokens` in the same position).
output = m.generate(prompt, 400)
# output[1] is presumably the generated continuation (output[0] the prompt);
# verify against taker's Model.generate. repr() keeps "\n" visible.
print(repr(output[1]))

"The sport of axe-throwing involves players aiming to throw axes at a designated target board. The game is a blend of target practice, precision and athleticism. It requires the use of specialized axes specifically designed for the sport. These axes are typically lighter and more aerodynamic than traditional axes. \nIn recent years, axe-throwing has become increasingly popular, particularly among groups of friends or individuals looking for a fun activity. It's a social activity, providing opportunities for friendly competition and camaraderie.\n\nContactless payments are revolutionizing the way we transact. They offer convenience, security, and speed.  Contactless payments can be made through various methods: NFC chip technology, Bluetooth, QR codes, or even smart cards.  They are being used by businesses, restaurants, and individuals.\nThis technology has become a global phenomenon, with many countries adopting contactless payments. It has significantly reduced the need for physical 

In [97]:
# `start` is the first paragraph (topic 1) of the generation printed by the
# previous cell, pasted in by hand up to and including the "\n\n" that
# separates the two topics. It is hard-coded, so re-running the previous cell
# does not change it.
start = "The sport of axe-throwing involves players aiming to throw axes at a designated target board. The game is a blend of target practice, precision and athleticism. It requires the use of specialized axes specifically designed for the sport. These axes are typically lighter and more aerodynamic than traditional axes. \nIn recent years, axe-throwing has become increasingly popular, particularly among groups of friends or individuals looking for a fun activity. It's a social activity, providing opportunities for friendly competition and camaraderie.\n\n"
# Original prompt whose activations are read later: instruction + first paragraph,
# ending exactly at the topic-switch "\n\n".
prompt_original = prompt + start
print(prompt_original)

Tell me about axe-throwing in 150 words and then tell me about contactless payments in another 150 words. Only do that. Make sure you don't add any headings or comments.

The sport of axe-throwing involves players aiming to throw axes at a designated target board. The game is a blend of target practice, precision and athleticism. It requires the use of specialized axes specifically designed for the sport. These axes are typically lighter and more aerodynamic than traditional axes. 
In recent years, axe-throwing has become increasingly popular, particularly among groups of friends or individuals looking for a fun activity. It's a social activity, providing opportunities for friendly competition and camaraderie.




In [98]:
def create_new_prompt_from_end_tokens(m, prompt_original, n_tokens_to_transfer, prefix):
    """Build a new prompt from ``prefix`` plus the last tokens of ``prompt_original``.

    Args:
        m: Model handle. Only ``m.get_ids(text)`` (used via
            ``.squeeze().tolist()``) and ``m.tokenizer`` are used.
        prompt_original: Text whose trailing tokens are copied.
        n_tokens_to_transfer: How many trailing tokens to copy. Must satisfy
            ``0 <= n_tokens_to_transfer <= len(tokens_original)``.
        prefix: Text placed in front of the copied tokens in the new prompt.

    Returns:
        ``(prompt_new, token_index_map, tokens_original, tokens_new)`` where
        ``token_index_map`` maps the index of each copied token in
        ``tokens_original`` to the index of the corresponding token counted
        from the end of ``tokens_new``.

    Raises:
        ValueError: If ``n_tokens_to_transfer`` is out of range, or the new
            prompt tokenizes to fewer tokens than were requested.

    Note:
        The map pairs the last ``n`` tokens of both prompts by position. It
        does not check that re-tokenizing ``prefix + tail`` reproduces the
        same tail tokens, so callers should sanity-check the pairs.
    """
    idlist_original = m.get_ids(prompt_original).squeeze().tolist()
    # squeeze() turns a one-token prompt into a scalar; keep it a list.
    if not isinstance(idlist_original, list):
        idlist_original = [idlist_original]
    tokens_original = m.tokenizer.convert_ids_to_tokens(idlist_original)
    n_tokens_original = len(tokens_original)

    if not 0 <= n_tokens_to_transfer <= n_tokens_original:
        raise ValueError(
            f"n_tokens_to_transfer must be between 0 and {n_tokens_original}, "
            f"got {n_tokens_to_transfer}"
        )

    # Slice from an explicit start index: `tokens[-0:]` would return *every*
    # token instead of none.
    first_original_index = n_tokens_original - n_tokens_to_transfer
    tokens_to_transfer = tokens_original[first_original_index:]
    string_to_transfer = (
        m.tokenizer.convert_tokens_to_string(tokens_to_transfer)
        if tokens_to_transfer
        else ""
    )
    prompt_new = prefix + string_to_transfer

    idlist_new = m.get_ids(prompt_new).squeeze().tolist()
    if not isinstance(idlist_new, list):
        idlist_new = [idlist_new]
    tokens_new = m.tokenizer.convert_ids_to_tokens(idlist_new)
    n_tokens_new = len(tokens_new)

    if n_tokens_new < n_tokens_to_transfer:
        # Otherwise the destination indices below would be negative.
        raise ValueError(
            f"new prompt has {n_tokens_new} tokens, fewer than the "
            f"{n_tokens_to_transfer} tokens to transfer"
        )

    first_new_index = n_tokens_new - n_tokens_to_transfer
    token_index_map = {
        first_original_index + i: first_new_index + i
        for i in range(n_tokens_to_transfer)
    }

    return prompt_new, token_index_map, tokens_original, tokens_new


def create_new_prompt_by_repeating_dummy_string(
    m, prompt_original, dummy_string, n_tokens_to_transfer, prefix
):
    """Build a new prompt of ``prefix`` followed by repeated placeholder strings.

    The last ``n_tokens_to_transfer`` tokens of ``prompt_original`` are paired,
    by position, with the last ``n_tokens_to_transfer`` tokens of the new
    prompt. For the pairing to be meaningful each repetition of
    ``dummy_string`` has to tokenize to exactly one token; this function only
    rejects the case where the new prompt ends up with too few tokens.

    Args:
        m: Model handle. Only ``m.get_ids(text)`` (used via
            ``.squeeze().tolist()``) and ``m.tokenizer`` are used.
        prompt_original: Text whose trailing token positions are the sources.
        dummy_string: Placeholder text repeated ``n_tokens_to_transfer`` times.
        n_tokens_to_transfer: Number of trailing tokens to pair up.
        prefix: Text placed in front of the placeholders in the new prompt.

    Returns:
        ``(prompt_new, token_index_map, tokens_original, tokens_new)`` where
        ``token_index_map`` maps indices in ``tokens_original`` to indices in
        ``tokens_new``.

    Raises:
        ValueError: If ``n_tokens_to_transfer`` exceeds the number of tokens in
            either prompt (the index map would contain negative indices).
    """
    idlist_original = m.get_ids(prompt_original).squeeze().tolist()
    # squeeze() turns a one-token prompt into a scalar; keep it a list.
    if not isinstance(idlist_original, list):
        idlist_original = [idlist_original]
    tokens_original = m.tokenizer.convert_ids_to_tokens(idlist_original)
    n_tokens_original = len(tokens_original)

    if n_tokens_to_transfer > n_tokens_original:
        raise ValueError(
            f"cannot transfer {n_tokens_to_transfer} tokens from a prompt of "
            f"{n_tokens_original} tokens"
        )

    prompt_new = prefix + dummy_string * n_tokens_to_transfer

    idlist_new = m.get_ids(prompt_new).squeeze().tolist()
    if not isinstance(idlist_new, list):
        idlist_new = [idlist_new]
    tokens_new = m.tokenizer.convert_ids_to_tokens(idlist_new)
    n_tokens_new = len(tokens_new)

    if n_tokens_to_transfer > n_tokens_new:
        # Happens when adjacent placeholders merge into fewer tokens; the
        # destination indices below would otherwise be negative.
        raise ValueError(
            f"new prompt tokenizes to {n_tokens_new} tokens, fewer than the "
            f"{n_tokens_to_transfer} placeholders requested; choose a "
            "dummy_string that stays one token per repetition"
        )

    first_original_index = n_tokens_original - n_tokens_to_transfer
    first_new_index = n_tokens_new - n_tokens_to_transfer
    token_index_map = {
        first_original_index + i: first_new_index + i
        for i in range(n_tokens_to_transfer)
    }

    return prompt_new, token_index_map, tokens_original, tokens_new


def create_new_prompt_by_transferring_all_of_one_token_type(
    m, prompt_original, token_to_transfer, dummy_string
):
    """Build a new prompt with one placeholder per occurrence of a given token.

    Every position in the tokenized ``prompt_original`` whose token equals
    ``token_to_transfer`` gets one copy of ``dummy_string`` in the new prompt,
    which therefore consists only of placeholders.

    Args:
        m: Model handle. Only ``m.get_ids(text)`` (used via
            ``.squeeze().tolist()``) and ``m.tokenizer`` are used.
        prompt_original: Text that is searched for ``token_to_transfer``.
        token_to_transfer: Token string to look for, in the form returned by
            ``m.tokenizer.convert_ids_to_tokens``.
        dummy_string: Placeholder text; it should tokenize to exactly one
            token per repetition for the index map to be meaningful.

    Returns:
        ``(prompt_new, token_index_map, tokens_original, tokens_new)`` where
        ``token_index_map`` maps the index of each matching token in
        ``tokens_original`` to the index of its placeholder in ``tokens_new``.
        Placeholders are taken to be the last tokens of ``tokens_new``, so any
        leading special token (such as ``<bos>``) is skipped without assuming
        that exactly one exists.

    Raises:
        ValueError: If the new prompt tokenizes to fewer tokens than there are
            matches (placeholders merged during tokenization).
    """
    idlist_original = m.get_ids(prompt_original).squeeze().tolist()
    # squeeze() turns a one-token prompt into a scalar; keep it a list.
    if not isinstance(idlist_original, list):
        idlist_original = [idlist_original]
    tokens_original = m.tokenizer.convert_ids_to_tokens(idlist_original)

    matching_indices = [
        index
        for index, token in enumerate(tokens_original)
        if token == token_to_transfer
    ]
    prompt_new = dummy_string * len(matching_indices)

    idlist_new = m.get_ids(prompt_new).squeeze().tolist()
    # With no matches the new prompt is empty and may be a single special token.
    if not isinstance(idlist_new, list):
        idlist_new = [idlist_new]
    tokens_new = m.tokenizer.convert_ids_to_tokens(idlist_new)

    first_new_index = len(tokens_new) - len(matching_indices)
    if first_new_index < 0:
        raise ValueError(
            f"new prompt tokenizes to {len(tokens_new)} tokens, fewer than the "
            f"{len(matching_indices)} placeholders written; choose a "
            "dummy_string that stays one token per repetition"
        )

    # this maps the index of token in tokens_original to the index of token in tokens_new
    token_index_map = {
        original_index: first_new_index + offset
        for offset, original_index in enumerate(matching_indices)
    }

    return prompt_new, token_index_map, tokens_original, tokens_new

In [116]:
# Number of trailing tokens of `prompt_original` whose activations are moved
# into the new prompt; also recorded in the results metadata later on.
n_tokens_to_transfer = 1

# Alternative 1: new prompt = the last tokens of the original prompt themselves.
# prompt_new, token_index_map, tokens_original, tokens_new = (
#     create_new_prompt_from_end_tokens(
#         m=m, prompt_original=prompt_original, n_tokens_to_transfer=n_tokens_to_transfer, prefix=""
#     )
# )

# Active choice: new prompt = placeholder strings only. A run of 16 dots is a
# single token for this tokenizer (see the tokenizer checks earlier in the
# notebook), so one repetition corresponds to one destination token.
prompt_new, token_index_map, tokens_original, tokens_new = (
    create_new_prompt_by_repeating_dummy_string(
        m=m,
        prompt_original=prompt_original,
        dummy_string="." * 16,
        n_tokens_to_transfer=n_tokens_to_transfer,
        prefix="",
    )
)

# Alternative 2: one placeholder for every "\n\n" token in the original prompt.
# prompt_new, token_index_map, tokens_original, tokens_new = create_new_prompt_by_transferring_all_of_one_token_type(
#     m=m, prompt_original=prompt_original, token_to_transfer="\n\n", dummy_string="@"*8)

# do sense check
# Show the new prompt, its tokens, the index map, and each (source token,
# destination token) pair so a misaligned map is visible before transplanting.
print(f"{prompt_new=}")
print(f"{tokens_new=}")
print()
print(token_index_map)
print()
for index_original, index_new in token_index_map.items():
    print(repr(tokens_original[index_original]), repr(tokens_new[index_new]))

prompt_new='................'
tokens_new=['<bos>', '................']

{144: 1}

'\n\n' '................'


## Transfer activations

In [117]:
# RESET HOOKS BEFORE TRANSPLANTING NEXT SET OF ACTIVATIONS
# Run this cell before every transfer cell; the printout below should show an
# empty ParameterDict for each hook once they are reset.
for h in m.hooks.neuron_replace.values():
    h.reset()

print(m.hooks.neuron_replace)

{'layer_0_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_0_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_1_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_1_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_2_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_2_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_3_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_3_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_4_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_4_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_5_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_5_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_6_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_6_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_7_attn_pre_out': NeuronReplace(
 

In [118]:
# Run the original prompt once and keep its mid-layer activations.
activations_original = m.get_midlayer_activations(prompt_original)

# Register, for every mapped token and every layer, the original prompt's
# activation as the replacement value at the destination token position.
# Restrict `layer_types` to ["attn"] to transplant attention outputs only.
layer_types = ["mlp", "attn"]
for original_index, new_index in token_index_map.items():
    for layer_type in layer_types:
        # Indexed below as [0, layer, token]; the leading 0 is presumably the
        # batch entry (confirm against taker's get_midlayer_activations).
        activations_of_type = activations_original[layer_type]
        for layer_number in range(m.cfg.n_layers):
            hook_name = f"layer_{layer_number}_{layer_type}_pre_out"
            m.hooks.neuron_replace[hook_name].add_token(
                new_index,
                activations_of_type[0, layer_number, original_index],
            )

print(m.hooks.neuron_replace)

{'layer_0_attn_pre_out': NeuronReplace(
  (param): ParameterDict(  (1): Parameter containing: [torch.cuda.HalfTensor of size 8x256 (cuda:0)])
), 'layer_0_mlp_pre_out': NeuronReplace(
  (param): ParameterDict(  (1): Parameter containing: [torch.cuda.HalfTensor of size 9216 (cuda:0)])
), 'layer_1_attn_pre_out': NeuronReplace(
  (param): ParameterDict(  (1): Parameter containing: [torch.cuda.HalfTensor of size 8x256 (cuda:0)])
), 'layer_1_mlp_pre_out': NeuronReplace(
  (param): ParameterDict(  (1): Parameter containing: [torch.cuda.HalfTensor of size 9216 (cuda:0)])
), 'layer_2_attn_pre_out': NeuronReplace(
  (param): ParameterDict(  (1): Parameter containing: [torch.cuda.HalfTensor of size 8x256 (cuda:0)])
), 'layer_2_mlp_pre_out': NeuronReplace(
  (param): ParameterDict(  (1): Parameter containing: [torch.cuda.HalfTensor of size 9216 (cuda:0)])
), 'layer_3_attn_pre_out': NeuronReplace(
  (param): ParameterDict(  (1): Parameter containing: [torch.cuda.HalfTensor of size 8x256 (cuda:0)])


## Generate text

In [119]:
# Results are appended to one JSONL file. The timestamp is pinned so that
# re-running the notebook keeps adding to the same file; use the commented
# line instead to start a new file for this run.
# current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
current_time = "2024-08-22_08-26-53"
filename = f"../results/{current_time}_LA_activation_transfer_different_scenarios.jsonl"

# Opening a file inside a missing directory raises FileNotFoundError, so make
# sure the results directory exists first.
os.makedirs(os.path.dirname(filename), exist_ok=True)

# Create the file when absent. Append mode never truncates, so existing
# results are safe even if the file appears between the check and the open.
if not exists(filename):
    with open(filename, "a"):
        pass

In [120]:
max_new_tokens = 150
temperature = 0.2

# test on single output
# output = m.generate(prompt_new, max_new_tokens, temperature=temperature)
# print(repr(output[1]))

# Fields that are identical for every sample written by this cell.
run_metadata = {
    "temperature": temperature,
    "max_new_tokens": max_new_tokens,
    "model": model_name,
    "transplant_layers": (0, m.cfg.n_layers),
    "transferred_token_num": n_tokens_to_transfer,
    "orig_prompt": prompt_original,
    "transplant_prompt": prompt_new,
    "other_info": f"gemma-{topic1}-{topic2}-try-new-prompt-as-dots-again",
}

# Generate three samples from the new prompt with the transplanted
# activations in place, silencing anything printed during generation, and
# append one JSON record per sample to the results file.
with HiddenPrints():
    for i in range(3):
        output = m.generate(prompt_new, max_new_tokens, temperature=temperature)

        # "output" stays the last key, matching the existing records.
        data = {**run_metadata, "output": output[1]}

        with open(filename, "a") as file:
            file.write(json.dumps(data) + "\n")