In [1]:
INTERACTIVE_MODE = False  # Set to True to run the notebook interactively
DEBUG = False  # When True, `patched_repr_defs` restricts the run to 8 common words
import sys

# Add the project's `src` directory to the import path and pick a tqdm flavour.
# The relative location of `src` differs between the two modes
# (NOTE(review): presumably because the working directory differs — confirm).
if INTERACTIVE_MODE:
    sys.path.append("../src")
    # Reload edited modules automatically before executing each cell.
    %load_ext autoreload
    %autoreload 2
    from tqdm.notebook import tqdm, trange
else:
    sys.path.append("./src")
    from tqdm import tqdm, trange

In [2]:
import torch as th
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
from pathlib import Path
from time import time
import itertools
from random import shuffle

# Turn autograd off globally; the visible cells only generate text and collect/patch activations.
_ = th.set_grad_enabled(False)

In [3]:
# Experiment name; `experiment` uses it as a sub-directory of the results path.
exp_name = "mean_repr_def"

In [4]:
# papermill parameters
batch_size = 8  # used for activation collection and baseline generation; default of --gen-batch-size
# model = "meta-llama/Llama-2-7b-hf"
model = "google/gemma-2-2b"  # model identifier; its last "/"-separated component names the results directory
model_path = None  # None -> resolved from `model` through dlab_model_path
trust_remote_code = False  # forwarded to load_model
# device = "auto"
device = "cuda:1"  # used for the language model (device_map) and for the sentence-embedding model
remote = False  # not referenced in the visible cells
num_few_shot = 5  # passed as `n` to the prompt builders
exp_id = "test" + str(int(time()))  # prefix of every output file name; default embeds the current Unix time
extra_args = []  # CLI-style arguments handed to the ArgumentParser below
use_tl = False  # forwarded to load_model (NOTE(review): presumably selects a TransformerLens backend — confirm)

In [5]:
from nnterp import load_model
from dlabutils import model_path as dlab_model_path
from argparse import ArgumentParser
from sentence_transformers import SentenceTransformer

# Secondary options arrive through `extra_args` (a list of CLI-style strings).
parser = ArgumentParser()
parser.add_argument("--gen-batch-size", default=batch_size, type=int)
parser.add_argument(
    "--embeddings-model",
    default="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    type=str,
)
parser.add_argument("--emb-batch-size", default=256, type=int)
pargs = parser.parse_args(extra_args)
gen_batch_size, emb_batch_size = pargs.gen_batch_size, pargs.emb_batch_size

# Sentence-embedding model used to score generated definitions.
embeddings_model = SentenceTransformer(
    dlab_model_path(pargs.embeddings_model),
    device=device,
)

# Language model under study; an explicit `model_path` wins over the resolved one.
model_path = dlab_model_path(model) if model_path is None else model_path
nn_model = load_model(
    model_path,
    trust_remote_code=trust_remote_code,
    device_map=device,
    use_tl=use_tl,
    no_space_on_bos=True,
)
tokenizer = nn_model.tokenizer
# Last path component of the model identifier, used to name the results directory.
model_name = model.rsplit("/", 1)[-1]



In [20]:
from nnterp.nnsight_utils import (
    collect_activations_batched,
    get_num_layers,
    get_layer_output,
)
from random import sample
from load_dataset import get_word_translation_dataset, get_cloze_dataset, load_synset
from utils import ulist
from prompt_tools import translation_prompts, def_prompt, get_obj_id
from torch.utils.data import DataLoader
import ast
import itertools
import warnings
from copy import deepcopy
from collections import defaultdict
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go


