In [1]:
INTERACTIVE_MODE = False  # Set to True to run the notebook interactively

import sys

if INTERACTIVE_MODE:
    sys.path.append("../src")
    %load_ext autoreload
    %autoreload 2
    from tqdm.notebook import tqdm, trange
else:
    sys.path.append("./src")
    from tqdm import tqdm, trange

In [2]:
import torch as th
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
from pathlib import Path
from time import time
import itertools
from random import shuffle

_ = th.set_grad_enabled(False)

In [3]:
exp_name = "mean_repr_def"

In [4]:
# papermill parameters
batch_size = 8
model = "meta-llama/Llama-2-7b-hf"
model_path = None
trust_remote_code = False
device = "auto"
remote = False
num_few_shot = 5
exp_id = "test" + str(int(time()))
extra_args = []
use_tl = False

In [5]:
from nnterp import load_model
from dlabutils import model_path as dlab_model_path
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--gen-batch-size", type=int, default=batch_size)
pargs = parser.parse_args(extra_args)
gen_batch_size = pargs.gen_batch_size

if model_path is None:
    model_path = dlab_model_path(model)
nn_model = load_model(
    model_path,
    trust_remote_code=trust_remote_code,
    device_map=device,
    use_tl=use_tl,
    no_space_on_bos=True,
)
tokenizer = nn_model.tokenizer
model_name = model.split("/")[-1]

In [43]:
from nnterp.nnsight_utils import (
    collect_activations_batched,
    get_num_layers,
    get_layer_output,
)
from random import sample
from load_dataset import get_word_translation_dataset, get_cloze_dataset
from utils import ulist
from prompt_tools import translation_prompts, def_prompt, get_obj_id


def patched_generation(prompts, reprs, gen_batch_size):
    print(len(prompts))
    pos = get_obj_id(prompts[0], tokenizer)
    out = []
    for i in trange(0, len(prompts), gen_batch_size):
        end = min(i + gen_batch_size, len(prompts))
        with nn_model.generate(
            prompts[i:end],
            max_new_tokens=50,
            stop_strings=["\n"],
            tokenizer=nn_model.tokenizer,
        ):
            for layer in range(get_num_layers(nn_model)):
                get_layer_output(nn_model, layer)[:, pos] = reprs[i:end, layer]
            out_model = nn_model.generator.output.tolist().save()
        print(out_model.value)
        out.extend(out_model.value)
    return out


