# Setup

In [None]:
%%bash
# Colab-only installer. `!(...)` negates the stat check, so when the
# google/colab package directory does not exist (i.e. not on Colab) the cell
# exits immediately without installing anything.
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
# Fresh clone of the ROME repository into /content/rome, replacing any old copy.
# git output starts /content/install.log; the pip output below is appended to it.
cd /content && rm -rf /content/rome
git clone https://github.com/kmeng01/rome rome > install.log 2>&1
pip install -r /content/rome/scripts/colab_reqs/rome.txt >> install.log 2>&1
pip install --upgrade google-cloud-storage >> install.log 2>&1

In [None]:
IS_COLAB = True
ALL_DEPS = False
try:
    import google.colab, torch, os

    IS_COLAB = True
    os.chdir("/content/rome")
    if not torch.cuda.is_available():
        raise Exception("Change runtime type to include a GPU.")
except ModuleNotFoundError as _:
    pass

!pip install transformers
!pip install datasets
!pip install fancy-einsum

In [None]:
# NOTE(review): this unconditionally sets IS_COLAB = True, overriding the
# detection in the setup cell. Downstream, IS_COLAB selects "gpt2-xl" as
# MODEL_NAME and skips the local os.chdir('../../rome/'). Confirm intended.
IS_COLAB = True
from collections import defaultdict
from torch.nn import functional as F
import joblib
import matplotlib.pyplot as plt

from torch import nn
import pandas as pd
import altair as alt
import os
import sys
from tqdm import tqdm
import numpy as np
from datetime import datetime
from typing import List, Any, Dict, Tuple
from torch import Tensor
import copy
from pathlib import Path
from fancy_einsum import einsum

In [None]:
# Locally, move into the ROME checkout (relative to this notebook's directory)
# so `util`, `rome`, and `experiments` are importable.
if IS_COLAB:
    pass  # the setup cell's chdir to /content/rome applies on Colab
else:
    os.chdir('../../rome/')

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from util import nethook
from util.generate import generate_interactive, generate_fast

### generate random covariance matrices
def generate_covariances(n_layers: int = 12, d_mlp: int = 3072,
                         root: str = './data/stats/gpt2/wikipedia_stats/') -> None:
    """
    Write random stand-in covariance statistics files, one per layer.

    Each file copies the entries of the layer-17 stats file in ./data/stats/.
    Its 'mom2.mom2' entry is replaced by A @ A.T for a fresh standard-normal
    A of shape (d_mlp, d_mlp), which is symmetric positive semi-definite. The
    values carry no information about any model; this is only useful for
    exercising code paths that need such a file to exist.

    Args:
        n_layers: number of layer files to write (default 12).
        d_mlp: size of the random square matrix (default 3072).
        root: output directory; created if it does not exist.
    """
    fname = 'transformer.h.{layer}.mlp.c_proj_float32_mom2_100000.npz'
    # Read the template once and close the file, instead of re-reading the
    # lazily-loaded NpzFile on every iteration and never closing it.
    with np.load("./data/stats/transformer.h.17.mlp.c_proj_float32_mom2_100000.npz") as template:
        base_data = dict(template.items())
    os.makedirs(root, exist_ok=True)  # np.savez does not create directories
    # `tqdm` is the class from `from tqdm import tqdm`; `tqdm.tqdm` would fail.
    for layer in tqdm(range(n_layers)):
        path = os.path.join(root, fname.format(layer=layer))
        data = dict(base_data)
        A = torch.randn(d_mlp, d_mlp)
        data['mom2.mom2'] = (A @ A.T).numpy()
        np.savez(path, **data)
# generate_covariances()

In [None]:
# GPT-2 small locally, GPT-2 XL when IS_COLAB.
# Other options: gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B.
MODEL_NAME = "gpt2" if not IS_COLAB else "gpt2-xl"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=False).to("cuda")
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
# Inference by default; run_rome_custom re-enables grads when it runs ROME.
model.requires_grad_(False)
N_LAYERS = len(model.transformer.h)
# The c_proj weight is unpacked as (d_mlp, d_model): GPT-2's c_proj stores
# its weight input-major.
D_MLP, D_MODEL = model.transformer.h[0].mlp.c_proj.weight.shape
# GPT-2 has no pad token; reuse EOS so batched tokenization can pad.
tok.pad_token = tok.eos_token
torch.cuda.empty_cache()
# Snapshot of every layer's MLP down-projection weight before any edit, keyed
# by parameter name so it can be written back with nethook.get_parameter.
ORIG_WEIGHTS = {
    f'transformer.h.{layer}.mlp.c_proj.weight': model.transformer.h[layer].mlp.c_proj.weight.clone().detach()
    for layer in range(N_LAYERS)
}

