# Setup

In [None]:
%load_ext autoreload
%autoreload 3

In [None]:
import torch as th
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import seaborn as sns
from pathlib import Path
from time import time
import itertools
from random import shuffle
from torch.utils.data import DataLoader

# Fix logger bug
# NOTE(review): babelnet is not referenced anywhere else in this cell, so it
# appears to be imported only for its import-time side effects before the
# nnsight logger is disabled — confirm the ordering requirement before moving it.
import babelnet
from nnsight import logger

# Silence the nnsight logger for the rest of the notebook.
logger.disabled = True

# Turn autograd off globally: this notebook only runs inference. The result is
# bound to `_` so it is not left as the cell's trailing expression.
_ = th.set_grad_enabled(False)

In [None]:
# Name of this experiment. It is used as a sub-directory of
# results/<model_name>/ and appears in the plot titles; replace the placeholder.
exp_name = "{YOUR EXP NAME}"

## Papermill args

In [None]:
# Language codes to experiment on; the "All plots" cell runs every ordered pair
# of distinct entries.
langs = ["fr", "de", "ru", "en", "zh"]
# Batch size used when filtering prompts and when running them through the model.
batch_size = 8
# Model identifier. Its last "/"-separated component names the results
# directory, and it is used as the load path when `model_path` is None.
model = "Llama-2-7b"
# Forwarded to load_model as `device_map`.
device = "auto"
# model_path = "/dlabscratch1/public/llm_weights/llama2_hf/Llama-2-7b-hf"
# Optional explicit location of the weights; None means "use `model`".
model_path = None
# Forwarded to load_model.
trust_remote_code = False
# Extra command-line style arguments parsed by the ArgumentParser cell.
extra_args = []
# Stem of the output file names (<exp_id>.json, <exp_id>.png, ...).
exp_id = None
# paper_only=True skips the "All plots" cell; skip_paper=True skips the
# "Selected args for the paper" cell.
paper_only = False
skip_paper = False
# Threshold handed to filter_prompts_by_prob.
# NOTE(review): the name is misspelled ("treshold") but it is a parameter name
# referenced elsewhere, so it is kept as is.
prob_treshold = 0.3
# Optional post-processing of the source / target words. When given as a
# string, it is looked up as an attribute of `prompt_tools` and partially
# applied with the matching *_kwargs dict.
map_source_lang = None
map_source_lang_kwargs = {}
map_target_lang = None
map_target_lang_kwargs = {}
# Forwarded to load_model.
use_tl = False
# Forwarded to translation_prompts as `n`.
num_few_shot = 5

## Command-line args

In [None]:
from argparse import ArgumentParser

# Parse the option strings supplied through the `extra_args` parameter.
# No options are declared in this template: register experiment-specific ones
# with parser.add_argument(...) before the parse_args call.
parser = ArgumentParser()
args = parser.parse_args(extra_args)
print(f"args: {args}")

## Loading and arg preprocessing

In [None]:
from exp_tools import load_model
import prompt_tools
from functools import partial

# Without an explicit weights location, load straight from the model identifier.
model_path = model if model_path is None else model_path
# Short model name: the last "/"-separated component of the identifier.
model_name = model.split("/")[-1]
langs = np.array(langs)

nn_model = load_model(
    model_path,
    trust_remote_code=trust_remote_code,
    device_map=device,
    use_tl=use_tl,
    # dispatch=True,
)
tokenizer = nn_model.tokenizer

# A word mapper given by name is resolved against `prompt_tools` and bound to
# its keyword arguments; anything that is not a string is left untouched.
if isinstance(map_source_lang, str):
    map_source_lang = partial(
        getattr(prompt_tools, map_source_lang), **map_source_lang_kwargs
    )
if isinstance(map_target_lang, str):
    map_target_lang = partial(
        getattr(prompt_tools, map_target_lang), **map_target_lang_kwargs
    )

## Plots

In [None]:
from exp_tools import (
    run_prompts,
    next_token_probs,
    filter_prompts_by_prob,
    remove_colliding_prompts,
)
from prompt_tools import translation_prompts
from translation_tools import get_bn_dataset as get_translations

from display_utils import plot_topk_tokens, k_subplots, plot_results, plot_k_results


