In [None]:
%load_ext autoreload
%autoreload 3
import sys

# Make the parent directory importable (presumably where the project modules
# imported further down, e.g. exp_tools / translation_tools / utils, live).
# Guarded so that re-running this cell does not keep appending duplicates.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
from nnsight import LanguageModel
from transformers import AutoTokenizer
import torch.nn.functional as F
import torch as th
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import seaborn as sns
from pathlib import Path
from time import time
import itertools
from random import shuffle
from torch.utils.data import DataLoader

# Fix logger bug
import babelnet
from nnsight import logger

# Turn off the logger object exported by nnsight (workaround referred to by
# the "Fix logger bug" comment on the imports above).
logger.disabled = True

# Disable autograd globally for the whole notebook. The result is bound to `_`
# so that the cell does not echo it as its output.
_ = th.set_grad_enabled(False)

In [None]:
# Notebook parameters.
# NOTE(review): this looks like a papermill-style "parameters" cell - confirm.
model = "croissantllm/CroissantLLMBase"  # model id; its last path component names the results directory
model_path = None  # optional checkpoint location; falls back to `model` when None
check_translation_performance = False  # not referenced in the cells visible here
batch_size = 64  # default batch size of shifted_cross_translation_plot
thinking_langs = ["en", "fr"]  # not referenced in the cells visible here
langs = ["en", "zh", "fr", "ru", "de"]  # the final cell runs every ordered pair of distinct languages
method = "logit_lens"  # removed again in the next cell ("Not used in this notebook")
trust_remote_code = False  # forwarded to LanguageModel; when True a tokenizer is also loaded explicitly
device = "auto"  # passed as `device_map` to LanguageModel

In [None]:
# `method` is not used in this notebook. Drop it with `pop` rather than `del`
# so that re-running this cell (without re-running the parameters cell first)
# does not raise NameError.
globals().pop("method", None)

langs = np.array(langs)
# For every language, the array of all *other* languages (candidate outputs).
out_langs = {
    lang: np.array([other for other in langs if other != lang]) for lang in langs
}

# Load from `model` itself unless a separate checkpoint path was given.
if model_path is None:
    model_path = model

# Only build the tokenizer by hand when remote code has to be trusted;
# otherwise `tokenizer=None` is handed to LanguageModel.
tokenizer = None
if trust_remote_code:
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        # No pad token defined: reuse the end-of-sequence token for padding.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

nn_model = LanguageModel(
    model_path,
    device_map=device,
    torch_dtype=th.float16,
    trust_remote_code=trust_remote_code,
    tokenizer=tokenizer,
    dispatch=True,
)

## Plots

In [None]:
from exp_tools import (
    run_prompts,
    BatchPatchscopePrompt,
    patchscope_lens_llama,
)
from translation_tools import translation_prompts

from translation_tools import get_bn_dataset as get_translations

# from translation_tools import get_gpt4_dataset as get_translations
from utils import plot_ci, plot_k, plot_topk_tokens
from copy import deepcopy


