In [1]:
INTERACTIVE_MODE = False  # Set to True to run the notebook interactively

import sys

# Make the project's `src` directory importable and pick a progress-bar
# flavour. The two branches use different relative paths, so the notebook
# presumably runs from a different working directory in each mode — confirm.
if INTERACTIVE_MODE:
    sys.path.append("../src")
    # Reload edited project modules automatically while iterating in the notebook.
    %load_ext autoreload
    %autoreload 3
    # Widget-based progress bar for a live notebook front-end.
    from tqdm.notebook import tqdm
else:
    sys.path.append("./src")
    # Plain-text progress bar for non-interactive (batch) execution.
    from tqdm import tqdm

In [2]:
import torch as th
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
import textwrap
from pathlib import Path
from time import time
import itertools
import os
from random import shuffle
from IPython.display import display

# This notebook only runs inference, so autograd is switched off globally.
_ = th.set_grad_enabled(False)

# SECURITY: credentials must never be hardcoded in a notebook. The Hugging Face
# access token is taken from the HF_TOKEN environment variable, which has to be
# set before the kernel starts. Any token that was ever committed in this cell
# must be treated as leaked and revoked.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    print("HF_TOKEN is not set; downloading gated Hugging Face models may fail.")

In [3]:
# Short identifier for this experiment. It is not referenced by any other
# visible cell — presumably consumed by tooling or output paths; confirm.
exp_name = "obj_patch_translation"

In [4]:
# papermill parameters
# Default values; a later "Parameters" cell may override any of them.
batch_size = 8  # default batch size handed to run_prompts by object_patching_plot
model = "meta-llama/Llama-2-7b-hf"  # model identifier; used as model_path when that is None
model_path = None  # explicit checkpoint location; None falls back to `model`
trust_remote_code = False  # forwarded to load_model
device = "auto"  # forwarded to load_model as device_map
remote = False  # forwarded to collect_activations_batched and object_lens
num_few_shot = 5  # passed as `n` to translation_prompts
exp_id = None  # default for object_patching_plot's exp_id argument
extra_args = []  # argv-style list parsed with ArgumentParser (supports --num-patches)
use_tl = False  # forwarded to load_model (presumably selects TransformerLens — confirm)
paper_args_str = ''  # JSON-encoded positional arguments for object_patching_plot
del_model = 'no'  # 'yes' makes the final cell delete the loaded model and tokenizer


In [5]:
# Parameters
# Run-specific overrides of the defaults defined in the previous cell
# (this looks like the cell papermill injects — do not edit by hand; confirm).
model = "meta-llama/Llama-3.2-3B"
# JSON list: [source_lang_pairs, input_lang, target_lang].
paper_args_str = "[[[\"hi_translit\", \"it\"], [\"ml_translit\", \"it\"], [\"ta_translit\", \"it\"], [\"te_translit\", \"it\"], [\"gu_translit\", \"it\"]], \"ml\", \"it\"]"
del_model = "no"


In [6]:
from exp_tools import load_model
from argparse import ArgumentParser



# Extra options arrive as an argv-style list of strings in `extra_args`.
parser = ArgumentParser()
parser.add_argument("--num-patches", type=int, default=-1)
pargs = parser.parse_args(extra_args)
num_patches = pargs.num_patches

# Without an explicit checkpoint location, load by model identifier.
model_path = model if model_path is None else model_path
nn_model = load_model(
    model_path, trust_remote_code=trust_remote_code, device_map=device, use_tl=use_tl
)
tokenizer = nn_model.tokenizer

## Plots

In [7]:
from exp_tools import (
    run_prompts,
)
from interventions import (
    object_lens,
    collect_activations,
    collect_activations_batched,
    get_num_layers,
)
from prompt_tools import translation_prompts, get_obj_id
from load_dataset import get_word_translation_dataset as get_translations

from utils import ulist
from display_utils import plot_topk_tokens, plot_results, plot_k_results, k_subplots
from copy import deepcopy