# Tools to run ROME and extract edit vectors

In [None]:
from rome import ROMEHyperParams
from rome import apply_rome_to_model
from rome.rome_main import *
from rome.layer_stats import layer_stats
from experiments.py.demo import print_loud

def generate(prompt:str, model) -> str:
    """Return the first completion `generate_fast` produces for `prompt` (max_out_len=50), using the global `tok`."""
    completions = generate_fast(model, tok, prompts=[prompt], max_out_len=50)
    return completions[0]

def get_W_proj(layer: int) -> Tensor:
    """Return a detached copy of the MLP down-projection (c_proj) weight of `layer` in the global `model`."""
    down_proj = model.transformer.h[layer].mlp.c_proj
    return down_proj.weight.detach().clone()

def decompose_along_W(W: Tensor, v: Tensor,
                      normalize: bool) -> Tuple[Tensor, Tensor]:
    """
    Split `v` into the part lying in the column space of `W` and the rest.

    A reduced QR factorization gives an orthonormal basis Q whose span
    contains col(W) (equal to it when W has full column rank). The first
    component is (v @ Q) @ Q.T and the second is v minus that. The names
    refer to the map x -> x @ W: the second component r satisfies r @ W ≈ 0.

    Args:
        W: matrix of shape (m, k).
        v: vector of length m.
        normalize: if True, scale each component to unit norm. A zero
            component then becomes NaN.

    Returns:
        (rowspace_component, nullspace_component).
    """
    basis = torch.linalg.qr(W).Q
    in_span = (v @ basis) @ basis.T
    remainder = v - in_span
    if not normalize:
        return in_span, remainder
    return in_span / in_span.norm(), remainder / remainder.norm()

def get_covariance(layer: int, model_name=MODEL_NAME,
                   sample_size: int = 100_000,
                   stats_dir: str = 'data/stats/') -> Tensor:
    """
    Return the 'mom2' statistic of `layer`'s MLP c_proj module on the GPU.

    The statistic comes from `layer_stats` (Wikipedia, float32) and is divided
    by `sample_size`.

    NOTE(review): the divisor is the number of dataset samples requested, not
    necessarily the number of vectors accumulated into 'mom2.mom2'. The result
    may therefore be a constant multiple of the mean second moment. Confirm
    this is intended.

    Args:
        layer: transformer block index.
        model_name: passed through to layer_stats (defaults to MODEL_NAME).
        sample_size: number of samples to collect; also used as the divisor.
            The default is 100_000, as before.
        stats_dir: directory passed to layer_stats (default 'data/stats/').
    """
    stat = layer_stats(
        model,
        tok,
        layer_name=f"transformer.h.{layer}.mlp.c_proj",
        stats_dir=stats_dir,
        ds_name='wikipedia',
        to_collect=['mom2'],
        model_name=model_name,
        sample_size=sample_size,
        precision='float32',
    )
    return torch.Tensor(stat.state_dict()['mom2.mom2']).cuda() / sample_size

# Example ROME rewrite request: "Steve Jobs was the founder of" -> "Microsoft".
request = dict(
    prompt="{} was the founder of",
    subject="Steve Jobs",
    target_new={"str": "Microsoft"},
)

# Prompts whose generations run_rome_custom can print before/after an edit.
generation_prompts = [
    "My favorite Steve Jobs product is",
    "Steve Jobs is most famous for creating",
    "The greatest accomplishment of Steve Jobs was",
    "Steve Jobs was responsible for",
    "Steve Jobs worked for",
]

from rome import ROMEHyperParams
from rome import apply_rome_to_model
from rome.rome_main import *
from rome.layer_stats import layer_stats
from experiments.py.demo import print_loud
# Index of the last transformer block, passed to ROME as "v_loss_layer".
V_LOSS_LAYER = N_LAYERS - 1

