# Setup

In [None]:
# Automatically reload edited project modules (exp_tools, interventions, ...)
# so changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 3

In [None]:
import torch as th
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import seaborn as sns
from pathlib import Path
from time import time
import itertools
from random import shuffle
from torch.utils.data import DataLoader

# Fix logger bug: nnsight's logger is disabled below to silence it.
# NOTE(review): `babelnet` is not used anywhere else in this notebook; it is
# presumably imported here, before nnsight, on purpose as part of this
# workaround — keep the order unless the logger issue is re-checked.
import babelnet
from nnsight import logger

logger.disabled = True

# Inference only: disable autograd globally. The return value is bound to `_`
# so the notebook does not echo it.
_ = th.set_grad_enabled(False)

In [None]:
# Experiment name: used as a sub-directory of the results path and in plot titles.
exp_name = "lang_patchscope"

## Papermill args

In [None]:
# Papermill parameters: any of these may be overridden by papermill at run time,
# so keep them as plain literal assignments with unchanged names.
# Languages; converted to an np.array in the loading cell and otherwise only
# used by the commented-out "All plots" cell.
langs = ["fr", "de", "ru", "en", "zh"]
# Default batch size of plt_lang_patch (forwarded to run_prompts).
batch_size = 8
# Model identifier; its last "/" component names the results directory.
model = "Llama-2-7b"
# Forwarded to load_model as `device_map`.
device = "auto"
# model_path = "/dlabscratch1/public/llm_weights/llama2_hf/Llama-2-7b-hf"
# Where to load weights from; falls back to `model` when None.
model_path = None
# Forwarded to load_model.
trust_remote_code = False
# Extra CLI-style flags, parsed by the argparse cell below.
extra_args = []
# Base name of the files plt_lang_patch writes.
exp_id = None
# NOTE(review): only referenced in the commented-out "All plots" cell.
paper_only = False
# NOTE(review): not referenced in the visible notebook; the misspelled name is
# kept so existing papermill invocations passing it keep working.
prob_treshold = 0.3

## Command-line args (passed via `extra_args`)

In [None]:
import shlex
from argparse import ArgumentParser

# Flags are parsed from the papermill parameter `extra_args` rather than the
# process command line: parse_args() without arguments would read sys.argv,
# which in a notebook kernel holds the kernel's own arguments.
parser = ArgumentParser()
# --use-obj becomes the default `use_obj` of plt_lang_patch, which passes it
# to translation_prompts as `cut_at_obj`.
parser.add_argument("--use-obj", "-o", action="store_true")
# `papermill -p extra_args "--use-obj"` yields a str; argparse would iterate it
# character by character, so split it shell-style first.
if isinstance(extra_args, str):
    extra_args = shlex.split(extra_args)
args = parser.parse_args(extra_args)
use_obj = args.use_obj
print(f"args: {args}")

## Load and prepare

In [None]:
from exp_tools import load_model

# With no explicit weights location, load the model identifier itself.
if model_path is None:
    model_path = model
# Short model name (last path component), used in result paths and titles.
model_name = model.rsplit("/", 1)[-1]
langs = np.array(langs)
nn_model = load_model(
    model_path,
    device_map=device,
    trust_remote_code=trust_remote_code,
)
tokenizer = nn_model.tokenizer

## Plots

In [None]:
from exp_tools import (
    run_prompts,
)
from interventions import patchscope_lens, TargetPromptBatch
from prompt_tools import lang_few_shot_prompts, translation_prompts
from translation_tools import get_cloze_dataset, get_bn_dataset

from display_utils import plot_topk_tokens, k_subplots, plot_results, plot_k_results
from utils import ulist


