# Setup

In [1]:
import torch as th
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
from pathlib import Path
from time import time
import itertools
import gc
from random import shuffle

# Turn autograd off globally; the visible cells only run generation / forward
# passes. The return value is bound to `_` so the cell displays nothing.
_ = th.set_grad_enabled(False)

In [2]:
# Experiment name; used as a directory component of the results path and
# prefixed with "debug-" when the --debug flag is passed.
exp_name = "mean_repr_def"

In [3]:
# papermill parameters
# (defaults below are used for interactive runs; papermill may override them)
batch_size = 8  # batch size for activation collection; --gen-batch-size defaults to half of it
model = "google/gemma-2-2b"  # model identifier; its last "/" component becomes model_name
model_path = None  # explicit checkpoint path; when None it is resolved with dlab_model_path(model)
trust_remote_code = False  # forwarded to load_model
device = "auto"  # forwarded to load_model as device_map; "auto" gives the embedding model device=None
remote = False  # not referenced by the visible cells
num_few_shot = 5  # number of few-shot examples (`n=`) for translation_prompts / def_prompt
exp_id = "test"  # run identifier; the value "test" switches on interactive mode
extra_args = []  # CLI-style arguments parsed by the ArgumentParser in the setup cell
use_tl = False  # forwarded to load_model (presumably selects TransformerLens — TODO confirm)

In [4]:
# Interactive mode is inferred from the default exp_id ("test"): when the
# parameter cell was not overridden we assume a human is running the notebook.
INTERACTIVE_MODE = exp_id == "test"  # Set to True to run the notebook interactively
# The trailing `and True` is a manual toggle (flip to False to silence debug in
# interactive runs). NOTE: DEBUG is reassigned from the --debug flag in the
# setup cell below, so this value only drives the print in this cell.
DEBUG = exp_id == "test" and True
import sys
if DEBUG:
    print("Debugging...")
if INTERACTIVE_MODE:
    print("Interactive mode!")
    # Make the run identifier unique by appending the current Unix timestamp.
    exp_id = "test" + str(int(time()))
    # Project modules are imported from ../src in interactive runs.
    sys.path.append("../src")
    %load_ext autoreload
    %autoreload 2
    from tqdm.notebook import tqdm, trange
else:
    # Non-interactive runs import project modules from ./src and use the
    # plain-text tqdm progress bars.
    sys.path.append("./src")
    from tqdm import tqdm, trange

Debugging...
Interactive mode!


In [5]:
from nnterp import load_model
from dlabutils import model_path as dlab_model_path
from argparse import ArgumentParser
from sentence_transformers import SentenceTransformer