# ROME hyperparameters. run_rome_custom deep-copies this dict and replaces
# "layers" with the single layer being edited, so [8] is only a placeholder.
PARAMS_XL = dict(
    layers=[8],
    fact_token="subject_last",
    v_num_grad_steps=20,
    v_lr=5e-1,
    v_loss_layer=V_LOSS_LAYER,
    v_weight_decay=0.5,
    clamp_norm_factor=4,
    kl_factor=0.0625,
    mom2_adjustment=True,
    context_template_length_params=[[5, 10], [10, 10]],
    # Module-name templates; "{}" is filled with a layer index.
    rewrite_module_tmp="transformer.h.{}.mlp.c_proj",
    layer_module_tmp="transformer.h.{}",
    mlp_module_tmp="transformer.h.{}.mlp",
    attn_module_tmp="transformer.h.{}.attn",
    ln_f_module="transformer.ln_f",
    lm_head_module="transformer.wte",
    # Second-moment statistics settings.
    mom2_dataset="wikipedia",
    mom2_n_samples=100_000,
    mom2_dtype="float32",
)

def rank_one_decomposition(A):
    """
    Factor a (numerically) rank-1 matrix A into vectors u, v with A ≈ outer(u, v).

    The pivot is the largest-magnitude entry A[i, j], and the factors are
    u = A[:, j] and v = A[i, :] / A[i, j]. For an exact rank-1 A = x y^T this
    gives outer(u, v) == A. The previous version always divided by A[0, 0],
    which yields NaN/inf whenever that entry is zero. The largest entry is
    never zero for a non-zero A, and it also makes the division as
    well-conditioned as possible when A is only approximately rank-1.

    Args:
        A: 2-D tensor of shape (m, n).

    Returns:
        (u, v) with shapes (m,) and (n,). For an all-zero A both are zero.
    """
    flat_idx = int(torch.argmax(A.abs()))
    i, j = divmod(flat_idx, A.shape[1])
    pivot = A[i, j]
    u = A[:, j]
    if pivot == 0:
        # A is entirely zero: return zero factors instead of dividing by zero.
        return u, torch.zeros_like(A[i, :])
    return u, A[i, :] / pivot

def recover_uv_from_weights(model_new,
                            layer: int,
                            orig_weights: Dict[str, Tensor]) -> Tuple[Tensor, Tensor, float]:
    """
    Recover the rank-1 update applied to `layer`'s c_proj weight.

    Args:
        model_new: the edited model.
        layer: transformer block index.
        orig_weights: pre-edit weights keyed by parameter name.

    Returns:
        Unit vectors u, v and `scale` such that
        W_new - W_orig ≈ scale * outer(u, v) (exact if the difference is
        rank-1). `scale` is a 0-d tensor despite the float annotation.
    """
    key = f"transformer.h.{layer}.mlp.c_proj.weight"
    with torch.no_grad():
        delta = model_new.transformer.h[layer].mlp.c_proj.weight - orig_weights[key]
        u_raw, v_raw = rank_one_decomposition(delta)
    u_norm = u_raw.norm()
    v_norm = v_raw.norm()
    return u_raw / u_norm, v_raw / v_norm, u_norm * v_norm

def apply_rome_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: ROMEHyperParams,
    copy=False,
    return_orig_weights=False,
) -> Tuple[AutoModelForCausalLM, Dict[str, Tensor], List[Dict]]:
    """
    Apply ROME edits for `requests` and also return the raw per-request deltas.

    For each request, `execute_rome` returns {weight name: (delta_u, delta_v)}.
    The update outer(delta_u, delta_v), reshaped by upd_matrix_match_shape to
    the weight's shape, is then added in place to that weight.

    Args:
        model: model to edit. It is edited in place unless `copy` is True.
        tok: tokenizer.
        requests: ROME request dicts.
        hparams: ROME hyperparameters.
        copy: deep-copy `model` first and edit the copy.
        return_orig_weights: snapshot each edited weight before it is first
            updated. An assert requires that this happens during the first
            request.

    Returns:
        (edited model, {weight name: pre-edit weight} — empty unless
        return_orig_weights — and the list of per-request delta dicts).
    """
    if copy:
        # The `copy` flag shadows the `copy` module, so import deepcopy
        # explicitly rather than relying on a star import providing it.
        from copy import deepcopy
        model = deepcopy(model)
    weights_copy = {}
    deltas_list = []
    for i, request in enumerate(requests):
        deltas = execute_rome(model, tok, request, hparams)
        deltas_list.append(deltas)
        with torch.no_grad():
            for w_name, (delta_u, delta_v) in deltas.items():
                upd_matrix = delta_u.unsqueeze(1) @ delta_v.unsqueeze(0)
                w = nethook.get_parameter(model, w_name)
                upd_matrix = upd_matrix_match_shape(upd_matrix, w.shape)
                if return_orig_weights and w_name not in weights_copy:
                    assert i == 0
                    weights_copy[w_name] = w.detach().clone()
                w[...] += upd_matrix
        print(f"New weights successfully inserted into {list(deltas.keys())}")
    return model, weights_copy, deltas_list