def plt_lang_patch(
    input_lang,
    target_lang,
    few_shot_langs,
    extra_langs=None,
    batch_size=batch_size,
    num_words=None,
    exp_id=None,
    num_examples=9,
    use_obj=use_obj,
):
    """Run the language patchscope for one ``input_lang -> target_lang`` pair.

    Source prompts come from ``translation_prompts`` on the
    ``get_bn_dataset(input_lang, target_lang)`` data. Target prompts come from
    ``lang_few_shot_prompts`` on the cloze dataset, with one target prompt per
    source prompt. Each batch of target prompts is run through
    ``patchscope_lens`` together with the matching slice of source prompts.

    Files are written under
    ``results/<model_name>/<exp_name>/-<input_lang>_<target_lang>-/``:

    * ``<exp_id>.json`` — probabilities returned by ``run_prompts``, keyed by
      ``target_lang`` and by each latent label.
    * ``<exp_id>.png`` — ``plot_results`` of those probabilities.
    * ``<exp_id>_k.png`` — ``plot_k_results`` for the first ``num_examples``
      prompts.
    * ``<exp_id>_heatmap.png`` — ``plot_topk_tokens`` of a ``scan=True``
      patchscope run on those examples.
    * ``<exp_id>_heatmap.meta.json`` — the prompts behind those examples.

    Args:
        input_lang: Language of the source (translation) prompts.
        target_lang: Target language of the translation and few-shot prompts.
        few_shot_langs: Languages passed to ``lang_few_shot_prompts`` for the
            few-shot part of the target prompts.
        extra_langs: Extra language(s) appended after ``input_lang`` in the
            list given to ``lang_few_shot_prompts`` (presumably the latent
            languages that are tracked). A single string is accepted.
        batch_size: Batch size forwarded to ``run_prompts``.
        num_words: Forwarded to both dataset loaders.
        exp_id: Base name of the output files. When ``None``, the current Unix
            time in seconds is used. Non-string ids are converted with ``str``.
        num_examples: Number of examples in the per-example plot, heatmap and
            metadata. Capped at the number of available prompts.
        use_obj: Forwarded to ``translation_prompts`` as ``cut_at_obj``.

    Side effects:
        Rebinds the globals ``source_prompts``, ``target_prompts``,
        ``target_probs`` and ``latent_probs`` so they can be inspected after
        the call. Also displays the plots and a metadata table.
    """
    global target_prompts, target_probs, latent_probs, source_prompts
    if extra_langs is None:
        extra_langs = []
    if isinstance(extra_langs, str):
        extra_langs = [extra_langs]
    # The output file names are built from exp_id; a None id used to crash
    # with `None + ".json"`.
    if exp_id is None:
        exp_id = str(int(time()))
    exp_id = str(exp_id)

    # --- Build source and target prompts ---------------------------------
    df_tr = get_bn_dataset(input_lang, target_lang, num_words=num_words)
    source_prompts = translation_prompts(
        df_tr, tokenizer, input_lang, target_lang, cut_at_obj=use_obj
    )
    source_prompts = [p.prompt for p in source_prompts]
    df = get_cloze_dataset(
        ulist([input_lang, target_lang, *extra_langs, *few_shot_langs]),
        num_words=num_words,
    )
    target_prompts = lang_few_shot_prompts(
        df,
        tokenizer,
        few_shot_langs,
        target_lang,
        [input_lang, *extra_langs],
        num_prompts=len(source_prompts),
    )

    def lang_patchscope(nn_model, prompt_batch, scan):
        # A running offset pairs each target batch with the next slice of
        # source prompts. This assumes run_prompts calls this function on
        # consecutive batches, in order.
        start = lang_patchscope.idx
        src_prompts = source_prompts[start : start + len(prompt_batch)]
        lang_patchscope.idx = start + len(prompt_batch)
        # NOTE(review): -2 is presumably the target-prompt position to patch
        # into; confirm against TargetPromptBatch.from_prompts.
        tgt_prompts = TargetPromptBatch.from_prompts(prompt_batch, -2)
        return patchscope_lens(nn_model, src_prompts, tgt_prompts, scan=scan)

    lang_patchscope.idx = 0

    target_probs, latent_probs = run_prompts(
        nn_model, target_prompts, batch_size=batch_size, get_probs=lang_patchscope
    )

    # --- Save probabilities ----------------------------------------------
    json_dic = {
        target_lang: target_probs.tolist(),
    }
    for label, probs in latent_probs.items():
        json_dic[label] = probs.tolist()
    # `pref` is always empty, which yields a directory named
    # "-<input>_<target>-". It is kept so existing result paths do not move.
    pref = "_".join([])
    path = (
        Path("results")
        / model_name
        / exp_name
        / (f"{pref}-{input_lang}_{target_lang}-")
    )
    path.mkdir(parents=True, exist_ok=True)
    with open(path / f"{exp_id}.json", "w") as f:
        json.dump(json_dic, f, indent=4)

    # --- Plot the probabilities over all prompts --------------------------
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    title = f"{model_name}: {exp_name} from ({input_lang} -> {target_lang})"
    plot_results(ax, target_probs, latent_probs, target_lang)
    ax.legend()
    ax.set_title(title)
    plt.tight_layout()
    plt.savefig(path / f"{exp_id}.png", dpi=300, bbox_inches="tight")
    plt.show()

    # Cap the example count so the indexing below cannot run past the prompts.
    num_examples = min(num_examples, len(source_prompts), len(target_prompts))
    if num_examples < 1:
        return

    # --- Plot k individual examples ---------------------------------------
    fig, axes = k_subplots(num_examples)
    plot_k_results(axes, target_probs, latent_probs, target_lang, num_examples)
    axes[num_examples - 1].legend()
    fig.suptitle(title)
    fig.savefig(path / f"{exp_id}_k.png", dpi=300, bbox_inches="tight")
    # Use plt.show(), as above. fig.show() only emits a warning on non-GUI
    # backends such as the notebook inline backend.
    plt.show()

    # --- Metadata for the examples -----------------------------------------
    # Copy to a list so json.dump also works when an np.array is passed.
    few_shot_langs_list = list(few_shot_langs)
    json_meta = {
        i: {
            "input lang": input_lang,
            "target lang": target_lang,
            "few shot langs": few_shot_langs_list,
            "extra langs": extra_langs,
            "source prompt": source_prompts[i],
            "target prompt": target_prompts[i].prompt,
            "target prompt target": target_prompts[i].target_strings,
            "target prompt latent": target_prompts[i].latent_strings,
        }
        for i in range(num_examples)
    }
    with pd.option_context(
        "display.max_colwidth",
        None,
        "display.max_columns",
        None,
        "display.max_rows",
        None,
    ):
        display(pd.DataFrame(json_meta))

    # --- Token-level scan of the same examples ------------------------------
    target_prompt_batch = [p.prompt for p in target_prompts[:num_examples]]
    # Rewind so these examples pair with source_prompts[:num_examples].
    lang_patchscope.idx = 0
    probs = lang_patchscope(
        nn_model,
        target_prompt_batch,
        scan=True,
    )
    plot_topk_tokens(
        probs, nn_model, title=title, file=path / f"{exp_id}_heatmap.png"
    )

    with open(path / f"{exp_id}_heatmap.meta.json", "w") as f:
        json.dump(json_meta, f, indent=4)

