In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join('..')))

import baukit
import numpy as np
from tqdm import tqdm

import locality.metric as metric
from causal_trace.utils import ModelandTokenizer
from dsets.counterfact import CounterFactDataset
from locality.dataset import (
    QA_Sample,
    VariableBindingFactRecallDataset,
    generate_synthetic_dataset,
)
from locality.functional import (
    filter_samples_by_model_knowledge,
    find_token_range,
    get_h,
    patch_output,
    predict_next_token,
)
from locality.utils import experiment_utils, logging_utils
from locality.utils.dataclasses import (
    ExperimentResults,
    LayerPatchingEfficacy,
    PatchingResults_for_one_pair,
    PatchingTrialResult,
    PredictedToken,
)

In [None]:
def patch_individual_layers_for_single_edit(
    mt: ModelandTokenizer,
    layers: list[int],
    source_QA: QA_Sample,
    edit_QA: QA_Sample,
    query: str,
) -> PatchingResults_for_one_pair:
    """Patch one layer at a time with edit-subject activations and record the effect.

    The edit prompt is ``query`` with every occurrence of ``source_QA.subject``
    replaced by ``edit_QA.subject``; ``get_h`` collects its activations for the
    requested layers. Then, for each layer separately, the model predicts the
    next token for the unmodified ``query`` while ``patch_output`` (installed
    through ``baukit.Trace``) overrides that layer's output at token index
    ``subject_end - 1`` of the source subject, i.e. the position this function
    treats as the subject's last token.

    A one-line summary per layer is printed as a side effect.

    Args:
        mt: Model/tokenizer wrapper; ``mt.model``, ``mt.tokenizer`` and
            ``mt.layer_name_format`` are used.
        layers: Layer indices to patch, each in its own forward pass.
        source_QA: Sample whose ``subject`` is located inside ``query``.
        edit_QA: Sample providing the replacement subject and the answer whose
            rank is tracked after patching.
        query: Prompt text containing ``source_QA.subject``.

    Returns:
        A ``PatchingResults_for_one_pair`` holding the patched token index, the
        decoded token at that index, and, keyed by layer index, the predicted
        tokens and the rank entry of ``edit_QA.answer`` after patching.
    """
    # TODO: Support for multiple edits
    # ! Multiple edit acting weird. Could not find out the bug.

    # Format every layer name once; the same names are used both to collect the
    # edit activations and to select the layer being patched.
    layer_names = [mt.layer_name_format.format(layer_idx) for layer_idx in layers]

    edit_h = get_h(
        mt=mt,
        prompt=query.replace(source_QA.subject, edit_QA.subject),
        subject=edit_QA.subject,
        layers=layer_names,
    )

    tokenized = mt.tokenizer(
        query, return_offsets_mapping=True, return_tensors="pt"
    ).to(mt.model.device)
    # The offsets are only needed to locate the subject; pop them so that
    # `tokenized` keeps only the remaining tokenizer outputs.
    offset_mapping = tokenized.pop("offset_mapping")[0]

    # Only the end of the subject's token range is needed here.
    _, subject_end = find_token_range(
        query, source_QA.subject, tokenizer=mt.tokenizer, offset_mapping=offset_mapping
    )

    subj_last_idx = subject_end - 1
    # Decode the patched token once; it is both printed and returned.
    edit_token = mt.tokenizer.decode(tokenized["input_ids"][0][subj_last_idx])

    edit_rank_after_patching: dict[int, tuple[int, PredictedToken]] = {}
    predictions: dict[int, list[PredictedToken]] = {}

    print(f"edit_index={subj_last_idx}")
    print(f"edit_token={edit_token}")

    print("-" * 50)
    for layer_idx, layer_name in zip(layers, layer_names):
        # While the trace is active, this layer's output at `subj_last_idx` is
        # replaced with the activation collected from the edit prompt.
        with baukit.Trace(
            module=mt.model,
            layer=layer_name,
            edit_output=patch_output(
                patch_layer=layer_name,
                patch_idx=subj_last_idx,
                patching_vector=edit_h[layer_name],
            ),
        ):
            preds, edit_answer_rank = predict_next_token(
                mt=mt, prompt=query, token_of_interest=edit_QA.answer
            )
        # A single prompt is passed, so only the first entry of each result is kept.
        predictions[layer_idx] = preds[0]
        edit_rank_after_patching[layer_idx] = edit_answer_rank[0]
        print(
            f"Layer {layer_idx} => rank({edit_QA.answer})={edit_answer_rank[0][0]} [{edit_answer_rank[0][1]}]  | preds={', '.join(str(p) for p in preds[0])}"
        )
    print("-" * 50)

    return PatchingResults_for_one_pair(
        source_QA=source_QA,
        edit_QA=edit_QA,
        edit_index=subj_last_idx,
        edit_token=edit_token,
        predictions_after_patching=predictions,
        rank_edit_ans_after_patching=edit_rank_after_patching,
    )

In [None]:
#-------------------------------------------
model_path = "meta-llama/Llama-2-7b-hf"
counterfact_relation_id = "P17"
# Upper bound on how many CounterFact records of this relation are kept.
num_relation_samples = 50
#-------------------------------------------

mt = ModelandTokenizer(model_path=model_path)

