# Setup

In [1]:
import torch as th
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
from pathlib import Path
from time import time
import itertools
import gc
from random import shuffle

# Turn autograd off for the whole notebook; none of the visible cells calls backward.
# The return value is bound to `_` so the cell does not display it.
_ = th.set_grad_enabled(False)

In [2]:
# Name of this experiment; used as a sub-directory of the results path in `experiment`.
exp_name = "mean_repr_def"

In [3]:
# papermill parameters
# Batch size for activation collection and baseline generation; also the default
# of the `--gen-batch-size` extra argument.
batch_size = 8
# Model identifier; its last "/"-separated component becomes `model_name`.
model = "google/gemma-2-2b"
# Explicit checkpoint path; when None it is resolved from `model` via `dlab_model_path`.
model_path = None
# Forwarded to `load_model`.
trust_remote_code = False
# Passed as `device_map` to `load_model`; "auto" is mapped to `device=None` for the
# sentence-embedding model.
device = "auto"
# NOTE(review): not referenced by any cell visible in this notebook.
remote = False
# Number of few-shot examples (`n=`) in translation and definition prompts.
num_few_shot = 5
# Prefix of every output file; the value "test" switches on interactive/debug mode.
exp_id = "test"
# Extra CLI-style arguments parsed with argparse in the model-loading cell
# (`--gen-batch-size`, `--embeddings-model`, `--emb-batch-size`).
extra_args = []
# Forwarded to `load_model`.
use_tl = False

In [4]:
# Both flags are derived from `exp_id` *before* it is rewritten below: leaving the
# parameter at its default "test" value enables interactive and debug behaviour.
INTERACTIVE_MODE = exp_id == "test"  # Set to True to run the notebook interactively
# `and True` acts as a manual switch: change it to False to stay interactive while
# disabling DEBUG (DEBUG restricts `patched_repr_defs` to a small word sample).
DEBUG = exp_id == "test" and True
import sys
if DEBUG:
    print("Debugging...")
if INTERACTIVE_MODE:
    print("Interactive mode!")
    # Append a timestamp so files written by successive interactive runs get
    # distinct `exp_id` prefixes.
    exp_id = "test" + str(int(time()))
    # Make the project modules importable (load_dataset, utils, prompt_tools, plot_utils).
    # Presumably the notebook lives one level below the repository root here — TODO confirm.
    sys.path.append("../src")
    # Reload edited project modules automatically while iterating.
    %load_ext autoreload
    %autoreload 2
    from tqdm.notebook import tqdm, trange
else:
    # Presumably executed from the repository root in non-interactive runs — TODO confirm.
    sys.path.append("./src")
    from tqdm import tqdm, trange

In [5]:
from nnterp import load_model
from dlabutils import model_path as dlab_model_path
from argparse import ArgumentParser
from sentence_transformers import SentenceTransformer

# Parse the optional extra arguments supplied through the `extra_args` parameter
# (a list of CLI-style strings), not through sys.argv.
parser = ArgumentParser()
# Batch size for patched generation; falls back to the papermill `batch_size`.
parser.add_argument("--gen-batch-size", type=int, default=batch_size)
# Sentence-embedding model used to compare definitions.
parser.add_argument(
    "--embeddings-model",
    type=str,
    default="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
)
# Batch size used when encoding definitions with the embedding model.
parser.add_argument("--emb-batch-size", type=int, default=256)
pargs = parser.parse_args(extra_args)
gen_batch_size = pargs.gen_batch_size
emb_batch_size = pargs.emb_batch_size
# "auto" is not forwarded to SentenceTransformer: it is replaced by None
# (presumably letting the library choose the device — TODO confirm).
embeddings_model = SentenceTransformer(dlab_model_path(pargs.embeddings_model), device=device if device != "auto" else None)

# Resolve the checkpoint location from the model name unless a path was given explicitly.
if model_path is None:
    model_path = dlab_model_path(model)
nn_model = load_model(
    model_path,
    trust_remote_code=trust_remote_code,
    device_map=device,
    use_tl=use_tl,
    no_space_on_bos=True,  # NOTE(review): semantics defined by nnterp.load_model — not visible here
)
tokenizer = nn_model.tokenizer
# Short model name (text after the last "/"); used in the results directory path.
model_name = model.split("/")[-1]



# Experiment functions