def run_rome_custom(model, tok, requests, generation_prompts, layer: int,
                    summarize_differences: bool = False,
                    print_post_update: bool = False
                    ):
    """
    Apply ROME to `model` in place, editing only `layer`, and optionally show generations.

    The hyperparameters are a copy of PARAMS_XL with "layers" set to [layer].

    Args:
        model, tok: model and tokenizer. The model is edited in place.
        requests: ROME request dicts.
        generation_prompts: prompts used for the optional generations.
        layer: the single layer to edit.
        summarize_differences: print pre- and post-edit generations side by side.
        print_post_update: print the post-edit generations.

    Returns:
        (model_new, orig_weights, deltas_list) from apply_rome_to_model.
    """
    # Grads are enabled before running ROME; they stay enabled afterwards.
    nethook.set_requires_grad(True, model)
    params_dict = copy.deepcopy(PARAMS_XL)
    params_dict["layers"] = [layer]
    hparams = ROMEHyperParams(**params_dict)

    pre_update_text = None
    if summarize_differences:
        pre_update_text = generate_fast(model, tok, generation_prompts, max_out_len=100)
        print(pre_update_text)

    model_new, orig_weights, deltas_list = apply_rome_to_model(
        model, tok, requests, hparams, return_orig_weights=True
    )

    # The summary also needs post-edit text, so generate it whenever either
    # flag is set. Before, summarize_differences alone raised NameError.
    post_update_text = None
    if print_post_update or summarize_differences:
        if print_post_update:
            print_loud("Generating post-update text")
        post_update_text = generate_fast(
            model_new, tok, generation_prompts, max_out_len=100
        )
        if print_post_update:
            print(post_update_text)

    if summarize_differences:
        print_loud("Summarizing differences")
        prompt_str, pre_str, post_str = "[Prompt]:", "[Pre-ROME]:", "[Post-ROME]:"
        pad_to = 1 + max(len(prompt_str), len(pre_str), len(post_str))
        for i, (prompt, pre, post) in enumerate(
            zip(generation_prompts, pre_update_text, post_update_text)
        ):
            if i > 0:
                print("-" * 10)
            for label, text in ((prompt_str, prompt), (post_str, post), (pre_str, pre)):
                print(label.ljust(pad_to), text)

    return model_new, orig_weights, deltas_list

def get_rome_edit(request: Dict[str, Any], layer: int
                  ) -> Tuple[Tensor, Tensor, Tensor, AutoModelForCausalLM]:
    """
    Apply one ROME edit at `layer` to the freshly restored global `model` and recover its factors.

    The global `model`'s c_proj weights are first reset to ORIG_WEIGHTS. The
    edit is then applied in place, so the returned model is the global
    `model` itself and it stays edited until the next reset.

    Args:
        request: ROME request dict (e.g. keys "prompt", "subject" and
            "target_new", where "target_new" is itself a dict).
        layer: the layer to edit.

    Returns:
        (u, v, scale, model_new). u and v are unit vectors and `scale` is a
        0-d tensor with W_new - W_orig ≈ scale * outer(u, v) for `layer`'s
        c_proj weight. (The previous annotation listed only three values.)
    """
    with torch.no_grad():
        # Undo any earlier edit so every call starts from the pristine weights.
        for k, v in ORIG_WEIGHTS.items():
            nethook.get_parameter(model, k)[...] = v.to(model.device).clone()
        print("Original model restored")
    model_new, orig_weights, deltas_list = run_rome_custom(
        model, tok, requests=[request],
        generation_prompts=generation_prompts, layer=layer,
    )
    u, v, scale = recover_uv_from_weights(model_new, layer, orig_weights)
    return u, v, scale, model_new

