# Setup

In [None]:
# Load IPython's autoreload extension so that edited modules are re-imported
# automatically before a cell is executed.
%load_ext autoreload
# Mode 3 (IPython >= 8): reload all modules like mode 2, and additionally pick
# up objects newly added to an already-imported module.
%autoreload 3

In [None]:
import torch as th
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import seaborn as sns
from pathlib import Path
from time import time
import itertools
from random import shuffle
from torch.utils.data import DataLoader

# Fix logger bug
# NOTE(review): babelnet is not referenced anywhere else in this notebook, so it
# is presumably imported here only for its side effects on logging and must come
# before nnsight -- confirm before reordering or removing.
import babelnet
from nnsight import logger

# Silence nnsight's logger for the rest of the notebook.
logger.disabled = True

# Turn autograd off globally for every cell below. The result is bound to `_`
# so the notebook does not echo the returned object as cell output.
_ = th.set_grad_enabled(False)

In [None]:
# Human-readable experiment name; logit_lens_plots uses it as a sub-directory of
# results/<model_name>/ and as part of every plot title.
exp_name = "Logit Lens on Definitions"

## Papermill args

In [None]:
# Languages to process: the "All plots" cell calls logit_lens_plots once per entry.
langs = ["fr", "de", "ru", "en", "zh"]
# Default batch size of logit_lens_plots (forwarded to filter_prompts_by_prob and run_prompts).
batch_size = 8
# Model identifier; its last "/"-separated component names the results directory,
# and it is used as the load path when model_path is None.
model = "Llama-2-7b"
# Forwarded to load_model as `device_map`.
device = "auto"
# model_path = "/dlabscratch1/public/llm_weights/llama2_hf/Llama-2-7b-hf"
# Optional explicit path to load the weights from; None falls back to `model`.
model_path = None
# Forwarded unchanged to load_model.
trust_remote_code = False
# Command-line style tokens parsed by the ArgumentParser in the "CL Args" cell.
extra_args = []
# File stem for every artefact written by logit_lens_plots.
exp_id = None
# When True, the "All plots" cell is skipped and only `paper_args` runs are produced.
paper_only = False
# Threshold handed to filter_prompts_by_prob.
# NOTE(review): the name is misspelled ("treshold") but it is part of the
# papermill parameter interface, so it is kept as is.
prob_treshold = 0.3

## Command-line args

In [None]:
from argparse import ArgumentParser

# Parse the `extra_args` token list (not sys.argv), so additional options can be
# supplied through the papermill parameter of the same name.
parser = ArgumentParser()
# One or more language codes; becomes the default `thinking_langs` argument of
# logit_lens_plots.
parser.add_argument("--thinking-langs", nargs="+", default=["en"])
args = parser.parse_args(extra_args)
print(f"args: {args}")

## Load and prepare

In [None]:
from exp_tools import load_model

# Short model name: whatever follows the last "/" of the model identifier.
model_name = model.rsplit("/", 1)[-1]
# Keep the language codes in a NumPy array for the cells below.
langs = np.array(langs)
# Without an explicit weights path, load straight from the model identifier.
model_path = model if model_path is None else model_path
# Options forwarded to load_model.
# dispatch=True was tried at some point and is intentionally left disabled.
nn_model = load_model(
    model_path, trust_remote_code=trust_remote_code, device_map=device
)
tokenizer = nn_model.tokenizer

## Plots

In [None]:
from exp_tools import (
    run_prompts,
    filter_prompts_by_prob,
    remove_colliding_prompts,
    logit_lens,
)
from translation_tools import cloze_prompts
from translation_tools import get_cloze_dataset

from utils import plot_ci, plot_k, plot_topk_tokens, k_subplots, ulist