In [6]:
from nnterp.nnsight_utils import (
    collect_activations_batched,
    get_num_layers,
    get_layer_output,
)
from random import sample
from load_dataset import get_word_translation_dataset, get_cloze_dataset, load_synset
from utils import ulist
from prompt_tools import translation_prompts, def_prompt, get_obj_id
from torch.utils.data import DataLoader
import ast
import itertools
import warnings
from copy import deepcopy
from collections import defaultdict
from plot_utils import plot_defs_comparison, plot_compare_setup



def extract_def(generation):
    """Extract the definition text from the last line of a decoded generation.

    The last line is expected to have the form ``"<word>" - "<definition>"``
    followed by a newline. Generations that were cut off (missing the closing
    quote and/or the final newline) are repaired before parsing, and a warning
    is emitted when the closing quote had to be added.

    Args:
        generation: decoded text whose last line holds the generated definition.
            Empty and one-character strings are accepted and yield ``""``.

    Returns:
        The text between the third double quote and the last double quote of the
        last line, i.e. the definition; inner double quotes are preserved.
    """
    # `endswith` is used instead of indexing generation[-2] / generation[-1] so
    # that empty or one-character inputs do not raise IndexError.
    if generation.endswith('"\n'):
        # Well-formed: closing quote followed by a newline.
        pass
    elif generation.endswith("\n"):
        # Newline present but the closing quote is missing: insert it.
        end = generation.split("\n")[-2]
        warnings.warn(f"Last line does not end with a quote: {end}")
        generation = generation[:-1] + '"\n'
    elif generation.endswith('"'):
        # Closing quote present, only the newline is missing.
        generation += "\n"
    else:
        # Truncated mid-definition: add both the quote and the newline.
        end = generation.split("\n")[-1]
        warnings.warn(f"Last line does not end with a quote and newline: {end}")
        generation += '"\n'

    # After normalisation the text ends with "\n", so the last line is at [-2].
    last_line = generation.split("\n")[-2]
    split = last_line.split('"')
    # Pieces are: '', word, ' - ', <definition pieces...>, ''. Re-joining with '"'
    # restores any double quotes that occur inside the definition itself.
    return '"'.join(split[3:-1])


def patched_generation(prompts, reprs, gen_batch_size):
    """Generate continuations while overwriting one token position in every layer.

    Args:
        prompts: sequence of prompt strings. The patched position is computed once,
            from the first prompt, and reused for every prompt.
        reprs: replacement activations, indexed as ``reprs[prompt_index, layer]``.
        gen_batch_size: number of prompts generated per batch.

    Returns:
        A list with one token-id list per prompt, as returned by the generator.
    """
    obj_pos = get_obj_id(prompts[0], tokenizer)
    generated = []
    for start in trange(0, len(prompts), gen_batch_size):
        # Slicing past the end is clamped, so the last batch may be shorter.
        stop = start + gen_batch_size
        batch_prompts = prompts[start:stop]
        batch_reprs = reprs[start:stop]
        with nn_model.generate(
            batch_prompts,
            max_new_tokens=50,
            stop_strings=["\n"],
            tokenizer=nn_model.tokenizer,
        ):
            # Overwrite the output of every layer at the chosen position.
            for layer_idx in range(get_num_layers(nn_model)):
                get_layer_output(nn_model, layer_idx)[:, obj_pos] = batch_reprs[
                    :, layer_idx
                ]
            saved_tokens = nn_model.generator.output.tolist().save()
        generated += saved_tokens.value
    return generated


