In [1]:
from common_imports import *
from fact_utils import setup_counterfact, COUNTERFACT_PATH, get_covariance_path
from transformers import AutoModelForCausalLM, AutoTokenizer
#! SET TO FALSE TO RUN THE ACTUAL EXPERIMENTS
# While True, this notebook loads "gpt2" instead of "gpt2-xl", replaces the
# stored activation covariance with a random matrix, and runs the patching
# sweep on 5 sampled facts with reduced settings.
DEBUGGING = True 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Project helper imported from fact_utils; presumably prepares the CounterFact
# file later read from COUNTERFACT_PATH — TODO confirm against fact_utils.
setup_counterfact()

# Load and setup model

In [3]:
# Small checkpoint while debugging, the full-size one for the real experiments.
if DEBUGGING:
    MODEL_NAME = "gpt2"
else:
    MODEL_NAME = "gpt2-xl"

# Load the language model onto the GPU, then its tokenizer.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=False)
model = model.to("cuda")
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse the end-of-sequence token as the padding token for batched encoding.
tok.pad_token = tok.eos_token

# Model weights are frozen; only patching directions are optimised below.
model.requires_grad_(False);

N_LAYERS = len(model.transformer.h)
# The MLP down-projection weight is unpacked as (D_MLP, D_MODEL).
D_MLP, D_MODEL = model.transformer.h[0].mlp.c_proj.weight.shape

# save the original weights to be able to recover after edits
ORIG_WEIGHTS = {
    f'transformer.h.{layer}.mlp.c_proj.weight': block.mlp.c_proj.weight.clone().detach()
    for layer, block in enumerate(model.transformer.h)
}

# Helpers

In [15]:
def tokenize_with_bos(s: str) -> List[str]:
    """
    Split a string into token strings, with the bos token string prepended.

    Note that this returns the tokens themselves (strings), not token ids; use
    `encode_with_bos` when ids are needed. The length of the result is what
    callers use to locate token positions.
    """
    return [tok.bos_token, *tok.tokenize(s)]

def encode_with_bos(s: str) -> Tensor:
    """
    Encode one string into token ids on the GPU, with the bos token id
    prepended. The result keeps a leading batch axis of size 1.
    """
    ids = tok(s, return_tensors='pt').input_ids.to('cuda')
    bos = torch.Tensor([tok.bos_token_id]).cuda().long()
    # Prepend bos along the sequence axis of the first (only) row.
    return torch.cat([bos.unsqueeze(0), ids[:1]], dim=1)

def mencode_with_bos(ss: List[str]) -> Tensor:
    """
    Batched version of `encode_with_bos`: encode several strings into one GPU
    tensor of token ids, padded by the tokenizer to the longest string, with a
    bos token id prepended to every row.
    """
    batch = tok(ss, return_tensors='pt', padding=True).input_ids.to('cuda')
    bos_id = torch.Tensor([tok.bos_token_id]).cuda().long()
    # One bos id per row, as a column to put in front of the batch.
    bos_column = bos_id.unsqueeze(0).repeat(batch.shape[0], 1)
    return torch.cat([bos_column, batch], dim=1)

def remove_all_hooks(model: nn.Module):
    """
    Drop every registered forward hook from `model` and from all of its
    submodules.

    Handy for cleaning up patching hooks that were left attached because an
    earlier run of the experiment crashed before removing them.
    """
    for module in model.modules():
        module._forward_hooks.clear()

def get_last_subj_token_idx(prompt: str, subject: str) -> int:
    """
    For a prompt template of the form "... {} ...", return the position of the
    last subject token once `subject` is substituted for {} and the text is
    tokenized with a bos token in front.

    Only the part of the template up to and including the placeholder is
    tokenized, so the final token of that text is the last subject token.
    """
    # everything before the first placeholder, followed by the subject itself
    text_up_to_subject = prompt.split('{}')[0] + subject
    return len(tokenize_with_bos(text_up_to_subject)) - 1

