# Transferring activations from end of prompt and generating text

Trying on other scenarios.

Contents:
- Create helper function
- Load model
- Create prompts
- Transfer activations
- Generate text

In [1]:
from taker import Model
from datetime import datetime
import json
from os.path import exists

  from .autonotebook import tqdm as notebook_tqdm


## Create helper function

In [2]:
import sys, os


class HiddenPrints:
    def __enter__(self):
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout.close()
        sys.stdout = self._original_stdout

## Load model

In [3]:
# model_name = "microsoft/phi-3-mini-4k-instruct"
model_name = "google/gemma-2-2b-it"
m = Model(model_name, dtype="int4")

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.08s/it]


Loaded model 'google/gemma-2-2b-it' with int4:
- Added 416 hooks across 26 layers




In [4]:
m.show_details()

 - n_layers : 26
 - d_model  : 2304
 - n_heads  : 8
 - d_head   : 256
 - d_mlp    : 9216


In [5]:
idlist = m.get_ids(".\n\n").squeeze().tolist()
print(idlist)
print(m.tokenizer.convert_ids_to_tokens(idlist))

idlist = m.get_ids("Sunday dinners.\n\n").squeeze().tolist()
print(idlist)
print(m.tokenizer.convert_ids_to_tokens(idlist))

# # for phi3
# idlist = m.get_ids("." * 16 * 5).squeeze().tolist()
# print(idlist)
# print(m.tokenizer.convert_ids_to_tokens(idlist))

# idlist = m.get_ids("." * (16 * 5 + 1)).squeeze().tolist()
# print(idlist)
# print(m.tokenizer.convert_ids_to_tokens(idlist))

# for gemma
idlist = m.get_ids("." * 16 * 5).squeeze().tolist()
print(idlist)
print(m.tokenizer.convert_ids_to_tokens(idlist))

idlist = m.get_ids("." * (16 * 5 + 1)).squeeze().tolist()
print(idlist)
print(m.tokenizer.convert_ids_to_tokens(idlist))

idlist = m.get_ids("@" * 8 * 5).squeeze().tolist()
print(idlist)
print(m.tokenizer.convert_ids_to_tokens(idlist))

idlist = m.get_ids("@" * (8 * 5 + 1)).squeeze().tolist()
print(idlist)
print(m.tokenizer.convert_ids_to_tokens(idlist))

idlist = m.get_ids("-" * 16 * 5).squeeze().tolist()
print(idlist)
print(m.tokenizer.convert_ids_to_tokens(idlist))

idlist = m.get_ids("-" * (16 * 5 + 1)).squeeze().tolist()
print(idlist)
print(m.tokenizer.convert_ids_to_tokens(idlist))

[2, 235265, 109]
['<bos>', '.', '\n\n']
[2, 20742, 90641, 235265, 109]
['<bos>', 'Sunday', '▁dinners', '.', '\n\n']
[2, 5519, 5519, 5519, 5519, 5519]
['<bos>', '................', '................', '................', '................', '................']
[2, 5519, 5519, 5519, 5519, 2779, 25984]
['<bos>', '................', '................', '................', '................', '........', '.........']
[2, 177176, 177176, 177176, 177176, 177176]
['<bos>', '@@@@@@@@', '@@@@@@@@', '@@@@@@@@', '@@@@@@@@', '@@@@@@@@']
[2, 177176, 177176, 177176, 177176, 177176, 235348]
['<bos>', '@@@@@@@@', '@@@@@@@@', '@@@@@@@@', '@@@@@@@@', '@@@@@@@@', '@']
[2, 3755, 3755, 3755, 3755, 3755]
['<bos>', '----------------', '----------------', '----------------', '----------------', '----------------']
[2, 3755, 3755, 3755, 3755, 3755, 235290]
['<bos>', '----------------', '----------------', '----------------', '----------------', '----------------', '-']


## Create prompts