def patched_repr_defs(lang_pairs, target_lang, path=None):
    """Generate definitions in ``target_lang`` from patched word representations.

    For the words present in both the translation dataset and the definition
    (cloze) dataset, activations are collected from translation prompts (one set
    per language pair) and from definition prompts (one set per output language).
    Four representations are then patched into target-language definition
    prompts: the mean over definition prompts, the first definition prompt set
    alone, the mean over translation prompts, and the first translation prompt
    set alone.

    Several intermediate objects are published as module-level globals so they
    can be inspected from other cells.

    Args:
        lang_pairs: sequence of ``(input_lang, output_lang)`` pairs.
        target_lang: language of the target prompts and of the generated definitions.
        path: directory (``pathlib.Path``) receiving
            ``<exp_id>_patched_generations.json``. When None, nothing is written.

    Returns:
        Mapping ``word -> {repr_type -> extracted definition}`` with repr types
        "from def", "from single def", "from trans" and "from single trans".
    """
    global def_source_prompts, trans_source_prompts, trans_mean_reprs, def_mean_reprs, target_prompts, full_def_dataset, def_dataset, trans_dataset
    lang_pairs = np.array(lang_pairs)
    output_langs = ulist(lang_pairs[:, 1])
    trans_dataset = get_word_translation_dataset(
        target_lang, ulist(lang_pairs.flatten()), v2=True
    )
    full_def_dataset = get_cloze_dataset(
        output_langs + [target_lang], drop_no_defs=True
    )
    # Keep only the words available in both the translation and definition datasets.
    common_words = set(trans_dataset["word_original"]).intersection(
        set(full_def_dataset["word_original"])
    )
    if DEBUG:
        # random.sample raises ValueError when asked for more items than exist,
        # so cap the sample size at the number of available words.
        common_words = sample(list(common_words), min(8, len(common_words)))
        print("Debug mode: using only 8 common words")
    trans_dataset = trans_dataset[trans_dataset["word_original"].isin(common_words)]
    def_dataset = full_def_dataset[full_def_dataset["word_original"].isin(common_words)]
    # NOTE(review): only the lengths are checked. Representations are indexed by
    # row position in each frame while results are keyed by def_dataset order, so
    # both frames are assumed to list the words in the same order — confirm.
    assert len(trans_dataset) == len(def_dataset)
    print(f"Found {len(common_words)} common words")
    # One list of prompts per language pair, cut at the object token.
    trans_source_prompts = np.array(
        [
            translation_prompts(
                trans_dataset,
                tokenizer,
                input_lang,
                output_lang,
                n=num_few_shot,
                cut_at_obj=True,
            )
            for input_lang, output_lang in lang_pairs
        ]
    )
    trans_source_prompts_str = np.array(
        [[p.prompt for p in prompts] for prompts in trans_source_prompts]
    )
    # One list of definition prompts per output language, cut at the object token.
    def_source_prompts = [
        def_prompt(
            def_dataset,
            tokenizer,
            output_lang,
            n=num_few_shot,
            use_word_to_def=True,
            cut_at_obj=True,
        )
        for output_lang in output_langs
    ]
    def_source_prompts_str = np.array(
        [[p.prompt for p in prompts] for prompts in def_source_prompts]
    )
    # Prompts are flattened pair-major, so after swapping the first two axes the
    # result can be reshaped back to one block per language pair.
    trans_activations = (
        collect_activations_batched(
            nn_model,
            list(trans_source_prompts_str.flatten()),
            batch_size=batch_size,
            tqdm=tqdm,
        )
        .transpose(0, 1)
        .reshape(len(lang_pairs), len(trans_dataset), get_num_layers(nn_model), -1)
    )  # num_langs, num_words, num_layers, model_dim
    trans_mean_reprs = trans_activations.mean(dim=0)  # num_words, num_layers, model_dim
    trans_single_reprs = trans_activations[0]  # num_words, num_layers, model_dim
    def_activations = (
        collect_activations_batched(
            nn_model,
            list(def_source_prompts_str.flatten()),
            batch_size=batch_size,
            tqdm=tqdm,
        )
        .transpose(0, 1)
        .reshape(len(output_langs), len(def_dataset), get_num_layers(nn_model), -1)
    )  # num_langs, num_words, num_layers, model_dim
    def_mean_reprs = def_activations.mean(dim=0)  # num_words, num_layers, model_dim
    def_single_reprs = def_activations[0]  # num_words, num_layers, model_dim
    # Build one target prompt per word from a dataset that excludes that word,
    # picking one of the resulting prompts at random.
    target_prompts = []
    for word_original in def_dataset["word_original"]:
        safe_df = full_def_dataset[full_def_dataset["word_original"] != word_original]
        safe_prompt = sample(
            def_prompt(
                safe_df,
                tokenizer,
                target_lang,
                n=num_few_shot,
                use_word_to_def=True,
                cut_at_obj=False,
            ),
            1,
        )[0]
        target_prompts.append(safe_prompt)
    target_prompts_str = [p.prompt for p in target_prompts]
    generations = {}
    for repr_type, reprs in [
        ("from def", def_mean_reprs),
        ("from single def", def_single_reprs),
        ("from trans", trans_mean_reprs),
        ("from single trans", trans_single_reprs),
    ]:
        output = patched_generation(target_prompts_str, reprs, gen_batch_size)
        generations[repr_type] = tokenizer.batch_decode(
            output, skip_special_tokens=True
        )

    defs = defaultdict(dict)
    generations_dict = defaultdict(dict)
    for repr_type in ["from def", "from single def", "from trans", "from single trans"]:
        for i, word_original in enumerate(def_dataset["word_original"]):
            gen = generations[repr_type][i]
            defs[word_original][repr_type] = extract_def(gen)
            generations_dict[word_original][repr_type] = gen
    # `path` defaults to None; only write the JSON dump when a directory is given
    # instead of failing on `None / str`.
    if path is not None:
        json_file = path / (exp_id + "_patched_generations.json")
        json_dic = {
            "defs": defs,
            "generations": generations_dict,
        }
        with open(json_file, "w") as f:
            json.dump(json_dic, f, indent=4)
    return defs