def get_mlp_activations(model: AutoModelForCausalLM,
                        prompt: str, layer: int, seq_pos: int
                        ) -> Tensor:
    """
    Get the post-nonlinearity activations of the prompt in the given MLP layer,
    at the given token position.

    The prompt is encoded with a bos token prepended, so `seq_pos` counts
    positions in that bos-prefixed sequence. The returned tensor is the
    activation vector of the first (only) batch element at `seq_pos`.
    """
    captured = []
    def hook_fn(module, input, output):
        # record the output of the MLP activation function
        captured.append(output)
    handle = model.transformer.h[layer].mlp.act.register_forward_hook(hook_fn)
    try:
        tokens = encode_with_bos(prompt)
        _ = model(tokens)
    finally:
        # always detach the hook, even if the forward pass fails, so that it
        # cannot leak into later runs
        handle.remove()
    return captured[0][0, seq_pos, :]

def get_covariance(layer: int, model_name=MODEL_NAME) -> Tensor:
    """
    Load covariance matrix of activations for a given layer.

    When DEBUGGING is set, a random positive semi-definite (D_MLP, D_MLP)
    matrix is generated instead of reading the stored statistics. Otherwise the
    matrix is read from the 'mom2.mom2' entry of the file returned by
    `get_covariance_path` and moved to the GPU as float32.

    NOTE(review): `model_name` is accepted but not used by this function.
    """
    if DEBUGGING:
        # generate a random covariance matrix for debugging
        X = torch.randn(D_MLP, D_MLP).cuda()
        return X @ X.T
    npz_path = get_covariance_path(layer=layer)
    # np.load yields NumPy arrays, which have no `.float()` method: convert to
    # a torch tensor first. The context manager closes the archive afterwards.
    with np.load(npz_path) as data:
        mom2 = torch.from_numpy(data['mom2.mom2'])
    return mom2.float().cuda()

def get_W_proj(layer: int) -> Tensor:
    """
    Return a detached copy of the down-projection (`c_proj`) weight of the MLP
    in the given layer of the global `model`.
    """
    mlp = model.transformer.h[layer].mlp
    return mlp.c_proj.weight.detach().clone()

def decompose_along_W(W: Tensor, v: Tensor, normalize: bool) -> Tuple[Tensor, Tensor]:
    """
    Split a vector v into two parts relative to a down-projection matrix W and
    return them as (rowspace component, nullspace component).

    The first part is the projection of v onto the span of the Q factor of a
    QR decomposition of W; the second part is the remainder of v. With
    `normalize` set, each part is divided by its own norm.
    """
    basis = torch.linalg.qr(W)[0]
    in_span = (v @ basis) @ basis.T
    remainder = v - in_span
    if not normalize:
        return in_span, remainder
    return in_span / in_span.norm(), remainder / remainder.norm()

# Generate a fact patching dataset from CounterFact

In [7]:
# How many CounterFact records (from the start of the file) to build pairs from.
N_EXAMPLES = 1_000
# How many fact-patching pairs to sample for each prompt group.
N_PER_PROMPT = 5
# JSON text is UTF-8; state the encoding explicitly so that loading does not
# depend on the platform's default encoding.
with open(COUNTERFACT_PATH, 'r', encoding='utf-8') as f:
    COUNTERFACT = json.load(f)

In [8]:
def group_by_relation(cf_examples: List[dict]) -> List[List[dict]]:
    """
    Partition counterfact examples by their relation id.

    Groups appear in the order in which each relation id is first seen, and
    examples keep their input order within a group.
    """
    buckets = {}
    for example in cf_examples:
        key = example['requested_rewrite']['relation_id']
        buckets.setdefault(key, []).append(example)
    return [members for members in buckets.values()]

def check_knowledge(prompt: str, subject: str, target: str) -> bool:
    """
    Check if the model knows the given fact, i.e. whether its top next-token
    prediction after the filled-in prompt is the first token of ' {target}'.

    Only the first token of the target is compared. The result is a plain
    Python bool (not a 0-dim tensor).
    """
    tokens = encode_with_bos(prompt.format(subject))
    # logits for the position after the last prompt token
    logits = model(tokens).logits[0, -1, :]
    obj_id = tok.encode(f' {target}')[0]
    return bool(logits.argmax().item() == obj_id)