# --- Extra (non-papermill) options, parsed from the `extra_args` list ---
parser = ArgumentParser()
# Default generation batch size is half the activation batch size, but never
# below 1: the value is later used as the step of a range(), and a step of 0
# (batch_size == 1) would raise ValueError.
parser.add_argument("--gen-batch-size", type=int, default=max(1, batch_size // 2))
parser.add_argument(
    "--embeddings-model",
    type=str,
    default="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
)
parser.add_argument("--emb-batch-size", type=int, default=256)
parser.add_argument("--debug", action="store_true")
pargs = parser.parse_args(extra_args)
# From here on DEBUG reflects the --debug flag only (it replaces the value set
# in the interactive-mode cell).
DEBUG = pargs.debug
if DEBUG:
    print("/!\\ Debugging...")
    exp_name = "debug-" + exp_name
gen_batch_size = pargs.gen_batch_size
emb_batch_size = pargs.emb_batch_size
# Sentence embedding model used to compare definitions; with device == "auto"
# the device choice is left to SentenceTransformer (device=None).
embeddings_model = SentenceTransformer(
    dlab_model_path(pargs.embeddings_model), device=device if device != "auto" else None
)

# Resolve the checkpoint location unless one was given explicitly.
if model_path is None:
    model_path = dlab_model_path(model)
nn_model = load_model(
    model_path,
    trust_remote_code=trust_remote_code,
    device_map=device,
    use_tl=use_tl,
    no_space_on_bos=True,
)
tokenizer = nn_model.tokenizer
# Short model name (text after the last "/"), used in the results path.
model_name = model.split("/")[-1]



# Experiment functions

In [1]:
from nnterp.nnsight_utils import (
    collect_activations_batched,
    get_num_layers,
    get_layer_output,
)
from random import sample
from load_dataset import get_word_translation_dataset, get_cloze_dataset, load_synset
from utils import ulist
from prompt_tools import translation_prompts, def_prompt, get_obj_id
from torch.utils.data import DataLoader
import ast
import itertools
import warnings
from copy import deepcopy
from collections import defaultdict
from plot_utils import plot_defs_comparison, plot_compare_setup, plot_losses_comparison
import torch.nn.functional as F
from itertools import chain


def extract_def(generation):
    """
    Extract the definition text from the last line of a decoded generation.

    The generation is first normalised so that it ends with a double quote
    followed by a newline (a warning is emitted whenever a closing quote has
    to be added). The last line is then split on double quotes and everything
    between the third quote and the final quote is returned, so definitions
    that themselves contain quotes are preserved.

    Args:
        generation: decoded text (prompt + completion). May be empty or a
            single character; such inputs yield an empty definition instead
            of raising IndexError.

    Returns:
        The extracted definition as a string ("" if the last line has fewer
        than four double quotes after normalisation).
    """
    # `endswith` (rather than indexing [-2], [-1]) keeps the same four cases
    # while staying safe on strings shorter than two characters.
    if generation.endswith('"\n'):
        pass  # already in the expected shape
    elif generation.endswith("\n"):
        end = generation.split("\n")[-2]
        warnings.warn(f"Last line does not end with a quote: {end}")
        # insert the missing closing quote before the trailing newline
        generation = generation[:-1] + '"\n'
    elif generation.endswith('"'):
        generation += "\n"
    else:
        end = generation.split("\n")[-1]
        warnings.warn(f"Last line does not end with a quote and newline: {end}")
        generation += '"\n'

    # After normalisation the text ends with "\n", so [-2] is the last line.
    last_line = generation.split("\n")[-2]
    split = last_line.split('"')
    # Skip the first three quote-delimited pieces; drop the empty piece after
    # the closing quote; re-join in case the definition contained quotes.
    return '"'.join(split[3:-1])


def patched_generation(prompts, reprs, gen_batch_size):
    """
    Generate completions for `prompts`, optionally overwriting one token
    position of every layer's output with precomputed representations.

    Args:
        prompts: list of prompt strings. Must be non-empty (the patch position
            is read from `prompts[0]`).
        reprs: None for unpatched generation, otherwise an object indexed as
            `reprs[i:end, layer]` — presumably a tensor of shape
            (num_prompts, num_layers, model_dim); TODO confirm.
        gen_batch_size: number of prompts generated per batch.

    Returns:
        A flat list with one entry per prompt, each being the `.tolist()` of
        the generator output for that prompt (token ids, prompt included —
        presumably; callers decode it with `tokenizer.batch_decode`).
    """
    # NOTE(review): the patch position is computed from the first prompt only
    # and reused for every prompt, so all prompts are assumed to share the
    # same object-token position — confirm against get_obj_id / the prompt
    # template.
    pos = get_obj_id(prompts[0], tokenizer)
    out = []
    for i in trange(
        0, len(prompts), gen_batch_size, desc="Generating patched generations"
    ):
        end = min(i + gen_batch_size, len(prompts))
        # Generation stops at the first newline or after 50 new tokens.
        with nn_model.generate(
            prompts[i:end],
            max_new_tokens=50,
            stop_strings=["\n"],
            tokenizer=nn_model.tokenizer,
        ):
            if reprs is not None:
                # Overwrite position `pos` of every layer's output with the
                # representation of the matching prompt at that layer.
                for layer in range(get_num_layers(nn_model)):
                    get_layer_output(nn_model, layer)[:, pos] = reprs[i:end, layer]
            out_model = nn_model.generator.output.tolist().save()
        out.extend(out_model)
    return out


def loss_on_defs(prompts: list[str], defs: dict[str, list[str]], reprs=None) -> dict[str, list[float]]:
    """
    Compute the loss of the model on candidate definitions appended to prompts.

    Each definition (with a closing double quote appended) is concatenated to
    the prompt of its concept; the loss of one definition is the cross-entropy
    summed over its tokens and divided by the number of definition tokens.
    Batches are of size `gen_batch_size` (module-level setting).

    Args:
        prompts: list of prompts for each concept, in the same order as
            `defs.values()` (they are zipped together)
        defs: dictionary of concept -> list of definitions
        reprs: optional tensor of shape (num_prompts, num_layers, model_dim) containing the representations to patch
    Returns:
        dictionary of concept -> list of losses (one per definition, in the
        order of `defs[concept]`)
    """
    patch_pos = []
    inputs = []
    def_lengths = []
    # Flatten to one (prompt, definition) pair per definition; `repr_index`
    # remembers which concept (row of `reprs`) each pair belongs to.
    all_defs = list(itertools.chain.from_iterable(defs.values()))
    all_prompts = list(
        itertools.chain.from_iterable(
            [p] * len(def_) for p, def_ in zip(prompts, defs.values())
        )
    )
    repr_index = list(
        itertools.chain.from_iterable(
            [[i] * len(def_) for i, def_ in enumerate(defs.values())]
        )
    )
    for prompt, def_ in zip(all_prompts, all_defs):
        pos = get_obj_id(prompt, tokenizer)
        # The closing quote is part of the scored continuation.
        def_ = def_ + '"'
        def_length = len(tokenizer.encode(def_, add_special_tokens=False))
        def_lengths.append(def_length)
        # Shift the patch position left by the number of appended tokens.
        # Assumes get_obj_id returns an index counted from the end of the
        # prompt (i.e. negative) — TODO confirm.
        patch_pos.append(pos - def_length)
        prompt_toks = tokenizer(prompt).input_ids
        def_toks = tokenizer(def_, add_special_tokens=False).input_ids
        input = {
            "input_ids": prompt_toks + def_toks,
            "attention_mask": [1] * (len(prompt_toks) + len(def_toks)),
        }
        inputs.append(input)
    def_lengths = th.tensor(def_lengths)
    def_losses = []
    for i in trange(0, len(inputs), gen_batch_size, desc="Computing losses"):
        end = min(i + gen_batch_size, len(inputs))
        batch_inputs = tokenizer.pad(inputs[i:end], return_tensors="pt")
        batch_patch_pos = th.tensor(patch_pos[i:end])
        batch_repr_index = th.tensor(repr_index[i:end])
        with nn_model.trace(batch_inputs):
            if reprs is not None:
                # For every layer, overwrite one position per row with the
                # representation of that row's concept.
                for layer in range(get_num_layers(nn_model)):
                    get_layer_output(nn_model, layer)[
                        th.arange(end - i), batch_patch_pos
                    ] = reprs[batch_repr_index, layer]
            logits = nn_model.output.logits.save()
        # Build a mask that is True on the last `def_length` positions of each
        # row, i.e. on the definition tokens. This relies on the sequences
        # being padded on the left so that definitions are right-aligned —
        # NOTE(review): confirm the tokenizer's padding side.
        mask = th.arange(logits.size(1)).unsqueeze(0) >= (
            logits.size(1) - (def_lengths[i:end].unsqueeze(1))
        )
        assert mask.shape[0] == logits.shape[0] and mask.shape[1] == logits.shape[1]
        # Standard next-token shift: logits at position t predict token t + 1.
        logits = logits[:, :-1, :]
        labels = batch_inputs.input_ids
        labels = labels[:, 1:].to(logits.device)
        # -100 is the default ignore_index of F.cross_entropy, so only
        # definition tokens contribute to the loss.
        labels[~mask[:, 1:]] = -100
        loss = F.cross_entropy(logits.transpose(1, 2), labels, reduction="none").sum(
            dim=1
        ) / mask[:, 1:].sum(
            dim=1
        ).to(logits.device)  # one mean per-token loss per row of the batch
        assert loss.dim() == 1
        def_losses.append(loss.cpu())
    losses = th.cat(def_losses)
    # Un-flatten: consecutive slices of `losses` belong to consecutive concepts.
    losses_dict = {}
    prev_idx = 0
    for concept in defs:
        losses_dict[concept] = losses[prev_idx : prev_idx + len(defs[concept])].tolist()
        prev_idx += len(defs[concept])
    return losses_dict


def patched_repr_defs(
    lang_pairs, target_lang, path=None, num_other_for_loss=4, generate_generations=True
):
    """
    Patch concept representations into target-language definition prompts,
    then generate definitions and score ground-truth definitions.

    Representations are collected from translation prompts (one set per
    language pair) and from definition prompts (one set per output language).
    Four variants are patched — mean / single over definition prompts and
    mean / single over translation prompts — plus an unpatched "prompting"
    baseline.

    Args:
        lang_pairs: sequence of (input_lang, output_lang) pairs.
        target_lang: language of the definition prompts being completed.
        path: directory in which "patched_generations_and_losses.json" is
            written. When None (the default) nothing is written; previously
            a None path raised TypeError after all computation was done.
        num_other_for_loss: number of randomly drawn *other* concepts per
            concept whose ground-truth definitions are scored as controls.
        generate_generations: when False, only losses are computed and the
            returned `defs` is empty.

    Returns:
        (defs, losses_dict, other_concept_defs) where
        defs[concept][repr_type] is the extracted definition,
        losses_dict[concept][repr_type] has keys "normal" and "other", and
        other_concept_defs is a list (length num_other_for_loss) of
        concept -> definitions-of-another-concept dictionaries.
    """
    lang_pairs = np.array(lang_pairs)
    output_langs = ulist(lang_pairs[:, 1])
    trans_dataset = get_word_translation_dataset(
        target_lang, ulist(lang_pairs.flatten()), v2=True
    )
    full_def_dataset = get_cloze_dataset(
        output_langs + [target_lang], drop_no_defs=True
    )
    # compute intersection between trans_dataset and def_dataset
    common_concepts = set(trans_dataset["word_original"]).intersection(
        set(full_def_dataset["word_original"])
    )
    if DEBUG:
        common_concepts = sample(list(common_concepts), 8)
        print("Debug mode: using only 8 common concepts")
    trans_dataset = trans_dataset[trans_dataset["word_original"].isin(common_concepts)]
    def_dataset = full_def_dataset[
        full_def_dataset["word_original"].isin(common_concepts)
    ]
    # From here on the concept order is the row order of def_dataset.
    # NOTE(review): translation activations below follow trans_dataset's row
    # order; this assumes both datasets list the concepts in the same order —
    # only the lengths are asserted. Confirm.
    common_concepts = def_dataset["word_original"].tolist()
    assert len(trans_dataset) == len(def_dataset)
    print(f"Found {len(common_concepts)} common concepts")
    trans_source_prompts = np.array(
        [
            translation_prompts(
                trans_dataset,
                tokenizer,
                input_lang,
                output_lang,
                n=num_few_shot,
                cut_at_obj=True,
            )
            for input_lang, output_lang in lang_pairs
        ]
    )
    trans_source_prompts_str = list(
        chain.from_iterable(
            [p.prompt for p in prompts] for prompts in trans_source_prompts
        )
    )
    def_source_prompts = [
        def_prompt(
            def_dataset,
            tokenizer,
            output_lang,
            n=num_few_shot,
            use_word_to_def=True,
            cut_at_obj=True,
        )
        for output_lang in output_langs
    ]
    def_source_prompts_str = list(
        chain.from_iterable(
            [p.prompt for p in prompts] for prompts in def_source_prompts
        )
    )

    gt_defs, _ = ground_truth_defs(target_lang, common_concepts)
    trans_activations = (
        collect_activations_batched(
            nn_model,
            list(trans_source_prompts_str),
            batch_size=batch_size,
            tqdm=tqdm,
        )
        .transpose(0, 1)
        .reshape(len(lang_pairs), len(trans_dataset), get_num_layers(nn_model), -1)
    )  # num_langs, num_concepts, num_layers, model_dim
    trans_mean_reprs = trans_activations.mean(
        dim=0
    )  # num_concepts, num_layers, model_dim
    # "single" variants use the first language pair / output language only.
    trans_single_reprs = trans_activations[0]  # num_concepts, num_layers, model_dim
    def_activations = (
        collect_activations_batched(
            nn_model,
            def_source_prompts_str,
            batch_size=batch_size,
            tqdm=tqdm,
        )
        .transpose(0, 1)
        .reshape(len(output_langs), len(def_dataset), get_num_layers(nn_model), -1)
    )  # num_langs, num_concepts, num_layers, model_dim
    def_mean_reprs = def_activations.mean(dim=0)  # num_concepts, num_layers, model_dim
    def_single_reprs = def_activations[0]  # num_concepts, num_layers, model_dim
    # Patched runs use a prompt built from a dataset that excludes the concept
    # itself (one prompt sampled at random), so the concept can only reach the
    # model through the patched representation.
    target_prompts = []
    for concept in common_concepts:
        safe_df = full_def_dataset[full_def_dataset["word_original"] != concept]
        safe_prompt = sample(
            def_prompt(
                safe_df,
                tokenizer,
                target_lang,
                n=num_few_shot,
                use_word_to_def=True,
                cut_at_obj=False,
            ),
            1,
        )[0]
        target_prompts.append(safe_prompt)
    target_prompts_str = [p.prompt for p in target_prompts]
    # The unpatched baseline uses the regular prompts built from def_dataset.
    baseline_target_prompts_str = [
        p.prompt
        for p in def_prompt(
            def_dataset,
            tokenizer,
            target_lang,
            n=num_few_shot,
            use_word_to_def=True,
            cut_at_obj=False,
        )
    ]
    other_concept_defs = [{} for _ in range(num_other_for_loss)]
    for i, concept in enumerate(common_concepts):
        # Draw indices in [0, n - 2] and shift those >= i by one: this samples
        # uniformly among all concepts except concept i itself.
        other_concepts = th.randint(0, len(common_concepts) - 1, (num_other_for_loss,))
        other_concepts[other_concepts >= i] += 1
        for j, other_concept in enumerate(other_concepts):
            other_concept_defs[j][concept] = gt_defs[common_concepts[other_concept]]
    generations = {}
    losses = {}
    other_concepts_losses = [{} for _ in range(num_other_for_loss)]

    # Single source of truth for the evaluated setups (name, representations);
    # `None` means no patching.
    repr_sources = [
        ("from def", def_mean_reprs),
        ("from single def", def_single_reprs),
        ("from trans", trans_mean_reprs),
        ("from single trans", trans_single_reprs),
        ("prompting", None),
    ]
    for repr_type, reprs in repr_sources:
        tgt_prompts = target_prompts_str
        if repr_type == "prompting":
            tgt_prompts = baseline_target_prompts_str
        if generate_generations:
            output = patched_generation(tgt_prompts, reprs, gen_batch_size)
            generations[repr_type] = tokenizer.batch_decode(
                output, skip_special_tokens=True
            )
        losses[repr_type] = loss_on_defs(tgt_prompts, gt_defs, reprs)
        for i, other_defs in enumerate(other_concept_defs):
            other_concepts_losses[i][repr_type] = loss_on_defs(
                tgt_prompts, other_defs, reprs
            )

    # Re-key everything by concept first, setup second.
    defs = defaultdict(dict)
    generations_dict = defaultdict(dict)
    losses_dict = defaultdict(lambda: defaultdict(dict))
    for repr_type, _ in repr_sources:
        for i, concept in enumerate(common_concepts):
            if generate_generations:
                gen = generations[repr_type][i]
                defs[concept][repr_type] = extract_def(gen)
                generations_dict[concept][repr_type] = gen
            losses_dict[concept][repr_type]["other"] = [
                ocl[repr_type][concept] for ocl in other_concepts_losses
            ]
            losses_dict[concept][repr_type]["normal"] = losses[repr_type][concept]
    json_dic = {
        "defs": defs,
        "generations": generations_dict,
        "losses": losses_dict,
    }
    # Persisting is optional: with the default path=None the results are only
    # returned.
    if path is not None:
        json_file = Path(path) / "patched_generations_and_losses.json"
        with open(json_file, "w") as f:
            json.dump(json_dic, f, indent=4)
    return defs, losses_dict, other_concept_defs


def generate_generations(prompts, concepts):
    """
    Generate a completion for every prompt and extract the definitions.

    Args:
        prompts: prompt objects exposing a `.prompt` string, one per concept.
        concepts: iterable of concept names, aligned with `prompts`.

    Returns:
        {"defs": concept -> extracted definition,
         "generations": concept -> "...\\n" followed by the last two lines of
         the decoded text (last few-shot line + generated line)}.
    """
    str_prompts = [p.prompt for p in prompts]
    dataloader = DataLoader(str_prompts, batch_size=batch_size)
    token_ids = []
    for batch in dataloader:
        # Generation stops at the first newline or after 50 new tokens.
        with nn_model.generate(
            batch, max_new_tokens=50, stop_strings=["\n"], tokenizer=nn_model.tokenizer
        ):
            out = nn_model.generator.output.tolist().save()
        token_ids.extend(out)
    decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    # Keep only the last two lines of text. Trailing newlines are stripped
    # first so that a generation cut off by `max_new_tokens` (no final "\n")
    # keeps its own, unfinished, last line; slicing `[-3:-1]` on the raw split
    # used to drop that line and made extract_def return the definition of the
    # last few-shot example instead.
    generations = {
        concept: "...\n" + "\n".join(generation.rstrip("\n").split("\n")[-2:])
        for concept, generation in zip(concepts, decoded)
    }
    defs = {
        concept: extract_def(generation) for concept, generation in generations.items()
    }
    return {"defs": defs, "generations": generations}


def word_patching_defs(
    source_langs, target_lang, concepts, gt_defs, other_concept_defs, exp_path
):
    """
    Word-level patching baseline: build target-language definition prompts in
    which the queried word is replaced by its senses in a randomly chosen
    source language, then generate definitions and score ground-truth ones.

    Args:
        source_langs: list of candidate source languages.
        target_lang: language of the definition prompt.
        concepts: concepts to process (values of the "word_original" column).
        gt_defs: concept -> ground-truth definitions, scored as "normal".
        other_concept_defs: list of concept -> definitions dictionaries,
            each scored as one entry of "other".
        exp_path: directory receiving "word_patching_defs.json".

    Returns:
        (defs, losses): concept -> extracted definition, and
        concept -> {"normal": losses, "other": [losses, ...]}.
    """
    cloze_df = get_cloze_dataset(source_langs + [target_lang], drop_no_defs=True)
    cloze_df = cloze_df[cloze_df["word_original"].isin(concepts)]
    assert len(cloze_df) == len(concepts)

    target_column = f"senses_{target_lang}"
    prompts = []
    for concept in concepts:
        # all rows except the one of the current concept
        remaining_rows = cloze_df[cloze_df["word_original"] != concept]
        source_lang = sample(source_langs, 1)[0]
        # re-add the concept's row, with its target-language senses swapped
        # for the senses of the sampled source language
        swapped_row = cloze_df[cloze_df["word_original"] == concept].iloc[0].copy()
        swapped_row[target_column] = swapped_row[f"senses_{source_lang}"]
        swapped_df = pd.concat(
            [remaining_rows, pd.DataFrame([swapped_row])], ignore_index=True
        )
        concept_prompts = def_prompt(
            swapped_df,
            tokenizer,
            target_lang,
            n=num_few_shot,
            use_word_to_def=True,
            concepts=[concept],
        )
        prompts.append(concept_prompts[0])

    result_dict = generate_generations(prompts, concepts)

    prompt_strings = [p.prompt for p in prompts]
    losses = {}
    for concept, loss in loss_on_defs(prompt_strings, gt_defs).items():
        losses[concept] = {"normal": loss, "other": []}
    # one loss_on_defs pass per control dictionary, appended in that order
    for other_defs in other_concept_defs:
        control_losses = loss_on_defs(prompt_strings, other_defs)
        for concept in concepts:
            losses[concept]["other"].append(control_losses[concept])

    # only the generations/definitions are persisted here, not the losses
    with open(exp_path / "word_patching_defs.json", "w") as f:
        json.dump(result_dict, f, indent=4)
    return result_dict["defs"], losses


# def gen_baseline_defs(lang, concepts, exp_path):
#     df = get_cloze_dataset(lang, drop_no_defs=True)
#     prompts = def_prompt(
#         df, tokenizer, lang, n=num_few_shot, use_word_to_def=True, concepts=concepts
#     )
#     json_dic = generate_generations(prompts, concepts)
#     json_file = exp_path / (exp_id + "baseline_defs.json")
#     with open(json_file, "w") as f:
#         json.dump(json_dic, f, indent=4)
#     return json_dic["defs"]


def ground_truth_defs(lang, concept):
    """
    Look up ground-truth definitions and the first sense of each concept.

    Args:
        lang: language passed to `load_synset`.
        concept: iterable of concept names (despite the singular name); each
            must appear in the "word_original" column of the synset frame.

    Returns:
        (gt_defs, tgt_words): concept -> `ulist` of the parsed "definitions"
        cell, and concept -> first element of the parsed "senses" cell. Both
        cells are stored as Python-literal strings and parsed with
        `ast.literal_eval`.
    """
    synset_df = load_synset(lang)

    def parsed_cell(word, column):
        # literal-eval the `column` value of the first row matching `word`
        matching = synset_df[synset_df["word_original"] == word]
        return ast.literal_eval(matching[column].tolist()[0])

    gt_defs = {}
    for word in concept:
        gt_defs[word] = ulist(parsed_cell(word, "definitions"))

    tgt_words = {}
    for word in concept:
        tgt_words[word] = parsed_cell(word, "senses")[0]

    return gt_defs, tgt_words


def generate_embeddings(prompts):
    """
    Embed texts with the module-level sentence embedding model.

    Args:
        prompts: either a list of strings, or a dict whose values are lists of
            strings (the lists are concatenated in dict order).

    Returns:
        The result of `embeddings_model.encode(..., convert_to_tensor=True)`,
        computed in batches of `emb_batch_size`.
    """
    texts = (
        list(itertools.chain.from_iterable(prompts.values()))
        if isinstance(prompts, dict)
        else prompts
    )
    return embeddings_model.encode(
        texts, batch_size=emb_batch_size, convert_to_tensor=True
    )


def compare_defs(defs, gt_defs_embeddings):
    """
    Score generated definitions against reference embeddings.

    Args:
        defs: a dictionary of concept -> list of definitions; the lists are
            flattened in dict order before being embedded.
        gt_defs_embeddings: reference embeddings; the callers in this
            notebook pass 2-D tensors (one row per reference).

    Returns:
        The output of `embeddings_model.similarity(candidates, references)`;
        callers index it as [flattened definition index][reference index].
    """
    flat_defs = [
        definition for concept_defs in defs.values() for definition in concept_defs
    ]
    candidate_embeddings = generate_embeddings(flat_defs)
    return embeddings_model.similarity(candidate_embeddings, gt_defs_embeddings)


def plot_results(result_dict, method_names, title_suffix, path):
    """
    Draw every figure for one experiment.

    One per-method comparison plot is written to `path / "extras"` (created
    if missing, not shown), followed by the cross-setup comparison plot in
    `path`, whose title suffix has its "<br>" markers removed.
    """
    extras_dir = path / "extras"
    extras_dir.mkdir(exist_ok=True)
    for method in method_names:
        plot_defs_comparison(
            result_dict, method, method_names[method], extras_dir, show=False
        )
    flat_suffix = title_suffix.replace("<br>", "")
    plot_compare_setup(result_dict, path, flat_suffix, exp_id)


def experiment(lang_pairs, target_lang):
    """
    Run the full definition experiment for one configuration.

    Generates definitions with every method (representation patching, word
    patching, plain prompting), computes losses on ground-truth definitions,
    compares all definitions to the ground truth with sentence embeddings,
    and writes JSON files and plots under
    results/<model_name>/<exp_name>/<pairs>-<target_lang>/<exp_id>.

    Args:
        lang_pairs: list of [input_lang, output_lang] pairs used as sources.
        target_lang: language in which definitions are generated.

    Returns:
        (result_dict, method_names, title_suffix, path) where
        result_dict[concept][method] holds the similarity statistics.
    """
    pref = "_".join("-".join(ls) for ls in lang_pairs)
    title_suffix = (
        "<br>(" + ", ".join([t for s, t in lang_pairs]) + ") -> " + target_lang
    )
    # The key order of this dict matters: it defines `methods`, which must
    # match the order of the per-concept lists in `all_defs` below.
    method_names = {
        "from trans": "Patching mean representation from translations" + title_suffix,
        "from def": "Patching mean representation from definitions" + title_suffix,
        "from single trans": "Patching single representation from translations"
        + title_suffix,
        "from single def": "Patching single representation from definitions"
        + title_suffix,
        "word patch": "Word patching" + title_suffix,
        "prompting": "Vanilla prompting" + f" {target_lang}",
        "repeat word": "Repeat word" + f" {target_lang}",
        "rnd gt": "Random GT definition" + f" {target_lang}",
    }
    methods = list(method_names.keys())
    path = Path("results") / model_name / exp_name / f"{pref}-{target_lang}" / exp_id
    path.mkdir(parents=True, exist_ok=True)
    source_output_langs = ulist([p[1] for p in lang_pairs])
    # === Generate Definitions ===
    defs, losses, other_concept_defs = patched_repr_defs(lang_pairs, target_lang, path)
    concepts = defs.keys()
    gt_defs, tgt_words = ground_truth_defs(target_lang, concepts)
    token_patching_defs, word_patching_losses = word_patching_defs(
        source_output_langs, target_lang, concepts, gt_defs, other_concept_defs, path
    )
    # Merge the word-patching losses into the per-concept loss dictionary.
    for w in losses:
        losses[w]["word patch"] = word_patching_losses[w]
    with open(path / "losses.json", "w") as f:
        json.dump(losses, f, indent=4)
    # methods[:-2] drops "repeat word" and "rnd gt", for which no loss exists.
    dict_factory_losses = lambda: {m: defaultdict(dict) for m in methods[:-2]}
    losses_dict = defaultdict(dict_factory_losses)
    # Summary statistics per concept and setup: over the concept's own
    # definitions ("normal") and over the control definitions ("other").
    for w in losses:
        for setup in losses[w]:
            max_loss = max(losses[w][setup]["normal"])
            losses_dict[w][setup]["max loss"] = max_loss
            losses_dict[w][setup]["mean loss"] = np.mean(losses[w][setup]["normal"])
            losses_dict[w][setup]["min loss"] = np.min(losses[w][setup]["normal"])
            all_losses = list(chain.from_iterable(losses[w][setup]["other"]))
            losses_dict[w][setup]["mean loss with others"] = np.mean(all_losses)
            losses_dict[w][setup]["mean max loss with others"] = np.mean(
                [max(other_losses) for other_losses in losses[w][setup]["other"]]
            )
            losses_dict[w][setup]["mean min loss with others"] = np.mean(
                [min(other_losses) for other_losses in losses[w][setup]["other"]]
            )
    with open(path / "losses_stats.json", "w") as f:
        json.dump(losses_dict, f, indent=4)
    plot_losses_comparison(losses_dict, path, title_suffix)

    # Shuffle each ground-truth list in place so that element 0 is a random
    # ground-truth definition (used for the "rnd gt" method).
    for w_defs in gt_defs.values():
        shuffle(w_defs)
    # One candidate text per method, in the same order as `methods`.
    all_defs = {
        concept: [
            defs[concept]["from trans"],
            defs[concept]["from def"],
            defs[concept]["from single trans"],
            defs[concept]["from single def"],
            token_patching_defs[concept],
            defs[concept]["prompting"],
            tgt_words[concept],
            gt_defs[concept][0],
        ]
        for concept in defs
    }

    # === Generate Embeddings ===
    gt_defs_embeddings = generate_embeddings(gt_defs)
    # all_idx[i] is the end offset (exclusive) of concept i's ground-truth
    # definitions inside the flattened embedding matrix.
    all_idx = []
    gt_embeddings_dict = {}
    prev_idx = 0
    for concept, defs_list in gt_defs.items():
        gt_embeddings_dict[concept] = gt_defs_embeddings[
            prev_idx : prev_idx + len(defs_list)
        ]
        prev_idx += len(defs_list)
        all_idx.append(prev_idx)
    # Trailing 0 so that all_idx[i - 1] with i == 0 (index -1) yields the
    # start offset of the first concept.
    all_idx.append(0)  # for the first concept

    mean_embeddings = th.stack([emb.mean(dim=0) for emb in gt_embeddings_dict.values()])
    # Same mean but without the first (randomly chosen) definition, unless it
    # is the only one.
    mean_embeddings_wo_fst = th.stack(
        [
            emb[1:].mean(dim=0) if len(emb) > 1 else emb.mean(dim=0)
            for emb in gt_embeddings_dict.values()
        ]
    )
    assert (
        mean_embeddings.shape[0] == len(defs)
        and mean_embeddings.shape[1] == gt_defs_embeddings.shape[1]
    ), f"Shape mismatch: mean_embeddings.shape={mean_embeddings.shape} != gt_defs_embeddings.shape={gt_defs_embeddings.shape}, len(defs)={len(defs)}"
    # Rows: flattened candidates (len(methods) per concept, concept-major).
    # Columns: ground-truth definitions / per-concept mean embeddings.
    similarities = compare_defs(all_defs, gt_defs_embeddings)
    similarities_mean = compare_defs(all_defs, mean_embeddings)
    similarities_mean_wo_fst = compare_defs(all_defs, mean_embeddings_wo_fst)

    dict_factory = lambda: {m: {} for m in methods}

    result_dict = defaultdict(dict_factory)
    for i, concept in enumerate(defs):
        for j, method in enumerate(methods):
            # Row of candidate (concept i, method j) is len(methods) * i + j.
            sim_w_mean = similarities_mean[len(methods) * i + j][i]
            sim_w_mean_wo_fst = similarities_mean_wo_fst[len(methods) * i + j][i]
            start_idx = all_idx[i - 1]
            if method == "rnd gt":
                # The candidate is ground-truth definition 0 of this concept,
                # so exclude that column from its own-concept similarities.
                start_idx += 1  # skip self
            sims = similarities[len(methods) * i + j][start_idx : all_idx[i]]
            # Similarities to the ground-truth definitions of all other
            # concepts (columns before and after this concept's slice).
            other_sims = th.cat(
                [
                    similarities[len(methods) * i + j][0 : all_idx[i - 1]],
                    similarities[len(methods) * i + j][all_idx[i] :],
                ]
            )
            # Best match within each other concept's definitions.
            max_sim_with_others = []
            for k in range(len(defs)):
                if k == i:
                    continue
                max_sim_with_others.append(
                    similarities[len(methods) * i + j][all_idx[k - 1] : all_idx[k]]
                    .max()
                    .item()
                )
            result_dict[concept][method][
                "mean sim with others"
            ] = other_sims.mean().item()
            result_dict[concept][method]["mean max sim with others"] = np.mean(
                max_sim_with_others
            )
            if method == "rnd gt" and len(gt_defs[concept]) == 1:
                # doesn't make sense to compare to itself
                result_dict[concept][method]["sim w mean"] = None
                result_dict[concept][method]["sim w mean fst"] = None
                result_dict[concept][method]["mean sim"] = None
                result_dict[concept][method]["max sim"] = None
            else:
                result_dict[concept][method]["sim w mean"] = sim_w_mean.item()
                # Note: the key says "fst" but the value is the similarity to
                # the mean computed *without* the first definition.
                result_dict[concept][method][
                    "sim w mean fst"
                ] = sim_w_mean_wo_fst.item()
                result_dict[concept][method]["mean sim"] = sims.mean().item()
                result_dict[concept][method]["max sim"] = sims.max().item()

    json_file = path / ("defs_comparison.json")
    with open(json_file, "w") as f:
        json.dump(result_dict, f, indent=4)
    plot_results(result_dict, method_names, title_suffix, path)
    return result_dict, method_names, title_suffix, path

ModuleNotFoundError: No module named 'load_dataset'

# Run experiment

In [7]:
# Language-pair configurations. Each entry of `paper_args` is a
# (lang_pairs, target_lang) tuple passed to `experiment`.
paper_pairs = [["de", "it"], ["nl", "fi"], ["zh", "es"], ["es", "ru"], ["ru", "ko"]]
paper_ins = ["de", "nl", "zh", "es", "ru"]
test_pairs = [["es", "fr"], ["es", "de"]]
hard_langs = ["ko", "ja", "et", "fi"]
easy_langs = ["en", "fr", "zh", "de"]
paper_args = [
    ([["es", "fr"], ["es", "de"]], "en"),
    (paper_pairs, "zh"),
    ([["it", "es"], ["it", "de"]], "fr"),
    (paper_pairs, "en"),
    ([["en", hl] for hl in hard_langs], "fr"),
    ([["hi", el] for el in easy_langs], "et"),
]
if DEBUG:
    # Debug runs only execute the first configuration.
    paper_args = [paper_args[0]]
# The loop variable is deliberately not called `pargs`: that name holds the
# parsed argparse namespace from the setup cell and must not be overwritten.
for experiment_args in paper_args:
    try:
        plot_args = experiment(*experiment_args)
    finally:
        # Release memory between configurations, even when one fails.
        gc.collect()
        # Guard the CUDA calls: on a host without CUDA, synchronize() raises
        # and, inside `finally`, would mask the experiment's own exception.
        if th.cuda.is_available():
            th.cuda.empty_cache()
            th.cuda.synchronize()

Found 156 common concepts


  0%|          | 0/39 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

--- Logging error ---
Traceback (most recent call last):
  File "/dlabscratch1/cdumas/thinking-lang/.conda/lib/python3.11/logging/__init__.py", line 1110, in emit
    msg = self.format(record)
          ^^^^^^^^^^^^^^^^^^^
  File "/dlabscratch1/cdumas/thinking-lang/.conda/lib/python3.11/logging/__init__.py", line 953, in format
    return fmt.format(record)
           ^^^^^^^^^^^^^^^^^^
  File "/dlabscratch1/cdumas/thinking-lang/.conda/lib/python3.11/logging/__init__.py", line 687, in format
    record.message = record.getMessage()
                     ^^^^^^^^^^^^^^^^^^^
  File "/dlabscratch1/cdumas/thinking-lang/.conda/lib/python3.11/logging/__init__.py", line 377, in getMessage
    msg = msg % self.args
          ~~~~^~~~~~~~~~~
TypeError: not all arguments converted during string formatting
Call stack:
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/dlabscratch1/cdumas/thinking-lang/.conda/lib/python3.11/site-package

  0%|          | 0/39 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

  0%|          | 0/293 [00:00<?, ?it/s]

You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


i: 0 < len(inputs): 1169
batch_patch_pos dtype (int?): torch.int64
batch_repr_index dtype (int?): torch.int64
inputs len : 8
{'input_ids': tensor([[     0,      0,      0,  ...,    689, 117095, 235281],
        [     2, 235281,  23427,  ...,    573,   9561,   1464],
        [     0,      0,      0,  ...,    576,  19071,   1464],
        ...,
        [     0,      0,      0,  ...,   1809,   2377, 235281],
        [     0,      0,      0,  ...,   3818,  13120, 235281],
        [     0,      0,      0,  ...,   6792,   9561,  71982]]), 'attention_mask': tensor([[0, 0, 0,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]])}
mask: torch.Size([8, 192])
logits: torch.Size([8, 192, 256000])
reprs: torch.Size([8, 26, 2304])
i: 4 < len(inputs): 1169
batch_patch_pos dtype (int?): torch.int64
batch_repr_index dtype (int?): torch.int64
inputs len : 8
{'i