def gen_baseline_defs(lang, words, exp_path):
    """Generate definitions for ``words`` by plain few-shot prompting (no patching).

    Args:
        lang: language of the prompts and definitions.
        words: iterable of words to define; the generated prompts are assumed to
            come back in the same order, since results are paired with ``zip``.
        exp_path: directory (``pathlib.Path``) where the JSON dump is written.

    Returns:
        Mapping ``word -> extracted definition``.
    """
    df = get_cloze_dataset(lang, drop_no_defs=True)
    prompts = def_prompt(
        df, tokenizer, lang, n=num_few_shot, use_word_to_def=True, words=words
    )
    str_prompts = [p.prompt for p in prompts]
    # NOTE(review): uses `batch_size`, whereas patched generation uses
    # `gen_batch_size` — confirm whether this difference is intended.
    dataloader = DataLoader(str_prompts, batch_size=batch_size)
    generations = []
    for batch in dataloader:
        with nn_model.generate(
            batch, max_new_tokens=50, stop_strings=["\n"], tokenizer=nn_model.tokenizer
        ):
            out = nn_model.generator.output.tolist().save()
        generations.extend(out.value)
    generations = tokenizer.batch_decode(generations, skip_special_tokens=True)
    # Keep only the last non-empty line (the generated `"word" - "definition"`).
    # Trailing newlines are stripped first: a generation that hit the token limit
    # has no final newline, and indexing `split("\n")[-2]` would then return the
    # preceding few-shot example instead of the generated line.
    generations = {
        word: generation.rstrip("\n").split("\n")[-1]
        for word, generation in zip(words, generations)
    }
    defs = {word: extract_def(generation) for word, generation in generations.items()}
    json_dic = {"defs": defs, "generations": generations}
    # NOTE(review): unlike the other output files there is no "_" between the
    # experiment id and the suffix; left unchanged in case readers rely on it.
    json_file = exp_path / (exp_id + "baseline_defs.json")
    with open(json_file, "w") as f:
        json.dump(json_dic, f, indent=4)
    return defs


def ground_truth_defs(lang, words):
    """Look up the reference definitions of each word in the synset table.

    Args:
        lang: language passed to ``load_synset``.
        words: iterable of words; each must appear in the "word_original" column.

    Returns:
        Mapping ``word -> ulist`` of definitions, parsed from the literal stored
        in the "definitions" column of the first matching row.
    """
    synset_df = load_synset(lang)
    gt = {}
    for word in words:
        matching_rows = synset_df[synset_df["word_original"] == word]
        # The column stores the definitions as a Python literal in text form.
        raw_definitions = matching_rows["definitions"].tolist()[0]
        gt[word] = ulist(ast.literal_eval(raw_definitions))
    return gt


def generate_embeddings(prompts):
    """Encode texts with the sentence-embedding model.

    Args:
        prompts: either a sequence of strings, or a dict whose values are
            sequences of strings; a dict is flattened in value order.

    Returns:
        The tensor returned by ``embeddings_model.encode`` for the (flattened) texts.
    """
    if isinstance(prompts, dict):
        flattened = []
        for group in prompts.values():
            flattened.extend(group)
        prompts = flattened
    return embeddings_model.encode(
        prompts,
        batch_size=emb_batch_size,
        convert_to_tensor=True,
    )