def collect_fact_patches(counterfact_examples: List[dict],) -> List[List[dict]]:
    """
    Build a fact patching dataset out of counterfact examples.

    The examples are grouped by relation id, and the prompt template of the
    first example of each group is used for every subject in that group. The
    result has one list per group, holding fact patching examples of the form
    {
        "prompt": e.g. 'The mother tongue of {} is',
        "base_subject": e.g. 'Danielle Darrieux',
        "base_target": e.g. 'French',
        "source_subject": e.g. 'Thomas Joannes Stieltjes',
        "source_target": e.g. 'Dutch',
    }
    so that all examples within a list share the same prompt.

    An ordered (base, source) pair is kept only if
        - the model knows both facts (see `check_knowledge`), and
        - the two targets differ.
    """
    dataset = []
    for group in tqdm(group_by_relation(counterfact_examples)):
        template = group[0]['requested_rewrite']['prompt']
        rewrites = [member['requested_rewrite'] for member in group]
        # query the model once per fact, before forming any pairs
        is_known = [
            check_knowledge(prompt=template,
                            subject=rewrite['subject'],
                            target=rewrite['target_true']['str'])
            for rewrite in rewrites
        ]
        known = [(idx, rewrite) for idx, rewrite in enumerate(rewrites) if is_known[idx]]
        pairs = []
        for base_pos, base_rewrite in known:
            for source_pos, source_rewrite in known:
                if base_pos == source_pos:
                    continue
                base_target = base_rewrite['target_true']['str']
                source_target = source_rewrite['target_true']['str']
                if base_target == source_target:
                    continue
                pairs.append({
                    'prompt': template,
                    'base_subject': base_rewrite['subject'],
                    'base_target': base_target,
                    'source_subject': source_rewrite['subject'],
                    'source_target': source_target,
                })
        dataset.append(pairs)
    return dataset

def sample_fact_patches(fact_patching_dataset: List[List[dict]], n_per_prompt: int) -> List[dict]:
    """
    Draw `n_per_prompt` examples, without replacement, from every group of the
    dataset returned by `collect_fact_patches` that has at least that many
    examples, and return them as one flat list. Smaller groups are skipped.
    """
    return [
        example
        for group in fact_patching_dataset
        if len(group) >= n_per_prompt
        for example in random.sample(group, n_per_prompt)
    ]

In [9]:
# Build all valid (base, source) pairs from the first N_EXAMPLES CounterFact
# records, then draw N_PER_PROMPT pairs from each sufficiently large group.
# NOTE(review): the sampling uses the global `random` state and no seed is set
# in the visible cells, so FACT_PATCHING_SAMPLES presumably changes between
# runs — confirm whether common_imports seeds it.
FACT_PATCHING_DATASET = collect_fact_patches(COUNTERFACT[:N_EXAMPLES])
FACT_PATCHING_SAMPLES = sample_fact_patches(FACT_PATCHING_DATASET, N_PER_PROMPT)

100%|██████████| 34/34 [00:12<00:00,  2.68it/s]


# Learning 1-dimensional activation patches to change factual recall

In [10]:
class LearnableDirection(nn.Module):
    """
    A trainable 1-dimensional subspace, stored as a single vector parameter
    `direction`. It is initialised to a random unit vector; keeping it at unit
    norm afterwards is the job of the training loop.
    """
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        # random Gaussian vector, rescaled to norm 1
        raw = torch.randn(dim)
        self.direction = nn.Parameter(raw / raw.norm())