In [6]:
topics_long = [
    "Renaissance art techniques",
    "Tropical fruit varieties",
    "Ancient Greek philosophy",
    "Electric car technology",
    "Jazz music history",
    "Sustainable architecture practices",
    "Deep sea creatures",
    "Medieval warfare tactics",
    "Genetic engineering ethics",
    "Volcanic eruption patterns",
    "18th century literature",
    "Quantum computing applications",
    "Traditional Japanese cuisine",
    "Renewable energy sources",
    "African wildlife conservation",
    "Industrial revolution impacts",
    "Impressionist painting movement",
    "Space exploration milestones",
    "Artificial intelligence development",
    "Ancient Egyptian hieroglyphs",
    "Climate change effects",
    "Psychological research methods",
    "Modern dance techniques",
    "Microbrewery beer production",
    "Endangered language preservation",
    "Neuroscience breakthrough discoveries",
    "Sustainable fashion trends",
    "Coral reef ecosystems",
    "Cybersecurity best practices",
    "Classical music composers",
    "Organic farming methods",
    "World War II strategies",
    "Virtual reality applications",
    "Indigenous art forms",
    "Astrophysics recent findings",
    "Urban planning challenges",
    "Culinary fusion experiments",
    "Cryptocurrency market fluctuations",
    "Bird migration patterns",
    "Blockchain technology uses",
    "3D printing innovations",
    "Sign language variations",
    "Nanotechnology medical applications",
    "Mindfulness meditation benefits",
    "Forensic science techniques",
    "Environmentally friendly transportation",
    "Extreme sports safety",
    "Alternative energy storage",
    "Global trade agreements",
    "Sustainable water management",
]

In [7]:
prompt_template = "Tell me about {topic1} in 150 words and then tell me about {topic2} in another 150 words. Only do that. Make sure you don't add any headings or comments.\n\n"
# prompt_template = "Tell me about {topic1} in 150 words and then tell me about {topic2} in another 150 words and then tell me about {topic3} in another 150 words. Only do that. Make sure you don't add any headings or comments.\n\n"
topic1 = "axe-throwing"
topic2 = topics_long[0]
# topic3 = "the history of paper"
prompt = prompt_template.format(topic1=topic1, topic2=topic2)
# prompt = prompt_template.format(topic1=topic1, topic2=topic2, topic3=topic3)
first_output = m.generate(prompt, 400)
print(repr(first_output[1]))

'The popularity of axe-throwing has skyrocketed in recent years, becoming a fun and engaging activity for people of all ages. Participants use an axe to throw at a designated target, which is typically made of wood, and often can be found in dedicated axe-throwing lanes or even in outdoor venues. The sport involves precision and technique, requiring players to maintain their balance, focus, and aim.\n\nThe tradition of artistic techniques began in the Renaissance, a period in history characterized by rebirth and innovation. This period saw the development of new techniques and styles, while pushing the boundaries of artistic expression. Techniques like perspective, chiaroscuro, and sfumato were pioneered during the Renaissance, leading to the development of some of the most iconic artworks and master paintings. \n\nThe exploration of human anatomy and the natural world, along with the advancement of knowledge and technology, revolutionized art in the 15th and 16th centuries. \n'


In [8]:
print(topic2)

Renaissance art techniques


In [9]:
start = first_output[1][: first_output[1].find("\n\n") + 2]
prompt_original = prompt + start
print(prompt_original)

Tell me about axe-throwing in 150 words and then tell me about Renaissance art techniques in another 150 words. Only do that. Make sure you don't add any headings or comments.

The popularity of axe-throwing has skyrocketed in recent years, becoming a fun and engaging activity for people of all ages. Participants use an axe to throw at a designated target, which is typically made of wood, and often can be found in dedicated axe-throwing lanes or even in outdoor venues. The sport involves precision and technique, requiring players to maintain their balance, focus, and aim.