def compare_defs(defs, gt_defs_embeddings):
    """
    Compute similarities between candidate definitions and reference embeddings.

    Args:
        defs: a dictionary of word -> list of definitions; the lists are
            concatenated in dictionary order before being embedded
        gt_defs_embeddings: a tensor of reference embeddings, forwarded unchanged
            to ``embeddings_model.similarity`` (callers in this notebook pass one
            row per reference embedding)

    Returns:
        The output of ``embeddings_model.similarity``; callers index it as
        ``[candidate_index][reference_index]``.
    """
    flat_defs = [
        definition for word_defs in defs.values() for definition in word_defs
    ]
    candidate_embeddings = generate_embeddings(flat_defs)
    return embeddings_model.similarity(candidate_embeddings, gt_defs_embeddings)


def experiment(lang_pairs, target_lang):
    """Run the full pipeline for one set of language pairs and one target language.

    Generates patched and baseline definitions, loads the ground-truth
    definitions, embeds everything, scores each method against the ground truth,
    writes ``<exp_id>_defs_comparison.json`` and produces the plots. Results are
    stored under ``results/<model_name>/<exp_name>/<pairs>-<target_lang>``.

    Several intermediate objects are published as module-level globals so they
    can be inspected from other cells.

    Args:
        lang_pairs: sequence of ``(input_lang, output_lang)`` pairs.
        target_lang: language in which definitions are generated and compared.

    Returns:
        List of figures: one per method from ``plot_defs_comparison``, followed
        by those returned by ``plot_compare_setup``.
    """
    global defs, baseline_defs, gt_defs, gt_defs_embeddings, mean_embeddings, gt_embeddings_dict, result_dict
    # Directory name encodes the language pairs, e.g. "es-fr_es-de-en".
    pref = "_".join("-".join(ls) for ls in lang_pairs)
    path = Path("results") / model_name / exp_name / f"{pref}-{target_lang}"
    path.mkdir(parents=True, exist_ok=True)

    # === Generate Definitions ===
    defs = patched_repr_defs(lang_pairs, target_lang, path)
    baseline_defs = gen_baseline_defs(target_lang, defs.keys(), path)
    gt_defs = ground_truth_defs(target_lang, defs.keys())
    print(
        f"len gt_defs: {len(gt_defs)}, len defs: {len(defs)}, len baseline_defs: {len(baseline_defs)}"
    )
    # Shuffle each ground-truth list in place so that element 0 is a random
    # reference definition (used as the "rnd gt" method below).
    for w_defs in gt_defs.values():
        shuffle(w_defs)
    # Per word, the candidate definitions in the same order as `method_names` below.
    all_defs = {
        word: [
            defs[word]["from trans"],
            defs[word]["from def"],
            defs[word]["from single trans"],
            defs[word]["from single def"],
            baseline_defs[word],
            gt_defs[word][0],
        ]
        for word in defs
    }
    print(f"len all_defs: {len(all_defs)}")

    # === Generate Embeddings ===
    # Embeddings of all ground-truth definitions, flattened word after word.
    gt_defs_embeddings = generate_embeddings(gt_defs)
    # all_idx[i] is the end offset (exclusive) of word i inside gt_defs_embeddings.
    all_idx = []
    gt_embeddings_dict = {}
    prev_idx = 0
    for word, defs_list in gt_defs.items():
        gt_embeddings_dict[word] = gt_defs_embeddings[
            prev_idx : prev_idx + len(defs_list)
        ]
        prev_idx += len(defs_list)
        all_idx.append(prev_idx)
    # Trailing 0 so that all_idx[i - 1] with i == 0 (i.e. all_idx[-1]) yields the
    # start offset of the first word.
    all_idx.append(0)  # for the first word

    # One mean ground-truth embedding per word.
    mean_embeddings = th.stack([emb.mean(dim=0) for emb in gt_embeddings_dict.values()])
    # Same mean but leaving out the first (random) definition when there are others.
    mean_embeddings_wo_fst = th.stack(
        [
            emb[1:].mean(dim=0) if len(emb) > 1 else emb.mean(dim=0)
            for emb in gt_embeddings_dict.values()
        ]
    )
    assert (
        mean_embeddings.shape[0] == len(defs)
        and mean_embeddings.shape[1] == gt_defs_embeddings.shape[1]
    ), f"Shape mismatch: mean_embeddings.shape={mean_embeddings.shape} != gt_defs_embeddings.shape={gt_defs_embeddings.shape}, len(defs)={len(defs)}"
    # Rows: candidate definitions (len(methods) per word, word after word).
    # Columns: individual ground-truth definitions / per-word mean embeddings.
    similarities = compare_defs(all_defs, gt_defs_embeddings)
    similarities_mean = compare_defs(all_defs, mean_embeddings)
    similarities_mean_wo_fst = compare_defs(all_defs, mean_embeddings_wo_fst)
    title_suffix = (
        "<br>(" + ", ".join([t for s, t in lang_pairs]) + ") -> " + target_lang
    )
    # Key order must match the order of the lists in `all_defs`.
    method_names = {
        "from trans": "Patching mean representation from translations" + title_suffix,
        "from def": "Patching mean representation from definitions" + title_suffix,
        "from single trans": "Patching single representation from translations"
        + title_suffix,
        "from single def": "Patching single representation from definitions"
        + title_suffix,
        "baseline": "Vanilla prompting" + f" {target_lang}",
        "rnd gt": "Random GT definition" + f" {target_lang}",
    }
    methods = method_names.keys()
    dict_factory = lambda: {m: {} for m in methods}

    result_dict = defaultdict(dict_factory)
    for i, word in enumerate(defs):
        for j, method in enumerate(methods):
            # Row of the j-th method of word i; column i is that word's mean embedding.
            sim_w_mean = similarities_mean[len(methods) * i + j][i]
            sim_w_mean_wo_fst = similarities_mean_wo_fst[len(methods) * i + j][i]
            # Start offset of word i's ground-truth definitions (0 for i == 0).
            start_idx = all_idx[i - 1]
            if method == "rnd gt":
                # The random GT definition is the first one of the word: drop it
                # from its own comparison set.
                start_idx += 1  # skip self
            # Similarities with this word's own ground-truth definitions.
            sims = similarities[len(methods) * i + j][start_idx : all_idx[i]]
            # Similarities with the ground-truth definitions of all other words.
            other_sims = th.cat(
                [
                    similarities[len(methods) * i + j][0 : all_idx[i - 1]],
                    similarities[len(methods) * i + j][all_idx[i] :],
                ]
            )
            result_dict[word][method]["mean sim with others"] = other_sims.mean().item()
            result_dict[word][method]["max sim with others"] = other_sims.max().item()
            if method == "rnd gt" and len(gt_defs[word]) == 1:
                # doesn't make sense to compare to itself
                result_dict[word][method]["sim w mean"] = None
                result_dict[word][method]["sim w mean fst"] = None
                result_dict[word][method]["mean sim"] = None
                result_dict[word][method]["max sim"] = None
            else:
                result_dict[word][method]["sim w mean"] = sim_w_mean.item()
                # Despite the key name, this is the similarity with the mean
                # computed *without* the first definition.
                result_dict[word][method]["sim w mean fst"] = sim_w_mean_wo_fst.item()
                result_dict[word][method]["mean sim"] = sims.mean().item()
                result_dict[word][method]["max sim"] = sims.max().item()

    json_file = path / (exp_id + "_defs_comparison.json")
    with open(json_file, "w") as f:
        json.dump(result_dict, f, indent=4)

    # === Plotting ===
    figs = []
    for method, title in method_names.items():
        fig = plot_defs_comparison(result_dict, method, title, path, exp_id)
        figs.append(fig)
    figs_compare = plot_compare_setup(
        result_dict, path, title_suffix.replace("<br>", ""), exp_id
    )
    figs.extend(figs_compare)
    return figs