def mget_patched_logits(
    model: AutoModelForCausalLM,
    prompts: List[str],
    layer: int, 
    patching_positions: List[int],
    vs: List[LearnableDirection],
    source_activations: Tensor,
    last_token_positions: Optional[List[int]] = None,
    ) -> Tensor:
    """
    Get the logits after patching given token positions along given directions
    from the given source activations.

    For example k, the post-nonlinearity MLP activation of `layer` at position
    `patching_positions[k]` is changed only along `vs[k].direction`, so that
    its projection on that direction equals the projection of
    `source_activations[k]`. All other positions are left untouched.

    Args:
        model: the language model to run.
        prompts: one prompt per example; they are encoded as one padded batch
            with a bos token prepended (no attention mask is passed).
        layer: index of the transformer block whose MLP is patched.
        patching_positions: per-example token position to patch.
        vs: per-example patching direction.
        source_activations: per-example activation vector to patch from.
        last_token_positions: per-example position to read the logits from;
            when None, the last token of each separately-encoded prompt is
            used.

    Returns:
        One row of logits per example, taken at `last_token_positions`.
    """
    if last_token_positions is None:
        prompts_encoded_separately = [encode_with_bos(prompt).squeeze(0) for prompt in prompts]
        last_token_positions = [len(prompt) - 1 for prompt in prompts_encoded_separately]
    n = len(prompts)
    example_idx = list(range(n))
    directions_tensor = torch.stack([v.direction for v in vs], dim=0)
    def hook_fn(module, input, output):
        acts = output[example_idx, patching_positions, :] # n x d_mlp
        current_projs = einsum("n_examples d_mlp, n_examples d_mlp -> n_examples", acts, directions_tensor)
        desired_projs = einsum("n_examples d_mlp, n_examples d_mlp -> n_examples", source_activations, directions_tensor)
        # move each activation along its direction by the projection gap
        # (dividing by the squared norm makes this correct for non-unit directions)
        new_act = acts.clone() + einsum("n_examples, n_examples d_mlp -> n_examples d_mlp", desired_projs - current_projs, directions_tensor) / (directions_tensor.norm(dim=-1) ** 2).unsqueeze(1)
        # overwrite only the patched positions, keep everything else
        mask = torch.zeros_like(output)
        mask[example_idx, patching_positions, :] = 1
        return torch.where(mask.bool(), new_act.unsqueeze(1), output)
    handle = model.transformer.h[layer].mlp.act.register_forward_hook(hook_fn)
    try:
        tokens = mencode_with_bos(prompts)
        logits = model(tokens).logits
    finally:
        # always detach the hook, even if the forward pass fails, so that it
        # cannot silently patch later runs
        handle.remove()
    return logits[example_idx, last_token_positions, :] # n x d_vocab

