In [None]:
INTERACTIVE_MODE = False  # Set to True to run the notebook interactively

import sys

if INTERACTIVE_MODE:
    sys.path.append("../src")
    %load_ext autoreload
    %autoreload 3
else:
    sys.path.append("./src")

In [None]:
import torch as th
import pandas as pd
import json
import matplotlib.pyplot as plt
from pathlib import Path
from time import time

# Turn off autograd globally; nothing visible in this notebook appears to need
# gradients. Binding the return value to `_` keeps it from being echoed as the
# cell's output.
_ = th.set_grad_enabled(False)

In [None]:
# papermill parameters

# --- Model selection / loading (all forwarded to `load_model`) ---
model = "meta-llama/Llama-2-7b-hf"
# When left as None, `model` is used as the load location instead.
model_path = None
trust_remote_code = False
device = "auto"
use_tl = False

# --- Execution ---
batch_size = 64
remote = False

# --- Experiment ---
num_few_shot = 5
# None means a timestamp is generated when the experiment runs.
exp_id = None
extra_args = []

In [None]:
from exp_tools import load_model
from argparse import ArgumentParser

# No options are declared on the parser here, so argparse will reject any
# unrecognised entry in `extra_args`.
parser = ArgumentParser()
pargs = parser.parse_args(extra_args)

# Fall back to the model identifier when no explicit path was supplied.
model_path = model if model_path is None else model_path
nn_model = load_model(
    model_path, trust_remote_code=trust_remote_code, device_map=device, use_tl=use_tl
)

## Plots

In [None]:
from exp_tools import run_prompts
from interventions import TargetPromptBatch, patchscope_lens
from prompt_tools import (
    translation_prompts,
    get_shifted_prompt_pairs,
    NotEnoughPromptsError,
)

from load_dataset import get_word_translation_dataset as get_translations

from utils import ulist
from display_utils import plot_k_results, plot_topk_tokens, plot_results, k_subplots