def extract_def(generation):
    """Extract the definition text from the last line of a generation.

    The last line is first normalised so that it ends with a closing quote
    followed by a newline; a warning is emitted when the closing quote had to
    be added. The normalised last line is then split on double quotes and
    everything from the fourth piece up to (excluding) the final piece is
    re-joined with double quotes, so quotes inside the definition survive.
    For a line shaped like ``"word" : "definition"`` this returns
    ``definition``.

    Args:
        generation: decoded text whose last line holds the generated entry.

    Returns:
        The extracted definition, or an empty string when the last line has
        fewer than four quote-separated pieces (this includes empty input).
    """
    # `str.endswith` is used rather than indexing the last two characters so
    # that empty or one-character generations do not raise IndexError.
    if generation.endswith('"\n'):
        pass
    elif generation.endswith("\n"):
        end = generation.split("\n")[-2]
        warnings.warn(f"Last line does not end with a quote: {end}")
        generation = generation[:-1] + '"\n'
    elif generation.endswith('"'):
        generation += "\n"
    else:
        end = generation.split("\n")[-1]
        warnings.warn(f"Last line does not end with a quote and newline: {end}")
        generation += '"\n'

    # After normalisation the text ends with "\n", so [-2] is the last real line.
    last_line = generation.split("\n")[-2]
    split = last_line.split('"')
    return '"'.join(split[3:-1])


def patched_generation(prompts, reprs, gen_batch_size):
    """Generate continuations while overwriting one token position in every layer.

    The patched position is computed once, from the first prompt, and reused
    for all prompts (NOTE(review): assumes every prompt has its object token
    at the same position — confirm against `get_obj_id` and the tokenizer
    padding side).

    Args:
        prompts: sequence of prompt strings.
        reprs: representations indexed as ``reprs[prompt_index, layer]``.
        gen_batch_size: number of prompts generated per batch.

    Returns:
        A list with one entry per prompt, taken from the generator output
        converted with ``tolist()``.
    """
    obj_pos = get_obj_id(prompts[0], tokenizer)
    generated = []
    for start in trange(0, len(prompts), gen_batch_size):
        # Slicing clamps at the end of the sequence, so the last batch may be shorter.
        stop = start + gen_batch_size
        batch_prompts = prompts[start:stop]
        with nn_model.generate(
            batch_prompts,
            max_new_tokens=50,
            stop_strings=["\n"],
            tokenizer=nn_model.tokenizer,
        ):
            for layer_idx in range(get_num_layers(nn_model)):
                layer_out = get_layer_output(nn_model, layer_idx)
                layer_out[:, obj_pos] = reprs[start:stop, layer_idx]
            batch_tokens = nn_model.generator.output.tolist().save()
        generated.extend(batch_tokens.value)
    return generated


