In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

from util import nethook
from util.generate import generate_fast
import os
from rome import apply_rome_to_model, ROMEHyperParams
from pathlib import Path

In [None]:
# The single counterfactual edit to apply. Filling the "{}" slot of `prompt`
# with `subject` yields the first generation prompt below; the edit aims to
# make the model complete it with `target_new["str"]` instead of the true fact.
request = {
    "prompt": "The name of the largest city in {} is",
    "subject": "France",
    "target_new": {
        "str": "Rome",
    },
}

# Prompts generated before and after the edit. The first is the edited fact
# itself; the others probe paraphrases ("largest"/"biggest"), related facts
# about France/Paris, and a neighbouring fact (Italy) that presumably should
# not be affected by the edit.
generation_prompts = [
    "The name of the largest city in France is",
    "The Eiffel Tower is located in",
    "The largest city in Italy is",
    "The largest city in France is",
    "The biggest city in France is",
    "Paris has a population of",
    "The capital of France is",  # previously misspelled "capitol" (a building)
]

MODEL_NAME = "gpt2-xl"  # model id; also selects the hparams JSON file
ALG_NAME = "ROME"       # editing algorithm label

In [None]:

# Tokenizer first: reuse EOS as the padding token (presumably because the
# tokenizer has no pad token of its own and the prompt list is generated
# as a batch -- confirm against generate_fast).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Load the full-precision weights and move them to the GPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    low_cpu_mem_usage=False,
).to("cuda")

# Enable gradients on the model's parameters.
# NOTE(review): ROME's own routines may toggle this per call -- confirm needed.
nethook.set_requires_grad(True, model)

In [None]:

# Locate the hyperparameter file at <HPARAMS_DIR>/<ALG_NAME>/<MODEL_NAME>.json.
# os.getenv returns None when the variable is unset, and Path(None) fails with
# an opaque TypeError (an empty value would silently resolve to "."), so fail
# early with an actionable message instead.
_hparams_dir_env = os.getenv("HPARAMS_DIR")
if not _hparams_dir_env:
    raise RuntimeError(
        "HPARAMS_DIR environment variable is not set; point it at the "
        "directory containing <ALG_NAME>/<MODEL_NAME>.json hyperparameter files."
    )
HPARAMS_DIR = Path(_hparams_dir_env)

hyperparams_path = HPARAMS_DIR / ALG_NAME / f"{MODEL_NAME}.json"

hparams = ROMEHyperParams.from_json(hyperparams_path)

print(hparams)

In [None]:
# Baseline: generate continuations from the unedited model, for comparison
# with the post-edit outputs.
pre_update_text = generate_fast(model, tokenizer, generation_prompts, max_out_len=100)
print(pre_update_text)


In [None]:
# Apply the edit. The ROME reference implementation of apply_rome_to_model
# takes a *list* of request dicts and iterates over it; passing the bare dict
# would iterate its keys, so wrap the single request in a list.
# NOTE(review): with `copy` left at its default the edit is presumably applied
# in place (edited_model is model), so re-running this cell would stack a second
# edit -- restore from `orig_weights` or reload the model before re-running.
edited_model, orig_weights = apply_rome_to_model(
    model,
    tokenizer,
    [request],
    hparams,
    return_orig_weights=True,
)

In [None]:
# Same prompts and output length as the baseline, now run on the edited model.
post_update_text = generate_fast(edited_model, tokenizer, generation_prompts, max_out_len=100)
print(post_update_text)

In [None]:
# Side-by-side view: for every prompt print the prompt, then the post-edit
# generation, then the pre-edit generation, with a dashed rule between prompts.
prompt_label, post_label, pre_label = "[Prompt]:", "[Post]:", "[Pre]:"
# Pad every label to one column past the longest so the texts line up.
pad_to = 1 + max(len(prompt_label), len(pre_label), len(post_label))
separator = "-" * 10

for i, (prompt, pre, post) in enumerate(
    zip(generation_prompts, pre_update_text, post_update_text)
):
    if i > 0:
        print(separator)
    for label, text in ((prompt_label, prompt), (post_label, post), (pre_label, pre)):
        print(label.ljust(pad_to), text)