In [14]:
### load ROME edits for our fact patching dataset
import joblib
# Rows with keys 'fact_idx', 'layer', 'u', 'v', 'scale' (see output below).
# joblib.load unpickles, so only load trusted files.
d = joblib.load(Path('fact_patching_rome_edit_rows.joblib'))
d[3]

{'fact_idx': 3,
 'layer': 0,
 'u': tensor([-0.0031, -0.0155,  0.0127,  ..., -0.0030,  0.0053,  0.0030],
        device='cuda:0'),
 'v': tensor([0.0408, 0.0024, 0.0162,  ..., 0.0354, 0.0264, 0.0276], device='cuda:0'),
 'scale': tensor(12.8593, device='cuda:0')}

# Converting ROME edits to subspace interventions

In [None]:
def edit_to_subspace_intervention(a: Tensor, b: Tensor, W: Tensor, Sigma: Tensor, beta:float, Sigma_pinv: Tensor, W_pinv: Tensor):
    """
    Convert an edit W + ab^T into a subspace intervention. Note that a, b here
    are normalized, so the real scale of the edit is unknown and applied later.
    
    `beta` is a parameter that should be guessed to optimize the objective. The
    procedure is derived in the paper.

    Args (as called from get_subspace_intervention_data; shapes follow the
    D_MLP, D_MODEL unpacking of the c_proj weight — TODO confirm):
        a: unit vector of length d_model (the ROME edit row's 'v').
        b: unit vector of length d_mlp (the ROME edit row's 'u').
        W: c_proj weight of shape (d_mlp, d_model).
        Sigma: (d_mlp, d_mlp) matrix from get_covariance.
        beta: positive scalar; alpha = sqrt(beta).
        Sigma_pinv, W_pinv: precomputed pseudo-inverses of Sigma and W.

    NOTE(review): with W stored as (d_mlp, d_model), the weight update that
    evaluate_edit applies is scale * outer(b, a), i.e. "ab^T" in the
    transposed layout.

    Returns:
        (v, w, obj, var_rome, var_ours, var_full): v is the intervention
        direction (length d_mlp) and w is the intermediate defined below. The
        rest are scalar quadratic forms under Sigma, described at each line.
    """
    alpha = np.sqrt(beta)
    # vector of lagrange multipliers:
    # solves (W^T Sigma^+ W) lambda = -2 beta^2 a - 2 beta W^T b
    lambda_ = torch.linalg.solve(W.T @ Sigma_pinv @ W, -2 * beta**2 * a - 2 * beta * W.T @ b)
    # w is the component of the solution v that is in the nullspace of W
    w = -W_pinv.T @ a - (1/beta) * b - (1/(2*beta**2)) * Sigma_pinv @ W @ lambda_
    v = alpha * W_pinv.T @ a + alpha * w
    # objective minimized over beta by get_subspace_intervention_data
    obj = alpha**2 * v.T @ Sigma @ v + 2 * alpha * b.T @ Sigma @ v
    # variance introduced by ROME (relative, because a,b are norm 1)
    var_rome = b.T @ Sigma @ b
    # variance of the difference between ROME and our method
    var_ours = (b + alpha * v).T @ Sigma @ (b + alpha * v)
    # variance of the subspace intervention's contribution
    var_full = alpha**2 * v.T @ Sigma @ v
    return v, w, obj, var_rome, var_ours, var_full