def patched_repr_defs(lang_pairs, target_lang, path=None):
    """Generate definitions in `target_lang` from patched latent representations.

    For the words present in both the translation and the definition
    datasets, activations are collected on translation prompts (one set per
    language pair) and on definition prompts (one set per output language).
    Four representations per word are then patched into a `target_lang`
    definition prompt: the mean over definition prompts, the first definition
    prompt set alone, the mean over translation prompts and the first
    translation prompt set alone.

    Several intermediate objects are published as module-level globals (see
    the `global` statement).

    Args:
        lang_pairs: sequence of ``[input_lang, output_lang]`` pairs.
        target_lang: language of the prompts the representations are patched into.
        path: directory receiving ``<exp_id>_patched_generations.json``. When
            None (the default) nothing is written.

    Returns:
        A mapping ``word -> {repr_type -> extracted definition}`` with the
        repr types "from def", "from single def", "from trans" and
        "from single trans".
    """
    global def_source_prompts, trans_source_prompts, trans_mean_reprs, def_mean_reprs, target_prompts, full_def_dataset, def_dataset, trans_dataset
    lang_pairs = np.array(lang_pairs)
    output_langs = ulist(lang_pairs[:, 1])
    trans_dataset = get_word_translation_dataset(
        target_lang, ulist(lang_pairs.flatten()), v2=True
    )
    full_def_dataset = get_cloze_dataset(
        output_langs + [target_lang], drop_no_defs=True
    )
    # compute intersection between trans_dataset and def_dataset
    common_words = set(trans_dataset["word_original"]).intersection(
        set(full_def_dataset["word_original"])
    )
    if DEBUG:
        common_words = list(common_words)[:8]
        print("Debug mode: using only 8 common words")
    trans_dataset = trans_dataset[trans_dataset["word_original"].isin(common_words)]
    def_dataset = full_def_dataset[full_def_dataset["word_original"].isin(common_words)]
    # The reshapes and the per-word indexing below rely on both frames having the same length.
    assert len(trans_dataset) == len(def_dataset)
    print(f"Found {len(common_words)} common words")
    trans_source_prompts = np.array(
        [
            translation_prompts(
                trans_dataset,
                tokenizer,
                input_lang,
                output_lang,
                n=num_few_shot,
                cut_at_obj=True,
            )
            for input_lang, output_lang in lang_pairs
        ]
    )
    trans_source_prompts_str = np.array(
        [[p.prompt for p in prompts] for prompts in trans_source_prompts]
    )
    def_source_prompts = [
        def_prompt(
            def_dataset,
            tokenizer,
            output_lang,
            n=num_few_shot,
            use_word_to_def=True,
            cut_at_obj=True,
        )
        for output_lang in output_langs
    ]
    def_source_prompts_str = np.array(
        [[p.prompt for p in prompts] for prompts in def_source_prompts]
    )
    trans_activations = (
        collect_activations_batched(
            nn_model,
            list(trans_source_prompts_str.flatten()),
            batch_size=batch_size,
            tqdm=tqdm,
        )
        .transpose(0, 1)
        .reshape(len(lang_pairs), len(trans_dataset), get_num_layers(nn_model), -1)
    )  # num_langs, num_words, num_layers, model_dim
    trans_mean_reprs = trans_activations.mean(dim=0)  # num_words, num_layers, model_dim
    trans_single_reprs = trans_activations[0]  # num_words, num_layers, model_dim
    def_activations = (
        collect_activations_batched(
            nn_model,
            list(def_source_prompts_str.flatten()),
            batch_size=batch_size,
            tqdm=tqdm,
        )
        .transpose(0, 1)
        .reshape(len(output_langs), len(def_dataset), get_num_layers(nn_model), -1)
    )  # num_langs, num_words, num_layers, model_dim
    def_mean_reprs = def_activations.mean(dim=0)  # num_words, num_layers, model_dim
    def_single_reprs = def_activations[0]  # num_words, num_layers, model_dim
    target_prompts = []
    for word_original in def_dataset["word_original"]:
        # Build the target prompt from a frame that excludes the current word,
        # then keep one randomly chosen prompt from it.
        safe_df = full_def_dataset[full_def_dataset["word_original"] != word_original]
        safe_prompt = sample(
            def_prompt(
                safe_df,
                tokenizer,
                target_lang,
                n=num_few_shot,
                use_word_to_def=True,
                cut_at_obj=False,
            ),
            1,
        )[0]
        target_prompts.append(safe_prompt)
    target_prompts_str = [p.prompt for p in target_prompts]

    # Single source of truth for the repr types, in generation order.
    repr_sources = {
        "from def": def_mean_reprs,
        "from single def": def_single_reprs,
        "from trans": trans_mean_reprs,
        "from single trans": trans_single_reprs,
    }
    generations = {}
    for repr_type, reprs in repr_sources.items():
        output = patched_generation(target_prompts_str, reprs, gen_batch_size)
        generations[repr_type] = tokenizer.batch_decode(
            output, skip_special_tokens=True
        )

    defs = defaultdict(dict)
    generations_dict = defaultdict(dict)
    for repr_type, decoded in generations.items():
        for i, word_original in enumerate(def_dataset["word_original"]):
            gen = decoded[i]
            defs[word_original][repr_type] = extract_def(gen)
            generations_dict[word_original][repr_type] = gen

    # `path` is optional: skip the dump instead of failing on `None / str`
    # after all the generation work has already been done.
    if path is not None:
        json_file = Path(path) / (exp_id + "_patched_generations.json")
        json_dic = {
            "defs": defs,
            "generations": generations_dict,
        }
        with open(json_file, "w") as f:
            json.dump(json_dic, f, indent=4)
    return defs