def mtrain_das(
    base_prompts: List[str], source_prompts: List[str],
    base_targets: List[str], source_targets: List[str],
    layer: int, 
    source_last_subj_poss: List[int], # position of last subject token in source_prompt
    base_last_subj_poss: List[int], # position of last subject token in base_prompt
    finishing_epochs: int = 200,
    lr: float = 1e-3, n_steps: int = 1000,
    end_factor: float = 1e-3,
    ):
    """
    Train 1-dimensional subspaces to change the given facts. This is batched,
    which speeds up the training compared to training each fact patching
    pair separately. 

    Each example gets its own direction, optimised with SGD to minimise the
    logit difference (base target minus source target) obtained when the base
    prompt is patched along that direction from the source activation. After
    every step the directions are rescaled to unit norm. Training runs for
    `n_steps + finishing_epochs` steps; the learning rate follows a LinearLR
    schedule over the first `n_steps` steps, ending at `end_factor * lr`.
    NOTE(review): LinearLR's `start_factor` is left at its library default
    (1/3 in the PyTorch docs), so training presumably starts below `lr` —
    confirm this is intended.

    Returns:
        - vs: one LearnableDirection per example, set to the direction that
          achieved that example's lowest loss (the earliest one on ties)
        - losses_per_example: for each example, its loss at every step
    """
    n = len(base_prompts)
    example_idx = list(range(n))
    # find the last token positions for the base prompts
    base_prompts_encoded_separately = [encode_with_bos(prompt).squeeze(0) for prompt in base_prompts]
    base_last_token_positions = [len(prompt) - 1 for prompt in base_prompts_encoded_separately]
    
    vs = [LearnableDirection(D_MLP).cuda() for _ in range(n)]
    base_idxs = [tok.encode(f' {base_target}')[0] for base_target in base_targets]
    source_idxs = [tok.encode(f' {source_target}')[0] for source_target in source_targets]
    source_activations = torch.stack([get_mlp_activations(model, source_prompt, layer, source_last_subj_pos)
                            for source_prompt, source_last_subj_pos in zip(source_prompts, source_last_subj_poss)], dim=0)
    optimizer = torch.optim.SGD([v.direction for v in vs], lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, end_factor=end_factor, total_iters=n_steps)
    pbar = tqdm(range(n_steps + finishing_epochs))
    losses_per_example = [[] for _ in range(n)]
    # Running best per example. Only the best direction is kept, rather than a
    # copy of every direction at every step, so memory does not grow with the
    # number of steps.
    best_losses = [float('inf')] * n
    best_directions = [None] * n
    for step in pbar:
        optimizer.zero_grad()
        logits = mget_patched_logits(
            model=model, prompts=base_prompts, layer=layer, 
            vs=vs, source_activations=source_activations,
            patching_positions=base_last_subj_poss,
            last_token_positions=base_last_token_positions,
        ) # shape (n, vocab_size)
        logit_diffs = logits[example_idx, base_idxs] - logits[example_idx, source_idxs]
        loss = logit_diffs.sum()
        # these losses belong to the directions *before* this step's update,
        # so snapshot the directions now
        step_losses = logit_diffs.detach().tolist()
        for j, step_loss in enumerate(step_losses):
            losses_per_example[j].append(step_loss)
            if step_loss < best_losses[j]:
                best_losses[j] = step_loss
                best_directions[j] = vs[j].direction.detach().clone()
        loss.backward()
        optimizer.step()
        if step <= n_steps:
            lr_scheduler.step()
        # normalize directions
        for v in vs:
            v.direction.data = v.direction.data / v.direction.data.norm()
        pbar.set_description(f'loss: {loss.item():.3f}, best losses: {best_losses}')
    # restore, for each example, the direction with the lowest loss; with zero
    # training steps there is nothing to restore and the initial direction stays
    for v, best_direction in zip(vs, best_directions):
        if best_direction is not None:
            v.direction.data = best_direction
    return vs, losses_per_example

# Analyzing fact patches

In [11]:
def get_patched_logits_full_mlp(
    base_prompt: str, source_prompt: str, 
    base_target: str, source_target: str,
    layer: int, source_last_subj_pos: int, base_last_subj_pos: int,
) -> Tensor:
    """
    Return the last-position logits of the base prompt after replacing the
    entire post-nonlinearity MLP activation of `layer` at `base_last_subj_pos`
    with the source prompt's activation at `source_last_subj_pos`.

    The result keeps the batch axis (size 1 here). `base_target` and
    `source_target` are not used by this function; callers compute the logit
    difference from the returned logits.
    """
    source_activation = get_mlp_activations(model, source_prompt,
                                                 layer, source_last_subj_pos)
    def hook_fn(module, input, output):
        # overwrite the whole activation vector at the patched position
        output[0, base_last_subj_pos, :] = source_activation
        return output
    handle = model.transformer.h[layer].mlp.act.register_forward_hook(hook_fn)
    try:
        tokens = encode_with_bos(base_prompt)
        logits = model(tokens).logits[:, -1, :]
    finally:
        # always detach the hook, even if the forward pass fails
        handle.remove()
    return logits

