# Setup

In [None]:
# Switch between interactive use and non-interactive (batch) execution.
INTERACTIVE_MODE = False  # Set to True to run the notebook interactively

import sys

# Make the project's `src` directory importable. The relative location differs
# between the two modes ("../src" vs "./src") — presumably because the working
# directory differs between an interactive session and a batch run; confirm.
if INTERACTIVE_MODE:
    sys.path.append("../src")
    # IPython magics (only valid inside IPython/Jupyter): load the autoreload
    # extension and enable it in mode 3.
    %load_ext autoreload
    %autoreload 3
else:
    sys.path.append("./src")

In [None]:
import torch as th
import pandas as pd
import json
import matplotlib.pyplot as plt
from pathlib import Path


# Disable gradient tracking globally for the rest of the notebook; the object
# returned by set_grad_enabled is deliberately discarded.
_ = th.set_grad_enabled(False)

In [None]:
# Experiment name; used as a sub-directory of the results path built in
# plot_lens_results (results/<model_name>/<exp_name>/...).
exp_name = "translation_lens"

## Papermill args

In [None]:
# papermill parameters
# Default values for the notebook parameters; each comment states where the
# value is consumed further down in this notebook.
batch_size = 8  # default batch size of plot_lens_results (forwarded to filtering and run_prompts)
model = "meta-llama/Llama-2-7b-hf"  # model identifier; its last "/"-separated part becomes model_name
model_path = None  # location handed to load_model; falls back to `model` when None
trust_remote_code = False  # forwarded to load_model
device = "auto"  # forwarded to load_model as device_map
remote = False  # forwarded to the lens function as `remote`
extra_args = []  # command-line style arguments parsed by the ArgumentParser below
exp_id = None  # file stem of the result files written by plot_lens_results
num_few_shot = 5  # passed as `n` to translation_prompts — presumably the number of few-shot examples; confirm
use_tl = False  # forwarded to load_model

## Command-line args

In [None]:
import json
from argparse import ArgumentParser, ArgumentTypeError


def _json_dict(text):
    """Parse a command-line value as a JSON object and return it as a dict.

    ``type=dict`` cannot be used for this: argparse calls the type on the raw
    string, and ``dict("...")`` raises for any non-empty string, so the
    ``*_kwargs`` options could never be set from the command line.

    Raises:
        ArgumentTypeError: if ``text`` is not valid JSON or is not a JSON object.
    """
    try:
        value = json.loads(text)
    except json.JSONDecodeError as err:
        raise ArgumentTypeError(f"invalid JSON: {err}") from err
    if not isinstance(value, dict):
        raise ArgumentTypeError(
            f"expected a JSON object, got {type(value).__name__}"
        )
    return value


# Options are read from the papermill parameter `extra_args` rather than from
# sys.argv, so the notebook can be configured when run non-interactively.
parser = ArgumentParser()
# Name of the lens; later matched against "logit_lens" / "patchscope".
parser.add_argument("--lens", type=str, default="patchscope")
parser.add_argument("--thinking-langs", "-t", type=str, nargs="*", default=["en"])
# Names of functions in `prompt_tools` applied to the source / target words,
# plus keyword arguments for them given as a JSON object,
# e.g. --map-source-lang_kwargs '{"key": "value"}'.
parser.add_argument("--map-source-lang", type=str, default=None)
parser.add_argument("--map-source-lang_kwargs", type=_json_dict, default={})
parser.add_argument("--map-target-lang", type=str, default=None)
parser.add_argument("--map-target-lang_kwargs", type=_json_dict, default={})
# NOTE: the option name is misspelled ("treshold"); it is kept as is because
# the rest of the notebook reads `args.prob_treshold`.
parser.add_argument("--prob-treshold", type=float, default=0.0)

args = parser.parse_args(extra_args)
print(f"args: {args}")

## Load and prepare

In [None]:
from exp_tools import load_model
from interventions import logit_lens, patchscope_lens
import prompt_tools
from functools import partial

# Short model name: everything after the last "/" of the model identifier.
model_name = model.rsplit("/", 1)[-1]

# Load from `model_path` when given, otherwise from the model identifier itself.
model_path = model if model_path is None else model_path
nn_model = load_model(
    model_path,
    trust_remote_code=trust_remote_code,
    device_map=device,
    use_tl=use_tl,
)
tokenizer = nn_model.tokenizer