def get_subspace_intervention_data(fact_patching_rome_edit_rows: List[dict],
                                   layer_step: int = 5,
                                   betas=None) -> List[dict]:
    """
    Convert ROME edits into subspace interventions by sweeping `beta`.

    For layers 0, layer_step, 2*layer_step, ... below N_LAYERS, each edit row
    for that layer goes through edit_to_subspace_intervention once per beta.
    The solution with the smallest objective is kept (the first one on ties).

    Args:
        fact_patching_rome_edit_rows: rows with keys 'layer', 'fact_idx', 'u'
            and 'v'.
        layer_step: stride over layers (default 5).
        betas: iterable of beta values to try. Defaults to
            np.linspace(0.05, 1.0, 20).

    Returns:
        One dict per converted edit. It holds the best solution ('v', 'w',
        'obj', 'beta', 'var_rome', 'var_ours', 'var_full') plus per-beta
        'obj_values' and 'sim_values' (|cosine| between the edit's 'u' and
        the intervention direction).

    Raises:
        ValueError: if `betas` is empty.
    """
    betas = np.linspace(start=0.05, stop=1.0, num=20) if betas is None else list(betas)
    if len(betas) == 0:
        raise ValueError("betas must contain at least one value")
    result_rows = []
    for layer in range(0, N_LAYERS, layer_step):
        rows_for_layer = [r for r in fact_patching_rome_edit_rows if r['layer'] == layer]
        if not rows_for_layer:
            # Nothing to convert: skip the expensive covariance and pseudo-inverses.
            continue
        Sigma = get_covariance(layer=layer)
        W = get_W_proj(layer=layer)
        Sigma_pinv = torch.linalg.pinv(Sigma)
        W_pinv = torch.linalg.pinv(W)
        for row in rows_for_layer:
            a, b = row['v'], row['u']
            obj_values, sim_values = [], []
            best_obj, solution = None, None
            for beta in betas:
                v, w, obj, var_rome, var_ours, var_full = edit_to_subspace_intervention(
                    a=a, b=b, W=W, Sigma=Sigma, beta=beta,
                    Sigma_pinv=Sigma_pinv, W_pinv=W_pinv
                )
                obj_values.append(obj.item())
                sim = (b.T @ v / b.norm() / v.norm()).abs()
                sim_values.append(sim.item())
                if best_obj is None or obj < best_obj:
                    best_obj = obj
                    solution = (v, w, obj, beta, var_rome, var_ours, var_full)
            v_best, w_best, obj_best, beta_best, var_rome_best, var_ours_best, var_full_best = solution
            result_rows.append({
                'layer': layer,
                'fact_idx': row['fact_idx'],
                'v': v_best,
                'w': w_best,
                'obj': obj_best,
                'beta': beta_best,
                'var_rome': var_rome_best,
                'var_ours': var_ours_best,
                'var_full': var_full_best,
                'obj_values': obj_values,
                'sim_values': sim_values,
            })
    return result_rows

def tokenize_with_bos(s: str) -> List[str]:
    """
    Tokenize `s` into token *strings* (not ids), with the bos token string prepended.

    The previous annotation said List[int], but tok.tokenize returns strings.
    Positions computed from this list are used to index sequences built by
    encode_with_bos / mencode_with_bos. That assumes tok.tokenize(s) yields
    the same tokens as tok(s).input_ids (presumably true for GPT-2, which adds
    no special tokens — TODO confirm for other tokenizers).
    """
    return [tok.bos_token] + tok.tokenize(s)

def encode_with_bos(s: str) -> Tensor:
    """
    Encode `s` as token ids on CUDA with the bos token id prepended.

    Returns a long tensor of shape (1, 1 + number of tokens in `s`).
    """
    ids = tok(s, return_tensors='pt').input_ids.to('cuda')[0]
    bos = torch.full((1,), tok.bos_token_id, dtype=torch.long, device=ids.device)
    return torch.cat((bos, ids), dim=0).unsqueeze(0)

def mencode_with_bos(ss: List[str]) -> Tensor:
    """
    Batch-encode `ss` on CUDA, padded to the longest string, and prepend a bos id to every row.

    A single str is also accepted and yields a batch of one row.

    NOTE(review): padding uses the tokenizer's pad token and padding side
    (pad = eos is set when the model is loaded; the side is presumably right
    for GPT-2). For shorter strings in a batch, position -1 is therefore a
    pad token — confirm before reading logits[:, -1] on multi-string batches.
    """
    ids = tok(ss, return_tensors='pt', padding=True).input_ids.to('cuda')
    bos_col = torch.full((ids.shape[0], 1), tok.bos_token_id, dtype=torch.long, device=ids.device)
    return torch.cat((bos_col, ids), dim=1)