def object_patching_plot(
    source_lang_pairs,
    input_lang,
    target_lang,
    extra_langs=None,
    batch_size=batch_size,
    num_words=None,
    num_pairs=200,
    exp_id=exp_id,
    k=4,
    remote=remote,
):
    """
    Experiment 2 of the paper:
    - For each source_lang_pairs, construct a prompt translating the same concept (e.g. DOG):
    L1: "CAT^L1" - L2: "CAT^L2"
    ...
    L1: "DOG^L1

    - Collect activation at the last token of the prompt and generate a mean latent representation for each layer

    - For each layer `j`: Run the target prompts which are translations from the input_lang to the target_lang. During the forward pass, patch at the last token of the concept to be translated with the mean latent representation of the source prompts from `j` to the last layer.

    Despite its name, this function does not draw any figure: it prints the
    shape and contents of the resulting probability tensors.

    Args:
        source_lang_pairs: sequence of (input language, target language) pairs;
            one source prompt per pair is built for every concept.
        input_lang: input language of the target (patched) prompts.
        target_lang: output language of the target prompts.
        extra_langs: a language code or list of codes tracked as additional
            latent languages. For each one, source and target latent
            tokens/strings are merged under the key "src + tgt {lang}".
        batch_size: batch size handed to run_prompts.
        num_words: forwarded to get_translations.
        num_pairs: number of (source concept, target concept) pairs to collect.
        exp_id: accepted for backward compatibility; currently unused.
        k: accepted for backward compatibility; currently unused.
        remote: forwarded to collect_activations_batched and object_lens.

    Returns:
        None. If fewer than num_pairs collision-free pairs can be collected, a
        message is printed and the function returns early.

    Side effects:
        Assigns the module-level globals source_df, target_df, target_prompts,
        target_probs, latent_probs, source_prompts, _source_prompts and
        _target_prompts so they can be inspected from other cells.
    """
    source_lang_pairs = np.array(source_lang_pairs)
    if extra_langs is None:
        extra_langs = []
    if isinstance(extra_langs, str):
        extra_langs = [extra_langs]
    global source_df, target_df, target_prompts, target_probs, latent_probs, source_prompts, _source_prompts, _target_prompts
    source_df = get_translations(
        "en",
        ulist([*source_lang_pairs.flatten(), input_lang, target_lang, *extra_langs]),
        num_words,
    )
    target_df = get_translations(
        input_lang,
        ulist([*source_lang_pairs.flatten(), target_lang, *extra_langs]),
        num_words,
    )

    # One list of prompts per language pair, zipped so that entry `i` holds the
    # prompts of every pair for the i-th concept.
    _source_prompts = list(
        zip(
            *[
                translation_prompts(
                    source_df,
                    nn_model.tokenizer,
                    inp_lang,
                    targ_lang,
                    [target_lang, *extra_langs],
                    augment_tokens=False,
                    n=num_few_shot,
                )
                for inp_lang, targ_lang in source_lang_pairs
            ]
        )
    )
    _target_prompts = translation_prompts(
        target_df,
        nn_model.tokenizer,
        input_lang,
        target_lang,
        [*extra_langs],  # [*list(zip(*source_lang_pairs))[1], *extra_langs],
        augment_tokens=False,
        n=num_few_shot,
    )

    collected_pairs = 0
    source_prompts = []
    target_prompts = []
    # Visit every (source concept, target concept) combination in random order.
    source_target = list(itertools.product(source_df.iterrows(), target_df.iterrows()))
    shuffle(source_target)

    for (i, source_row), (j, target_row) in source_target:
        # The patched-in concept must differ from the concept being translated.
        if source_row["word_original"] == target_row["word_original"]:
            continue
        # NOTE(review): the DataFrame index labels are used as list positions;
        # this assumes both frames carry a default 0..n-1 index — confirm.
        src_p = _source_prompts[i]
        # Copy so the shared prompt object is not mutated across pairs.
        targ_p = deepcopy(_target_prompts[j])
        # Put the source concept's target-language tokens first, then the
        # target prompt's own latent tokens.
        latent_tokens = {f"source_{target_lang}": src_p[0].latent_tokens[target_lang]}
        latent_tokens.update(**targ_p.latent_tokens)
        targ_p.latent_tokens = latent_tokens
        targ_p.latent_strings["sources"] = ulist(
            sum([p.target_strings for p in src_p], [])
        )
        targ_p.latent_strings[f"source_{target_lang}"] = src_p[0].latent_strings[
            target_lang
        ]
        for lang in extra_langs:
            # Merge source and target tokens of each extra language into one
            # entry and drop the target-only entry.
            targ_p.latent_tokens[f"src + tgt {lang}"] = ulist(
                targ_p.latent_tokens[lang] + src_p[0].latent_tokens[lang]
            )
            targ_p.latent_strings[f"src + tgt {lang}"] = ulist(
                targ_p.latent_strings[lang] + src_p[0].latent_strings[lang]
            )
            del targ_p.latent_tokens[lang]
        if targ_p.has_no_collisions():
            source_prompts.append(src_p)
            target_prompts.append(targ_p)
            collected_pairs += 1
        if collected_pairs >= num_pairs:
            break
    if collected_pairs < num_pairs:
        print(
            f"Could only collect {collected_pairs} pairs for {source_lang_pairs.tolist()} - {input_lang} -> {target_lang}, skipping..."
        )
        return
    source_prompts = np.array(source_prompts)
    # Cut each source prompt after its last concept word by dropping the final
    # two '"'-separated segments.
    source_prompts_str = np.array(
        [['"'.join(p.prompt.split('"')[:-2]) for p in ps] for ps in source_prompts]
    )
    idx = get_obj_id(target_prompts[0].prompt, nn_model.tokenizer)

    def object_patching(
        nn_model, prompt_batch, scan, source_prompt_batch=None, only_first=False
    ):
        """Probability callback for run_prompts.

        Averages the activations of the source prompts that belong to the
        current batch over the language pairs and passes them to object_lens
        for patching at position `idx`.

        The running `object_patching.offset` selects the matching slice of
        source prompts, so batches are expected to arrive sequentially and in
        order (assumption about run_prompts — confirm).
        """
        offset = object_patching.offset
        batch_size = len(prompt_batch)
        if source_prompt_batch is None:
            source_prompt_batch = source_prompts_str[offset : offset + batch_size]
        if only_first:
            # Keep only the first language pair's prompt for every concept.
            source_prompt_batch = source_prompt_batch[:, :1]
        hiddens = collect_activations_batched(
            nn_model, source_prompt_batch.flatten(), batch_size=batch_size, remote=remote
        )
        hiddens = hiddens.transpose(0, 1)  # (all_prompts, layer, hidden_size)
        hiddens = hiddens.reshape(
            batch_size, source_prompt_batch.shape[1], get_num_layers(nn_model), -1
        ).mean(
            dim=1
        )  # (batch_size, num_layers, hidden_size)
        hiddens = hiddens.transpose(0, 1)  # (num_layers, batch_size, hidden_size)
        object_patching.offset += batch_size
        return object_lens(
            nn_model,
            prompt_batch,
            idx,
            hiddens=hiddens,
            scan=scan,
            num_patches=num_patches,
            remote=remote,
        )

    object_patching.offset = 0
    target_probs, latent_probs = run_prompts(
        nn_model, target_prompts, batch_size=batch_size, get_probs=object_patching, tqdm=tqdm
    )

    # Report the results; the tensors are printed in full.
    print("Target Probabilities Shape:", target_probs.shape)
    print("Target Probabilities Sample:\n", target_probs)

    print("\nLatent Probabilities Keys:", latent_probs.keys())
    for key, probs in latent_probs.items():
        print(f"\n{key} Shape:", probs.shape)
        print(f"{key} Sample:\n", probs)