def logit_lens_plots(
    lang,
    thinking_langs=args.thinking_langs,
    batch_size=batch_size,
    num_words=None,
    exp_id=None,
    num_examples=9,
):
    """
    Run the logit-lens experiment on definition prompts of one language and save the results.

    Cloze prompts are built for `lang`, filtered, and run through `nn_model`
    with `logit_lens`. All artefacts are written to
    ``results/<model_name>/<exp_name>/<lang>--<thinking langs joined by "_">/``:

    - ``<exp_id>.json``: probabilities of the target language and of every
      latent label returned by `run_prompts`.
    - ``<exp_id>.png``: one `plot_ci` curve per language/label.
    - ``<exp_id>_k.png``: `plot_k` curves for the first `num_examples` prompts.
    - ``<exp_id>_heatmap.png``: `plot_topk_tokens` output for the same prompts.
    - ``<exp_id>_heatmap.meta.json``: prompt text and target/latent strings of
      those prompts.

    Args:
        lang: Language code of the definitions to evaluate.
        thinking_langs: Additional language codes passed to `get_cloze_dataset`
            and `cloze_prompts`; ``None`` is treated as an empty list.
        batch_size: Batch size forwarded to `filter_prompts_by_prob` and
            `run_prompts`.
        num_words: Forwarded to `get_cloze_dataset`.
        exp_id: File stem of the artefacts. Any value is converted with
            ``str``; ``None`` falls back to the current Unix timestamp so that
            the default no longer fails when the file names are built.
        num_examples: Number of prompts shown in the per-example plots and the
            heatmap; also the minimum number of prompts required to proceed.

    Returns:
        None. If fewer than `num_examples` prompts survive filtering, a message
        is printed and nothing is written.
    """
    if thinking_langs is None:
        thinking_langs = []
    # Resolve the file stem before any expensive work: the previous code
    # concatenated `exp_id + ".json"` directly, which raised TypeError for the
    # default None (and for non-string ids) only after the model had been run.
    exp_id = str(int(time())) if exp_id is None else str(exp_id)

    df = get_cloze_dataset(ulist([lang, *thinking_langs]), num_words=num_words)
    # Drop the rows whose definition list for `lang` is empty.
    df = df[df[f"definitions_wo_ref_{lang}"].map(lambda x: x != [])]
    target_prompts = cloze_prompts(df, tokenizer, lang, thinking_langs)
    target_prompts = remove_colliding_prompts(target_prompts, ignore_langs=f"source_{lang}")
    # NOTE(review): presumably keeps only the prompts the model answers with
    # probability >= prob_treshold -- confirm against exp_tools.
    target_prompts = filter_prompts_by_prob(
        target_prompts, nn_model, prob_treshold, batch_size=batch_size
    )
    if len(target_prompts) < num_examples:
        print(
            f"Skipping {lang} because only {len(target_prompts)} prompts available"
        )
        return
    print(f"Found {len(df)} definitions for {lang}")

    target_probs, latent_probs = run_prompts(
        nn_model, target_prompts, batch_size=batch_size, get_probs=logit_lens
    )

    # Raw probabilities, keyed by the target language and by each latent label.
    json_dic = {
        lang: target_probs.tolist(),
    }
    for label, probs in latent_probs.items():
        json_dic[label] = probs.tolist()
    thinking_str = "_".join(thinking_langs)
    path = (
        Path("results")
        / model_name
        / exp_name
        / (f"{lang}--{thinking_str}")
    )
    path.mkdir(parents=True, exist_ok=True)
    json_file = path / (exp_id + ".json")
    with open(json_file, "w") as f:
        json.dump(json_dic, f, indent=4)

    # Aggregate plot: one curve for the target language plus one per latent label.
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    colors = sns.color_palette("tab10", 1 + len(latent_probs))
    title = (
        f"{model_name}: {exp_name} {lang}"
    )
    plot_ci(ax, target_probs, label=lang, color=colors[0])
    for i, (label, probs) in enumerate(latent_probs.items()):
        plot_ci(ax, probs, label=label, color=colors[i + 1], init=False)
    ax.legend()
    ax.set_title(title)
    plt.tight_layout()
    plot_file = path / (exp_id + ".png")
    plt.savefig(plot_file, dpi=300, bbox_inches="tight")
    plt.show()

    # Plot k examples
    fig, axes = k_subplots(num_examples)
    plot_k(
        axes,
        target_probs[:num_examples],
        label=lang,
        color=colors[0],
        k=num_examples,
    )
    for i, (label, probs) in enumerate(latent_probs.items()):
        plot_k(
            axes,
            probs[:num_examples],
            label=label,
            color=colors[i + 1],
            init=False,
            k=num_examples,
        )
    # A single legend on the last subplot is enough for the whole grid.
    axes[num_examples - 1].legend()
    fig.suptitle(title)
    plt_file = path / (exp_id + "_k.png")
    fig.savefig(plt_file, dpi=300, bbox_inches="tight")
    fig.show()

    # Metadata of the prompts shown in the per-example plots and the heatmap.
    # At least `num_examples` prompts exist here thanks to the early return above.
    json_meta = {}
    for i, prompt in enumerate(target_prompts[:num_examples]):
        json_meta[i] = {
            "definition lang": lang,
            "target prompt": prompt.prompt,
            "target prompt target": prompt.target_strings,
            "target prompt latent": prompt.latent_strings,
        }
    json_df = pd.DataFrame(json_meta)
    # Lift all pandas display limits so the full prompts are visible in the notebook.
    with pd.option_context(
        "display.max_colwidth",
        None,
        "display.max_columns",
        None,
        "display.max_rows",
        None,
    ):
        display(json_df)

    # Heatmap of the top tokens for the same example prompts.
    target_prompt_batch = [p.prompt for p in target_prompts[:num_examples]]
    probs = logit_lens(
        nn_model,
        target_prompt_batch,
        scan=True,
    )
    file = path / (exp_id + "_heatmap.png")
    plot_topk_tokens(probs, nn_model, title=title, file=file)

    meta_file = path / (exp_id + "_heatmap.meta.json")
    with open(meta_file, "w") as f:
        json.dump(json_meta, f, indent=4)

## Selected args for the paper

In [None]:
# Each entry holds the positional arguments of one logit_lens_plots call.
# The list is empty here, so this cell currently produces no plots.
paper_args = []
for args_ in paper_args:
    # Release PyTorch's cached, unused GPU memory before each run.
    th.cuda.empty_cache()
    logit_lens_plots(*args_, exp_id=exp_id)

## All plots

In [None]:
# Unless only the paper figures were requested, run the experiment for every
# configured language with the default thinking languages and batch size.
if not paper_only:
    for lang in langs:
        # Release PyTorch's cached, unused GPU memory before each run.
        th.cuda.empty_cache()
        logit_lens_plots(
            lang,
            exp_id=exp_id,
        )