def mean_repr_score(lang_pairs, target_lang, num_few_shot=5):
    global def_source_prompts, trans_source_prompts, trans_mean_reprs, def_mean_reprs, target_prompts, full_def_dataset, def_dataset, trans_dataset
    lang_pairs = np.array(lang_pairs)
    input_langs = ulist(lang_pairs[:, 0])
    output_langs = ulist(lang_pairs[:, 1])
    trans_dataset = get_word_translation_dataset(
        target_lang, ulist(lang_pairs.flatten()), v2=True
    )
    full_def_dataset = get_cloze_dataset(
        output_langs + [target_lang], drop_no_defs=True
    )
    # compute intersection between trans_dataset and def_dataset
    common_words = set(trans_dataset["word_original"]).intersection(
        set(full_def_dataset["word_original"])
    )
    # common_words = list(common_words)[:8]  # TODO: remove this
    trans_dataset = trans_dataset[trans_dataset["word_original"].isin(common_words)]
    def_dataset = full_def_dataset[full_def_dataset["word_original"].isin(common_words)]
    assert len(trans_dataset) == len(def_dataset)
    print(f"Found {len(common_words)} common words")
    trans_source_prompts = np.array(
        [
            translation_prompts(
                trans_dataset,
                tokenizer,
                input_lang,
                output_lang,
                n=num_few_shot,
                cut_at_obj=True,
            )
            for input_lang, output_lang in lang_pairs
        ]
    )
    trans_source_prompts_str = np.array(
        [[p.prompt for p in prompts] for prompts in trans_source_prompts]
    )
    def_source_prompts = [
        def_prompt(
            def_dataset,
            tokenizer,
            output_lang,
            n=num_few_shot,
            use_word_to_def=True,
            cut_at_obj=True,
        )
        for output_lang in output_langs
    ]
    def_source_prompts_str = np.array(
        [[p.prompt for p in prompts] for prompts in def_source_prompts]
    )
    trans_activations = collect_activations_batched(
        nn_model,
        list(trans_source_prompts_str.flatten()),
        batch_size=batch_size,
        tqdm=tqdm,
    ).transpose(
        0, 1
    )  # num_prompts, num_layers, model_dim
    trans_mean_reprs = trans_activations.reshape(
        len(lang_pairs), len(trans_dataset), get_num_layers(nn_model), -1
    ).mean(
        dim=0
    )  # num_words, num_layers, model_dim
    def_activations = collect_activations_batched(
        nn_model,
        list(def_source_prompts_str.flatten()),
        batch_size=batch_size,
        tqdm=tqdm,
    ).transpose(
        0, 1
    )  # num_prompts, num_layers, model_dim
    def_mean_reprs = def_activations.reshape(
        len(output_langs), len(def_dataset), get_num_layers(nn_model), -1
    ).mean(
        dim=0
    )  # num_words, num_layers, model_dim
    print(def_mean_reprs.shape, trans_mean_reprs.shape)
    target_prompts = []
    for word_original in def_dataset["word_original"]:
        safe_df = full_def_dataset[full_def_dataset["word_original"] != word_original]
        safe_prompt = sample(
            def_prompt(
                safe_df,
                tokenizer,
                target_lang,
                n=num_few_shot,
                use_word_to_def=True,
                cut_at_obj=False,
            ),
            1,
        )[0]
        target_prompts.append(safe_prompt)
    target_prompts_str = [p.prompt for p in target_prompts]
    def_output_generations = patched_generation(
        target_prompts_str, def_mean_reprs, gen_batch_size
    )
    def_output_generations = tokenizer.batch_decode(def_output_generations, skip_special_tokens=True)
    pref = "_".join("-".join(ls) for ls in lang_pairs)
    path = Path("results") / model_name / exp_name / f"{pref}-{target_lang}"
    path.mkdir(parents=True, exist_ok=True)
    json_file = path / (exp_id + "_def_generations.json")
    json_dic = {
        word_original: generation
        for word_original, generation in zip(
            def_dataset["word_original"], def_output_generations
        )
    }
    with open(json_file, "w") as f:
        json.dump(json_dic, f, indent=4)

    trans_output_generations = patched_generation(
        target_prompts_str, trans_mean_reprs, gen_batch_size
    )
    trans_output_generations = tokenizer.batch_decode(trans_output_generations)
    json_file = path / (exp_id + "_trans_generations.json")
    json_dic = {
        word_original: generation
        for word_original, generation in zip(
            def_dataset["word_original"], trans_output_generations
        )
    }
    with open(json_file, "w") as f:
        json.dump(json_dic, f, indent=4)

In [44]:
paper_pairs = [["de", "it"], ["nl", "fi"], ["zh", "es"], ["es", "ru"], ["ru", "ko"]]
paper_ins = ["de", "nl", "zh", "es", "ru"]
test_pairs = [["es", "fr"], ["es", "de"]]
paper_args = [
    (test_pairs, "en"),
    (paper_pairs, "en"),
    [[[l, "hi"] for l in paper_ins], "en"],
]
for pargs in paper_args:
    mean_repr_score(*pargs)


Found 8 common words


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

torch.Size([8, 32, 4096]) torch.Size([8, 32, 4096])
8


  0%|          | 0/1 [00:00<?, ?it/s]

