In [1]:
INTERACTIVE_MODE = False  # Set to True to run the notebook interactively

import sys

if INTERACTIVE_MODE:
    sys.path.append("../src")
    %load_ext autoreload
    %autoreload 3
# else:
sys.path.append("./src")

In [2]:
import torch as th
import pandas as pd
import json
import matplotlib.pyplot as plt
from pathlib import Path
from time import time

# Inference only: turn off autograd gradient tracking for the rest of the session.
# The trailing ";" suppresses the cell output without binding "_"; assigning "_"
# would stop IPython from updating its "_"/"__"/"___" output-history variables.
th.set_grad_enabled(False);

In [3]:
from nnsight import CONFIG

# Credentials must not be committed with the notebook: read the NDIF API key
# from the environment instead of hardcoding it.
# NOTE(review): a key was previously hardcoded here; it is exposed in version
# control and should be revoked/rotated.
import os

_ndif_api_key = os.environ.get("NDIF_API_KEY")
if _ndif_api_key:
    CONFIG.set_default_api_key(_ndif_api_key)
else:
    print(
        "NDIF_API_KEY is not set; remote execution may fail unless nnsight "
        "already has an API key configured."
    )

In [6]:
# papermill parameters
# --- model -------------------------------------------------------------------
# Hub id; its last path component also names the results directory.
model = "meta-llama/Meta-Llama-3.1-8B"
# Optional local weights directory; `model` is used when this is None.
model_path = None
trust_remote_code = False
use_tl = False
# Forwarded to load_model as `device_map`.
device = "auto"

# --- execution ---------------------------------------------------------------
# Forwarded as `remote=` to the patchscope / run_prompts calls.
remote = True
batch_size = 128

# --- experiment --------------------------------------------------------------
num_few_shot = 5
# Stem of every output file; the current Unix time is used when None.
exp_id = None
# CLI-style options (e.g. ["--langs", "en", "fr"]) parsed by the argparse cell.
extra_args = []

In [7]:
from nnterp import load_model
from argparse import ArgumentParser

# Experiment options arrive as a CLI-style list via the papermill parameter
# `extra_args`. An explicit list is always passed because parse_args(None)
# would fall back to sys.argv, which inside a Jupyter kernel holds the
# kernel's own command-line arguments.
parser = ArgumentParser()
parser.add_argument("--langs", nargs="+", default=["en"])
parser.add_argument("--features", nargs="+", default=["gender"])
parser.add_argument("--num_pairs", type=int, default=64)
pargs = parser.parse_args(args=extra_args)

# A local weights path, when given, takes precedence over the hub id.
model_path = model if model_path is None else model_path

nn_model = load_model(
    model_path,
    device_map=device,
    trust_remote_code=trust_remote_code,
    use_tl=use_tl,
    no_space_on_bos=True,
)

## Plots

In [8]:
from nnterp.interventions import TargetPromptBatch, patchscope_lens
from exp_tools import run_prompts
from prompt_tools import (
    feature_prompts,
    get_shifted_prompt_pairs,
    NotEnoughPromptsError,
)

from load_dataset import get_feature_dataset

from utils import ulist
from display_utils import plot_k_results, plot_topk_tokens, plot_results, k_subplots


