In [None]:
%load_ext autoreload
%autoreload 3
import sys

# Add the parent directory to the module search path.
# NOTE(review): presumably so that exp_tools, translation_tools and utils
# (imported further down) resolve from the repository root — confirm.
sys.path.append("..")

In [None]:
import torch as th
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import seaborn as sns
from pathlib import Path
from time import time
import itertools
from random import shuffle
from torch.utils.data import DataLoader

# Fix logger bug: silence nnsight's logger.
# NOTE(review): babelnet is imported first, presumably because the import
# order matters for the logging setup — confirm before reordering.
import babelnet
from nnsight import logger

logger.disabled = True

# Disable gradient tracking globally; the return value is discarded.
_ = th.set_grad_enabled(False)

In [None]:
# Experiment name; used as a sub-directory of the results path in object_patching_plot.
exp_name = "mean_obj_patching"

In [None]:
# Languages used as the default for both --inp-langs and --out-langs.
langs = ["fr", "de", "ru", "en", "zh"]
# Default batch size of object_patching_plot.
batch_size = 8
# Model identifier; its last "/"-separated component names the results directory,
# and it doubles as the model path when model_path is None.
model = "Llama-2-7b"
# Passed to load_model as device_map.
device = "auto"
# model_path = "/dlabscratch1/public/llm_weights/llama2_hf/Llama-2-7b-hf"
# Explicit location of the weights; None falls back to `model`.
model_path = None
# Passed through to load_model.
trust_remote_code = False
# Passed to object_lens.
# NOTE(review): -1 presumably means "no limit" — confirm against exp_tools.object_lens.
num_patches = -1
# Command-line style arguments handed to the ArgumentParser below.
extra_args = []
# Identifier used in output file names; None makes object_patching_plot use a timestamp.
exp_id = None

In [None]:
from exp_tools import load_model
from argparse import ArgumentParser

# Optional command-line style overrides, supplied through `extra_args`.
parser = ArgumentParser()
parser.add_argument("--inp-langs", "-il", nargs="+", default=langs)
parser.add_argument("--out-langs", "-ol", nargs="+", default=langs)
parser.add_argument("--paper-only", "-po", action="store_true")
args = parser.parse_args(extra_args)
print(f"args: {args}")

# Bind the parsed options to the module-level names the later cells rely on.
# Previously `out_langs` was read below before being defined, and `inp_langs` /
# `paper_only` were never assigned, so a fresh kernel failed with NameError.
inp_langs = args.inp_langs
paper_only = args.paper_only

langs = np.array(langs)
# For every language, the requested output languages except the language itself.
out_langs = {
    lang: np.array([l for l in args.out_langs if l != lang]) for lang in langs
}
# Without an explicit path, load the model by its name.
if model_path is None:
    model_path = model
nn_model = load_model(
    model_path,
    trust_remote_code=trust_remote_code,
    device_map=device,
    # dispatch=True,
)
tokenizer = nn_model.tokenizer

## Plots

In [None]:
from exp_tools import (
    run_prompts,
    object_lens,
    collect_activations,
    get_num_layers,
)
from translation_tools import translation_prompts
from translation_tools import get_bn_dataset as get_translations

from utils import plot_ci, plot_k, plot_topk_tokens, ulist
from copy import deepcopy


def get_id_to_patch(sample_prompt):
    """
    Return the negative token index of the position to patch in a prompt.

    The prompt is cut just before its second-to-last double quote. The index
    returned points at the token immediately preceding the tokens of the
    trailing part, counted from the end of the prompt.

    Raises:
        ValueError: if tokenizing the two parts separately does not give the
            same tokens as tokenizing the whole prompt.
    """

    def encode(text):
        return tokenizer.encode(text, add_special_tokens=False)

    pieces = sample_prompt.split('"')
    head_text = '"'.join(pieces[:-2])
    # Re-attach the quote that the split removed at the cut position.
    tail_text = '"' + '"'.join(pieces[-2:])

    head_ids = encode(head_text)
    tail_ids = encode(tail_text)
    if head_ids + tail_ids != encode(sample_prompt):
        raise ValueError("This is weird, check code")
    return -len(tail_ids) - 1