def your_ploting_func(
    input_lang,
    target_lang,
    extra_langs=None,
    batch_size=batch_size,
    num_words=None,
    exp_id=None,
    num_examples=9,
):
    """
    Run the translation-prompt experiment for one language pair and save it.

    Builds few-shot translation prompts from ``input_lang`` to
    ``target_lang``, removes colliding prompts, keeps only those accepted by
    ``filter_prompts_by_prob`` (threshold ``prob_treshold``), runs them through
    ``nn_model`` and writes, under
    ``results/<model_name>/<exp_name>/<input_lang>_<target_lang>/``:

    * ``<exp_id>.json`` - target and latent probabilities,
    * ``<exp_id>.png`` - the aggregate plot,
    * ``<exp_id>_k.png`` - one subplot for each of ``num_examples`` prompts,
    * ``<exp_id>_heatmap.png`` / ``<exp_id>_heatmap.meta.json`` - top-k token
      plot for the first ``num_examples`` prompts and the prompts it shows.

    Args:
        input_lang: Language code of the source words (a column of the
            translation dataframe).
        target_lang: Language code to translate into.
        extra_langs: Additional language code, or list of codes, requested
            from the dataset and forwarded to ``translation_prompts``.
            ``None`` means no extra language.
        batch_size: Batch size used for filtering and for running the prompts.
        num_words: Forwarded to ``get_translations``.
        exp_id: Stem of the output file names. When ``None`` the current Unix
            timestamp (in seconds) is used.
        num_examples: Number of prompts shown individually; also the minimum
            number of prompts that must survive filtering.

    Returns:
        None. Returns early, after printing a message, when fewer than
        ``num_examples`` prompts are left after filtering.
    """
    if exp_id is None:
        # The notebook passes exp_id=None by default; without a fallback the
        # output file names below could not be built.
        exp_id = str(int(time()))
    if extra_langs is None:
        extra_langs = []
    if isinstance(extra_langs, str):
        extra_langs = [extra_langs]
    # global foo  # Var you might want to access for debugging purposes
    # foo = 2
    df = get_translations(
        input_lang,
        [target_lang, *extra_langs],
        num_words=num_words,
    )
    # Columns holding the unprocessed words; they are ignored when looking for
    # colliding prompts.
    no_proc = []
    if map_source_lang is not None:
        df[input_lang + " no proc"] = df[input_lang]
        df[input_lang] = df[input_lang].apply(map_source_lang)
        no_proc.append(input_lang + " no proc")
    if map_target_lang is not None:
        df[target_lang + " no proc"] = df[target_lang]
        df[target_lang] = df[target_lang].apply(map_target_lang)
        no_proc.append(target_lang + " no proc")

    target_prompts = translation_prompts(
        df, tokenizer, input_lang, target_lang, extra_langs, n=num_few_shot
    )
    target_prompts = remove_colliding_prompts(target_prompts, ignore_langs=no_proc)
    print(f"Number of non-colliding prompts: {len(target_prompts)}")
    target_prompts = filter_prompts_by_prob(
        target_prompts, nn_model, prob_treshold, batch_size=batch_size
    )
    print(f"Number of prompts after filtering: {len(target_prompts)}")
    if len(target_prompts) < num_examples:
        print("Not enough prompts after filtering")
        return

    def get_prob_func(nn_model, prompt_batch, scan):
        # `scan` is accepted but not used here.
        # NOTE(review): presumably required by the `get_probs` callback
        # signature of run_prompts - confirm.
        return next_token_probs(
            nn_model,
            prompt_batch,
        ).unsqueeze(1)

    target_probs, latent_probs = run_prompts(
        nn_model, target_prompts, batch_size=batch_size, get_probs=get_prob_func
    )

    # Raw probabilities, keyed by the target language and by latent label.
    json_dic = {
        target_lang: target_probs.tolist(),
    }
    for label, probs in latent_probs.items():
        json_dic[label] = probs.tolist()
    path = Path("results") / model_name / exp_name / f"{input_lang}_{target_lang}"
    path.mkdir(parents=True, exist_ok=True)
    # f-strings rather than `exp_id + ...` so that a non-string id also works.
    json_file = path / f"{exp_id}.json"
    with open(json_file, "w") as f:
        json.dump(json_dic, f, indent=4)
    print(f"Saved data to {json_file.resolve()}")

    # Aggregate plot over all prompts.
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    title = f"{model_name}: {exp_name} ({input_lang} -> {target_lang})"
    plot_results(ax, target_probs, latent_probs, target_lang)
    ax.legend()
    ax.set_title(title)
    plt.tight_layout()
    plot_file = path / f"{exp_id}.png"
    plt.savefig(plot_file, dpi=300, bbox_inches="tight")
    plt.show()

    # Plot k examples
    fig, axes = k_subplots(num_examples)
    plot_k_results(axes, target_probs, latent_probs, target_lang, num_examples)
    axes[num_examples - 1].legend()
    fig.suptitle(title)
    plt_file = path / f"{exp_id}_k.png"
    fig.savefig(plt_file, dpi=300, bbox_inches="tight")
    fig.show()

    # Describe the prompts shown in the per-example plots and in the heatmap.
    json_heatmap = {}
    for i, prompt in enumerate(target_prompts[:num_examples]):
        json_heatmap[i] = {
            "input lang": input_lang,
            "target lang": target_lang,
            "target prompt": prompt.prompt,
            "target prompt target": prompt.target_strings,
            "target prompt latent": prompt.latent_strings,
        }
    json_df = pd.DataFrame(json_heatmap)
    with pd.option_context(
        "display.max_colwidth",
        None,
        "display.max_columns",
        None,
        "display.max_rows",
        None,
    ):
        display(json_df)
    target_prompt_batch = [p.prompt for p in target_prompts[:num_examples]]
    probs = get_prob_func(
        nn_model,
        target_prompt_batch,
        scan=True,
    )
    file = path / f"{exp_id}_heatmap.png"
    plot_topk_tokens(probs, nn_model, title=title, file=file)

    meta_file = path / f"{exp_id}_heatmap.meta.json"
    with open(meta_file, "w") as f:
        json.dump(json_heatmap, f, indent=4)

## Selected args for the paper

In [None]:
# Run the plotting function on a hand-picked list of argument tuples.
if not skip_paper:
    # Each entry is unpacked as the positional arguments of your_ploting_func;
    # the list is empty in this template.
    paper_args = []
    for f_args in paper_args:
        # Free cached GPU memory between two runs.
        th.cuda.empty_cache()
        your_ploting_func(*f_args, exp_id=exp_id)

## All plots

In [None]:
if not paper_only:
    # Walk every ordered (input, output) pair of `langs`, in the order two
    # nested loops would produce, skipping pairs of identical languages.
    for in_lang, out_lang in itertools.product(langs, repeat=2):
        if in_lang == out_lang:
            continue
        # ... more nested loops
        th.cuda.empty_cache()
        your_ploting_func(in_lang, out_lang, exp_id=exp_id)