<span style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">An Exception was encountered at '<a href="#papermill-error-cell">In [8]</a>'.</span>

In [1]:
INTERACTIVE_MODE = True  # Set to True to run the notebook interactively

import sys

if INTERACTIVE_MODE:
    sys.path.append("./src")
    %load_ext autoreload
    %autoreload 3
    from tqdm.notebook import tqdm
else:
    sys.path.append("./src")
    from tqdm import tqdm

In [2]:
import torch as th
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
from pathlib import Path
from time import time
import itertools
from random import shuffle

# Disable torch gradient tracking globally for the rest of the session;
# assigning to `_` keeps the returned object from being echoed as cell output.
_ = th.set_grad_enabled(False)

In [3]:
# Experiment name; used as the results sub-directory: results/<model_name>/<exp_name>/...
exp_name = "obj_patch_translation"

In [4]:
# papermill parameters
# When executed through papermill, these defaults are overridden by the injected
# "# Parameters" cell that follows.

# Model selection and loading
model = "meta-llama/Llama-2-7b-hf"
model_path = None  # None -> load from `model`
trust_remote_code = False
device = "auto"
use_tl = False
remote = False

# Experiment settings
batch_size = 8
num_few_shot = 5
exp_id = None  # None -> a Unix timestamp is used as the output file stem
extra_args = []  # extra CLI-style flags, parsed with argparse further down

In [5]:
# Parameters
# Injected by papermill for this run; these values override the defaults of the
# parameters cell above (e.g. the model is switched to bloom-7b1 and exp_id to "test").
model = "bigscience/bloom-7b1"
model_path = None
remote = False
trust_remote_code = False
device = "auto"
batch_size = 8
num_few_shot = 5
use_tl = False
exp_id = "test"
extra_args = []


In [8]:
from argparse import ArgumentParser

from nnterp import load_model

# Additional flags arrive through the papermill `extra_args` list rather than sys.argv.
parser = ArgumentParser()
parser.add_argument("--num-patches", type=int, default=-1)
pargs = parser.parse_args(args=extra_args)
num_patches = pargs.num_patches  # forwarded to `object_lens` by the experiment below

# Fall back to the `model` name when no explicit checkpoint path was provided.
model_path = model if model_path is None else model_path
nn_model = load_model(
    model_path,
    trust_remote_code=trust_remote_code,
    no_space_on_bos=True,
    device_map=device,
    use_tl=use_tl,
)
tokenizer = nn_model.tokenizer

In [9]:
# Display the loaded model's module tree.
nn_model

