In [None]:
import os

import torch
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig

# Inference only: disable autograd globally so forward passes don't build graphs.
torch.set_grad_enabled(False)

# Location of the model checkpoint (local directory or Hugging Face Hub id).
# Set the LLAMA_PATH environment variable before starting the kernel,
# or replace this lookup with your own path.
path = os.environ.get("LLAMA_PATH")
if not path:
    raise RuntimeError(
        "Set the LLAMA_PATH environment variable (or edit `path` in this cell) "
        "to the directory or Hub id of a Llama checkpoint."
    )

# The demo uses Llama-2; other causal LMs supported by Transformers should be
# adaptable by swapping the model/tokenizer classes.
# Left padding keeps each prompt's last real token at the end of the sequence in a batch.
tokenizer = LlamaTokenizer.from_pretrained(path, padding_side='left')
# Llama tokenizers typically define no pad token; reuse EOS so batched inputs can be padded.
tokenizer.pad_token = tokenizer.eos_token

model = LlamaForCausalLM.from_pretrained(path, device_map="cuda")

In [35]:
from tools.utils import extract_hidden, break2tokens
from tools.patch_scope import patch_scope
from tools.inference import forward

# Every index below (token position and layer) is 0-based.

# Source prompt: run it once and keep the per-layer hidden states.
prompts = ["Amazon's former CEO attended Oscars"]
print(break2tokens(tokenizer, prompts))
outputs = forward(model, tokenizer, prompts, with_hiddens=True)

# Position 6 is the 'O' token that closes "CEO" (see the tokenization printed by
# this cell). layer_id=7 picks layer 7's hidden state; whether that is the layer's
# output or input is defined by extract_hidden -- TODO confirm.
hiddens = extract_hidden(outputs, token_pos=6, layer_id=7)

# Few-shot identity prompt ("x->x" pairs). The hidden state is patched into the
# target prompt so the model can verbalize what it encodes.
identity_prompt = 'cat->cat; 135->135; hello->hello; ?->'
result = patch_scope(
    model,
    tokenizer,
    hiddens,
    target_prompt=identity_prompt,
    target_token_pos=-2,  # presumably the '?' placeholder (second-to-last token) -- confirm
    target_layer_id=7,
    gen_len=2,
)
print(result)

[['<s>', 'Amazon', "'", 's', 'former', 'CE', 'O', 'attended', 'O', 'sc', 'ars']]
['hello;']


In [36]:
# Source prompt: a few-shot "fruit->colour" pattern that stops right where the
# colour of apple would come next.
prompts = ["lemon->yellow, grape->purple, apple->"]
print(break2tokens(tokenizer, prompts))
outputs = forward(model, tokenizer, prompts, with_hiddens=True)

# Position 12 is the 'apple' token in the tokenization printed by this cell
# (position 13 is the trailing '->'). Layer index 7, 0-based.
hiddens = extract_hidden(outputs, token_pos=12, layer_id=7)

# Decode that hidden state through the few-shot identity prompt,
# generating a single token.
identity_prompt = 'cat->cat; 135->135; hello->hello; ?->'
result = patch_scope(
    model,
    tokenizer,
    hiddens,
    target_prompt=identity_prompt,
    target_token_pos=-2,  # presumably the '?' placeholder (second-to-last token) -- confirm
    target_layer_id=7,
    gen_len=1,
)
print(result)

[['<s>', 'le', 'mon', '->', 'yellow', ',', 'gra', 'pe', '->', 'pur', 'ple', ',', 'apple', '->']]
['apple']