def gen_baseline_defs(lang, words, exp_path):
    """Generate definitions by plain few-shot prompting, without any patching.

    Args:
        lang: language forwarded to `get_cloze_dataset` and `def_prompt`.
        words: words to define; they are zipped with the generations, so the
            prompts are assumed to come back in the same order
            (NOTE(review): confirm against `def_prompt`).
        exp_path: directory receiving ``<exp_id>baseline_defs.json``.

    Returns:
        A dict ``word -> extracted definition``.
    """
    df = get_cloze_dataset(lang, drop_no_defs=True)
    prompts = def_prompt(
        df, tokenizer, lang, n=num_few_shot, use_word_to_def=True, words=words
    )
    str_prompts = [p.prompt for p in prompts]
    dataloader = DataLoader(str_prompts, batch_size=batch_size)
    generations = []
    for batch in dataloader:
        with nn_model.generate(
            batch, max_new_tokens=50, stop_strings=["\n"], tokenizer=nn_model.tokenizer
        ):
            out = nn_model.generator.output.tolist().save()
        generations.extend(out.value)
    generations = tokenizer.batch_decode(generations, skip_special_tokens=True)
    # Keep only the last non-empty line. Stripping trailing newlines first
    # (instead of taking split("\n")[-2]) also covers generations that hit
    # max_new_tokens before emitting a newline, where [-2] would select the
    # preceding line, and generations ending in several newlines, where [-2]
    # would be empty.
    generations = {
        word: generation.rstrip("\n").split("\n")[-1]
        for word, generation in zip(words, generations)
    }
    defs = {word: extract_def(generation) for word, generation in generations.items()}
    json_dic = {"defs": defs, "generations": generations}
    # NOTE(review): unlike the other outputs there is no "_" between exp_id and
    # the suffix; kept as is so existing result files keep their names.
    json_file = exp_path / (exp_id + "baseline_defs.json")
    with open(json_file, "w") as f:
        json.dump(json_dic, f, indent=4)
    return defs


def ground_truth_defs(lang, words):
    """Look up the reference definitions of each word in the synset table.

    For every word, the first row whose ``word_original`` equals the word is
    used; its ``definitions`` cell is parsed with `ast.literal_eval` and passed
    through `ulist`.

    Args:
        lang: language forwarded to `load_synset`.
        words: iterable of words to look up.

    Returns:
        A dict ``word -> ulist(parsed definitions)``.

    Raises:
        IndexError: if a word has no matching row.
    """
    synsets = load_synset(lang)
    gt_by_word = {}
    for word in words:
        matching_rows = synsets[synsets["word_original"] == word]
        raw_definitions = matching_rows["definitions"].tolist()[0]
        gt_by_word[word] = ulist(ast.literal_eval(raw_definitions))
    return gt_by_word


def generate_embeddings(prompts):
    """Embed texts with the module-level sentence-embedding model.

    Args:
        prompts: either a sequence of strings, or a dict whose values are
            sequences of strings; a dict is flattened in iteration order.

    Returns:
        The result of ``embeddings_model.encode(..., convert_to_tensor=True)``
        for the (flattened) texts, encoded in batches of `emb_batch_size`.
    """
    texts = prompts
    if isinstance(texts, dict):
        texts = [text for group in texts.values() for text in group]
    return embeddings_model.encode(
        texts,
        batch_size=emb_batch_size,
        convert_to_tensor=True,
    )


def compare_defs(defs, gt_defs_embeddings):
    """Score generated definitions against reference embeddings.

    Args:
        defs: a dictionary ``word -> list of definitions``; the lists are
            flattened in iteration order before being embedded.
        gt_defs_embeddings: embeddings of the reference definitions, passed
            unchanged as the second argument of ``embeddings_model.similarity``.

    Returns:
        The output of ``embeddings_model.similarity``. The caller indexes it
        as ``[flattened definition index][reference embedding index]``.
    """
    flat_defs = [definition for word_defs in defs.values() for definition in word_defs]
    flat_defs_embeddings = generate_embeddings(flat_defs)
    return embeddings_model.similarity(flat_defs_embeddings, gt_defs_embeddings)


