In [1]:
# IPython autoreload: re-import edited Python modules (e.g. the local `src` package)
# before executing code, so library edits show up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import transformers
from src.models import ModelandTokenizer

# Environment report. torch.cuda.get_device_name() raises when no CUDA device
# is visible, so only query it when CUDA is actually available.
cuda_device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else None
print(f"{torch.__version__=}, {torch.version.cuda=}")
print(
    f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {cuda_device_name=}"
)
print(f"{transformers.__version__=}")

model_key = "meta-llama/Llama-3.3-70B-Instruct"
# model_key = "google/gemma-2-27b-it"

# bfloat16 weights; device_map="auto" lets the loader place layers across the
# available devices. Eager attention is requested explicitly — presumably so
# per-head attention weights can be read out by the head utilities below (TODO confirm).
mt = ModelandTokenizer(
    model_key=model_key,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)

meta-llama/Llama-3.3-70B-Instruct not found in /disk/u/arnab/Codes/Models
If not found in cache, model will be downloaded from HuggingFace to cache directory


torch.__version__='2.7.0+cu126', torch.version.cuda='12.6'
torch.cuda.is_available()=True, torch.cuda.device_count()=8, torch.cuda.get_device_name()='NVIDIA A100 80GB PCIe'
transformers.__version__='4.55.3'


Loading checkpoint shards:   0%|          | 0/30 [00:00<?, ?it/s]

In [3]:
# Pick one previously-localized filter head to inspect for the loaded model.
# Heads for any other model must be located first (scripts/locate_selection_heads.py).
showcase_head_by_model = {
    "meta-llama/Llama-3.3-70B-Instruct": (35, 19),
    "google/gemma-2-27b-it": (29, 3),
}
if model_key not in showcase_head_by_model:
    raise ValueError("For other models you need to localize the heads first. Check scripts/locate_selection_heads.py")
layer_idx, head_idx = showcase_head_by_model[model_key]

## Checking the behavior of a filter head on one example 

In [4]:
from src.selection.data import SelectOneTask
from typing import Literal
import os
# from src.utils import env_utils # you should create the env.yml file as per instructions in readme.md

########################### knobs ###########################
prompt_template_idx = 3  # try out different templates
option_style: Literal["single_line", "numbered"] = "single_line"
n_distractors = 5  # SelectOne task: total options shown = n_distractors + 1
#############################################################

# Path is relative to the working directory; once env.yml is set up it can be
# prefixed with env_utils.DEFAULT_DATA_DIR. Other entity files (e.g.
# "profession.json", ...) can be loaded the same way.
objects_task_path = os.path.join("data_save", "selection", "objects.json")
select_task = SelectOneTask.load(path=objects_task_path)


['name', 'prompt_templates', 'odd_one_prompt_templates', 'order_prompt_templates', 'count_prompt_templates', 'yes_no_prompt_templates', 'first_item_in_cat_prompt_templates', 'last_item_in_cat_prompt_templates', 'categories', 'exclude_categories']


In [5]:
# Draw one random question whose answer is a "fruit" from the loaded task.
# filter_by_lm_prediction=True is passed through to the sampler — presumably it
# keeps only samples the model itself answers correctly (TODO confirm).
sample = select_task.get_random_sample(
    mt=mt,
    option_style=option_style,
    prompt_template_idx=prompt_template_idx,
    category="fruit",  # e.g. "actor" when a people-based task file is loaded
    filter_by_lm_prediction=True,
)

sample_prompt = sample.prompt()
print(sample_prompt, ">>", sample.obj)
answer_piece = mt.tokenizer.decode([sample.ans_token_id])
print(f'"{answer_piece}"')

fruit >> ['Grape', 'Kiwi', 'Plum', 'Pineapple', 'Orange', 'Raspberry', 'Strawberry', 'Cherry', 'Banana', 'Pear', 'Blueberry', 'Mango', 'Watermelon', 'Apple', 'Peach']
Options: Socks, Nightstand, Wardrobe, Slow cooker, Skyscraper, Pineapple.
Which among these objects mentioned above is a fruit?
Answer: >> Pineapple
" Pine"


In [6]:
from src.selection.functional import verify_head_patterns

#! The selected head is a good filter head but not a perfect one; if the
#! pattern looks noisy, try other samples or other heads.
inspected_heads = [(layer_idx, head_idx)]
attn_pattern = verify_head_patterns(
    mt=mt,
    prompt=sample.prompt(),
    heads=inspected_heads,
)

## Patching the query state to transfer the predicate

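See Figure 1 in the paper. The cells below reproduce it: the filter head's query states are cached on a *source* prompt (asking for a fruit) and patched into a *destination* prompt (asking for a vehicle) at the last few token positions. This tests whether the source's filtering predicate transfers, i.e. whether the destination run now favours the fruit option.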

<p align="center">
<img src="notebooks/figures/fig_1_sliced-crop-1.png" style="width:100%;"/>
</p> 

In [7]:
from src.selection.data import get_counterfactual_samples_within_task

# Counterfactual pair from the same task: the source prompt asks for a "fruit",
# the destination prompt asks for a "vehicle".
source_sample, destination_sample = get_counterfactual_samples_within_task(
    mt=mt,
    task=select_task,
    prompt_template_idx=prompt_template_idx,
    option_style=option_style,
    patch_category="fruit",
    clean_category="vehicle",
    # mcqify=True, # make it a multi-choice question
)

print("=" * 20)
for role, pair_member in (("Source:", source_sample), ("Destination:", destination_sample)):
    pair_prompt = pair_member.prompt()
    answer_piece = mt.tokenizer.decode([pair_member.ans_token_id])
    print(role, pair_prompt, ">>", f'"{answer_piece}"')

# The destination also records a tracked option (in the run shown: "Grape",
# index 1, i.e. the destination option belonging to the source's category).
track_meta = destination_sample.metadata
print(
    track_meta["track_type_obj"],
    track_meta["track_type_obj_idx"],
    mt.tokenizer.decode(track_meta["track_type_obj_token_id"]),
)

type(task)=<class 'src.selection.data.SelectOneTask'>
Source: Options: Dog, Wardrobe, Banana, Earring, Tractor, Comb.
Which among these objects mentioned above is a fruit?
Answer: >> " Banana"
Destination: Options: Tape, Grape, Apartment, Scooter, Saxophone, Surfboard.
Which among these objects mentioned above is a vehicle?
Answer: >> " Sco"
Grape 1  Grape


In [8]:
# Setting the options and the prompt template manually to replicate Figure 1 exactly.
# For L35 H19 of Llama-3.3-70B-Instruct
#! Comment this cell out to check a random sample pair instead.

from src.selection.data import MCQify_sample
from src.selection.utils import get_first_token_id

FIG1_PROMPT_TEMPLATE = "<_options_>\nFind the <_category_>\nAnswer:"

source_sample.options = ["Cherry", "Knife", "Pants", "Car"]
source_sample.prompt_template = FIG1_PROMPT_TEMPLATE
print("Source:", source_sample.prompt())

destination_sample.options = ["Binder", "Peach", "Watch", "Scooter", "Phone"]
destination_sample.prompt_template = FIG1_PROMPT_TEMPLATE
# The answer attribute is `.obj` (read as `sample.obj` earlier in the notebook);
# assigning `.object` would only create an unrelated attribute.
destination_sample.obj = "Scooter"
destination_sample.obj_idx = destination_sample.options.index(destination_sample.obj)

# Track the source-category option (the fruit) inside the destination. After
# MCQify the options are labelled "a.", "b.", ... so the tracked token is the
# first token of that option's letter. Keep all track_* metadata in sync so the
# stale values from the random pair do not linger.
track_obj = "Peach"
track_obj_idx = destination_sample.options.index(track_obj)
destination_sample.metadata["track_type_obj"] = track_obj
destination_sample.metadata["track_type_obj_idx"] = track_obj_idx
destination_sample.metadata["track_type_obj_token_id"] = get_first_token_id(
    name=chr(ord("a") + track_obj_idx), tokenizer=mt.tokenizer, prefix=" "
)
destination_sample = MCQify_sample(sample=destination_sample, tokenizer=mt)
print("\nDestination:", destination_sample.prompt())

Source: Options: Cherry, Knife, Pants, Car.
Find the fruit
Answer:

Destination: a. Binder
b. Peach
c. Watch
d. Scooter
e. Phone
Find the vehicle
Answer:


In [9]:
from src.tokens import prepare_input
from src.functional import interpret_logits

# --- Source prompt: tokenize, run with the filter head, show the top-5 next tokens.
source_prompt = source_sample.prompt()
source_tokenized = prepare_input(prompts=source_prompt, tokenizer=mt)
source_attn = verify_head_patterns(
    mt=mt,
    prompt=source_prompt,
    heads=[(layer_idx, head_idx)],
)
source_predictions = interpret_logits(
    tokenizer=mt.tokenizer,
    logits=source_attn["logits"].squeeze(),
    k=5,
)
print("Source predictions:", list(map(str, source_predictions)))

# --- Destination prompt: same, and additionally track the token of the
# source-category option to get an unpatched ("clean") baseline score.
destination_prompt = destination_sample.prompt()
destination_attn = verify_head_patterns(
    mt=mt,
    prompt=destination_prompt,
    heads=[(layer_idx, head_idx)],
)
destination_tokenized = prepare_input(prompts=destination_prompt, tokenizer=mt)

track_token_id = destination_sample.metadata["track_type_obj_token_id"]
destination_predictions, dest_track = interpret_logits(
    tokenizer=mt.tokenizer,
    logits=destination_attn["logits"].squeeze(),
    k=5,
    interested_tokens=[track_token_id],
)
print("Destination predictions:", list(map(str, destination_predictions)))
print(dest_track)

# In the output shown, dest_track maps token id -> (rank, PredictedToken).
clean_score = dest_track[track_token_id][1].logit
print(f"{clean_score=}")

Source predictions: ['" Cherry"[45805] (p=0.945, logit=22.250)', '" The"[578] (p=0.032, logit=18.875)', '" CH"[6969] (p=0.007, logit=17.375)', '" ""[330] (p=0.003, logit=16.500)', '" cherry"[41980] (p=0.001, logit=15.500)']


Destination predictions: ['" d"[294] (p=0.789, logit=21.500)', '" Option"[7104] (p=0.044, logit=18.625)', '" ("[320] (p=0.039, logit=18.500)', '" Sco"[50159] (p=0.035, logit=18.375)', '" The"[578] (p=0.031, logit=18.250)']
OrderedDict([(293, (6, PredictedToken(token=' b', prob=0.00775146484375, logit=16.875, token_id=293, metadata=None)))])
clean_score=16.875


In [10]:
# Size of the next-token logits returned for the source run (torch.Size([128256]) in the run shown).
source_attn["logits"].size()

torch.Size([128256])

In [11]:
# checking the effect of transferring the q_state:
# cache the selected head's query projections on the source prompt and patch
# them into the destination prompt at the mapped token positions.

from src.selection.functional import cache_q_projections
from src.functional import PatchSpec

map_indices = {-3: -3, -2: -2, -1: -1}  # source_token_idx -> destination_token_idx
q_states = cache_q_projections(
    mt=mt,
    input=source_tokenized,
    heads=[(layer_idx, head_idx)],
    # A concrete list, not a dict_keys view (which cannot be indexed or sliced).
    token_indices=[list(map_indices.keys())],
)[0]

# q_states is keyed by (layer, head, source token index).
q_patches = []
for (l_idx, h_idx, source_token_idx), q_proj in q_states.items():
    q_patches.append(PatchSpec(
        location=(
            mt.attn_module_name_format.format(l_idx) + ".q_proj",
            h_idx,
            map_indices[source_token_idx],
        ),
        patch=q_proj.squeeze(),
    ))

patched_run = verify_head_patterns(
    prompt=destination_sample.prompt(),
    mt=mt,
    heads=[(layer_idx, head_idx)],
    query_patches=q_patches,
)

patched_predictions, patched_track = interpret_logits(
    tokenizer=mt.tokenizer,
    logits=patched_run["logits"].squeeze(),
    k=5,
    interested_tokens=[destination_sample.metadata["track_type_obj_token_id"]],
)
print("Patched predictions:", [str(pred) for pred in patched_predictions])
print(patched_track)
patched_score = patched_track[destination_sample.metadata["track_type_obj_token_id"]][1].logit
print(f"{patched_score=}")

# Positive => patching pushed the model toward the source-category option.
improvement = patched_score - clean_score
print(f"Δ score after patching query state of a single head: {improvement:.4f}")

Patched predictions: ['" d"[294] (p=0.750, logit=21.000)', '" ("[320] (p=0.054, logit=18.375)', '" Option"[7104] (p=0.042, logit=18.125)', '" Sco"[50159] (p=0.029, logit=17.750)', '" The"[578] (p=0.029, logit=17.750)']
OrderedDict([(293, (6, PredictedToken(token=' b', prob=0.02001953125, logit=17.375, token_id=293, metadata=None)))])
patched_score=17.375
Δ score after patching query state of a single head: 0.5000


## Patching all the identified filter heads

In [12]:
# (layer, head) filter heads per model. Keys are the bare model names (the part
# of the HF id after the org prefix) so every entry follows the same convention.
filter_heads = {
    "Llama-3.3-70B-Instruct": [
        (28, 40),
        (28, 45),
        (29, 56),
        (29, 57),
        (29, 60),
        (29, 61),
        (29, 62),
        (30, 62),
        (31, 0),
        (31, 32),
        (31, 33),
        (31, 36),
        (31, 37),
        (31, 38),
        (31, 39),
        (31, 40),
        (31, 43),
        (32, 12),
        (32, 19),
        (32, 48),
        (33, 18),
        (33, 21),
        (33, 23),
        (33, 30),
        (33, 43),
        (33, 46),
        (34, 1),
        (34, 6),
        (34, 33),
        (34, 45),
        (35, 5),
        (35, 17),
        (35, 18),
        (35, 19),
        (35, 20),
        (35, 22),
        (35, 23),
        (35, 27),
        (35, 28),
        (35, 36),
        (35, 40),
        (35, 42),
        (36, 17),
        (36, 22),
        (36, 40),
        (36, 44),
        (36, 47),
        (36, 52),
        (36, 54),
        (37, 0),
        (37, 3),
        (37, 4),
        (37, 7),
        (37, 16),
        (37, 28),
        (37, 30),
        (37, 36),
        (37, 39),
        (38, 19),
        (38, 23),
        (38, 49),
        (38, 50),
        (38, 51),
        (39, 35),
        (39, 36),
        (39, 41),
        (39, 44),
        (39, 45),
        (42, 28),
        (42, 30),
        (42, 31),
        (45, 1),
        (47, 17),
        (47, 18),
        (49, 1),
        (49, 4),
        (49, 5),
        (49, 7),
        (50, 34),
    ],
    # Previously keyed "google/gemma-2-27b-it", which the bare-name lookup
    # (model_key.split("/")[-1]) could never match.
    "gemma-2-27b-it": [
        (20, 3),
        (21, 13),
        (21, 29),
        (22, 5),
        (22, 6),
        (22, 7),
        (22, 22),
        (22, 30),
        (23, 2),
        (23, 6),
        (23, 13),
        (23, 19),
        (23, 20),
        (23, 22),
        (23, 24),
        (23, 31),
        (24, 4),
        (24, 5),
        (24, 6),
        (24, 7),
        (24, 9),
        (24, 12),
        (24, 14),
        (25, 8),
        (25, 15),
        (26, 2),
        (26, 4),
        (26, 5),
        (26, 16),
        (26, 18),
        (26, 23),
        (26, 25),
        (26, 30),
        (27, 5),
        (28, 3),
        (28, 12),
        (28, 13),
        (28, 16),
        (28, 17),
        (28, 20),
        (28, 21),
        (28, 27),
        (28, 31),
        (29, 2),
        (29, 10),
        (29, 16),
        (29, 22),
        (29, 23),
        (29, 24),
        (29, 26),
        (29, 27),
        (29, 29),
        (30, 6),
        (30, 8),
        (30, 11),
        (30, 14),
        (30, 15),
        (30, 20),
        (30, 21),
        (31, 2),
        (31, 3),
        (31, 24),
        (31, 31),
        (33, 12),
        (33, 16),
        (33, 17),
        (34, 14),
        (34, 19),
        (35, 9),
        (35, 25),
    ],
}

In [13]:
# Look up the filter heads for the loaded model. Accept either the full HF id
# (e.g. "google/gemma-2-27b-it") or the bare model name as the dict key.
model_name = model_key.split("/")[-1]
if model_key in filter_heads:
    heads = filter_heads[model_key]
elif model_name in filter_heads:
    heads = filter_heads[model_name]
else:
    raise KeyError(
        f"No filter heads listed for {model_key!r}; localize them first "
        "(see scripts/locate_selection_heads.py)."
    )

# Run both prompts with every filter head tracked.
source_attn = verify_head_patterns(
    mt=mt,
    prompt=source_sample.prompt(),
    heads=heads,
)

destination_attn = verify_head_patterns(
    mt=mt,
    prompt=destination_sample.prompt(),
    heads=heads,
)

In [14]:
from src.selection.functional import cache_q_projections
from src.functional import PatchSpec

#! check out validate_q_proj_ie_on_sample_pair in src/selection/optimization.py

# Same recipe as the single-head experiment, now applied to every filter head:
# cache query projections on the source prompt, patch them into the destination.
map_indices = {-3: -3, -2: -2, -1: -1}  # source_token_idx -> destination_token_idx
cached_positions = list(map_indices.keys())
q_states = cache_q_projections(
    mt=mt,
    input=source_tokenized,
    heads=heads,
    token_indices=[cached_positions],
)[0]

# q_states is keyed by (layer, head, source token index).
q_patches = [
    PatchSpec(
        location=(
            f"{mt.attn_module_name_format.format(layer)}.q_proj",
            head,
            map_indices[src_tok],
        ),
        patch=q_vec.squeeze(),
    )
    for (layer, head, src_tok), q_vec in q_states.items()
]

patched_run = verify_head_patterns(
    prompt=destination_sample.prompt(),
    mt=mt,
    heads=heads,
    query_patches=q_patches,
)

tracked_id = destination_sample.metadata["track_type_obj_token_id"]
patched_predictions, patched_track = interpret_logits(
    tokenizer=mt.tokenizer,
    logits=patched_run["logits"].squeeze(),
    k=5,
    interested_tokens=[tracked_id],
)
print("Patched predictions:", list(map(str, patched_predictions)))
print(patched_track)
patched_score = patched_track[tracked_id][1].logit
print(f"{patched_score=}")

improvement = patched_score - clean_score
print(f"Δ score after patching query state for {len(heads)} filter heads: {improvement:.4f}")

Patched predictions: ['" b"[293] (p=0.758, logit=21.750)', '" Peach"[64695] (p=0.169, logit=20.250)', '" ("[320] (p=0.014, logit=17.750)', '" Option"[7104] (p=0.014, logit=17.750)', '" The"[578] (p=0.012, logit=17.625)']
OrderedDict([(293, (1, PredictedToken(token=' b', prob=0.7578125, logit=21.75, token_id=293, metadata=None)))])
patched_score=21.75
Δ score after patching query state for 79 filter heads: 4.8750