def object_patching_plot(
    source_lang_pairs,
    input_lang,
    target_lang,
    extra_langs=None,
    batch_size=batch_size,
    num_words=None,
    num_pairs=200,
    exp_id=None,
    k=4,
):
    """
    Run mean object patching for one translation direction, then plot and save it.

    Patchscope setup:
      * Source: for every (source_input_lang, source_target_lang) pair in
        ``source_lang_pairs`` a translation prompt about the same source word
        is built and truncated before its second-to-last double quote. The
        activations collected on these prompts are averaged over the pairs.
      * Target: a translation prompt ``input_lang: A -> target_lang:`` about a
        different word. The averaged activations are handed to ``object_lens``
        together with the patch index computed by ``get_id_to_patch``.

    Args:
        source_lang_pairs: sequence of (input language, output language)
            pairs used for the source prompts.
        input_lang: input language of the target prompt.
        target_lang: output language of the target prompt.
        extra_langs: currently unused; kept for backward compatibility.
        batch_size: batch size passed to ``run_prompts``.
        num_words: forwarded to ``get_translations``.
        num_pairs: number of (source word, target word) prompt pairs to
            collect.
        exp_id: identifier used in the output file names; defaults to the
            current Unix timestamp.
        k: number of examples shown in the per-example plot, the metadata
            table and the heatmap.

    Returns:
        None. The function prints a message and returns early when no source
        language pair is given or when fewer than ``num_pairs`` collision-free
        prompt pairs can be collected.

    Side effects:
        Writes ``<exp_id>.json``, ``<exp_id>.png``, ``<exp_id>_k.png`` and
        ``<exp_id>_heatmap.meta.json`` under
        ``results/<model_name>/<exp_name>/<pairs>-<input_lang>_<target_lang>-/``,
        passes ``<exp_id>_heatmap.png`` to ``plot_topk_tokens`` as ``file``,
        and stores several intermediate objects in module-level globals so
        they can be inspected after the call.
    """
    global source_df, target_df, target_prompts, target_probs, latent_probs, source_prompts, _source_prompts, _target_prompts
    source_lang_pairs = np.array(source_lang_pairs)
    if source_lang_pairs.size == 0:
        # Without source prompts there is nothing to average or patch.
        print(
            f"No source language pairs given for {input_lang} -> {target_lang}, skipping..."
        )
        return
    model_name = model.split("/")[-1]
    if exp_id is None:
        exp_id = str(int(time()))
    else:
        exp_id = str(exp_id)
    source_df = get_translations(
        target_lang,
        [*source_lang_pairs.flatten(), input_lang],
        num_words,
    )
    target_df = get_translations(
        input_lang,
        [*source_lang_pairs.flatten(), target_lang],
        num_words,
    )

    # One prompt list per source language pair, zipped so that entry i holds
    # the prompts of every pair for the i-th source word.
    _source_prompts = list(
        zip(
            *[
                translation_prompts(
                    source_df,
                    nn_model.tokenizer,
                    inp_lang,
                    targ_lang,
                    [target_lang],
                    augment_tokens=False,
                )
                for inp_lang, targ_lang in source_lang_pairs
            ]
        )
    )
    _target_prompts = translation_prompts(
        target_df,
        nn_model.tokenizer,
        input_lang,
        target_lang,
        [],
        augment_tokens=False,
    )

    # Randomly pair source words with target words until enough pairs with
    # distinct words and no token collisions have been collected.
    collected_pairs = 0
    source_prompts = []
    target_prompts = []
    source_target = list(itertools.product(source_df.iterrows(), target_df.iterrows()))
    shuffle(source_target)

    for (i, source_row), (j, target_row) in source_target:
        if source_row["word_original"] == target_row["word_original"]:
            continue
        src_p = _source_prompts[i]
        targ_p = deepcopy(_target_prompts[j])
        # All source prompts of one word must agree on the target_lang latent
        # tokens. Comparing every prompt to the first also works with a single
        # source pair, where indexing src_p[1] used to raise IndexError.
        reference_tokens = src_p[0].latent_tokens[target_lang]
        assert all(
            p.latent_tokens[target_lang] == reference_tokens for p in src_p[1:]
        ), "check code"
        latent_tokens = {f"source_{target_lang}": reference_tokens}
        latent_tokens.update(**targ_p.latent_tokens)
        targ_p.latent_tokens = latent_tokens
        targ_p.latent_strings[f"sources"] = ulist(
            sum([p.target_strings for p in src_p], [])
        )
        targ_p.latent_strings[f"source_{target_lang}"] = src_p[0].latent_strings[
            target_lang
        ]
        if targ_p.has_no_collisions():
            source_prompts.append(src_p)
            target_prompts.append(targ_p)
            collected_pairs += 1
        if collected_pairs >= num_pairs:
            break
    if collected_pairs < num_pairs:
        print(
            f"Could only collect {collected_pairs} pairs for {source_lang_pairs.tolist()} - {input_lang} -> {target_lang}, skipping..."
        )
        return

    # Source prompts truncated before their second-to-last double quote.
    source_prompts_str = [
        ['"'.join(p.prompt.split('"')[:-2]) for p in ps] for ps in source_prompts
    ]
    idx = get_id_to_patch(target_prompts[0].prompt)

    def object_patching(nn_model, prompt_batch, scan, source_prompt_batch=None):
        """Patch the mean source activations into one batch of target prompts.

        When ``source_prompt_batch`` is not given, the matching source prompts
        are taken from ``source_prompts_str`` using a running offset, which
        assumes batches arrive in order and exactly once.
        """
        offset = object_patching.offset
        batch_size = len(prompt_batch)
        if source_prompt_batch is None:
            source_prompt_batch = source_prompts_str[offset : offset + batch_size]
        # Flatten to one list of prompts: all language pairs of word 0, then word 1, ...
        source_prompt_batch = sum(source_prompt_batch, [])
        dataloader = DataLoader(source_prompt_batch, batch_size=batch_size)
        hiddens = []
        for batch in dataloader:
            acts = collect_activations(nn_model, batch)
            hiddens.append(th.stack(acts).transpose(0, 1))
        hiddens = th.cat(hiddens, dim=0)  # (all_prompts, layer, hidden_size)
        # Group the prompts by word and average over the source language pairs.
        hiddens = hiddens.reshape(
            batch_size, len(source_prompts_str[0]), get_num_layers(nn_model), -1
        ).mean(
            dim=1
        )  # (batch_size, num_layers, hidden_size)
        hiddens = hiddens.transpose(0, 1)  # (num_layers, batch_size, hidden_size)
        object_patching.offset += batch_size
        return object_lens(
            nn_model,
            prompt_batch,
            idx,
            hiddens=hiddens,
            scan=scan,
            num_patches=num_patches,
        )

    object_patching.offset = 0
    target_probs, latent_probs = run_prompts(
        nn_model, target_prompts, batch_size=batch_size, method=object_patching
    )

    # Save the raw probabilities.
    json_dic = {
        target_lang: target_probs.tolist(),
    }
    for label, probs in latent_probs.items():
        json_dic[label] = probs.tolist()
    pref = "_".join("-".join(ls) for ls in source_lang_pairs)
    path = (
        Path("results")
        / model_name
        / exp_name
        / (f"{pref}-{input_lang}_{target_lang}-")
    )
    path.mkdir(parents=True, exist_ok=True)
    json_file = path / (exp_id + ".json")
    with open(json_file, "w") as f:
        json.dump(json_dic, f, indent=4)

    # Raw probabilities plot
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    colors = sns.color_palette("tab10", 1 + len(latent_probs))
    pref = pref.replace("_", " ")
    title = f"{model_name}: ObjPatch from ({pref}) into ({input_lang} -> {target_lang})"
    plot_ci(ax, target_probs, label=target_lang, color=colors[0])
    for i, (label, probs) in enumerate(latent_probs.items()):
        plot_ci(ax, probs, label=label, color=colors[i + 1], init=False)
    ax.legend()
    ax.set_title(title + " - Raw probs")
    plt.tight_layout()
    plot_file = path / (exp_id + ".png")
    plt.savefig(plot_file, dpi=300, bbox_inches="tight")
    plt.show()

    # Plot k examples
    fig, axes = plt.subplots(1, k, figsize=(5 * k, 5))
    plot_k(
        axes,
        target_probs[:k],
        label=target_lang,
        color=colors[0],
        k=k,
    )
    for i, (label, probs) in enumerate(latent_probs.items()):
        plot_k(
            axes,
            probs[:k],
            label=label,
            color=colors[i + 1],
            init=False,
            k=k,
        )
    axes[-1].legend()
    fig.suptitle(title + " - Raw probs")
    plt_file = path / (exp_id + "_k.png")
    fig.savefig(plt_file, dpi=300, bbox_inches="tight")
    fig.show()

    # Describe the first k examples, display them and store them next to the heatmap.
    json_meta = {}
    for i in range(k):
        json_meta[i] = {
            "source pairs": source_lang_pairs.tolist(),
            "input lang": input_lang,
            "target lang": target_lang,
            "source prompt": {
                "-".join(l): source_prompts_str[i][j]
                for j, l in enumerate(source_lang_pairs)
            },
            "source prompt target": {
                "-".join(l): source_prompts[i][j].target_strings
                for j, l in enumerate(source_lang_pairs)
            },
            "source prompt latent": {
                "-".join(l): source_prompts[i][j].latent_strings
                for j, l in enumerate(source_lang_pairs)
            },
            "target prompt": target_prompts[i].prompt,
            "target prompt target": target_prompts[i].target_strings,
            "target prompt latent": target_prompts[i].latent_strings,
        }
    json_df = pd.DataFrame(json_meta)
    with pd.option_context(
        "display.max_colwidth",
        None,
        "display.max_columns",
        None,
        "display.max_rows",
        None,
    ):
        display(json_df)
    # Re-run the patching on the first k examples with scan=True for the heatmap.
    target_prompt_batch = [p.prompt for p in target_prompts[:k]]
    probs = object_patching(
        nn_model,
        target_prompt_batch,
        scan=True,
        source_prompt_batch=source_prompts_str[:k],
    )
    file = path / (exp_id + "_heatmap.png")
    plot_topk_tokens(probs, nn_model, title=title, file=file)

    meta_file = path / (exp_id + "_heatmap.meta.json")
    with open(meta_file, "w") as f:
        json.dump(json_meta, f, indent=4)