def plot_defs_comparison(result_dict, from_what, title=None, path=None):
    """Plot one histogram per similarity metric for a single definition source.

    Args:
        result_dict: mapping ``word -> source -> metric -> score``.
        from_what: the source whose scores are plotted (e.g. "from trans").
        title: text appended to the figure title; falls back to `from_what`.
        path: when not None, directory where the figure is written as
            ``<exp_id>_<from_what>_defs_comparison.png``.

    Returns:
        The plotly figure (also displayed with ``fig.show()``).
    """
    metric_titles = {
        "mean sim": "Mean Similarity with GT Defs",
        "sim w mean": "Similarity with Mean Embedding of GT Defs",
        "max sim": "Max Similarity with GT Defs",
        "mean sim with others": "Mean Similarity with Other Words",
        "max sim with others": "Max Similarity with Other Words",
    }
    # Gather every series first, keyed by the display title.
    scores_by_title = {}
    for metric, label in metric_titles.items():
        scores_by_title[label] = [
            result_dict[word][from_what][metric] for word in result_dict
        ]

    n_cols = 3
    fig = make_subplots(
        rows=2,
        cols=n_cols,
        subplot_titles=list(metric_titles.values()),
    )
    # 1-based (row, col) of each metric's subplot, filled row by row.
    cells = [
        (idx // n_cols + 1, idx % n_cols + 1) for idx in range(len(metric_titles))
    ]
    for (row, col), (label, scores) in zip(cells, scores_by_title.items()):
        fig.add_trace(
            go.Histogram(x=scores, name=label, bingroup="a"),
            row=row,
            col=col,
        )
    fig.update_layout(
        title_text=f"Definition Embedding Comparison Metrics - {title or from_what}",
        showlegend=False,
        autosize=True,
        height=800,
        width=1000,
    )

    for row, col in cells:
        fig.update_xaxes(
            title_text="Similarity Score",
            row=row,
            col=col,
            range=[0, 1],
        )
        fig.update_yaxes(title_text="Count", row=row, col=col)
    fig.show()
    if path is not None:
        fig.write_image(path / f"{exp_id}_{from_what}_defs_comparison.png", scale=3)
    return fig


def _non_null_scores(result_dict, source, metric):
    """Collect `metric` of `source` over all words, skipping scores stored as None."""
    return [
        score
        for score in (result_dict[word][source][metric] for word in result_dict)
        if score is not None
    ]


def plot_compare_setup(result_dict, path=None):
    """Compare definition sources with a bar chart of means and with histograms.

    Scores stored as None are ignored. `experiment` stores None for the
    "rnd gt" source when a word has a single ground-truth definition, and
    averaging a list that contains None raises TypeError.

    Args:
        result_dict: mapping ``word -> source -> metric -> score``.
        path: when not None, directory where both figures are written as
            ``<exp_id>_defs_comparison.png`` and ``<exp_id>_defs_histogram.png``.

    Returns:
        A tuple ``(bar_figure, histogram_figure)``; both are also displayed.
    """
    titles = {
        "mean sim": "Mean Similarity with GT Defs",
        "sim w mean": "Similarity with Mean Embedding of GT Defs",
        "max sim": "Max Similarity with GT Defs",
        "mean sim with others": "Mean Similarity with Other Words",
        "max sim with others": "Max Similarity with Other Words",
    }
    metrics = list(titles.keys())
    sources = ["from trans", "from def", "baseline", "rnd gt"]
    colors = {
        "from trans": "blue",
        "from def": "red",
        "baseline": "green",
        "rnd gt": "purple",
    }

    # Non-null scores for each source and metric, shared by both figures.
    scores = {
        source: {
            metric: _non_null_scores(result_dict, source, metric) for metric in metrics
        }
        for source in sources
    }

    # Calculate means for each metric and source (NaN when no score is available)
    means = {
        source: {
            metric: np.mean(scores[source][metric])
            if scores[source][metric]
            else float("nan")
            for metric in metrics
        }
        for source in sources
    }

    # Create subplots
    fig = make_subplots(
        rows=2,
        cols=3,
        subplot_titles=[titles[m] for m in metrics],
    )

    # Add bars for each metric
    for i, metric in enumerate(metrics):
        row = (i // 3) + 1
        col = (i % 3) + 1

        fig.add_trace(
            go.Bar(
                x=sources,
                y=[means[source][metric] for source in sources],
                name=titles[metric],
                marker_color=[colors[source] for source in sources],
            ),
            row=row,
            col=col,
        )

        # Update axes
        if col == 1:
            fig.update_yaxes(title_text="Mean Score", row=row, col=col, range=[0, 1])
        if row == 2:
            fig.update_xaxes(title_text="Source", row=row, col=col)

    fig.update_layout(
        title_text="Mean Metrics Across Definition Sources",
        showlegend=False,
        autosize=True,
        title_font_size=24,
        margin=dict(t=80, b=30, l=30, r=30, pad=0),
        height=800,
        width=1200,
    )

    fig.show()
    if path is not None:
        fig.write_image(path / f"{exp_id}_defs_comparison.png", scale=3)

    # Create histogram plots
    fig2 = make_subplots(
        rows=2,
        cols=3,
        subplot_titles=[titles[m] for m in metrics],
    )

    # Add histograms for each metric
    for i, metric in enumerate(metrics):
        row = (i // 3) + 1
        col = (i % 3) + 1

        for source in sources:
            fig2.add_trace(
                go.Histogram(
                    x=scores[source][metric],
                    name=source,
                    # One legend entry per source: only the first subplot contributes.
                    showlegend=i == 0,
                    marker_color=colors[source],
                    opacity=0.75,
                    nbinsx=20,
                ),
                row=row,
                col=col,
            )

        # Update axes
        if col == 1:
            fig2.update_yaxes(title_text="Count", row=row, col=col)
        if row == 2:
            fig2.update_xaxes(title_text="Score", row=row, col=col, range=[0, 1])

    fig2.update_layout(
        title_text="Distribution of Metrics Across Definition Sources",
        showlegend=True,
        barmode="group",
        autosize=True,
        title_font_size=24,
        margin=dict(t=80, b=30, l=60, r=60, pad=0),
        height=800,
        width=1200,
    )

    fig2.show()
    if path is not None:
        fig2.write_image(path / f"{exp_id}_defs_histogram.png", scale=3)
    return fig, fig2


def experiment(lang_pairs, target_lang):
    """Run the full pipeline for one set of language pairs and one target language.

    Steps: generate patched, baseline and ground-truth definitions; embed them;
    compute per-word similarity metrics for every method; dump the metrics to
    ``<exp_id>_defs_comparison.json`` under
    ``results/<model_name>/<exp_name>/<pairs>-<target_lang>``; draw the plots.

    Several intermediate objects are published as module-level globals (see
    the `global` statement).

    Args:
        lang_pairs: sequence of ``[input_lang, output_lang]`` pairs.
        target_lang: language in which definitions are generated and compared.

    Returns:
        ``(fig_trans, fig_def, fig_baseline, fig_rnd_gt, fig_compare)``, where
        ``fig_compare`` is itself the ``(fig, fig2)`` tuple returned by
        `plot_compare_setup`.
    """
    global defs, baseline_defs, gt_defs, gt_defs_embeddings, mean_embeddings, gt_embeddings_dict, result_dict
    # Directory name: pairs joined as "in-out", separated by "_", then "-<target_lang>".
    pref = "_".join("-".join(ls) for ls in lang_pairs)
    path = Path("results") / model_name / exp_name / f"{pref}-{target_lang}"
    path.mkdir(parents=True, exist_ok=True)

    # === Generate Definitions ===
    defs = patched_repr_defs(lang_pairs, target_lang, path)
    baseline_defs = gen_baseline_defs(target_lang, defs.keys(), path)
    gt_defs = ground_truth_defs(target_lang, defs.keys())
    print(
        f"len gt_defs: {len(gt_defs)}, len defs: {len(defs)}, len baseline_defs: {len(baseline_defs)}"
    )
    # Shuffle in place so that gt_defs[word][0] is a random ground-truth definition.
    # This happens before the embeddings are computed, so the "rnd gt" definition
    # is also the first embedding of the word's slice below.
    for w_defs in gt_defs.values():
        shuffle(w_defs)
    # One entry per method for every word; the order must match `methods` below.
    all_defs = {
        word: [
            defs[word]["from trans"],
            defs[word]["from def"],
            defs[word]["from single trans"],
            defs[word]["from single def"],
            baseline_defs[word],
            gt_defs[word][0],
        ]
        for word in defs
    }
    print(f"len all_defs: {len(all_defs)}")

    # === Generate Embeddings ===
    gt_defs_embeddings = generate_embeddings(gt_defs)
    # all_idx[i] is the exclusive end offset of word i's definitions inside the
    # flattened ground-truth embeddings.
    all_idx = []
    gt_embeddings_dict = {}
    prev_idx = 0
    for word, defs_list in gt_defs.items():
        gt_embeddings_dict[word] = gt_defs_embeddings[
            prev_idx : prev_idx + len(defs_list)
        ]
        prev_idx += len(defs_list)
        all_idx.append(prev_idx)
    # Trailing sentinel: for i == 0, all_idx[i - 1] is all_idx[-1] == 0, the start offset.
    all_idx.append(0)  # for the first word

    mean_embeddings = th.stack([emb.mean(dim=0) for emb in gt_embeddings_dict.values()])
    assert (
        mean_embeddings.shape[0] == len(defs)
        and mean_embeddings.shape[1] == gt_defs_embeddings.shape[1]
    ), f"Shape mismatch: mean_embeddings.shape={mean_embeddings.shape} != gt_defs_embeddings.shape={gt_defs_embeddings.shape}, len(defs)={len(defs)}"
    # Rows: flattened all_defs (len(methods) rows per word).
    # Columns: every ground-truth definition / one mean embedding per word.
    similarities = compare_defs(all_defs, gt_defs_embeddings)
    similarities_mean = compare_defs(all_defs, mean_embeddings)
    methods = [
        "from trans",
        "from def",
        "from single trans",
        "from single def",
        "baseline",
        "rnd gt",
    ]
    dict_factory = lambda: {m: {} for m in methods}

    result_dict = defaultdict(dict_factory)
    for i, word in enumerate(defs):
        for j, method in enumerate(methods):
            # Row of (word i, method j) in the flattened all_defs.
            sim_w_mean = similarities_mean[len(methods) * i + j][i]
            start_idx = all_idx[i - 1]
            if method == "rnd gt":
                start_idx += 1  # skip self
            # Similarities with this word's own ground-truth definitions.
            sims = similarities[len(methods) * i + j][start_idx : all_idx[i]]
            # Similarities with the ground-truth definitions of all other words.
            # NOTE(review): this is empty when there is only one word — verify
            # how `.max()` behaves in that case.
            other_sims = th.cat(
                [
                    similarities[len(methods) * i + j][0 : all_idx[i - 1]],
                    similarities[len(methods) * i + j][all_idx[i] :],
                ]
            )
            result_dict[word][method]["mean sim with others"] = other_sims.mean().item()
            result_dict[word][method]["max sim with others"] = other_sims.max().item()
            if method == "rnd gt" and len(gt_defs[word]) == 1:
                # doesn't make sense to compare to itself
                # These None values are dumped to the JSON file and reach the plotting helpers.
                result_dict[word][method]["sim w mean"] = None
                result_dict[word][method]["mean sim"] = None
                result_dict[word][method]["max sim"] = None
            else:
                result_dict[word][method]["sim w mean"] = sim_w_mean.item()
                result_dict[word][method]["mean sim"] = sims.mean().item()
                result_dict[word][method]["max sim"] = sims.max().item()

    json_file = path / (exp_id + "_defs_comparison.json")
    with open(json_file, "w") as f:
        json.dump(result_dict, f, indent=4)

    # === Plotting ===
    fig_trans = plot_defs_comparison(
        result_dict,
        "from trans",
        "Patching mean representation from translations",
        path,
    )
    fig_def = plot_defs_comparison(
        result_dict, "from def", "Patching mean representation from definitions", path
    )
    fig_baseline = plot_defs_comparison(
        result_dict, "baseline", "Vanilla prompting", path
    )
    fig_rnd_gt = plot_defs_comparison(
        result_dict, "rnd gt", "Random GT definition", path
    )
    fig_compare = plot_compare_setup(result_dict, path)
    return fig_trans, fig_def, fig_baseline, fig_rnd_gt, fig_compare

In [12]:
# Language-pair configurations: each entry of `paper_args` is
# (list of [input_lang, output_lang] pairs, target language).
paper_pairs = [["de", "it"], ["nl", "fi"], ["zh", "es"], ["es", "ru"], ["ru", "ko"]]
paper_ins = ["de", "nl", "zh", "es", "ru"]
test_pairs = [["es", "fr"], ["es", "de"]]
paper_args = [
    (test_pairs, "en"),
    (paper_pairs, "en"),
    ([["en", "fr"]], "en"),
    ([[lang, "hi"] for lang in paper_ins], "en"),
]
# The loop variables are deliberately not called `pargs`: that name already
# holds the parsed argparse namespace and must not be overwritten.
for exp_lang_pairs, exp_target_lang in paper_args:
    figs = experiment(exp_lang_pairs, exp_target_lang)