def analyze_das_patch(
    base_prompt: str, source_prompt: str, 
    base_target: str, source_target: str,
    layer: int, source_last_subj_pos: int, base_last_subj_pos: int,
    das_result: LearnableDirection,
    ) -> Tuple[dict, pd.DataFrame]:
    """
    Evaluate a learned patching direction for one (base, source) fact pair.

    Returns a pair:
        - a dict with the norms of the rowspace and nullspace components of
          the direction ('row_norm', 'null_norm'), relative to the
          down-projection weight of `layer`
        - a DataFrame with one row per method, holding the logit difference
          (base target minus source target) and the predicted token for
            - 'clean': no intervention
            - 'das': patching along the full direction
            - 'full_mlp': patching the entire MLP activation
            - 'row': patching along the normalized rowspace component only
    """
    direction = das_result.direction.data.detach().clone()
    # evaluate through a fresh module holding a copy of the direction
    das_result = LearnableDirection(D_MLP).cuda()
    das_result.direction.data = direction

    v_row, v_null = decompose_along_W(W=get_W_proj(layer=layer), v=direction, normalize=False)
    norm_metrics = {
        'row_norm': v_row.norm().item(),
        'null_norm': v_null.norm().item(),
    }

    base_idx = tok.encode(f' {base_target}')[0]
    source_idx = tok.encode(f' {source_target}')[0]
    source_activation = get_mlp_activations(model, source_prompt, layer, source_last_subj_pos)

    def summarize(method, logits_1d):
        # one metrics row for a vector of logits produced by an intervention
        top_token = int(logits_1d.argmax().item())
        return {
            'method': method,
            'ld': (logits_1d[base_idx] - logits_1d[source_idx]).item(),
            'predict_source': top_token == source_idx,
            'predict_base': top_token == base_idx,
            'prediction': tok.decode(top_token),
        }

    def patch_along(direction_module):
        # logits of the base prompt after a 1-dimensional patch
        return mget_patched_logits(
            model=model, prompts=[base_prompt], layer=layer,
            vs=[direction_module], source_activations=source_activation.unsqueeze(0),
            patching_positions=[base_last_subj_pos],
            last_token_positions=None,
        )[0]

    ### clean run: the flags are fixed rather than computed from the logits
    clean_logits = model(encode_with_bos(base_prompt)).logits[:, -1, :]
    rows = [{
        'method': 'clean',
        'ld': (clean_logits[0, base_idx] - clean_logits[0, source_idx]).item(),
        'predict_source': False,
        'predict_base': True,
        'prediction': tok.decode(clean_logits[0, :].argmax().item()),
    }]

    ### patch using the full direction
    rows.append(summarize('das', patch_along(das_result)))

    ### patch the entire MLP
    full_mlp_logits = get_patched_logits_full_mlp(base_prompt, source_prompt,
                                                  base_target, source_target,
                                                  layer, source_last_subj_pos, base_last_subj_pos)
    rows.append(summarize('full_mlp', full_mlp_logits[0]))

    ### patch using the row component only
    das_row = LearnableDirection(D_MLP).cuda()
    das_row.direction.data = v_row / v_row.norm()
    rows.append(summarize('row', patch_along(das_row)))

    return norm_metrics, pd.DataFrame(rows)

# Main activation patching experiment