# Run experiment

In [8]:
# Language-pair configurations; each entry of a pair list is [input_lang, output_lang].
paper_pairs = [["de", "it"], ["nl", "fi"], ["zh", "es"], ["es", "ru"], ["ru", "ko"]]
paper_ins = ["de", "nl", "zh", "es", "ru"]
test_pairs = [["es", "fr"], ["es", "de"]]
hard_langs = ["ko", "ja", "et", "fi"]
easy_langs = ["en", "fr", "zh", "de"]
# Each entry is (lang_pairs, target_lang), unpacked into `experiment`.
paper_args = [
    (paper_pairs, "zh"),
    ([["es", "fr"], ["es", "de"]], "en"),
    ([["it", "es"], ["it", "de"]], "fr"),
    (paper_pairs, "en"),
    ([["en", hl] for hl in hard_langs], "fr"),
    ([["hi", el] for el in easy_langs], "et"),
]
# The loop variable is deliberately not called `pargs`: that name holds the parsed
# argparse namespace from the model-loading cell and must not be overwritten.
for exp_args in paper_args:
    figs = experiment(*exp_args)
    # Release memory between experiments.
    gc.collect()
    # torch.cuda.synchronize() fails on hosts without CUDA, so only touch the
    # CUDA allocator when a device is actually available.
    if th.cuda.is_available():
        th.cuda.empty_cache()
        th.cuda.synchronize()