In [10]:
def create_new_prompt_from_end_tokens(m, prompt_original, n_tokens_to_transfer, prefix):
    idlist_original = m.get_ids(prompt_original).squeeze().tolist()
    tokens_original = m.tokenizer.convert_ids_to_tokens(idlist_original)
    n_tokens_original = len(tokens_original)

    tokens_to_transfer = tokens_original[-n_tokens_to_transfer:]
    string_to_transfer = m.tokenizer.convert_tokens_to_string(tokens_to_transfer)
    prompt_new = prefix + string_to_transfer

    idlist_new = m.get_ids(prompt_new).squeeze().tolist()
    tokens_new = m.tokenizer.convert_ids_to_tokens(idlist_new)
    n_tokens_new = len(tokens_new)

    token_index_map = {
        n_tokens_original
        - n_tokens_to_transfer
        + i: n_tokens_new
        - n_tokens_to_transfer
        + i
        for i in range(n_tokens_to_transfer)
    }

    return prompt_new, token_index_map, tokens_original, tokens_new


def create_new_prompt_by_repeating_dummy_string(
    m, prompt_original, dummy_string, n_tokens_to_transfer, prefix
):
    idlist_original = m.get_ids(prompt_original).squeeze().tolist()
    tokens_original = m.tokenizer.convert_ids_to_tokens(idlist_original)
    n_tokens_original = len(tokens_original)

    prompt_new = prefix + dummy_string * n_tokens_to_transfer

    idlist_new = m.get_ids(prompt_new).squeeze().tolist()
    tokens_new = m.tokenizer.convert_ids_to_tokens(idlist_new)
    n_tokens_new = len(tokens_new)

    token_index_map = {
        n_tokens_original
        - n_tokens_to_transfer
        + i: n_tokens_new
        - n_tokens_to_transfer
        + i
        for i in range(n_tokens_to_transfer)
    }

    return prompt_new, token_index_map, tokens_original, tokens_new


def create_new_prompt_by_transferring_all_of_one_token_type(
    m, prompt_original, token_to_transfer, dummy_string
):
    idlist_original = m.get_ids(prompt_original).squeeze().tolist()
    tokens_original = m.tokenizer.convert_ids_to_tokens(idlist_original)

    prompt_new = ""
    n_tokens_to_transfer = 0
    # this maps the index of token in tokens_original to the index of token in tokens_new
    token_index_map = {}

    for i, token in enumerate(tokens_original):
        if token == token_to_transfer:
            prompt_new += dummy_string
            n_tokens_to_transfer += 1
            token_index_map[i] = n_tokens_to_transfer
        else:
            continue

    idlist_new = m.get_ids(prompt_new).squeeze().tolist()
    tokens_new = m.tokenizer.convert_ids_to_tokens(idlist_new)

    return prompt_new, token_index_map, tokens_original, tokens_new

In [11]:
n_tokens_to_transfer = 1

prompt_new, token_index_map, tokens_original, tokens_new = (
    create_new_prompt_from_end_tokens(
        m=m,
        prompt_original=prompt_original,
        n_tokens_to_transfer=n_tokens_to_transfer,
        prefix="",
    )
)

# prompt_new, token_index_map, tokens_original, tokens_new = (
#     create_new_prompt_by_repeating_dummy_string(
#         m=m,
#         prompt_original=prompt_original,
#         dummy_string="." * 16,
#         n_tokens_to_transfer=n_tokens_to_transfer,
#         prefix="",
#     )
# )

# prompt_new, token_index_map, tokens_original, tokens_new = create_new_prompt_by_transferring_all_of_one_token_type(
#     m=m, prompt_original=prompt_original, token_to_transfer="\n\n", dummy_string="@"*8)

# do sense check
print(f"{prompt_new=}")
print(f"{tokens_new=}")
print()
print(token_index_map)
print()
for index_original, index_new in token_index_map.items():
    print(repr(tokens_original[index_original]), repr(tokens_new[index_new]))

prompt_new='\n\n'
tokens_new=['<bos>', '\n\n']

{126: 1}

'\n\n' '\n\n'


## Transfer activations

In [12]:
# RESET HOOKS BEFORE TRANSPLANTING NEXT SET OF ACTIVATIONS
for h in m.hooks.neuron_replace.values():
    h.reset()

print(m.hooks.neuron_replace)