In [12]:
def mrun_fact_patching(patching_facts: List[dict],
                      layer_step: int,
                      n_steps: int = 600, finishing_epochs: int = 0,
                      num_trials_per_lr: int = 2, 
                      end_factor: float = 1e-3,
                      lrs: tuple = (3*1e-1, 1e-1, 3*1e-2, 1e-2,)):
    """
    Sweep over learning rates on every `layer_step`-th layer and keep, for each
    fact patching pair, the 1-dimensional subspace with the lowest logit
    difference across all training steps and trials. Each learning rate in
    `lrs` is tried `num_trials_per_lr` times.

    Returns:
        - best_vs: list of dicts {'layer', 'v', 'fact_idx'}, where 'v' is the
        best direction as a numpy array and 'fact_idx' indexes
        `patching_facts`
        - norm_metrics_df: columns (row_norm, null_norm, layer, fact_idx)
        - patching_metrics_df: (method, ld, predict_source, predict_base,
        prediction, layer, fact_idx)
    """
    n_facts = len(patching_facts)
    base_prompts = [fact['prompt'].format(fact['base_subject']) for fact in patching_facts]
    source_prompts = [fact['prompt'].format(fact['source_subject']) for fact in patching_facts]
    base_targets = [fact["base_target"] for fact in patching_facts]
    source_targets = [fact["source_target"] for fact in patching_facts]
    base_seq_poss = [get_last_subj_token_idx(fact['prompt'], fact['base_subject']) for fact in patching_facts]
    source_seq_poss = [get_last_subj_token_idx(fact['prompt'], fact['source_subject']) for fact in patching_facts]

    # repeat the learning rates once per trial
    lr_trials = lrs * num_trials_per_lr
    norm_metric_rows, patching_metric_dfs, best_vs = [], [], []
    for layer in range(0, N_LAYERS, layer_step):
        best_loss = [float('inf')] * n_facts
        best_vs_for_layer = [None] * n_facts
        for lr in lr_trials:
            vs, losses_per_example = mtrain_das(
                base_prompts=base_prompts, source_prompts=source_prompts,
                base_targets=base_targets, source_targets=source_targets,
                end_factor=end_factor,
                layer=layer, source_last_subj_poss=source_seq_poss, base_last_subj_poss=base_seq_poss,
                lr=lr, n_steps=n_steps, finishing_epochs=finishing_epochs,
            )
            # keep a trial's direction only where it beats the best so far
            for i, losses in enumerate(losses_per_example):
                lowest = min(losses)
                if lowest < best_loss[i]:
                    best_loss[i], best_vs_for_layer[i] = lowest, vs[i]
        for i, best_v in enumerate(best_vs_for_layer):
            norm_metrics, patching_metrics_df = analyze_das_patch(
                base_prompt=base_prompts[i], source_prompt=source_prompts[i],
                base_target=base_targets[i], source_target=source_targets[i],
                layer=layer, source_last_subj_pos=source_seq_poss[i], base_last_subj_pos=base_seq_poss[i],
                das_result=best_v,
            )
            norm_metric_rows.append({**norm_metrics, 'layer': layer, 'fact_idx': i})
            patching_metric_dfs.append(patching_metrics_df.assign(layer=layer, fact_idx=i))
            best_vs.append({'layer': layer, 'v': best_v.direction.detach().cpu().numpy(), 'fact_idx': i})
    return (
        best_vs,
        pd.DataFrame(norm_metric_rows),
        pd.concat(patching_metric_dfs, ignore_index=True),
    )

In [13]:
# Clear forward hooks that an interrupted earlier run may have left attached,
# so the sweep below starts from an unpatched model.
remove_all_hooks(model)

In [16]:
# Run the patching sweep: a quick smoke test on 5 randomly drawn facts while
# debugging, otherwise the full sweep over all sampled facts.
# NOTE(review): `random.sample` uses the global `random` state with no seed set
# in the visible cells, so the debug subset presumably differs between runs.
if DEBUGGING:
    best_vs, norm_metrics_df, patching_metrics_df = mrun_fact_patching(
        patching_facts=random.sample(FACT_PATCHING_SAMPLES, 5),
        n_steps=100,
        layer_step=10, 
        lrs=(1e-1, ),
        num_trials_per_lr=1,
    )
else:
    # remaining settings (n_steps, lrs, ...) fall back to the function defaults
    best_vs, norm_metrics_df, patching_metrics_df = mrun_fact_patching(
        patching_facts=FACT_PATCHING_SAMPLES,
        layer_step=5,
        num_trials_per_lr=2,
    )

loss: -89.713, best losses: [-28.467674255371094, -9.408184051513672, -49.2857666015625, -38.286354064941406, -9.971298217773438]: 100%|██████████| 100/100 [00:03<00:00, 31.57it/s] 
loss: 4.821, best losses: [2.83795166015625, 2.13458251953125, -5.0811004638671875, 1.311492919921875, 3.6180877685546875]: 100%|██████████| 100/100 [00:02<00:00, 43.99it/s]     


# From patches to rank-1 edits