Debug mode: using only 8 common words
Found 8 common words


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


Last line does not end with a quote and newline: "刀子" - "角（gǎn）是几何学中，两条直线或两条线段相交的点，两条直线或两条线段相交的点，两条直线或两条线段相交


Last line does not end with a quote and newline: "房屋" - "水生動物，有脊椎，有四肢，有肺，有鰓，有魚鱗，有魚骨，有魚尾，有魚嘴，有魚眼，有魚鼻，有魚舌，有魚鰓，


Last line does not end with a quote and newline: "房屋" - "水生動物，有鱗、有鰓，有四肢，有尾，有肺，有眼睛，有耳，有鼻孔，有口，有肛門，有腎，有膀胱，有肝，有脾，



len gt_defs: 8, len defs: 8, len baseline_defs: 8
len all_defs: 8


Debug mode: using only 8 common words
Found 8 common words


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

len gt_defs: 8, len defs: 8, len baseline_defs: 8
len all_defs: 8


Debug mode: using only 8 common words
Found 8 common words


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


Last line does not end with a quote and newline: "porte" - "Animal domestique de la famille des Suidae, de la famille des Artiodactyla, de la famille des Artiodactyla, de la famille des Artiodactyla, de la famille des Artiodactyla, de la famille des Artiodactyla, de


Last line does not end with a quote and newline: "porte" - "Animal domestique de la famille des ruminants, de la famille des Bovidae, de la famille des Artiodactyla, de la famille des Artiodactyla, de la famille des Artiodactyla, de la famille des Artiodactyla, de



len gt_defs: 8, len defs: 8, len baseline_defs: 8
len all_defs: 8


Debug mode: using only 8 common words
Found 8 common words


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


Last line does not end with a quote and newline: "television" - "The organ of vision in humans and other vertebrates, consisting of a transparent, spherical structure in the eye socket, which is lined with a pigmented layer called the retina, and which contains the photoreceptors, the rods and cones, which are responsible for



len gt_defs: 8, len defs: 8, len baseline_defs: 8
len all_defs: 8


Debug mode: using only 8 common words
Found 8 common words


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


Last line does not end with a quote and newline: "fille" - "La science qui étudie les relations entre les individus et les groupes sociaux, et les relations entre les individus et les groupes sociaux, et les relations entre les individus et les groupes sociaux, et les relations entre les individus et les groupes sociaux, et les


Last line does not end with a quote and newline: "fille" - "Science qui étudie les relations entre les ressources et les besoins, les moyens de production et les moyens de distribution, les échanges et les échanges, les prix et les quantités, les revenus et les dépenses, les profits et les pertes, les richesses



len gt_defs: 8, len defs: 8, len baseline_defs: 8
len all_defs: 8


Debug mode: using only 8 common words
Found 8 common words


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


Last line does not end with a quote and newline: "koer" - "Pühajärve (Pühajärve) ja Pühajärve (Pühajärve) kõrval asuv linn, kus asub Pühajärve (Pühajärve) ja Pü


Last line does not end with a quote and newline: "Kirst" - "Kõrge, ümmargune, ümbritsetud kaela ja kaela taga, peaga, peaga, peaga, peaga, peaga, peaga, peaga, peaga, peaga


Last line does not end with a quote and newline: "Kirst" - "Kõrge, ümmargune, tavaliselt kaunistatud, tavaliselt kaunistatud, tavaliselt kaunistatud, tavaliselt kaunistatud, tavaliselt kaunistatud, tavaliselt ka


Last line does not end with a quote and newline: "puu" - "täisnurkne, ühekorruseline, ühekorruseline, ühekorruseline, ühekorruseline, ühekorruseline, ühekorruseline, ühekor


Last line does not end with a quote and newline: "silm" - "suur, ükski teine ​​väike, ükski teine ​​väike, ükski teine ​​väike, ükski teine ​​väike, ükski teine ​​väike, üks


Last line does not end with a quote and newline: "koer" - "suur, õhuke, õhukesed, õhukesed, õhuk

len gt_defs: 8, len defs: 8, len baseline_defs: 8
len all_defs: 8