{'layer_0_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_0_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_1_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_1_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_2_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_2_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_3_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_3_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_4_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_4_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_5_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_5_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_6_attn_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_6_mlp_pre_out': NeuronReplace(
  (param): ParameterDict()
), 'layer_7_attn_pre_out': NeuronReplace(
 

In [13]:
activations_original = m.get_midlayer_activations(prompt_original)

for original_index, new_index in token_index_map.items():
    for layer_type in ["mlp", "attn"]:
        # for layer_type in ["attn"]:
        for layer_number in range(m.cfg.n_layers):
            hook = m.hooks.neuron_replace[f"layer_{layer_number}_{layer_type}_pre_out"]
            hook.add_token(
                new_index,
                activations_original[layer_type][0, layer_number, original_index],
            )

print(m.hooks.neuron_replace)

{'layer_0_attn_pre_out': NeuronReplace(
  (param): ParameterDict(  (1): Parameter containing: [torch.cuda.HalfTensor of size 8x256 (cuda:0)])
), 'layer_0_mlp_pre_out': NeuronReplace(
  (param): ParameterDict(  (1): Parameter containing: [torch.cuda.HalfTensor of size 9216 (cuda:0)])
), 'layer_1_attn_pre_out': NeuronReplace(
  (param): ParameterDict(  (1): Parameter containing: [torch.cuda.HalfTensor of size 8x256 (cuda:0)])
), 'layer_1_mlp_pre_out': NeuronReplace(
  (param): ParameterDict(  (1): Parameter containing: [torch.cuda.HalfTensor of size 9216 (cuda:0)])
), 'layer_2_attn_pre_out': NeuronReplace(
  (param): ParameterDict(  (1): Parameter containing: [torch.cuda.HalfTensor of size 8x256 (cuda:0)])
), 'layer_2_mlp_pre_out': NeuronReplace(
  (param): ParameterDict(  (1): Parameter containing: [torch.cuda.HalfTensor of size 9216 (cuda:0)])
), 'layer_3_attn_pre_out': NeuronReplace(
  (param): ParameterDict(  (1): Parameter containing: [torch.cuda.HalfTensor of size 8x256 (cuda:0)])


## Generate text

In [14]:
# current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
current_time = "2024-08-29_08-29-11"
filename = f"../results/{current_time}_LA_activation_transfer_long_topics.jsonl"

if not exists(filename):
    with open(filename, "w") as f:
        pass

In [15]:
max_new_tokens = 150
temperature = 0.2

# # test on single output
# output = m.generate(prompt_new, max_new_tokens, temperature=temperature)
# print(repr(output[1]))

with HiddenPrints():
    for i in range(3):
        output = m.generate(prompt_new, max_new_tokens, temperature=temperature)

        data = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "model": model_name,
            "transplant_layers": (0, m.cfg.n_layers),
            "transferred_token_num": n_tokens_to_transfer,
            "orig_prompt": prompt_original,
            "transplant_prompt": prompt_new,
            "other_info": f"gemma-{topic1}-{topic2}-shorter topic-attempt 2",
            "output": output[1],
        }

        with open(filename, "a") as file:
            file.write(json.dumps(data) + "\n")

## Create loop through full workflow

In [17]:
for topic2 in topics_long[1:]:
    prompt_template = "Tell me about {topic1} in 150 words and then tell me about {topic2} in another 150 words. Only do that. Make sure you don't add any headings or comments.\n\n"
    topic1 = "axe-throwing"
    prompt = prompt_template.format(topic1=topic1, topic2=topic2)
    first_output = m.generate(prompt, 400)
    start = first_output[1][: first_output[1].find("\n")] + "\n\n"
    prompt_original = prompt + start
    output_original = first_output[1][first_output[1].find("\n") + 2 :]

    n_tokens_to_transfer = 1
    prompt_new, token_index_map, tokens_original, tokens_new = (
        create_new_prompt_from_end_tokens(
            m=m,
            prompt_original=prompt_original,
            n_tokens_to_transfer=n_tokens_to_transfer,
            prefix="",
        )
    )

    for h in m.hooks.neuron_replace.values():
        h.reset()

    activations_original = m.get_midlayer_activations(prompt_original)

    for original_index, new_index in token_index_map.items():
        for layer_type in ["mlp", "attn"]:
            for layer_number in range(m.cfg.n_layers):
                hook = m.hooks.neuron_replace[
                    f"layer_{layer_number}_{layer_type}_pre_out"
                ]
                hook.add_token(
                    new_index,
                    activations_original[layer_type][0, layer_number, original_index],
                )

    current_time = "2024-08-29_08-29-11"
    filename = f"../results/{current_time}_LA_activation_transfer_long_topics.jsonl"

    max_new_tokens = 150
    temperature = 0.2

    with HiddenPrints():
        for i in range(3):
            output = m.generate(prompt_new, max_new_tokens, temperature=temperature)

            data = {
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
                "model": model_name,
                "transplant_layers": (0, m.cfg.n_layers),
                "transferred_token_num": n_tokens_to_transfer,
                "orig_prompt": prompt_original,
                "orig_output": output_original,
                "transplant_prompt": prompt_new,
                "other_info": f"gemma-{topic1}-{topic2}",
                "output": output[1],
            }

            with open(filename, "a") as file:
                file.write(json.dumps(data) + "\n")

### Read jsonl file

In [18]:
import pandas as pd

In [20]:
results_df = pd.read_json(filename, lines=True)

In [21]:
results_df.head()

Unnamed: 0,temperature,max_new_tokens,model,transplant_layers,transferred_token_num,orig_prompt,transplant_prompt,other_info,output,orig_output
0,0.2,150,google/gemma-2-2b-it,"[0, 26]",1,Tell me about axe-throwing in 150 words and th...,\n\n,gemma-axe-throwing-Renaissance art techniques,"The Renaissance, a period of great artistic an...",
1,0.2,150,google/gemma-2-2b-it,"[0, 26]",1,Tell me about axe-throwing in 150 words and th...,\n\n,gemma-axe-throwing-Renaissance art techniques,"The Renaissance, a period of great artistic an...",
2,0.2,150,google/gemma-2-2b-it,"[0, 26]",1,Tell me about axe-throwing in 150 words and th...,\n\n,gemma-axe-throwing-Renaissance art techniques,"The Renaissance, a period of great cultural an...",
3,0.2,150,google/gemma-2-2b-it,"[0, 26]",1,Tell me about axe-throwing in 150 words and th...,\n\n,gemma-axe-throwing-Tropical fruit varieties,The world of professional wrestling is full of...,\nThe challenge of axe throwing lies in mainta...
4,0.2,150,google/gemma-2-2b-it,"[0, 26]",1,Tell me about axe-throwing in 150 words and th...,\n\n,gemma-axe-throwing-Tropical fruit varieties,The world of the internet is vast and ever-exp...,\nThe challenge of axe throwing lies in mainta...


In [39]:
for num in range(1, len(topics_long)):
    print("Topic:", results_df.loc[3 * num, "other_info"].split("-")[-1])
    print()
    print("End of orig prompt:", repr(results_df.loc[3 * num, "orig_prompt"][-50:]))
    print()
    print("Start of orig output:", repr(results_df.loc[3 * num, "orig_output"][:50]))
    print()
    for i in range(3 * num, 3 * (num + 1)):
        print(f"Transferred output {i-(3*num)+1}")
        print(results_df.loc[i, "output"][:250])
        print("-" * 10)
    print("=" * 10)

Topic: Tropical fruit varieties

End of orig prompt: 'm casual competition to high-stakes tournaments.\n\n'

Start of orig output: '\nThe challenge of axe throwing lies in maintaining'

Transferred output 1
The world of professional wrestling is full of drama, excitement, and athleticism. It's a unique blend of entertainment and sport, where athletes push their bodies to the limit and deliver captivating performances. 

The world of professional wrestli
----------
Transferred output 2
The world of the internet is vast and ever-expanding. It is a place where information is readily available, and people from all walks of life can connect and share their experiences. The internet has become an integral part of our lives, and it is es
----------
Transferred output 3
The world's largest and most powerful telescope, the James Webb Space Telescope (JWST), has captured stunning images of the cosmos. Its infrared observations allow it to see through dust clouds and reveal hidden details of dist