## Language pairs selected for the paper

In [None]:
# (input_lang, target_lang) pairs reported in the paper.
paper_args = [
    ("de", "it"),
    ("it", "de"),
    ("fr", "zh"),
    ("en", "zh"),
    ("zh", "fr"),
    ("zh", "en"),
]
# Passed to plt_lang_patch as `few_shot_langs` for every pair.
fs_langs = ["fr", "de", "ru", "en", "zh", "it"]
# plt_lang_patch names its output files after exp_id, so it must not be None.
# A single id shared by all pairs keeps this run's outputs together.
if exp_id is None:
    exp_id = str(int(time()))
for f_args in paper_args:
    # Release cached, unused GPU memory from the previous pair first.
    th.cuda.empty_cache()
    plt_lang_patch(*f_args, fs_langs, extra_langs="en", exp_id=exp_id)

## All plots

In [None]:
# NOTE(review): plt_lang_patch requires a positional `few_shot_langs` argument,
# which the call below does not pass. Add it (and fill in the elided loops)
# before re-enabling this cell, otherwise it raises TypeError.
# if not paper_only:
#     for in_lang in langs:
#         for out_lang in langs:
#             if in_lang == out_lang:
#                 continue
#             # ... more nested loops
#             th.cuda.empty_cache()
#             plt_lang_patch(
#                 in_lang,
#                 out_lang,
#                 exp_id=exp_id,
#             )