counterfact = CounterFactDataset(data_dir="counterfact")

# Keep only records of the configured relation, capped at num_relation_samples.
relation_filter = [
    d
    for d in counterfact
    if d["requested_rewrite"]["relation_id"] == counterfact_relation_id
][:num_relation_samples]

# (subject, true object) string pairs taken from each record's requested rewrite.
relation_subj_obj_mapping = [
    (
        d["requested_rewrite"]["subject"],
        d["requested_rewrite"]["target_true"]["str"],
    )
    for d in relation_filter
]

In [None]:
# Pick 5 distinct in-context examples with a dedicated, seeded generator so that
# Restart & Run All selects the same examples every time (the global np.random
# state is unseeded in this notebook).
icl_seed = 0
icl_rng = np.random.default_rng(icl_seed)
icl_examples = [
    relation_subj_obj_mapping[k]
    for k in icl_rng.choice(len(relation_subj_obj_mapping), size=5, replace=False)
]

# Filter the (subject, object) pairs by model knowledge, prompting with the
# template below together with the in-context examples.
# NOTE(review): presumably this keeps only the pairs the model answers
# correctly -- confirm in locality.functional. The in-context examples are drawn
# from the same list that is being filtered; verify that this is intended.
filtered_subj_obj_mapping = filter_samples_by_model_knowledge(
    mt,
    relation_subj_obj_mapping,
    prompt_template=" {} is located in the country of",
    icl_examples=icl_examples,
)

# Number of pairs returned by the filter.
len(filtered_subj_obj_mapping)

In [None]:
# Build the synthetic dataset from filtered_subj_obj_mapping (the output of the
# model-knowledge filter). Both templates are filled in by
# generate_synthetic_dataset; the exact meaning of num_options, num_icl and
# batch_size is defined there (presumably options per prompt, in-context
# examples per prompt and batch size -- TODO confirm in locality.dataset).
dataset = generate_synthetic_dataset(
    filtered_subj_obj_mapping,
    variable_binding_template=" {} is visiting {}",
    query_template=" {} is in {}.",
    num_options=3,
    num_icl=5,
    batch_size=10,
)

In [None]:
# Print the first field of the first dataset entry (a string, judging by the
# .replace() call applied to the same field further down, where it serves as
# the query text). print() shows the text itself rather than its repr.
print(dataset[0][0])

In [None]:
from locality_scripts.layer_significance import get_edit_target

# Derive the edit target(s) from the dataset's QA samples; the structure of the
# result is defined by get_edit_target in locality_scripts.layer_significance.
edit_target = get_edit_target(dataset.qa_samples)

In [None]:
# Display the value returned by get_edit_target for inspection.
edit_target

In [None]:
# Display the repr of the first field of the first dataset entry (the same
# expression that is printed in an earlier cell).
dataset[0][0]

In [None]:
from locality.dataset import QA_Sample

dset_idx = 3

# Look up the dataset entry once so that the source and edit queries are
# guaranteed to be derived from the same base prompt and the same QA sample
# (dataset.qa_samples[dset_idx] is taken to describe dataset[dset_idx]).
base_query = dataset[dset_idx][0]
base_sample = dataset.qa_samples[dset_idx]

# Source sample: the base prompt with its subject swapped for the source subject.
source_subject = "Louvre"
source_object = "France"
source_QA = QA_Sample(
    query=base_query.replace(base_sample.subject, source_subject),
    subject=source_subject,
    answer=source_object,
    variable=base_sample.variable,
)

# Edit sample: the same base prompt with its subject swapped for the edit subject.
edit_subject = "Cox's Bazar"
edit_object = "Bangladesh"
edit_QA = QA_Sample(
    query=base_query.replace(base_sample.subject, edit_subject),
    subject=edit_subject,
    answer=edit_object,
    variable=base_sample.variable,
)

In [None]:
# Unpatched baseline: next-token prediction for the source prompt, with the
# edit answer passed as the token of interest.
predict_next_token(
    mt=mt,
    prompt=source_QA.query,
    token_of_interest=edit_QA.answer,
)

In [None]:
# Unpatched reference: next-token prediction for the edit prompt itself, with
# the edit answer passed as the token of interest.
predict_next_token(
    mt=mt,
    prompt=edit_QA.query,
    token_of_interest=edit_QA.answer,
)

In [None]:
# Collect the activations of the edit prompt for layers 9, 10 and 11; the layer
# names are produced by applying mt.layer_name_format to each index.
# NOTE(review): presumably get_h reads them at the subject's token position --
# confirm in locality.functional.
edit_h = get_h(
    mt=mt,
    prompt=edit_QA.query,
    subject=edit_QA.subject,
    layers=list(map(mt.layer_name_format.format, (9, 10, 11))),
)

In [None]:
# Inspect the collected activations (indexed by layer name where they are used
# for patching).
edit_h

In [None]:
# Patch layers 9, 10 and 11 one at a time on the source prompt, using
# activations taken from the edit subject; the returned
# PatchingResults_for_one_pair is shown as the cell output.
patch_individual_layers_for_single_edit(
    mt=mt,
    layers=[9, 10, 11],
    source_QA=source_QA,
    edit_QA=edit_QA,
    query=source_QA.query,
)