def get_last_subj_token_idx(prompt: str, subject: str) -> int:
    """
    Index of the subject's last token in the bos-prefixed tokenization of the filled-in prompt.

    `prompt` has the form "... {} ...". Only the text before the first "{}"
    plus `subject` is tokenized (via tokenize_with_bos), and the index of that
    prefix's final token is returned. This assumes the tokenizer does not
    merge the subject's last token with the text that follows it.
    """
    prefix = prompt.split('{}')[0]
    prefix_tokens = tokenize_with_bos(prefix + subject)
    return len(prefix_tokens) - 1

def get_rewrite_score(
    logits_before_edit, logits_after_edit, source_target_idx, base_target_idx
):
    """
    Normalized increase in the probability of `source_target_idx` (batch row 0).

    Computes (p_after - p_before) / (1 - p_before), where p is the softmax
    probability of the source target token. `base_target_idx` is accepted
    but unused.
    """
    p_before = logits_before_edit.softmax(dim=-1)[0, source_target_idx]
    p_after = logits_after_edit.softmax(dim=-1)[0, source_target_idx]
    return ((p_after - p_before) / (1 - p_before)).item()

def evaluate_subspace_intervention(v: Tensor, layer: int, patching_fact: dict, scale: float):
    """
    Rewrite score of a rank-1 subspace intervention on `layer`'s MLP activations.

    `v` is first multiplied by sqrt(scale), giving the intervention the same
    scale as the rank-1 edit. A forward hook on
    `model.transformer.h[layer].mlp.act` then adds -(acts . v) * v to the
    activation at the base prompt's last subject token. The final-position
    logits with and without the hook are compared by get_rewrite_score, using
    the first token of " {source_target}". Only the base prompt is run.

    Args:
        v: intervention direction (same length as the activation vector).
        layer: transformer block index.
        patching_fact: dict with 'prompt' (containing "{}"), 'base_subject',
            'base_target' and 'source_target'.
        scale: edit scale; v is multiplied by sqrt(scale).

    Returns:
        The rewrite score as a float.
    """
    v = np.sqrt(scale) * v  # this makes the intervention have the same scale as the rank-1 edit

    base_prompt = patching_fact['prompt'].format(patching_fact['base_subject'])
    base_target = patching_fact['base_target']
    source_target = patching_fact['source_target']
    last_subj_pos_base = get_last_subj_token_idx(patching_fact['prompt'], patching_fact['base_subject'])
    base_target_idx = tok.encode(f' {base_target}')[0]
    source_target_idx = tok.encode(f' {source_target}')[0]

    def hook_fn(module, input, output):
        acts = output[0, last_subj_pos_base, :]
        update_coef = einsum("d_mlp, d_mlp -> ", acts, v)
        update = - update_coef * v
        output[0, last_subj_pos_base, :] += update
        return output

    tokens = mencode_with_bos(base_prompt)
    # Inference only: avoid building autograd graphs (ROME leaves grads enabled).
    with torch.no_grad():
        clean_logits = model(tokens).logits[:, -1, :]
        handle = model.transformer.h[layer].mlp.act.register_forward_hook(hook_fn)
        try:
            intervened_logits = model(tokens).logits[:, -1, :]
        finally:
            # Always detach the hook so a failed forward pass cannot leave it
            # altering every later forward pass through this layer.
            handle.remove()
    return get_rewrite_score(
        logits_before_edit=clean_logits,
        logits_after_edit=intervened_logits,
        source_target_idx=source_target_idx,
        base_target_idx=base_target_idx,
    )

def evaluate_edit(a: Tensor, b: Tensor, scale: float, 
                  patching_fact: dict, layer: int):
    """
    Rewrite score of temporarily adding scale * outer(b, a) to `layer`'s c_proj weight.

    The base prompt is run before and after the weight update, and the
    final-position logits are compared by get_rewrite_score using the first
    token of " {source_target}". The original weight is always restored,
    even if the edited forward pass raises.

    Args:
        a: ROME factor whose length matches the weight's second dimension.
        b: ROME factor whose length matches the weight's first dimension.
        scale: edit scale.
        patching_fact: dict with 'prompt' (containing "{}"), 'base_subject',
            'base_target' and 'source_target'.
        layer: transformer block index.

    Returns:
        The rewrite score as a float.
    """
    base_prompt = patching_fact['prompt'].format(patching_fact['base_subject'])
    base_target = patching_fact['base_target']
    source_target = patching_fact['source_target']
    base_target_idx = tok.encode(f' {base_target}')[0]
    source_target_idx = tok.encode(f' {source_target}')[0]

    toks = encode_with_bos(s=base_prompt)  # already on CUDA; reused for both passes
    c_proj = model.transformer.h[layer].mlp.c_proj
    update = torch.outer(a, b).T * scale
    # Inference only: avoid building autograd graphs (ROME leaves grads enabled).
    with torch.no_grad():
        logits_before_edit = model(toks).logits[:, -1, :]
        original_weight = c_proj.weight.detach().clone()
        try:
            c_proj.weight += update
            logits_after_edit = model(toks).logits[:, -1, :]
        finally:
            # Restore in place so the Parameter keeps its identity.
            c_proj.weight.copy_(original_weight)

    return get_rewrite_score(
        logits_before_edit, logits_after_edit, source_target_idx, base_target_idx
    )