BloomForCausalLM(
  (model): BloomModel(
    (word_embeddings): Embedding(250880, 4096)
    (word_embeddings_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0-29): 30 x BloomBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (self_attn): BloomAttention(
          (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
          (dense): Linear(in_features=4096, out_features=4096, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (mlp): BloomMLP(
          (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
          (gelu_impl): BloomGelu()
          (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
        )
      )
    )
    (ln_final): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
  

## Plots

In [7]:
from exp_tools import (
    run_prompts,
)
from interventions import (
    object_lens,
    collect_activations,
    collect_activations_batched,
    get_num_layers,
)
from prompt_tools import translation_prompts, get_obj_id
from load_dataset import get_word_translation_dataset as get_translations

from utils import ulist
from display_utils import plot_topk_tokens, plot_results, plot_k_results, k_subplots
from copy import deepcopy


def object_patching_plot(
    source_lang_pairs,
    input_lang,
    target_lang,
    extra_langs=None,
    batch_size=batch_size,
    num_words=None,
    num_pairs=200,
    exp_id=exp_id,
    k=4,
    remote=remote,
    seed=None,
):
    """
    Experiment 2 of the paper:
    - For each pair in `source_lang_pairs`, construct a prompt translating the same concept (e.g. DOG):
        L1: "CAT^L1" - L2: "CAT^L2"
        ...
        L1: "DOG^L1
    - Collect the activation at the last token of each source prompt and average them into a
      mean latent representation for each layer.
    - For each layer `j`: run the target prompts, which translate from `input_lang` to
      `target_lang`. During the forward pass, patch the last token of the concept to be
      translated with the mean latent of the source prompts, from `j` to the last layer.

    We plot both the probabilities obtained from the mean latent and the probabilities
    obtained from the latent of the first pair of `source_lang_pairs` only.

    Args:
        source_lang_pairs: sequence of (input_lang, target_lang) pairs used to build the
            source prompts (converted to a numpy array).
        input_lang: language of the word to translate in the target prompts.
        target_lang: language the target prompts translate into.
        extra_langs: a language code or a list of codes whose tokens are tracked as extra
            latents. Defaults to none.
        batch_size: prompts per forward pass. NOTE: the default is bound to the notebook's
            `batch_size` when this cell runs, not when the function is called.
        num_words: forwarded to `get_translations` (presumably caps the number of words —
            TODO confirm).
        num_pairs: number of (source, target) prompt pairs to collect. If fewer can be
            collected, a message is printed and the function returns early.
        exp_id: stem used for every output file; `None` uses the current Unix time.
            Default bound to the notebook's `exp_id` at definition time.
        k: number of individual examples in the per-example plot, metadata table and
            heatmap. Clamped to the number of collected pairs.
        remote: forwarded to activation collection and `run_prompts`. Default bound to the
            notebook's `remote` at definition time.
        seed: optional seed for shuffling the (source, target) pairings. `None` keeps the
            previous behaviour and uses the global `random` state.

    Returns:
        None. Results (JSON, PNG figures, heatmap metadata) are written under
        `results/<model_name>/<exp_name>/` and figures are displayed. Several
        intermediates are also exposed as notebook globals for inspection.
    """
    from random import Random

    source_lang_pairs = np.array(source_lang_pairs)
    if extra_langs is None:
        extra_langs = []
    if isinstance(extra_langs, str):
        extra_langs = [extra_langs]
    model_name = model.split("/")[-1]
    # Exposed as notebook globals so intermediate objects can be inspected after a run.
    global source_df, target_df, target_prompts, target_probs, latent_probs, source_prompts, _source_prompts, _target_prompts
    if exp_id is None:
        exp_id = str(int(time()))
    else:
        exp_id = str(exp_id)
    source_df = get_translations(
        "en",
        ulist([*source_lang_pairs.flatten(), input_lang, target_lang, *extra_langs]),
        num_words,
    )
    target_df = get_translations(
        input_lang,
        ulist([*source_lang_pairs.flatten(), target_lang, *extra_langs]),
        num_words,
    )

    # One tuple per word: the source prompts of that word for every source language pair.
    _source_prompts = list(
        zip(
            *[
                translation_prompts(
                    source_df,
                    nn_model.tokenizer,
                    inp_lang,
                    targ_lang,
                    [target_lang, *extra_langs],
                    augment_tokens=False,
                    n=num_few_shot,
                )
                for inp_lang, targ_lang in source_lang_pairs
            ]
        )
    )
    _target_prompts = translation_prompts(
        target_df,
        nn_model.tokenizer,
        input_lang,
        target_lang,
        [*extra_langs],  # [*list(zip(*source_lang_pairs))[1], *extra_langs],
        augment_tokens=False,
        n=num_few_shot,
    )

    collected_pairs = 0
    source_prompts = []
    target_prompts = []
    # Every (source word, target word) combination, visited in random order.
    # NOTE(review): `i` / `j` below are DataFrame index labels but are used as positions
    # into the prompt lists; this assumes a default RangeIndex — confirm.
    source_target = list(itertools.product(source_df.iterrows(), target_df.iterrows()))
    if seed is None:
        shuffle(source_target)
    else:
        Random(seed).shuffle(source_target)

    for (i, source_row), (j, target_row) in source_target:
        # Patching a concept into a prompt about the same concept would be uninformative.
        if source_row["word_original"] == target_row["word_original"]:
            continue
        src_p = _source_prompts[i]
        targ_p = deepcopy(_target_prompts[j])
        latent_tokens = {f"source_{target_lang}": src_p[0].latent_tokens[target_lang]}
        latent_tokens.update(targ_p.latent_tokens)
        targ_p.latent_tokens = latent_tokens
        targ_p.latent_strings["sources"] = ulist(
            sum([p.target_strings for p in src_p], [])
        )
        targ_p.latent_strings[f"source_{target_lang}"] = src_p[0].latent_strings[
            target_lang
        ]
        # Merge the source and target tokens of each extra language into a single latent.
        for lang in extra_langs:
            targ_p.latent_tokens[f"src + tgt {lang}"] = ulist(
                targ_p.latent_tokens[lang] + src_p[0].latent_tokens[lang]
            )
            targ_p.latent_strings[f"src + tgt {lang}"] = ulist(
                targ_p.latent_strings[lang] + src_p[0].latent_strings[lang]
            )
            del targ_p.latent_tokens[lang]
        if targ_p.has_no_collisions():
            source_prompts.append(src_p)
            target_prompts.append(targ_p)
            collected_pairs += 1
        if collected_pairs >= num_pairs:
            break
    if collected_pairs < num_pairs:
        print(
            f"Could only collect {collected_pairs} pairs for {source_lang_pairs.tolist()} - {input_lang} -> {target_lang}, skipping..."
        )
        return
    source_prompts = np.array(source_prompts)
    # Drop the last two '"'-separated parts so each source prompt ends on the concept word.
    source_prompts_str = np.array(
        [['"'.join(p.prompt.split('"')[:-2]) for p in ps] for ps in source_prompts]
    )
    idx = get_obj_id(target_prompts[0].prompt, nn_model.tokenizer)

    def object_patching(
        nn_model, prompt_batch, scan, source_prompt_batch=None, only_first=False
    ):
        """Patch the mean source latent into `prompt_batch` and return `object_lens` output.

        Unless `source_prompt_batch` is given, source prompts are read from
        `source_prompts_str` starting at `object_patching.offset`, which is advanced by the
        batch size on every call. `only_first` keeps only the first source pair.
        """
        offset = object_patching.offset
        batch_size = len(prompt_batch)
        if source_prompt_batch is None:
            source_prompt_batch = source_prompts_str[offset : offset + batch_size]
        if only_first:
            source_prompt_batch = source_prompt_batch[:, :1]
        hiddens = collect_activations_batched(
            nn_model,
            source_prompt_batch.flatten(),
            batch_size=batch_size,
            remote=remote,
        )
        hiddens = hiddens.transpose(0, 1)  # (all_prompts, layer, hidden_size)
        # Average over the source pairs of each target prompt.
        hiddens = hiddens.reshape(
            batch_size, source_prompt_batch.shape[1], get_num_layers(nn_model), -1
        ).mean(
            dim=1
        )  # (batch_size, num_layers, hidden_size)
        hiddens = hiddens.transpose(0, 1)  # (num_layers, batch_size, hidden_size)
        object_patching.offset += batch_size
        return object_lens(
            nn_model,
            prompt_batch,
            idx,
            hiddens=hiddens,
            scan=scan,
            num_patches=num_patches,
            remote=remote,
        )

    # Patch with the mean latent of all source pairs.
    object_patching.offset = 0
    target_probs, latent_probs = run_prompts(
        nn_model,
        target_prompts,
        batch_size=batch_size,
        get_probs=object_patching,
        tqdm=tqdm,
    )

    # Patch with the latent of the first source pair only.
    object_patching.offset = 0
    of_target_probs, of_latent_probs = run_prompts(
        nn_model,
        target_prompts,
        batch_size=batch_size,
        get_probs=object_patching,
        get_probs_kwargs=dict(only_first=True),
        tqdm=tqdm,
    )

    # Get the baseline to normalize the plots
    all_source_prompts = source_prompts.flatten()
    source_prompts_probs = (
        run_prompts(
            nn_model,
            all_source_prompts,
            batch_size=batch_size,
            get_probs_kwargs=dict(remote=remote),
            tqdm=tqdm,
        )[0]
        .squeeze()
        .reshape(len(source_prompts), -1)
    )

    target_prompts_probs, _ = run_prompts(
        nn_model,
        target_prompts,
        batch_size=batch_size,
        get_probs_kwargs=dict(remote=remote),
        tqdm=tqdm,
    )

    json_dic = {
        target_lang: target_probs.tolist(),
        "source prompt probs": source_prompts_probs.squeeze().tolist(),
        "target prompt probs": target_prompts_probs.squeeze().tolist(),
    }
    for label, probs in latent_probs.items():
        json_dic[label] = probs.tolist()
    json_dic["only first"] = {target_lang: of_target_probs.tolist()}
    for label, probs in of_latent_probs.items():
        json_dic["only first"][label] = probs.tolist()
    pref = "_".join("-".join(ls) for ls in source_lang_pairs)
    path = (
        Path("results")
        / model_name
        / exp_name
        / (f"{pref}-{input_lang}_{target_lang}-")
    )
    path.mkdir(parents=True, exist_ok=True)
    json_file = path / (exp_id + ".json")
    with open(json_file, "w") as f:
        json.dump(json_dic, f, indent=4)

    fig, ax = plt.subplots(figsize=(10, 5))
    pref = pref.replace("_", " ")
    title = f"{model_name}: ObjPatch from ({pref}) into ({input_lang} -> {target_lang})"
    plot_results(
        ax,
        target_probs,
        latent_probs,
        target_lang,
        source_baseline=source_prompts_probs.mean(),
        target_baseline=target_prompts_probs.mean(),
    )
    ax.legend()
    ax.set_title(title)
    plt.tight_layout()
    plot_file = path / (exp_id + ".png")
    plt.savefig(plot_file, dpi=300, bbox_inches="tight")
    plt.show()

    fig, ax = plt.subplots(figsize=(10, 5))
    plot_results(ax, of_target_probs, of_latent_probs, target_lang)
    ax.legend()
    pref2 = pref.split(" ")[0]
    ax.set_title(
        f"{model_name}: ObjPatch from ({pref2}) into ({input_lang} -> {target_lang})"
    )
    plt.tight_layout()
    plot_file = path / (exp_id + "_only_first.png")
    plt.savefig(plot_file, dpi=300, bbox_inches="tight")
    plt.show()

    # Never index past the collected pairs (num_pairs may be smaller than k).
    k = min(k, len(target_prompts))

    # Plot k examples
    fig, axes = k_subplots(k)
    plot_k_results(axes, target_probs, latent_probs, target_lang, k)
    axes[k - 1].legend()
    fig.suptitle(title)
    plt_file = path / (exp_id + "_k.png")
    fig.savefig(plt_file, dpi=300, bbox_inches="tight")
    fig.show()
    # Compute a single example
    json_meta = {}
    for i in range(k):
        json_meta[i] = {
            "source pairs": source_lang_pairs.tolist(),
            "input lang": input_lang,
            "target lang": target_lang,
            "source prompt": {
                "-".join(l) + " " + str(j): source_prompts_str[i][j]
                for j, l in enumerate(source_lang_pairs)
            },
            "source prompt target": {
                "-".join(l) + " " + str(j): source_prompts[i][j].target_strings
                for j, l in enumerate(source_lang_pairs)
            },
            "source prompt latent": {
                "-".join(l) + " " + str(j): source_prompts[i][j].latent_strings
                for j, l in enumerate(source_lang_pairs)
            },
            "target prompt": target_prompts[i].prompt,
            "target prompt target": target_prompts[i].target_strings,
            "target prompt latent": target_prompts[i].latent_strings,
        }
    json_df = pd.DataFrame(json_meta)
    with pd.option_context(
        "display.max_colwidth",
        None,
        "display.max_columns",
        None,
        "display.max_rows",
        None,
    ):
        display(json_df)
    # Heatmap of the top tokens for the first k target prompts (scan=True).
    target_prompt_batch = [p.prompt for p in target_prompts[:k]]
    probs = object_patching(
        nn_model,
        target_prompt_batch,
        scan=True,
        source_prompt_batch=source_prompts_str[:k],
    )
    file = path / (exp_id + "_heatmap.png")
    plot_topk_tokens(probs, nn_model, title=title, file=file)

    meta_file = path / (exp_id + "_heatmap.meta.json")
    with open(meta_file, "w") as f:
        json.dump(json_meta, f, indent=4)

<span id="papermill-error-cell" style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">Execution using papermill encountered an exception here and stopped:</span>

In [8]:
# Language configurations reproduced from the paper.
paper_pairs = [("de", "it"), ("nl", "fi"), ("zh", "es"), ("es", "ru"), ("ru", "ko")]
paper_ins = ["de", "nl", "zh", "es", "ru"]
paper_args = [
    # (source_lang_pairs, input_lang, target_lang)
    ([(lang, "hi") for lang in paper_ins], "fr", "et"),
    (paper_pairs, "fr", "zh"),
]
# Distinct loop name so the argparse namespace `pargs` is not overwritten.
for plot_args in paper_args:
    object_patching_plot(*plot_args, extra_langs=["en"])


Running prompts:   0%|          | 0/25 [00:00<?, ?it/s]


Running prompts:   0%|          | 0/25 [00:00<?, ?it/s]




AttributeError: 'BloomForCausalLM' object has no attribute 'model'