In [None]:
# Fixed configurations that are always run, each given as
# (source_lang_pairs, input_lang, target_lang).
paper_args = [
    (
        [("de", "fr"), ("es", "hi"), ("ru", "de"), ("ja", "en"), ("it", "ko"),
         ("de", "hi"), ("es", "fi"), ("ru", "nl"), ("ja", "de"), ("it", "ru")],
        "fr",
        "zh",
    ),
    (
        [("de", "fr"), ("nl", "fi"), ("zh", "es"), ("es", "ru"), ("ru", "ko")],
        "fr",
        "zh",
    ),
]
# A dedicated loop variable keeps the parsed argparse namespace `args` intact.
for plot_args in paper_args:
    object_patching_plot(*plot_args, exp_id=exp_id)

In [None]:
# Full sweep over every (source target language, input language, output language)
# combination; skipped when only the fixed configurations are requested.
if not paper_only:
    for source_target_lang, in_lang in itertools.product(langs, langs):
        for out_lang in out_langs[in_lang]:
            th.cuda.empty_cache()
            # Source prompts translate into source_target_lang from every input
            # language other than source_target_lang and the current input language.
            sweep_pairs = [
                (l, source_target_lang)
                for l in inp_langs
                if l != source_target_lang and l != in_lang
            ]
            object_patching_plot(sweep_pairs, in_lang, out_lang, exp_id=exp_id)