In [18]:
def convert_patch_to_edit(
    layer: int, v: np.ndarray, 
    patching_fact: dict,
):
    """
    Turn a patching direction into a rank-1 weight edit.

    Returns vectors a, b such that patching along `v` corresponds to the
    weight edit W <--- W + ab^T of the MLP down-projection in `layer`, with
    a, b chosen to minimize the "damage" to the model (the formula is derived
    in the appendix of the paper), together with a dict of metrics:
        - cosine similarity of the activation difference (source minus base)
          to the rowspace / nullspace components of v
        - the variance of the edit, b^T Sigma b
        - logit difference (base target minus source target) and prediction of
          the model on the base prompt while the edit is applied

    The edit is applied only temporarily; the original weight is put back
    before returning, also if the forward pass fails.
    """
    v = torch.tensor(v).cuda()

    template = patching_fact['prompt']
    base_subject = patching_fact['base_subject']
    source_subject = patching_fact['source_subject']
    base_prompt = template.format(base_subject)
    source_prompt = template.format(source_subject)
    base_target = patching_fact['base_target']
    source_target = patching_fact['source_target']
    base_target_idx = tok.encode(f' {base_target}')[0]
    source_target_idx = tok.encode(f' {source_target}')[0]

    # activations at the last subject token of each prompt
    base_act = get_mlp_activations(
        model=model, prompt=base_prompt, layer=layer,
        seq_pos=get_last_subj_token_idx(template, base_subject),
    )
    source_act = get_mlp_activations(
        model=model, prompt=source_prompt, layer=layer,
        seq_pos=get_last_subj_token_idx(template, source_subject),
    )

    # how the activation difference aligns with the rowspace and nullspace
    # components of v
    W_proj = get_W_proj(layer=layer)
    v_row, v_null = decompose_along_W(W=W_proj, v=v, normalize=False)
    act_diff = source_act - base_act
    metrics = {
        'act_diff_sim_to_row': torch.cosine_similarity(act_diff, v_row, dim=-1).item(),
        'act_diff_sim_to_null': torch.cosine_similarity(act_diff, v_null, dim=-1).item(),
    }

    # the two factors of the rank-1 edit
    Sigma = get_covariance(layer=layer)
    a = (act_diff @ v) * (v @ W_proj)
    Sigma_pinv = torch.linalg.pinv(Sigma)
    whitened_base = base_act @ Sigma_pinv
    b = whitened_base / (whitened_base @ base_act)
    metrics['variance_of_edit'] = (b @ Sigma @ b).item()

    ### now, run the model with this edit
    update = torch.outer(a, b).T
    proj_weight = model.transformer.h[layer].mlp.c_proj.weight
    backup = proj_weight.data.clone()
    try:
        proj_weight.data += update
        toks = encode_with_bos(s=base_prompt)
        logits_after_edit = model(toks.cuda()).logits[:, -1, :]
    finally:
        proj_weight.data = backup

    # compare the edited model's output with what the patch is meant to achieve
    top_token = torch.argmax(logits_after_edit, dim=-1)
    metrics['edit_logit_diff'] = (
        logits_after_edit[0, base_target_idx] - logits_after_edit[0, source_target_idx]
    ).item()
    metrics['edit_prediction'] = tok.decode([top_token.item()])
    metrics['edit_predicts_source'] = top_token.item() == source_target_idx
    metrics['edit_predicts_base'] = top_token.item() == base_target_idx
    return a, b, metrics

In [2]:
# load the saved dataset of patching facts
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only load files produced by a trusted run.
# NOTE(review): the next cell indexes `exp_facts` with the `fact_idx` values
# stored in `best_vs`, which index the `patching_facts` list passed to
# `mrun_fact_patching`; this file presumably holds that same list in the same
# order — confirm, otherwise the edits are computed for the wrong facts.
_, _, _, exp_facts = joblib.load('patching_exp_outputs/fact_patching_results.joblib')

In [None]:
# save the results of converting the patching directions to edits
# One row per (layer, fact) entry of `best_vs`: the rank-1 edit factors a, b as
# numpy arrays, plus the metrics returned by `convert_patch_to_edit`.
editing_result_rows = []
for best_v_data in tqdm(best_vs):
    layer = best_v_data['layer']
    v = best_v_data['v']
    fact_idx = best_v_data['fact_idx']
    a, b, metrics = convert_patch_to_edit(
        layer=layer, v=v,
        # look the fact up in the loaded dataset by its index
        patching_fact=exp_facts[fact_idx]
    )
    editing_result_rows.append({
        'layer': layer,
        'fact_idx': fact_idx,
        'a': a.detach().cpu().numpy(),
        'b': b.detach().cpu().numpy(),
        **metrics,
    })
# written to the current working directory
joblib.dump(editing_result_rows, 'editing_result_rows.joblib')