def evaluate_rome_and_subspace(fp_results, 
                               rank1_to_subsp_result,
                               rome_edit_results,):
    """
    Compare rewrite scores of ROME edits and their subspace interventions per (layer, fact_idx).

    Only pairs present in both inputs are evaluated. get_subspace_intervention_data
    covers only every 5th layer by default, so ROME rows for other layers
    are skipped (their count is printed) instead of raising KeyError.
    Subspace rows without a matching ROME row are ignored.

    Args:
        fp_results: element 3 is indexed by fact_idx to get the patching-fact
            dicts (TODO: give this index a name).
        rank1_to_subsp_result: rows with 'layer', 'fact_idx' and 'v', e.g. from
            get_subspace_intervention_data.
        rome_edit_results: rows with 'layer', 'fact_idx', 'u', 'v' and
            'scale' (0-d tensor or number).

    Returns:
        A list of dicts with 'layer', 'fact_idx', 'rewrite_score_rome' and
        'rewrite_score_subsp'.
    """
    patching_facts = fp_results[3]
    data = {}  # (layer, fact_idx) -> inputs for both evaluations
    for r in rome_edit_results:
        layer, fact_idx = r['layer'], r['fact_idx']
        data[(layer, fact_idx)] = {
            'patching_fact': patching_facts[fact_idx],
            'a': r['v'],
            'b': r['u'],
            'scale': float(r['scale']),  # works for 0-d tensors and plain numbers
        }
    for r in rank1_to_subsp_result:
        key = (r['layer'], r['fact_idx'])
        if key in data:
            data[key]['v_subsp'] = r['v']

    pairs = [(key, d) for key, d in data.items() if 'v_subsp' in d]
    n_missing = len(data) - len(pairs)
    if n_missing:
        print(f"Skipping {n_missing} ROME edits with no matching subspace intervention")

    results = []
    for (layer, fact_idx), d in tqdm(pairs):
        rewrite_score_rome = evaluate_edit(
            a=d['a'], b=d['b'], scale=d['scale'],
            patching_fact=d['patching_fact'], layer=layer
        )
        rewrite_score_subsp = evaluate_subspace_intervention(
            v=d['v_subsp'], layer=layer, patching_fact=d['patching_fact'], scale=d['scale']
        )
        results.append({
            'layer': layer,
            'fact_idx': fact_idx,
            'rewrite_score_rome': rewrite_score_rome,
            'rewrite_score_subsp': rewrite_score_subsp,
        })
    return results

In [None]:
### use this code to compute the rewrite scores 
import joblib
# joblib.load unpickles, so only load trusted files.
rome_edit_results = joblib.load('fact_patching_rome_edit_rows.joblib')
fp_results = joblib.load('patching_exp_outputs/fact_patching_results.joblib')
# Subspace interventions are derived from the ROME edits. This yields rows
# with 'layer', 'fact_idx' and 'v', as evaluate_rome_and_subspace expects.
# Previously this re-loaded fact_patching_results.joblib, the same file as
# fp_results, which is indexed [3] for patching facts rather than being
# such a list of rows.
rank1_to_subsp_result = get_subspace_intervention_data(rome_edit_results)

rome_to_subsp_results = evaluate_rome_and_subspace(
    fp_results=fp_results,
    rank1_to_subsp_result=rank1_to_subsp_result,
    rome_edit_results=rome_edit_results,
)