# Replace the mapping function *names* given on the command line by the
# corresponding `prompt_tools` functions, pre-bound with their keyword arguments.
if isinstance(args.map_source_lang, str):
    args.map_source_lang = partial(
        getattr(prompt_tools, args.map_source_lang),
        **args.map_source_lang_kwargs,
    )
if isinstance(args.map_target_lang, str):
    args.map_target_lang = partial(
        getattr(prompt_tools, args.map_target_lang),
        **args.map_target_lang_kwargs,
    )

# Select the lens implementation: exactly "logit_lens", or any name that
# contains "patchscope"; everything else is rejected.
lens_name = args.lens
if lens_name != "logit_lens" and "patchscope" not in lens_name:
    raise ValueError(f"Invalide method:  {args.lens}")
lens_func = logit_lens if lens_name == "logit_lens" else patchscope_lens

## Plots

In [None]:
from exp_tools import (
    run_prompts,
    filter_prompts_by_prob,
    remove_colliding_prompts,
)
from prompt_tools import translation_prompts
from load_dataset import get_word_translation_dataset as get_translations

from display_utils import (
    plot_topk_tokens,
    k_subplots,
    plot_results,
    plot_k_results,
)


def plot_lens_results(
    input_lang,
    output_lang,
    thinking_langs=None,
    batch_size=batch_size,
    num_words=None,
    exp_id=None,
    num_examples=9,
):
    """Run the configured lens on translation prompts and save results and plots.

    Builds ``input_lang -> output_lang`` translation prompts from the word
    translation dataset, passes them through ``remove_colliding_prompts`` and
    ``filter_prompts_by_prob`` (threshold ``args.prob_treshold``), runs
    ``lens_func`` via ``run_prompts`` and writes the outputs to
    ``results/<model_name>/<exp_name>/<lens_name>-<input_lang>_<output_lang>-/``:

    * ``<exp_id>.json``: target and latent probabilities,
    * ``<exp_id>.png``: figure drawn by ``plot_results``,
    * ``<exp_id>_k.png``: figure drawn by ``plot_k_results``,
    * ``<exp_id>_heatmap.png``: path handed to ``plot_topk_tokens``,
    * ``<exp_id>_heatmap.meta.json``: prompts used for the heatmap.

    Args:
        input_lang: language of the words to translate.
        output_lang: language to translate into.
        thinking_langs: additional languages requested from the dataset and
            passed to ``translation_prompts``; ``output_lang`` is removed from
            this list. ``None`` means no additional language.
        batch_size: batch size forwarded to ``filter_prompts_by_prob`` and
            ``run_prompts``.
        num_words: forwarded to the dataset loader.
        exp_id: stem of the output file names. When ``None`` the current Unix
            timestamp is used.
        num_examples: number of prompts used for the per-example plots and the
            heatmap; also the minimum number of prompts required to proceed.

    Returns:
        None. Returns early, after printing a message, when no prompt or fewer
        than ``num_examples`` prompts are left after filtering.
    """
    from time import time  # local import: only used for the fallback run id

    # File names are built by string concatenation, so the id must be a string
    # (the original crashed on the default ``exp_id=None``).
    exp_id = str(int(time())) if exp_id is None else str(exp_id)

    # The output language is requested separately, so drop it from the
    # additional languages.
    thinking_langs = [l for l in (thinking_langs or []) if l != output_lang]
    df = get_translations(
        input_lang,
        [output_lang, *thinking_langs],
        num_words=num_words,
    )

    # When a mapping function is configured, keep the unprocessed words in a
    # "<lang> no proc" column and hand those columns to translation_prompts
    # alongside the thinking languages.
    no_proc = []
    if args.map_source_lang is not None:
        df[input_lang + " no proc"] = df[input_lang]
        df[input_lang] = df[input_lang].apply(args.map_source_lang)
        no_proc.append(input_lang + " no proc")
    if args.map_target_lang is not None:
        df[output_lang + " no proc"] = df[output_lang]
        df[output_lang] = df[output_lang].apply(args.map_target_lang)
        no_proc.append(output_lang + " no proc")

    target_prompts = translation_prompts(
        df,
        tokenizer,
        input_lang,
        output_lang,
        thinking_langs + no_proc,
        augment_tokens=False,
        n=num_few_shot,
    )
    target_prompts = remove_colliding_prompts(target_prompts, ignore_langs=no_proc)
    print(f"Number of non-colliding prompts: {len(target_prompts)}")
    target_prompts = filter_prompts_by_prob(
        target_prompts, nn_model, args.prob_treshold, batch_size=batch_size
    )
    print(f"Number of prompts after filtering: {len(target_prompts)}")

    # Check the empty case first so that its more specific message is reachable.
    if len(target_prompts) == 0:
        print(f"No prompts left after filtering for {input_lang} -> {output_lang}")
        return
    if len(target_prompts) < num_examples:
        print("Not enough prompts after filtering")
        return

    target_probs, latent_probs = run_prompts(
        nn_model,
        target_prompts,
        batch_size=batch_size,
        get_probs=lens_func,
        get_probs_kwargs=dict(remote=remote),
    )

    # Raw probabilities, keyed by the output language and by each latent label.
    json_dic = {output_lang: target_probs.tolist()}
    for label, probs in latent_probs.items():
        json_dic[label] = probs.tolist()

    # NOTE(review): the original read `pref` before ever assigning it, which
    # raised UnboundLocalError. The lens name is used as the directory prefix
    # because it is what distinguishes runs in the plot title — confirm that
    # this is the intended prefix. The directory name format (including the
    # trailing "-") is kept as written.
    pref = lens_name
    path = (
        Path("results")
        / model_name
        / exp_name
        / (f"{pref}-{input_lang}_{output_lang}-")
    )
    path.mkdir(parents=True, exist_ok=True)
    json_file = path / (exp_id + ".json")
    with open(json_file, "w") as f:
        json.dump(json_dic, f, indent=4)

    # Figure over all prompts.
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    title = f"{model_name}: {lens_name} for {input_lang} -> {output_lang}"
    plot_results(ax, target_probs, latent_probs, output_lang)
    ax.legend()
    ax.set_title(title)
    plt.tight_layout()
    plot_file = path / (exp_id + ".png")
    plt.savefig(plot_file, dpi=300, bbox_inches="tight")
    plt.show()

    # Per-example figure for `num_examples` prompts; the legend is only drawn
    # on the last subplot.
    fig, axes = k_subplots(num_examples)
    plot_k_results(axes, target_probs, latent_probs, output_lang, num_examples)
    axes[num_examples - 1].legend()
    fig.suptitle(title)
    plt_file = path / (exp_id + "_k.png")
    fig.savefig(plt_file, dpi=300, bbox_inches="tight")
    fig.show()

    # Metadata of the first `num_examples` prompts, shown in the notebook and
    # saved next to the heatmap.
    example_prompts = target_prompts[:num_examples]
    json_meta = {
        i: {
            "input lang": input_lang,
            "target lang": output_lang,
            "target prompt": prompt.prompt,
            "target prompt target": prompt.target_strings,
            "target prompt latent": prompt.latent_strings,
        }
        for i, prompt in enumerate(example_prompts)
    }
    json_df = pd.DataFrame(json_meta)
    with pd.option_context(
        "display.max_colwidth",
        None,
        "display.max_columns",
        None,
        "display.max_rows",
        None,
    ):
        # `display` is the IPython built-in available inside notebooks.
        display(json_df)

    # Run the lens once more on those prompts and plot the top-k tokens.
    target_prompt_batch = [p.prompt for p in example_prompts]
    probs = lens_func(
        nn_model,
        target_prompt_batch,
        scan=True,
        remote=remote,
    )
    file = path / (exp_id + "_heatmap.png")
    plot_topk_tokens(probs, nn_model, title=title, file=file)

    meta_file = path / (exp_id + "_heatmap.meta.json")
    with open(meta_file, "w") as f:
        json.dump(json_meta, f, indent=4)

## Selected args for the paper

In [None]:
# (input language, output language) pairs to run.
paper_args = [["de", "it"]]
for f_args in paper_args:
    # Release cached CUDA memory before each run.
    th.cuda.empty_cache()
    plot_lens_results(
        *f_args,
        thinking_langs=list(args.thinking_langs),  # fresh list per call
        exp_id=exp_id,
    )