[[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 29908, 307, 974, 29908, 448, 376, 29909, 12566, 573, 21653, 393, 18469, 470, 7190, 278, 2246, 310, 263, 5214, 29908, 13, 29908, 791, 2330, 29908, 448, 376, 10773, 4482, 29899, 5890, 2982, 5139, 287, 491, 6133, 5962, 29936, 7148, 385, 560, 549, 403, 29892, 13774, 2919, 29892, 330, 2705, 269, 4757, 292, 316, 2590, 310, 278, 11563, 29915, 29879, 7101, 29892, 15574, 24046, 1546, 1023, 19223, 470, 1546, 20238, 310, 22696, 470, 19223, 29892, 322, 4049, 6943, 263, 4840, 411, 385, 714, 1026, 1213, 13, 29908, 272, 12644, 404, 29908, 448, 376, 1625, 473, 29892, 5982, 1546, 2654, 322, 13328, 297, 278, 18272, 310, 3578, 29908, 13, 29908, 1129, 3522, 29908, 448, 376, 29909, 2924, 310, 282, 3222, 29892, 607, 338, 15579, 491, 385, 8718, 297, 385, 4274, 310, 1067, 6046, 322, 19700, 363, 8635, 29936, 5491, 2919, 3307, 297, 1797, 304, 2480

  0%|          | 0/1 [00:00<?, ?it/s]

[[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 29908, 307, 974, 29908, 448, 376, 29909, 12566, 573, 21653, 393, 18469, 470, 7190, 278, 2246, 310, 263, 5214, 29908, 13, 29908, 791, 2330, 29908, 448, 376, 10773, 4482, 29899, 5890, 2982, 5139, 287, 491, 6133, 5962, 29936, 7148, 385, 560, 549, 403, 29892, 13774, 2919, 29892, 330, 2705, 269, 4757, 292, 316, 2590, 310, 278, 11563, 29915, 29879, 7101, 29892, 15574, 24046, 1546, 1023, 19223, 470, 1546, 20238, 310, 22696, 470, 19223, 29892, 322, 4049, 6943, 263, 4840, 411, 385, 714, 1026, 1213, 13, 29908, 272, 12644, 404, 29908, 448, 376, 1625, 473, 29892, 5982, 1546, 2654, 322, 13328, 297, 278, 18272, 310, 3578, 29908, 13, 29908, 1129, 3522, 29908, 448, 376, 29909, 2924, 310, 282, 3222, 29892, 607, 338, 15579, 491, 385, 8718, 297, 385, 4274, 310, 1067, 6046, 322, 19700, 363, 8635, 29936, 5491, 2919, 3307, 297, 1797, 304, 2480

In [32]:
print(target_prompts[0].prompt)


"function" - "(mathematics) A relation in which each element of the domain is associated with exactly one element of the codomain."
"pin" - "A human limb; commonly used to refer to a whole limb but technically only the part of the limb between the knee and ankle"
"whistle" - "Single-note woodwind instrument, in the percussion section of the orchestra, and with many other applications in sport and other fields."
"church" - "A Christian house of worship; a building where Christian religious services take place."
"button" - "A piece of wood or metal, usually flat and elongated, turning on a nail or screw, to fasten something, such as a door."
"hammer" - "


In [42]:
toks = tokenizer(target_prompts[0].prompt, return_tensors="pt")
nn_model._model.generate(**toks, max_new_tokens=10, stop_strings=["\n"], tokenizer=nn_model.tokenizer)

tensor([[    1, 29908,   845,  1698, 17977,   495, 29908,   448,   376, 29909,
          4742,   304,   270,  1160,   264,  6709,   337,  9917,   491,  1518,
          2548,   664,   491, 28172,   263, 22576,  1549, 16169, 29892,   322,
           577, 17415,   278, 28310,  5864,   304, 12871, 29908,    13, 29908,
           845,  7297, 29908,   448,   376, 13440,   705,   279,   528, 10501,
           304,  6216,   278,  3661,   313, 22503,   278,   385, 29895,   280,
         29897,   411,   263, 25706,  7568,   310,   454,  1624,   470,   715,
          6288,   322,   263, 14419,   322,   540,   295,   310, 14200,   631,
          5518, 29908,    13, 29908, 29500, 29908,   448,   376,  1576,  1667,
          2894,   310, 20612,   948, 26533,   322,  1301, 29886, 12232,   297,
          6133, 18577, 29892,  5491, 19849,   310,   263, 12151,  7933, 12995,
           311, 10959,   304,   278, 20805,  4153,   470,   491,   263,   380,
          2235,  1213,    13, 29908,   272, 12644,  