In [8]:
# Decode the JSON-encoded positional arguments:
# [source_lang_pairs, input_lang, target_lang].
# NOTE: the default paper_args_str ('') is not valid JSON, so a value must be
# supplied through the notebook parameters before this cell can run.
paper_args = json.loads(paper_args_str)
print(paper_args)
# Run the patching experiment, additionally tracking English as a latent language.
object_patching_plot(*paper_args, extra_langs=["en"])
    

[[['hi_translit', 'it'], ['ml_translit', 'it'], ['ta_translit', 'it'], ['te_translit', 'it'], ['gu_translit', 'it']], 'ml', 'it']
source_lang: en
df keys: ['fr', 'de', 'ru', 'en', 'zh', 'es', 'ja', 'ko', 'et', 'fi', 'nl', 'hi', 'word_original', 'it', 'hi_translit', 'ml_translit', 'ta_translit', 'te_translit', 'gu_translit', 'ml', 'ta', 'te', 'gu']
tatget_lang: hi_translit
tatget_lang: it
tatget_lang: ml_translit
tatget_lang: ta_translit
tatget_lang: te_translit
tatget_lang: gu_translit
tatget_lang: ml
tatget_lang: en
Out_DF
                                          hi_translit  \
0              [kitab, pustak, granth, sastra, pothi]   
1                                [megha, badal, nimb]   
2                                      [jhola, potli]   
3                      [parvat, giri, shikhar, pahar]   
4               [vastra, kapda, malmal, rakhi, dhoti]   
..                                                ...   
89  [official, karamchari, prashasanik, sarkari, n...   
90            

Running prompts:   0%|                                                                                                       | 0/25 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


Running prompts:   4%|███▊                                                                                           | 1/25 [00:11<04:35, 11.47s/it]

Running prompts:   8%|███████▌                                                                                       | 2/25 [00:18<03:20,  8.71s/it]

Running prompts:  12%|███████████▍                                                                                   | 3/25 [00:25<02:54,  7.93s/it]

Running prompts:  16%|███████████████▏                                                                               | 4/25 [00:32<02:37,  7.50s/it]

Running prompts:  20%|███████████████████                                                                            | 5/25 [00:39<02:28,  7.44s/it]

Running prompts:  24%|██████████████████████▊                                                                        | 6/25 [00:46<02:17,  7.23s/it]

Running prompts:  28%|██████████████████████████▌                                                                    | 7/25 [00:53<02:08,  7.15s/it]

Running prompts:  32%|██████████████████████████████▍                                                                | 8/25 [01:00<01:59,  7.05s/it]

Running prompts:  36%|██████████████████████████████████▏                                                            | 9/25 [01:06<01:51,  6.97s/it]

Running prompts:  40%|█████████████████████████████████████▌                                                        | 10/25 [01:14<01:46,  7.08s/it]

Running prompts:  44%|█████████████████████████████████████████▎                                                    | 11/25 [01:21<01:37,  7.00s/it]

Running prompts:  48%|█████████████████████████████████████████████                                                 | 12/25 [01:27<01:30,  6.96s/it]

Running prompts:  52%|████████████████████████████████████████████████▉                                             | 13/25 [01:34<01:24,  7.01s/it]

Running prompts:  56%|████████████████████████████████████████████████████▋                                         | 14/25 [01:41<01:16,  6.95s/it]

Running prompts:  60%|████████████████████████████████████████████████████████▍                                     | 15/25 [01:48<01:09,  7.00s/it]

Running prompts:  64%|████████████████████████████████████████████████████████████▏                                 | 16/25 [01:56<01:03,  7.05s/it]

Running prompts:  68%|███████████████████████████████████████████████████████████████▉                              | 17/25 [02:02<00:55,  6.98s/it]

Running prompts:  72%|███████████████████████████████████████████████████████████████████▋                          | 18/25 [02:09<00:48,  6.93s/it]

Running prompts:  76%|███████████████████████████████████████████████████████████████████████▍                      | 19/25 [02:16<00:41,  7.00s/it]

Running prompts:  80%|███████████████████████████████████████████████████████████████████████████▏                  | 20/25 [02:23<00:35,  7.02s/it]

Running prompts:  84%|██████████████████████████████████████████████████████████████████████████████▉               | 21/25 [02:30<00:27,  6.96s/it]

Running prompts:  88%|██████████████████████████████████████████████████████████████████████████████████▋           | 22/25 [02:37<00:21,  7.01s/it]

Running prompts:  92%|██████████████████████████████████████████████████████████████████████████████████████▍       | 23/25 [02:44<00:13,  6.94s/it]

Running prompts:  96%|██████████████████████████████████████████████████████████████████████████████████████████▏   | 24/25 [02:51<00:06,  6.91s/it]

Running prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [02:58<00:00,  6.96s/it]

Running prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [02:58<00:00,  7.14s/it]




'111111111111111'

Target Probabilities Shape: torch.Size([200, 28])
Target Probabilities Sample:
 tensor([[1.2146e-02, 1.2680e-02, 1.2764e-02,  ..., 8.4668e-01, 8.4277e-01,
         8.4424e-01],
        [6.5804e-03, 6.3286e-03, 6.2599e-03,  ..., 8.4766e-01, 8.4326e-01,
         8.4424e-01],
        [2.4585e-01, 2.3523e-01, 2.3633e-01,  ..., 6.4990e-01, 6.5479e-01,
         6.5283e-01],
        ...,
        [3.7849e-05, 3.8683e-05, 3.7611e-05,  ..., 8.6133e-01, 8.5742e-01,
         8.5938e-01],
        [6.2988e-01, 6.3965e-01, 6.3867e-01,  ..., 5.8984e-01, 5.8984e-01,
         5.8984e-01],
        [4.5443e-04, 4.7398e-04, 4.7970e-04,  ..., 5.4779e-02, 5.4535e-02,
         5.3711e-02]], dtype=torch.float16)

Latent Probabilities Keys: dict_keys(['source_it', 'src + tgt en'])

source_it Shape: torch.Size([200, 28])
source_it Sample:
 tensor([[1.6785e-01, 1.5027e-01, 1.5063e-01,  ..., 6.6662e-04, 6.7329e-04,
         6.6900e-04],
        [4.0259e-01, 4.0405e-01, 4.0527e-01,  ..., 2.6398e-03, 2.7027e-03,
   

In [9]:
# Optional cleanup, requested with del_model == 'yes'.
if del_model == 'yes':
    # `model` is only the checkpoint name (a str). The loaded network is bound
    # to `nn_model`, so that is the reference to drop to release its memory.
    del nn_model
    del tokenizer