def shifted_cross_translation_plot(
    source_input_lang,
    source_target_lang,
    input_lang,
    target_lang,
    extra_langs=None,
    batch_size=batch_size,
    num_words=None,
    num_pairs=200,
    ax=None,
    time_=None,
    k=4,
):
    """
    Run a cross-prompt patchscope experiment, then save and plot its results.

    The patchscope source is a translation prompt
    ``source_input_lang: A -> source_target_lang:`` and the patchscope target is
    a translation prompt ``input_lang: B -> target_lang:`` (patched at index -1),
    where A and B are different words. The probabilities returned for the target
    language and for every latent label are written to disk and plotted.

    Files written under
    ``results/<model_name>/shifted_cross_translation/<src>_<src_tgt>-<in>_<tgt>-/``:
    ``<time_>.json`` (probabilities), ``<time_>.png`` (plot normalised by the
    un-patched baselines), ``<time_>_k.png`` (first ``k`` examples, raw
    probabilities), ``<time_>_heatmap.png`` (via ``plot_topk_tokens``) and
    ``<time_>_heatmap.meta.json`` (description of the first ``k`` examples).

    Args:
        source_input_lang: input language of the source prompts.
        source_target_lang: output language of the source prompts.
        input_lang: input language of the target prompts.
        target_lang: output language of the target prompts.
        extra_langs: extra latent language(s) tracked on the target prompts;
            a single string or a list of strings.
        batch_size: batch size forwarded to ``run_prompts``.
        num_words: forwarded unchanged to ``get_translations``.
        num_pairs: maximum number of (source, target) prompt pairs to use.
        ax: optional matplotlib Axes for the main plot; a new figure is created
            when omitted.
        time_: identifier used as the file-name stem; defaults to the current
            Unix time.
        k: number of leading examples shown individually.

    Raises:
        ValueError: if no (source, target) prompt pair could be built.

    Side effects:
        Besides writing the files above, ``source_df``, ``target_df``,
        ``source_prompts``, ``target_prompts``, ``target_probs`` and
        ``latent_probs`` are published as module-level globals so they can be
        inspected from other cells.
    """
    global source_df, target_df, target_prompts, target_probs, latent_probs, source_prompts

    if extra_langs is None:
        extra_langs = []
    elif isinstance(extra_langs, str):
        extra_langs = [extra_langs]
    model_name = model.split("/")[-1]
    time_ = str(int(time()) if time_ is None else time_)

    source_df = get_translations(
        source_input_lang,
        [source_target_lang, input_lang, target_lang],
        num_words,
    )
    target_df = get_translations(
        input_lang,
        [source_target_lang, source_input_lang, target_lang],
        num_words,
    )

    _source_prompts = translation_prompts(
        source_df,
        nn_model.tokenizer,
        source_input_lang,
        source_target_lang,
        [target_lang],
    )
    _target_prompts = translation_prompts(
        target_df,
        nn_model.tokenizer,
        input_lang,
        target_lang,
        [source_target_lang, *extra_langs],
    )

    # Randomly pair source rows with target rows that are about a different word.
    # NOTE(review): `i` / `j` are DataFrame index labels used to index the
    # prompt collections; this assumes translation_prompts returns one prompt
    # per row and that the frames carry a default 0..n-1 index - confirm.
    source_target = list(itertools.product(source_df.iterrows(), target_df.iterrows()))
    shuffle(source_target)
    source_prompts = []
    target_prompts = []
    for (i, source_row), (j, target_row) in source_target:
        if source_row["word_original"] == target_row["word_original"]:
            continue
        # TODO: token overlap between the paired prompts is not checked yet.
        source_prompts.append(deepcopy(_source_prompts[i]))
        target_prompts.append(deepcopy(_target_prompts[j]))
        if len(source_prompts) >= num_pairs:
            break
    if not target_prompts:
        raise ValueError(
            "No (source, target) prompt pair could be built for "
            f"({source_input_lang} -> {source_target_lang}) into "
            f"({input_lang} -> {target_lang})"
        )

    # Track the source prompt's expected outputs as extra latents of the target
    # prompt (the prompts were deep-copied above, so the originals are untouched).
    src_target_key = f"source_{source_target_lang}"
    src_latent_key = f"source_{target_lang}"
    for targ_p, src_p in zip(target_prompts, source_prompts):
        targ_p.latent_tokens[src_target_key] = src_p.target_tokens
        targ_p.latent_tokens[src_latent_key] = src_p.latent_tokens[target_lang]
        targ_p.latent_strings[src_target_key] = src_p.target_string
        targ_p.latent_strings[src_latent_key] = src_p.latent_strings[target_lang]
    source_prompts_str = [p.prompt for p in source_prompts]

    def transverse_patchscope(nn_model, prompt_batch, scan):
        # Pair each batch of target prompts with the matching slice of source
        # prompts. NOTE(review): relies on run_prompts invoking this callback
        # exactly once per batch, in order - confirm.
        offset = transverse_patchscope.offset
        target_patchscope_prompts = BatchPatchscopePrompt.from_prompts(prompt_batch, -1)
        source_prompt_batch = source_prompts_str[offset : offset + len(prompt_batch)]
        transverse_patchscope.offset += len(prompt_batch)
        return patchscope_lens_llama(
            nn_model, source_prompt_batch, target_patchscope_prompts, scan=scan
        )

    transverse_patchscope.offset = 0
    target_probs, latent_probs = run_prompts(
        nn_model, target_prompts, batch_size=batch_size, method=transverse_patchscope
    )

    # Un-patched next-token probabilities, used as baselines to normalise the plot.
    source_prompts_probs, _ = run_prompts(
        nn_model, source_prompts, batch_size=batch_size, method="next_token_probs"
    )
    source_prompts_baseline = source_prompts_probs.mean()
    target_prompts_probs, _ = run_prompts(
        nn_model, target_prompts, batch_size=batch_size, method="next_token_probs"
    )
    target_prompts_baseline = target_prompts_probs.mean()

    json_dic = {
        target_lang: target_probs.tolist(),
        "source prompt probs": source_prompts_probs.squeeze().tolist(),
        "target prompt probs": target_prompts_probs.squeeze().tolist(),
    }
    for label, label_probs in latent_probs.items():
        json_dic[label] = label_probs.tolist()
    path = (
        Path("results")
        / model_name
        / "shifted_cross_translation"
        / (f"{source_input_lang}_{source_target_lang}-{input_lang}_{target_lang}-")
    )
    path.mkdir(parents=True, exist_ok=True)
    json_file = path / (time_ + ".json")
    with open(json_file, "w") as f:
        json.dump(json_dic, f, indent=4)

    # --- Main plot: probabilities normalised by the matching baseline ---
    created_fig = None
    if ax is None:
        created_fig, ax = plt.subplots()
    colors = sns.color_palette("tab10", 4 + len(extra_langs))
    plot_ci(
        ax, target_probs / target_prompts_baseline, label=target_lang, color=colors[0]
    )
    for color_idx, (label, label_probs) in enumerate(latent_probs.items(), start=1):
        # "source_*" latents come from the source prompt, so they are
        # normalised by the source baseline.
        if "source" in label:
            baseline = source_prompts_baseline
        else:
            baseline = target_prompts_baseline
        plot_ci(
            ax, label_probs / baseline, label=label, color=colors[color_idx], init=False
        )
    ax.legend()
    title = f"{model_name}: HeteroPatch from ({source_input_lang} -> {source_target_lang}) into ({input_lang} -> {target_lang})"
    ax.set_title(title)
    # Save through the Axes' own figure: pyplot's "current figure" is not
    # necessarily the one that owns a caller-supplied `ax`.
    plot_file = path / (time_ + ".png")
    ax.figure.savefig(plot_file, dpi=300, bbox_inches="tight")
    plt.show()
    if created_fig is not None:
        # Release the figure we created; this function runs many times in a loop.
        plt.close(created_fig)

    # --- Per-example plot: raw probabilities of the first k examples ---
    fig, axes = plt.subplots(1, k, figsize=(5 * k, 5))
    plot_k(
        axes,
        target_probs[:k],
        label=target_lang,
        color=colors[0],
        k=k,
    )
    for color_idx, (label, label_probs) in enumerate(latent_probs.items(), start=1):
        plot_k(
            axes,
            label_probs[:k],
            label=label,
            color=colors[color_idx],
            init=False,
            k=k,
        )
    axes[-1].legend()
    fig.suptitle(title + " - Raw probabilities")
    plt_file = path / (time_ + "_k.png")
    fig.savefig(plt_file, dpi=300, bbox_inches="tight")
    # Figure.show() is meant for GUI backends and only warns on non-GUI
    # (e.g. notebook inline) backends; pyplot.show() works with both.
    plt.show()
    plt.close(fig)

    # --- Describe the first k examples and plot their top tokens ---
    # Clamp to the number of collected pairs so that a small dataset or a
    # small `num_pairs` cannot trigger an IndexError.
    n_examples = min(k, len(source_prompts))
    json_meta = {}
    for i in range(n_examples):
        json_meta[i] = {
            "source input lang": source_input_lang,
            "source target lang": source_target_lang,
            "input lang": input_lang,
            "target lang": target_lang,
            "source prompt": source_prompts_str[i],
            "source prompt target": source_prompts[i].target_string,
            "source prompt latent": source_prompts[i].latent_strings,
            "target prompt": target_prompts[i].prompt,
            "target prompt target": target_prompts[i].target_string,
            "target prompt latent": target_prompts[i].latent_strings,
        }
    json_df = pd.DataFrame(json_meta)
    with pd.option_context(
        "display.max_colwidth",
        None,
        "display.max_columns",
        None,
        "display.max_rows",
        None,
    ):
        display(json_df)
    target_prompt_batch = BatchPatchscopePrompt.from_prompts(
        [p.prompt for p in target_prompts[:k]], -1
    )
    example_probs = patchscope_lens_llama(
        nn_model, source_prompts_str[:k], target_prompt_batch
    )
    file = path / (time_ + "_heatmap.png")
    plot_topk_tokens(example_probs, nn_model, title=title, file=file)

    meta_file = path / (time_ + "_heatmap.meta.json")
    with open(meta_file, "w") as f:
        json.dump(json_meta, f, indent=4)

In [None]:
# One timestamp shared by every run launched below: it is passed as `time_`
# to shifted_cross_translation_plot, which uses it as the stem of the file
# names it writes.
time_ = int(time())
print(f"Experiment time id: {time_}")

In [None]:
# Every ordered (input language, output language) pair with distinct languages.
language_pairs = [(src, tgt) for src in langs for tgt in out_langs[src]]

# Patch every source configuration into every target configuration. The
# iteration order is the same as four nested loops (source pair outermost).
for (source_input_lang, source_target_lang), (in_lang, out_lang) in itertools.product(
    language_pairs, repeat=2
):
    th.cuda.empty_cache()
    shifted_cross_translation_plot(
        source_input_lang,
        source_target_lang,
        in_lang,
        out_lang,
        time_=time_,
    )
    plt.show()
    # Release every figure of this iteration so that they do not accumulate
    # in memory over the whole sweep.
    plt.close("all")