def shifted_translation_plot(
    source_input_lang,
    source_output_lang,
    input_lang,
    target_lang,
    extra_langs=None,
    batch_size=batch_size,
    num_words=None,
    num_pairs=200,
    exp_id=exp_id,
    k=4,
    remote=remote,
):
    """
    Run a patchscope between two translation prompts, then save and plot the results.

    The source hidden state is taken from index -1 of the prompt
        source_input_lang: A -> source_output_lang:
    and patched into index -1 of the target prompt
        input_lang: A -> target_lang:
    The resulting probabilities are plotted for target_lang together with the
    latent languages (source_output_lang and extra_langs).

    Args:
        source_input_lang: Input language of the source prompt.
        source_output_lang: Output language of the source prompt.
        input_lang: Input language of the target prompt.
        target_lang: Output language of the target prompt.
        extra_langs: Optional language code or list of codes, forwarded to the
            dataset, prompt and pairing helpers. A single string is wrapped in
            a list; None means no extra languages.
        batch_size: Forwarded to `run_prompts`.
        num_words: Forwarded to the word-translation dataset loader.
        num_pairs: Forwarded to `get_shifted_prompt_pairs`.
        exp_id: Stem of every output file name. When None, the current Unix
            timestamp is used.
        k: Number of individual examples to plot and to describe in the
            heatmap metadata.
        remote: Forwarded to `patchscope_lens` and `run_prompts`.

    Returns:
        None. If `get_shifted_prompt_pairs` raises NotEnoughPromptsError the
        configuration is reported and skipped without writing any file.

    Side effects:
        - Writes `<exp_id>.json`, `<exp_id>.png`, `<exp_id>_k.png`,
          `<exp_id>_heatmap.png` and `<exp_id>_heatmap.meta.json` under
          `results/<model name>/shifted_translation/<language combination>/`.
        - Rebinds the module-level names `source_df`, `target_df`,
          `source_prompts`, `target_prompts`, `target_probs` and
          `latent_probs` so they can be inspected from later cells.
    """
    global source_df, target_df, target_prompts, target_probs, latent_probs, source_prompts

    if extra_langs is None:
        extra_langs = []
    elif isinstance(extra_langs, str):
        extra_langs = [extra_langs]
    model_name = model.split("/")[-1]
    exp_id = str(int(time())) if exp_id is None else str(exp_id)

    source_df = get_translations(
        source_input_lang,
        ulist([source_output_lang, input_lang, target_lang, *extra_langs]),
        num_words,
    )
    target_df = get_translations(
        input_lang,
        ulist([source_output_lang, source_input_lang, target_lang, *extra_langs]),
        num_words,
    )

    _source_prompts = translation_prompts(
        source_df,
        nn_model.tokenizer,
        source_input_lang,
        source_output_lang,
        [target_lang, *extra_langs],
        augment_tokens=False,
        n=num_few_shot,
    )
    _target_prompts = translation_prompts(
        target_df,
        nn_model.tokenizer,
        input_lang,
        target_lang,
        [source_output_lang, *extra_langs],
        augment_tokens=False,
        n=num_few_shot,
    )
    try:
        source_prompts, target_prompts = get_shifted_prompt_pairs(
            source_df,
            target_df,
            _source_prompts,
            _target_prompts,
            source_input_lang,
            source_output_lang,
            input_lang,
            target_lang,
            extra_langs,
            num_pairs,
            merge_extra_langs=True,
        )
    except NotEnoughPromptsError:
        # Still best-effort (no exception propagates), but leave a trace in the
        # notebook output so a skipped configuration is not mistaken for a run.
        print(
            f"Skipping ({source_input_lang} -> {source_output_lang}) into "
            f"({input_lang} -> {target_lang}): not enough prompt pairs"
        )
        return

    source_prompts_str = [p.prompt for p in source_prompts]

    def transverse_patchscope(nn_model, prompt_batch, scan):
        # `run_prompts` only hands over the target batch, so the matching
        # source prompts are selected with a running offset stored on the
        # function object. This assumes batches arrive in order — TODO confirm
        # against `run_prompts`.
        offset = transverse_patchscope.offset
        target_patchscope_prompts = TargetPromptBatch.from_prompts(prompt_batch, -1)
        source_prompt_batch = source_prompts_str[offset : offset + len(prompt_batch)]
        transverse_patchscope.offset += len(prompt_batch)
        return patchscope_lens(
            nn_model,
            source_prompt_batch,
            target_patchscope_prompts,
            scan=scan,
            remote=remote,
        )

    transverse_patchscope.offset = 0
    target_probs, latent_probs = run_prompts(
        nn_model, target_prompts, batch_size=batch_size, get_probs=transverse_patchscope
    )

    # Get the baseline to normalize the plots
    source_prompts_probs, _ = run_prompts(
        nn_model,
        source_prompts,
        batch_size=batch_size,
        get_prob_kwargs=dict(remote=remote),
    )
    target_prompts_probs, _ = run_prompts(
        nn_model,
        target_prompts,
        batch_size=batch_size,
        get_prob_kwargs=dict(remote=remote),
    )

    json_dic = {
        target_lang: target_probs.tolist(),
        "source prompt probs": source_prompts_probs.squeeze().tolist(),
        "target prompt probs": target_prompts_probs.squeeze().tolist(),
    }
    for label, probs in latent_probs.items():
        json_dic[label] = probs.tolist()
    path = (
        Path("results")
        / model_name
        / "shifted_translation"
        / (f"{source_input_lang}_{source_output_lang}-{input_lang}_{target_lang}")
    )
    path.mkdir(parents=True, exist_ok=True)
    json_file = path / (exp_id + ".json")
    with open(json_file, "w") as f:
        json.dump(json_dic, f, indent=4)

    fig, ax = plt.subplots(figsize=(10, 5))

    # Raw probabilities plot
    plot_results(
        ax,
        target_probs,
        latent_probs,
        target_lang,
        source_baseline=source_prompts_probs.mean(),
        target_baseline=target_prompts_probs.mean(),
    )
    ax.legend()
    title = f"{model_name}: HeteroPatch from ({source_input_lang} -> {source_output_lang}) into ({input_lang} -> {target_lang})"
    ax.set_title(title)
    plt.tight_layout()
    plot_file = path / (exp_id + ".png")
    plt.savefig(plot_file, dpi=300, bbox_inches="tight")
    plt.show()

    # Plot k examples
    fig, axes = k_subplots(k)
    plot_k_results(axes, target_probs, latent_probs, target_lang, k=k)
    axes[k - 1].legend()
    fig.suptitle(title)
    plt_file = path / (exp_id + "_k.png")
    fig.savefig(plt_file, dpi=300, bbox_inches="tight")
    # `Figure.show()` only warns on non-GUI backends (the case for headless
    # papermill runs); `plt.show()` matches how the first figure is displayed.
    plt.show()

    # Record the prompts behind the displayed examples. The count is bounded by
    # the number of available prompts so that a large `k` cannot index past the
    # end (the slices used further down are already tolerant of that).
    num_examples = min(k, len(source_prompts_str), len(target_prompts))
    json_meta = {}
    for i in range(num_examples):
        json_meta[i] = {
            "source input lang": source_input_lang,
            "source target lang": source_output_lang,
            "input lang": input_lang,
            "target lang": target_lang,
            "source prompt": source_prompts_str[i],
            "source prompt target": source_prompts[i].target_strings,
            "source prompt latent": source_prompts[i].latent_strings,
            "target prompt": target_prompts[i].prompt,
            "target prompt target": target_prompts[i].target_strings,
            "target prompt latent": target_prompts[i].latent_strings,
        }
    json_df = pd.DataFrame(json_meta)
    # Lift pandas' truncation limits only for this display call.
    with pd.option_context(
        "display.max_colwidth",
        None,
        "display.max_columns",
        None,
        "display.max_rows",
        None,
    ):
        display(json_df)

    # Re-run the patchscope on the first k pairs only, for the top-token heatmap.
    target_prompt_batch = TargetPromptBatch.from_prompts(
        [p.prompt for p in target_prompts[:k]], -1
    )
    probs = patchscope_lens(
        nn_model, source_prompts_str[:k], target_prompt_batch, remote=remote
    )
    file = path / (exp_id + "_heatmap.png")
    plot_topk_tokens(probs, nn_model, title=title, file=file)

    meta_file = path / (exp_id + "_heatmap.meta.json")
    with open(meta_file, "w") as f:
        json.dump(json_meta, f, indent=4)

In [None]:
# Each entry is (source_input_lang, source_output_lang, input_lang, target_lang),
# matching the positional parameters of `shifted_translation_plot`.
paper_args = [("de", "it", "fr", "zh")]

for args in paper_args:
    # English is passed as the only extra language for every configuration.
    shifted_translation_plot(*args, extra_langs=["en"])