In [1]:
import os
import json

import sys

sys.path.append("../")

##################################################################
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
##################################################################

import logging
from src.utils import logging_utils
from src.utils import env_utils

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

import torch
import transformers

logger.info(f"{torch.__version__=}, {torch.version.cuda=}")
logger.info(
    f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}"
)
logger.info(f"{transformers.__version__=}")

from src.utils.training_utils import get_device_map

model_key = "meta-llama/Llama-3.3-70B-Instruct"

device_map = get_device_map(model_key, 32, n_gpus=8)
print(device_map)

os.chdir("/disk/u/gio/mechinterp")
print(os.getcwd())

from src.models import ModelandTokenizer

mt = ModelandTokenizer(
    model_key=model_key,
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    attn_implementation="eager"
)


2025-09-01 09:01:06 __main__ INFO     torch.__version__='2.8.0+cu128', torch.version.cuda='12.8'
2025-09-01 09:01:06 __main__ INFO     torch.cuda.is_available()=True, torch.cuda.device_count()=8, torch.cuda.get_device_name()='NVIDIA A100 80GB PCIe'
2025-09-01 09:01:06 __main__ INFO     transformers.__version__='4.54.1'
2025-09-01 09:01:09 git.cmd DEBUG    Popen(['git', 'version'], cwd=/disk/u/gio/retrieval/notebooks, stdin=None, shell=False, universal_newlines=False)
2025-09-01 09:01:09 git.cmd DEBUG    Popen(['git', 'version'], cwd=/disk/u/gio/retrieval/notebooks, stdin=None, shell=False, universal_newlines=False)
{'model.embed_tokens': 7, 'model.norm': 7, 'model.rotary_emb': 7, 'lm_head': 7, 'model.layers.0': 0, 'model.layers.1': 1, 'model.layers.2': 2, 'model.layers.3': 3, 'model.layers.4': 4, 'model.layers.5': 5, 'model.layers.6': 6, 'model.layers.7': 7, 'model.layers.8': 0, 'model.layers.9': 1, 'model.layers.10': 2, 'model.layers.11': 3, 'model.layers.12': 4, 'model.layers.13': 5,

Loading checkpoint shards:   0%|          | 0/30 [00:00<?, ?it/s]

2025-09-01 09:02:00 src.models INFO     loaded model <models/meta-llama/Llama-3.3-70B-Instruct> | size: 134570.516 MB | dtype: torch.bfloat16 | device: cuda:7


In [2]:
os.chdir("/disk/u/gio/retrieval")
print(os.getcwd())

/disk/u/gio/retrieval


## Load the optimized heads

In [3]:
from matplotlib import pyplot as plt
import numpy as np

optimized_path = os.path.join(
    "/disk/u/arnab/Codes/Projects/retrieval/results",
    "selection/optimized_heads",
    mt.name.split("/")[-1],
    "distinct_options/select_one/epoch_10.npz"
)

optimization_results = np.load(optimized_path, allow_pickle=True)

2025-09-01 09:08:31 matplotlib DEBUG    matplotlib data path: /disk/u/gio/.conda/envs/retrieval2/lib/python3.11/site-packages/matplotlib/mpl-data
2025-09-01 09:08:31 matplotlib DEBUG    CONFIGDIR=/disk/u/gio/.config/matplotlib
2025-09-01 09:08:31 matplotlib DEBUG    interactive is False
2025-09-01 09:08:31 matplotlib DEBUG    platform is linux
2025-09-01 09:08:31 matplotlib DEBUG    CACHEDIR=/disk/u/gio/.cache/matplotlib
2025-09-01 09:08:31 matplotlib.font_manager DEBUG    Using fontManager instance from /disk/u/gio/.cache/matplotlib/fontlist-v390.json


In [7]:
optimal_head_mask = torch.tensor(optimization_results["optimal_mask"]).to(torch.float32)
heads_selected = torch.nonzero(optimal_head_mask > 0.5, as_tuple=False).tolist()
heads_selected = [(layer_idx, head_idx) for layer_idx, head_idx in heads_selected if layer_idx < 52]
print(len(heads_selected))

HEADS = heads_selected

79


In [8]:
import copy
import random
from src.selection.utils import KeyedSet, get_first_token_id, verify_correct_option
from src.selection.data import SelectionSample
from src.tokens import prepare_input
from src.selection.data import SelectOddOneOutTask

@torch.inference_mode()
def get_counterfactual_samples_odd_one_out(
    task: SelectOddOneOutTask,
    obj_category: str | None = None,
    distractor_category: str | None = None,
    prompt_template_idx=3,
    option_style="single_line",
    filter_by_lm_prediction: bool = True,
    n_distractors: int = 5,
    counterfact_obj_idx: int | None = None,
    verbose=True,
):
    obj_category = obj_category or random.choice(task.categories)
    distractor_category = distractor_category or random.choice(
        list(set(task.categories) - {obj_category})
    )
    assert obj_category != distractor_category, f"{obj_category=} {distractor_category=}"

    if verbose:
        logger.info(
            f"obj_category={obj_category}, distractor_category={distractor_category}, prompt_template_idx={prompt_template_idx}, option_style={option_style}, n_distractors={n_distractors}"
        )

    patch_sample = task.get_random_sample(
        mt=mt,
        option_style=option_style,
        prompt_template_idx=prompt_template_idx,
        obj_category=obj_category,
        distractor_category=distractor_category,
        filter_by_lm_prediction=False,
        n_distractors=n_distractors,
        obj_idx = random.choice(list(set(list(range(n_distractors + 1))) - {0, 1}))
    )
    if verbose: logger.info(f"patch_sample={str(patch_sample)}")

    #! criterion = not distractor_category
    # Options (2): 
    # 1. distractor_category (selected)
    # 2. category not in [obj_category, distractor_category] 
    not_dist_category_sample = task.get_random_sample(
        mt=mt,
        option_style=option_style,
        prompt_template_idx=prompt_template_idx,
        obj_category=distractor_category,
        distractor_category=random.choice(
            list(set(task.categories) - {obj_category, distractor_category})
        ),
        filter_by_lm_prediction=False,
        exclude_objs=patch_sample.options,
        n_distractors=1,
        obj_idx=counterfact_obj_idx
    )
    if verbose: logger.info(f"not_dist_category_sample={str(not_dist_category_sample)}")
    track_idx = 1 ^ not_dist_category_sample.obj_idx
    not_dist_category_sample.metadata = {
        "track_type_obj": not_dist_category_sample.options[track_idx],
        "track_type_obj_idx": track_idx,
        "track_type_obj_token_id": get_first_token_id(
            not_dist_category_sample.options[track_idx], mt.tokenizer, prefix=" "
        ),
    }

    #! criterion = is obj_category
    # Options (2):
    # 1. obj_category
    # 2. category not in [obj_category, distractor_category] (selected)
    is_obj_category_sample = task.get_random_sample(
        mt=mt,
        option_style=option_style,
        prompt_template_idx=prompt_template_idx,
        distractor_category=obj_category,
        obj_category=random.choice(
            list(set(task.categories) - {obj_category, distractor_category})
        ),
        filter_by_lm_prediction=False,
        exclude_objs=patch_sample.options,
        n_distractors=1,
        obj_idx=counterfact_obj_idx
    )
    if verbose: logger.info(f"is_obj_category_sample={str(is_obj_category_sample)}")
    track_idx = 1 ^ is_obj_category_sample.obj_idx
    is_obj_category_sample.metadata = {
        "track_type_obj_idx": track_idx,
        "track_type_obj": is_obj_category_sample.options[track_idx],
        "track_type_obj_category": obj_category,
        "track_type_obj_token_id": get_first_token_id(
            is_obj_category_sample.options[track_idx], mt.tokenizer, prefix=" "
        ),
    }

    if filter_by_lm_prediction:
        test_samples = [patch_sample, not_dist_category_sample, is_obj_category_sample]

        for sample in test_samples:
            tokenized = prepare_input(tokenizer=mt, prompts=sample.prompt())
            is_correct, predictions, track_options = verify_correct_option(
                mt=mt, target=sample.obj, options=sample.options, input=tokenized
            )
            sample.metadata["tokenized"] = tokenized.data
            if verbose:
                logger.info(sample.prompt())
                logger.info(
                    f"{sample.subj} | {sample.category} -> {sample.obj} | pred={[str(p) for p in predictions]}"
                )
            if not is_correct:
                if verbose:
                    logger.error(
                        f'Prediction mismatch: {track_options[list(track_options.keys())[0]]}["{mt.tokenizer.decode(predictions[0].token_id)}"] != {sample.ans_token_id}["{mt.tokenizer.decode(sample.ans_token_id)}"]'
                    )
                return get_counterfactual_samples_odd_one_out(
                    task=task,
                    obj_category=obj_category,
                    distractor_category=distractor_category,
                    prompt_template_idx=prompt_template_idx,
                    option_style=option_style,
                    filter_by_lm_prediction=filter_by_lm_prediction,
                    n_distractors=n_distractors,
                )
            sample.prediction = predictions

    return {
        "patch_sample": patch_sample,
        "not_dist_category_sample": not_dist_category_sample,
        "is_obj_category_sample": is_obj_category_sample
    }

## Experiments

In [10]:
from src.selection.data import SelectOddOneOutTask

select_odd_one = SelectOddOneOutTask.load(
    path=os.path.join(
        env_utils.DEFAULT_DATA_DIR,
        "selection",
        "objects.json"
    )
)

print(select_odd_one)

SelectOddOneOutTask: (different objects)
Categories: fruit(14), vehicle(15), furniture(15), animal(15), music instrument(15), clothing(15), electronics(15), sport equipment(15), kitchen appliance(15), vegetable(14), building(15), office supply(15), bathroom item(15), flower(15), tree(15), jewelry(14)



In [11]:
distractor_category="fruit"
obj_category="electronics"
option_style="single_line"
prompt_template_idx=3
N_DISTRACTORS=5
counterfact_obj_idx=1

In [12]:
exp_samples = get_counterfactual_samples_odd_one_out(
    task=select_odd_one,
    prompt_template_idx=prompt_template_idx,
    option_style=option_style,
    filter_by_lm_prediction=True,
    n_distractors=5,
    verbose=False
)

patch_sample = exp_samples["patch_sample"]
not_dist_category_sample = exp_samples["not_dist_category_sample"]
is_obj_category_sample = exp_samples["is_obj_category_sample"]

print(exp_samples['patch_sample'])
print(exp_samples['not_dist_category_sample'])
print(exp_samples['is_obj_category_sample'])

Watermelon -> Celery (3): ['Pineapple', 'Cantaloupe', 'Blackberry', 'Celery', 'Honeydew', 'Cranberry']
Peony -> Raspberry (0): ['Raspberry', 'Iris']
Potato -> Monitor (1): ['Turnip', 'Monitor']


In [13]:
from src.selection.functional import verify_head_patterns

patch_attn_info = verify_head_patterns(
    prompt = patch_sample.prompt(),
    options = patch_sample.options,
    pivot = patch_sample.subj,
    mt = mt,
    heads = HEADS
)

not_dist_category_attn_info = verify_head_patterns(
    prompt = not_dist_category_sample.prompt(),
    options = not_dist_category_sample.options,
    pivot = not_dist_category_sample.subj,
    mt = mt,
    heads = HEADS
)

is_obj_category_attn_info = verify_head_patterns(
    prompt = is_obj_category_sample.prompt(),
    options = is_obj_category_sample.options,
    pivot = is_obj_category_sample.subj,
    mt = mt,
    heads = HEADS
)

0 patches to ablate possible answer information from options
2025-09-01 09:09:48 src.selection.functional DEBUG    Predictions: ['" Cel"[47643] (p=0.867, logit=21.375)', '" The"[578] (p=0.071, logit=18.875)', '" Among"[22395] (p=0.014, logit=17.250)', '" It"[1102] (p=0.005, logit=16.125)', '" "[220] (p=0.005, logit=16.125)']
2025-09-01 09:09:48 src.selection.functional INFO     Combined attention matrix for all heads


0 patches to ablate possible answer information from options
2025-09-01 09:09:48 src.selection.functional DEBUG    Predictions: ['" Raspberry"[48665] (p=0.926, logit=20.875)', '" ("[320] (p=0.017, logit=16.875)', '" raspberry"[94802] (p=0.012, logit=16.500)', '" R"[432] (p=0.005, logit=15.688)', '" None"[2290] (p=0.005, logit=15.688)']
2025-09-01 09:09:49 src.selection.functional INFO     Combined attention matrix for all heads


0 patches to ablate possible answer information from options
2025-09-01 09:09:49 src.selection.functional DEBUG    Predictions: ['" Monitor"[24423] (p=0.898, logit=20.750)', '" The"[578] (p=0.019, logit=16.875)', '" ("[320] (p=0.016, logit=16.750)', '" A"[362] (p=0.015, logit=16.625)', '" Turn"[12268] (p=0.008, logit=16.000)']
2025-09-01 09:09:49 src.selection.functional INFO     Combined attention matrix for all heads