def shifted_feature_plot(
    source_lang,
    target_lang,
    feature="gender",
    batch_size=batch_size,
    num_pairs=200,
    exp_id=exp_id,
    k=4,
    remote=remote,
    layers=None,
    num_few_shot=num_few_shot,
):
    """Cross-prompt patchscope ("HeteroPatch") for one feature and language pair.

    The hidden state at index -1 of each *source* prompt (a few-shot feature
    prompt in ``source_lang``) is patched into index -1 of the paired *target*
    prompt (a few-shot feature prompt in ``target_lang``). The resulting
    probabilities of the target answer and of the latent answers are plotted
    against the unpatched source/target prompt probabilities as baselines.

    Outputs go to ``results/<model>/shifted_feature/<feature>/<src>-<tgt>/``:
      * ``<exp_id>.json``              -- probabilities and baselines
      * ``<exp_id>.png``               -- aggregate plot
      * ``<exp_id>_k.png``             -- per-example plots for the first ``k`` pairs
      * ``<exp_id>_heatmap.png``       -- top-token heatmap for the first ``k`` pairs
      * ``<exp_id>_heatmap.meta.json`` -- prompts used for those ``k`` examples

    Args:
        source_lang: language of the prompts the hidden states are taken from.
        target_lang: language of the prompts the hidden states are patched into.
        feature: name of the feature dataset to load (e.g. ``"gender"``).
        batch_size: prompts per batch passed to ``run_prompts``.
        num_pairs: number of (source, target) prompt pairs to build.
        exp_id: file stem for all outputs; current Unix time when ``None``.
        k: number of individual examples to plot/tabulate; clamped to the
            number of pairs actually built.
        remote: forwarded as ``remote=`` to every model call.
        layers: optional layer selection forwarded to ``patchscope_lens`` for
            the main run (the heatmap call does not pass it).
        num_few_shot: number of few-shot demonstrations per prompt.

    Returns:
        ``None``. Returns early, after printing why, when
        ``get_shifted_prompt_pairs`` raises ``NotEnoughPromptsError``.
    """
    model_name = model.split("/")[-1]
    # Intermediate objects are published as notebook globals (presumably for
    # inspection after a run).
    global source_df, target_df, target_prompts, target_probs, latent_probs, source_prompts, _source_prompts, _target_prompts, source_prompts_str
    exp_id = str(int(time())) if exp_id is None else str(exp_id)

    source_df = get_feature_dataset(feature, source_lang)
    target_df = get_feature_dataset(feature, target_lang)

    # Both languages' few-shot prompts are built with identical options.
    prompt_kwargs = dict(
        augment_tokens=False,
        ignore_start_of_word=True,
        n=num_few_shot,
    )
    _source_prompts = feature_prompts(
        source_df, source_lang, nn_model.tokenizer, **prompt_kwargs
    )
    _target_prompts = feature_prompts(
        target_df, target_lang, nn_model.tokenizer, **prompt_kwargs
    )

    cross_lingual = source_lang != target_lang
    try:
        source_prompts, target_prompts = get_shifted_prompt_pairs(
            source_df,
            target_df,
            _source_prompts,
            _target_prompts,
            None,
            "corr" if cross_lingual else None,
            None,
            f"cfact {source_lang}" if cross_lingual else None,
            [],
            num_pairs,
            merge_extra_langs=True,
            label_col="label",
        )
    except NotEnoughPromptsError as err:
        # Best effort inside a sweep: skip this combination, but say so
        # instead of returning silently.
        print(
            f"Skipping {feature} {source_lang} -> {target_lang}: "
            f"not enough prompt pairs for num_pairs={num_pairs} ({err})"
        )
        return
    source_prompts_str = [p.prompt for p in source_prompts]
    # Never ask for more examples than pairs were actually built; the example
    # table below indexes source_prompts / target_prompts with range(k).
    k = min(k, len(source_prompts), len(target_prompts))

    def transverse_patchscope(nn_model, prompt_batch, scan=False, remote=remote):
        """``get_probs`` hook for ``run_prompts`` patching source states into ``prompt_batch``.

        ``run_prompts`` only hands over target prompts, so the matching source
        prompts are sliced from ``source_prompts_str`` with a running offset.
        NOTE(review): assumes batches arrive in order and cover
        ``target_prompts`` exactly once -- confirm against ``run_prompts``.
        ``scan`` is accepted but unused here.
        """
        offset = transverse_patchscope.offset
        # Patch into the last token (index -1) of every target prompt.
        target_patchscope_prompts = TargetPromptBatch.from_prompts(prompt_batch, -1)
        source_prompt_batch = source_prompts_str[offset : offset + len(prompt_batch)]
        transverse_patchscope.offset += len(prompt_batch)
        return patchscope_lens(
            nn_model,
            source_prompt_batch,
            target_patchscope_prompts,
            remote=remote,
            layers=layers,
        )

    transverse_patchscope.offset = 0
    # `remote` reaches the model through transverse_patchscope's default argument.
    target_probs, latent_probs = run_prompts(
        nn_model,
        target_prompts,
        batch_size=batch_size,
        get_probs=transverse_patchscope,
    )

    # Baselines used to normalize the plots: run_prompts on each prompt set
    # without the patching hook.
    source_prompts_probs, _ = run_prompts(
        nn_model,
        source_prompts,
        batch_size=batch_size,
        get_probs_kwargs=dict(remote=remote),
    )
    target_prompts_probs, _ = run_prompts(
        nn_model,
        target_prompts,
        batch_size=batch_size,
        get_probs_kwargs=dict(remote=remote),
    )

    json_dic = {
        target_lang: target_probs.tolist(),
        "source prompt probs": source_prompts_probs.squeeze().tolist(),
        "target prompt probs": target_prompts_probs.squeeze().tolist(),
    }
    for label, probs in latent_probs.items():
        json_dic[label] = probs.tolist()
    path = (
        Path("results")
        / model_name
        / "shifted_feature"
        / feature
        / f"{source_lang}-{target_lang}"
    )
    path.mkdir(parents=True, exist_ok=True)
    with open(path / f"{exp_id}.json", "w") as f:
        json.dump(json_dic, f, indent=4)

    title = (
        f"{model_name}: HeteroPatch for {feature} from {source_lang} into {target_lang}"
    )

    # Aggregate plot of the raw probabilities against both baselines.
    fig, ax = plt.subplots(figsize=(10, 5))
    plot_results(
        ax,
        target_probs,
        latent_probs,
        target_lang,
        source_baseline=source_prompts_probs.mean(),
        target_baseline=target_prompts_probs.mean(),
    )
    ax.legend()
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(path / f"{exp_id}.png", dpi=300, bbox_inches="tight")
    plt.show()
    # Close explicitly: this function runs once per (lang, lang, feature)
    # combination, and open figures would otherwise accumulate.
    plt.close(fig)

    # Per-example plots for the first k pairs. plt.show() (not fig.show())
    # renders immediately on the inline backend, keeping output in order.
    fig, axes = k_subplots(k)
    plot_k_results(axes, target_probs, latent_probs, target_lang, k=k)
    axes[k - 1].legend()
    fig.suptitle(title)
    fig.savefig(path / f"{exp_id}_k.png", dpi=300, bbox_inches="tight")
    plt.show()
    plt.close(fig)

    # Table of the prompts behind the k examples (also saved as heatmap metadata).
    json_meta = {}
    for i in range(k):
        json_meta[i] = {
            "source lang": source_lang,
            "target lang": target_lang,
            "source prompt": source_prompts_str[i],
            "source prompt target": source_prompts[i].target_strings,
            "source prompt latent": source_prompts[i].latent_strings,
            "target prompt": target_prompts[i].prompt,
            "target prompt target": target_prompts[i].target_strings,
            "target prompt latent": target_prompts[i].latent_strings,
        }
    json_df = pd.DataFrame(json_meta)
    with pd.option_context(
        "display.max_colwidth",
        None,
        "display.max_columns",
        None,
        "display.max_rows",
        None,
    ):
        display(json_df)

    # Top-token heatmap for the same k pairs (computed without `layers`).
    target_prompt_batch = TargetPromptBatch.from_prompts(
        [p.prompt for p in target_prompts[:k]], -1
    )
    probs = patchscope_lens(
        nn_model, source_prompts_str[:k], target_prompt_batch, remote=remote
    )
    plot_topk_tokens(
        probs, nn_model.tokenizer, title=title, file=path / f"{exp_id}_heatmap.png"
    )

    with open(path / f"{exp_id}_heatmap.meta.json", "w") as f:
        json.dump(json_meta, f, indent=4)

In [None]:
from load_dataset import NoDatasetFound
# Sweep every (source language, target language, feature) combination,
# same-language pairs included. A combination without a dataset is reported
# and skipped instead of aborting the whole sweep.
for source_lang, target_lang, feature in (
    (src, tgt, feat)
    for src in pargs.langs
    for tgt in pargs.langs
    for feat in pargs.features
):
    try:
        shifted_feature_plot(
            source_lang,
            target_lang,
            feature=feature,
            num_pairs=pargs.num_pairs,
            batch_size=batch_size,
            exp_id=exp_id,
            remote=remote,
        )
    except NoDatasetFound as missing:
        print(missing)