In [20]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import os, time, json
import pandas as pd
from openai import OpenAI
from tqdm.auto import tqdm

import sys
sys.path.append("../")
import os

import logging
from src.utils import logging_utils
from src.utils import env_utils
from src import functional

from nnsight import LanguageModel
from src.models import ModelandTokenizer
from transformers import BitsAndBytesConfig

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

import torch
import transformers

logger.info(f"{torch.__version__=}, {torch.version.cuda=}")
logger.info(f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}")
logger.info(f"{transformers.__version__=}")

  from .autonotebook import tqdm as notebook_tqdm


2025-04-18 16:11:12 __main__ INFO     torch.__version__='2.6.0+cu124', torch.version.cuda='12.4'
2025-04-18 16:11:12 __main__ INFO     torch.cuda.is_available()=True, torch.cuda.device_count()=1, torch.cuda.get_device_name()='NVIDIA RTX A6000'
2025-04-18 16:11:12 __main__ INFO     transformers.__version__='4.51.2'


In [9]:
from src.functional import extract_entities_with_oracle_LM

# entity = "Leonardo da Vinci"
# entity = "Benjamin Franklin"
# entity = "Japan"
entity = "Daredevil"

keywords_gpt = extract_entities_with_oracle_LM(entity, oracle="gpt4o")
keywords_gpt

2025-04-18 16:24:33 src.functional DEBUG    found cached gpt4o response for 93ba798cd545a644364989fde8874a7e - loading


[['Marvel Comics',
  'Daredevil is a superhero character published by Marvel Comics'],
 ['Matt Murdock', "Daredevil's real name is Matt Murdock"],
 ['Blindness',
  'Daredevil is blind due to a childhood accident involving radioactive material'],
 ['Heightened Senses',
  'Daredevil has superhuman senses that compensate for his lack of sight'],
 ['Radar Sense',
  'Daredevil possesses a radar sense that allows him to perceive his surroundings'],
 ["Hell's Kitchen",
  "Daredevil operates primarily in the Hell's Kitchen neighborhood of New York City"],
 ['Lawyer', 'Matt Murdock is a lawyer by profession'],
 ['Foggy Nelson',
  "Foggy Nelson is Matt Murdock's best friend and law partner"],
 ['Karen Page',
  'Karen Page is a close friend and love interest of Matt Murdock'],
 ['Kingpin',
  "Kingpin, also known as Wilson Fisk, is one of Daredevil's main adversaries"],
 ['Elektra', 'Elektra is a skilled assassin and love interest of Daredevil'],
 ['The Hand', 'The Hand is a ninja organization tha

In [10]:
keywords_claude = extract_entities_with_oracle_LM(entity, oracle="claude")
keywords_claude

2025-04-18 16:24:33 src.functional DEBUG    found cached gpt4o response for 93ba798cd545a644364989fde8874a7e - loading


[['Matt Murdock', 'Daredevil is the superhero alter ego of Matt Murdock'],
 ['Marvel Comics', 'Daredevil is a character published by Marvel Comics'],
 ['Stan Lee', 'Daredevil was co-created by Stan Lee'],
 ['Bill Everett', 'Daredevil was co-created by Bill Everett'],
 ["Hell's Kitchen",
  "Daredevil operates primarily in the Hell's Kitchen neighborhood of New York City"],
 ['Blindness',
  'Matt Murdock/Daredevil is blind, having lost his sight in a childhood accident'],
 ['Enhanced senses',
  'Despite being blind, Daredevil possesses superhuman senses of hearing, smell, taste, and touch'],
 ['Radar sense',
  'Daredevil has a radar sense that functions as a form of echolocation'],
 ['Lawyer', 'In his civilian identity, Matt Murdock works as an attorney'],
 ['Nelson and Murdock',
  'Matt Murdock co-founded the law firm Nelson and Murdock with his friend Foggy Nelson'],
 ['Foggy Nelson',
  "Franklin 'Foggy' Nelson is Matt Murdock's best friend and law partner"],
 ['Karen Page',
  'Karen P

In [11]:
####################################
# entity = "Japan"
# other_entity = "Germany"

# entity = "Benjamin Franklin"
# other_entity = "Leonardo da Vinci"

entity = "Daredevil"
other_entity = "Toph Beifong"
####################################

In [12]:
connections_gpt = extract_entities_with_oracle_LM(entity, oracle="gpt4o", other_entity=other_entity)
connections_gpt

2025-04-18 16:24:33 src.functional DEBUG    found cached gpt4o response for 3b7b60a93afa6cf934d6eb80528e1d6e - loading


[['Blindness',
  'Both Daredevil and Toph Beifong are characters who are blind but have developed extraordinary abilities that compensate for their lack of sight.'],
 ['Enhanced Senses',
  'Daredevil and Toph Beifong both have heightened senses that allow them to perceive the world in unique ways despite their blindness.'],
 ['Fictional Characters',
  'Both Daredevil and Toph Beifong are fictional characters from popular media franchises.'],
 ['Martial Arts',
  'Both characters are skilled in martial arts and use their combat abilities to fight against their adversaries.'],
 ['Justice',
  'Daredevil and Toph Beifong both fight for justice in their respective stories, often standing up against corruption and evil.'],
 ['Comic/Animated Series',
  "Daredevil is a character from Marvel Comics, while Toph Beifong is from the animated series 'Avatar: The Last Airbender', both of which are popular in the comic and animation genres."]]

In [7]:
connections_claude = extract_entities_with_oracle_LM(entity, oracle="claude", other_entity=other_entity)
connections_claude

2025-04-18 16:11:52 src.functional DEBUG    found cached gpt4o response for 3b7b60a93afa6cf934d6eb80528e1d6e - loading


[['Fictional characters',
  'Both Daredevil and Toph Beifong are fictional characters from popular media franchises.'],
 ['Disability as strength',
  "Both characters have a sensory disability that they've transformed into a unique ability - Daredevil is blind but has enhanced other senses, while Toph Beifong is blind but can 'see' through earthbending/seismic sense."],
 ['Enhanced senses',
  'Both have developed extraordinary sensory abilities that compensate for their blindness - Daredevil has radar sense and enhanced hearing, while Toph has seismic sense to detect vibrations.'],
 ['Martial arts experts',
  'Both are exceptional hand-to-hand combatants and martial artists in their respective universes.'],
 ['Justice seekers',
  'Both characters fight for justice and to protect the innocent in their respective worlds.'],
 ['Stubborn personalities',
  'Both are known for their determination, stubbornness, and unwillingness to be defined by their disabilities.'],
 ['Mentors',
  'Both ha

In [8]:
# model_key = "Qwen/Qwen2-7B"
model_key = "Qwen/Qwen2.5-14B"
# model_key = "Qwen/Qwen2.5-32B"

mt = ModelandTokenizer(
    model_key=model_key,
    torch_dtype=torch.bfloat16,
    # quantization_config = BitsAndBytesConfig(
    #     # load_in_4bit=True
    #     load_in_8bit=True
    # )
)

If not found in cache, model will be downloaded from HuggingFace to cache directory
2025-04-18 16:11:52 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2025-04-18 16:11:52 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /Qwen/Qwen2.5-14B/resolve/main/config.json HTTP/1.1" 200 0
2025-04-18 16:11:53 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /Qwen/Qwen2.5-14B/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


2025-04-18 16:12:03 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 8/8 [11:25<00:00, 85.66s/it]


2025-04-18 16:23:29 urllib3.connectionpool DEBUG    Resetting dropped connection: huggingface.co
2025-04-18 16:23:29 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /Qwen/Qwen2.5-14B/resolve/main/generation_config.json HTTP/1.1" 200 0
2025-04-18 16:23:31 src.models INFO     loaded model <Qwen/Qwen2.5-14B> | size: 28171.604 MB | dtype: torch.bfloat16 | device: cuda:0


In [9]:
mt.tokenizer.decode(512)

' le'

In [None]:
from src.functional import get_keywords_from_text
def process_keywords(keywords, tokenizer):
    key_tokens = []
    for k in keywords:
        k = k.strip().split(" ")[0]
        key_tokens.extend(get_keywords_from_text(text = k, tokenizer=tokenizer, maybe_prepend_space=True))
        key_tokens.extend(get_keywords_from_text(text = k.lower(), tokenizer=tokenizer, maybe_prepend_space=True))

        # key_tokens.extend(get_keywords_from_text(text = k, tokenizer=tokenizer, maybe_prepend_space=False))
        # key_tokens.extend(get_keywords_from_text(text = k.lower(), tokenizer=tokenizer, maybe_prepend_space=False))

    
    key_tokens = list(set(key_tokens))

    # remove empty/trivial tokens
    filtered_key_tokens = []
    for k in key_tokens:
        tok = tokenizer.decode(k, skip_special_tokens=True).strip()
        if len(tok) <= 1:
            continue
        filtered_key_tokens.append(k)
    return filtered_key_tokens

def extract_keywords(
    entity: str,
    oracles: list[str] = ["gpt4o", "claude"],
    other_entity: str = None,
):
    keywords = []
    for oracle in oracles:
        if oracle not in ["gpt4o", "claude"]:
            raise ValueError(f"Oracle {oracle} not supported.")
        keywords.extend(extract_entities_with_oracle_LM(entity, oracle=oracle, other_entity=other_entity))
    
    keywords = [k[0] for k in keywords]
    keywords = list(set(keywords))
    keywords = process_keywords(keywords, mt.tokenizer)

    return keywords
        

# keywords = extract_keywords(
#     entity, 
#     oracles=["gpt4o", "claude"], 
#     other_entity=other_entity
# )
# [f"{t} (\"{mt.tokenizer.decode(t)}\")" for t in keywords]


2025-04-18 16:39:29 src.functional DEBUG    found cached gpt4o response for 3b7b60a93afa6cf934d6eb80528e1d6e - loading
2025-04-18 16:39:29 src.functional DEBUG    found cached gpt4o response for 3b7b60a93afa6cf934d6eb80528e1d6e - loading


['11745 (" Justice")',
 '12161 (" justice")',
 '51396 (" stubborn")',
 '36859 (" martial")',
 '43582 (" fictional")',
 '61449 (" Enhanced")',
 '19724 (" comic")',
 '48593 (" Ment")',
 '23922 (" enhanced")',
 '66611 (" Stub")',
 '27254 (" disability")',
 '75607 (" mentors")',
 '54270 (" Blind")',
 '39963 (" Comic")',
 '74268 (" Disability")',
 '71261 (" Martial")',
 '42654 (" Fiction")',
 '84415 (" blindness")']

In [14]:
# t_indices = get_keywords_from_text(text = " Blindness", tokenizer=mt.tokenizer, maybe_prepend_space=False)

# [f"{t} (\"{mt.tokenizer.decode(t)}\")" for t in t_indices]


In [21]:
with open(os.path.join(env_utils.DEFAULT_DATA_DIR, "coincidences_sample.json")) as f:
    coincidences = json.load(f)

logger.info(f"{len(coincidences['examples'])=}")

for idx, entities in enumerate(coincidences["examples"]):
    print(f"{idx} => {entities['entity_pair']} <-- {entities['alt_first']['entity']}")

2025-04-18 16:44:52 __main__ INFO     len(coincidences['examples'])=13
0 => ['Germany', 'Japan'] <-- Korea
1 => ['Hugh Jackman', 'Ryan Reynolds'] <-- Celine Dion
2 => ['Bhutan', 'Nepal'] <-- India
3 => ['Mount Athos', 'Vatican City'] <-- Italy
4 => ['Dead Sea Scrolls', 'Rosetta Stone'] <-- Pyramid of Giza
5 => ['Leonardo da Vinci', 'Benjamin Franklin'] <-- George Washington
6 => ['Toph Beifong', 'Daredevil'] <-- Punisher
7 => ['Julius Caesar', 'Nepoleon Bonaparte'] <-- Victor Hugo
8 => ['Christopher Columbus', 'Vasco da Gama'] <-- Christiano Ronaldo
9 => ['Whale', 'Elephant'] <-- Horse
10 => ['jellyfish', 'lobster'] <-- salmon
11 => ['crocodile', 'shark'] <-- salmon
12 => ['spider', 'crab'] <-- lobster


In [117]:
from src.probing.utils import (
    ProbingPrompt,
    ProbingLatents,
    prepare_probing_input,
    get_lm_generated_answer,
    check_if_answer_is_correct,
)

Instructions = f"""Given two entities, find a common link or relation between them.
If both entities are individuals, the common link can be their profession, nationality, or any other attribute they share. Their relation can be if someone is the student/teacher of the other etc.
Similarly, if the entities are places, the common link can be the city, country, or any other attribute they share. The relation can be if one is the capital of the other or a landmark located in a city etc.
If there is no connection just answer "None"."""

# Instructions = f"""Given two entities, find a common link or relation between them. If there is no connection just answer "None"."""

block_separator = "\n#"
question_marker = "\nQ: "
answer_marker = "\nA:"

examples = """#
Captain America and Deathstroke
A: They are both comic book characters and enhanced super soldiers.
#
Q: Tiger Woods and Phil Mickelson
A: They are both professional golfers.
#
Q: Rome and Italy
A: Rome is the capital city of Italy.
#
Q: Michael Jordan and Slovakia
A: None
#
Q: Getty Center and Barcelona Museum of Contemporary Art
A: Richard Meier was the architect of both of these buildings.
"""

# entities = coincidences["examples"][0]["entity_pair"]
# entities = ("Whale", "Dolphin")
# entities = ("Nautilus", "Dolphin")
# entities = ("Abraham Lincoln", "John F. Kennedy")
# entities = ("Brad Pitt", "Angelina Jolie")
# entities = ("Emu", "Ostrich")
# entities = ("Elephant", "Whale")
# entities = ("Wolverine", "Penguin")
# entities = ("Giraffe", "Reindeer")
entities = coincidences["examples"][5]["entity_pair"]
# entities = ("Celine Dion", "Ryan Reynolds")
# entities = ("Daredevil", "Toph Beifong")
# entities = ("piano", "bass")
# entities = ("salmon", "bass")
# entities = ("Leonardo da Vinci", "Gianluigi Buffon")
# entities = ("Henry Ford", "Nikola Tesla")
# entities = ("Michael Jordan", "John F. Kennedy")
# entities = ("Statue of Liberty", "Eiffel Tower")
# entities = ("macondo", "hogwarts")
# entities = ("Japan", "Germany")
# entities = ("Harry Potter", "Naruto Uzumaki")
# entities = ("Hans Landa", "Dr. King Schultz")
# entities = ("Ricky Ponting", "Hugh Jackman")

print(entities)

prefix = f"""{Instructions}
{examples}
"""

prompt = prepare_probing_input(
    mt=mt,
    entities=entities,
    prefix=prefix,
    answer_marker=answer_marker,
    question_marker=question_marker,
    block_separator=block_separator,
    is_a_reasoning_model="deepseek" in model_key.lower(),
    # is_a_reasoning_model=True
    answer_prefix=" They are/were both",
    return_offsets_mapping=True
)

print(mt.tokenizer.decode(prompt.tokenized["input_ids"][0]))

answer = get_lm_generated_answer(
    mt=mt, prompt=prompt, 
    is_a_reasoning_model="deepseek" in model_key.lower()
    # is_a_reasoning_model=True
)
print(f"{answer=}")

from src.probing.utils import check_if_answer_is_correct
check_if_answer_is_correct(
    answer = answer,
    entities=entities,
    oracle_model="gpt4o"
)


from src.functional import get_keywords_from_text

keywords = get_keywords_from_text(text = answer, tokenizer=mt.tokenizer)

[f"{k}[\"{mt.tokenizer.decode(k)}\"]" for k in keywords]

['Leonardo da Vinci', 'Benjamin Franklin']
Given two entities, find a common link or relation between them.
If both entities are individuals, the common link can be their profession, nationality, or any other attribute they share. Their relation can be if someone is the student/teacher of the other etc.
Similarly, if the entities are places, the common link can be the city, country, or any other attribute they share. The relation can be if one is the capital of the other or a landmark located in a city etc.
If there is no connection just answer "None".
#
Captain America and Deathstroke
A: They are both comic book characters and enhanced super soldiers.
#
Q: Tiger Woods and Phil Mickelson
A: They are both professional golfers.
#
Q: Rome and Italy
A: Rome is the capital city of Italy.
#
Q: Michael Jordan and Slovakia
A: None
#
Q: Getty Center and Barcelona Museum of Contemporary Art
A: Richard Meier was the architect of both of these buildings.
#
Q: Leonardo da Vinci and Benjamin Frankli

['44933[" polym"]']

#### Extracting the attributes for entities

In [36]:
entity_1_keywords = extract_keywords(entity=entities[0])
[f"{t} (\"{mt.tokenizer.decode(t)}\")" for t in entity_1_keywords]

2025-04-17 17:38:57 src.functional DEBUG    found cached gpt4o response for 1abbbdb768b86cdf584eaf6c38b409d2 - loading
2025-04-17 17:38:57 src.functional DEBUG    found cached gpt4o response for 1abbbdb768b86cdf584eaf6c38b409d2 - loading


['10246 (" lie")',
 '2055 (" So")',
 '9237 (" Earth")',
 '66072 (" swamp")',
 '69661 (" Xin")',
 '2083 (" team")',
 '65084 (" sno")',
 '9278 (" sand")',
 '4158 (" white")',
 '576 (" The")',
 '20033 (" armor")',
 '18494 (" western")',
 '26183 (" mud")',
 '10824 (" Master")',
 '9805 (" lin")',
 '9317 (" metal")',
 '48234 (" Ember")',
 '10867 (" Western")',
 '34444 (" republic")',
 '84114 (" Sok")',
 '20628 (" neutral")',
 '1187 (" la")',
 '18603 (" Om")',
 '7341 (" master")',
 '5807 (" White")',
 '6319 (" dark")',
 '9393 (" earth")',
 '7861 (" om")',
 '47290 (" Lily")',
 '64708 (" Za")',
 '11461 (" Bad")',
 '13510 (" ga")',
 '30919 (" Sugar")',
 '31954 (" kor")',
 '20701 (" avatar")',
 '11487 (" Space")',
 '7909 (" Team")',
 '37609 (" Avatar")',
 '68340 (" Mud")',
 '35070 (" Kor")',
 '773 (" so")',
 '19206 (" Metal")',
 '2823 (" Be")',
 '83726 (" sok")',
 '279 (" the")',
 '62749 (" Spiritual")',
 '4384 (" tw")',
 '3873 (" bad")',
 '5429 (" Republic")',
 '311 (" to")',
 '71479 (" seismic"

In [37]:
entity_2_keywords = extract_keywords(entity=entities[1])
[f"{t} (\"{mt.tokenizer.decode(t)}\")" for t in entity_2_keywords]

2025-04-17 17:38:58 src.functional DEBUG    found cached gpt4o response for 93ba798cd545a644364989fde8874a7e - loading
2025-04-17 17:38:58 src.functional DEBUG    found cached gpt4o response for 93ba798cd545a644364989fde8874a7e - loading


['9223 (" born")',
 '61449 (" Enhanced")',
 '30249 (" fog")',
 '34354 (" spider")',
 '32819 (" Billy")',
 '9270 (" Frank")',
 '15417 (" lawyer")',
 '13374 (" Matt")',
 '576 (" The")',
 '6210 (" King")',
 '70214 (" Lawyer")',
 '1613 (" ac")',
 '4179 (" net")',
 '16473 (" Catholic")',
 '22105 (" Marvel")',
 '71261 (" Martial")',
 '40031 (" defenders")',
 '36859 (" martial")',
 '22642 (" Netflix")',
 '21622 (" Bull")',
 '1156 (" first")',
 '1161 (" char")',
 '45194 (" stan")',
 '3731 (" Red")',
 '1687 (" echo")',
 '82091 (" catholic")',
 '13484 (" Radio")',
 '7355 (" Ben")',
 '25790 (" frank")',
 '67786 (" Radar")',
 '11477 (" king")',
 '30436 (" Pun")',
 '28390 (" Spider")',
 '6381 (" Ac")',
 '44782 (" marvel")',
 '3318 (" ben")',
 '66831 (" Dare")',
 '9214 (" stick")',
 '279 (" the")',
 '58148 (" Fog")',
 '3892 (" Def")',
 '2363 (" Man")',
 '1340 (" she")',
 '62788 (" elek")',
 '26439 (" matt")',
 '31051 (" punish")',
 '23887 (" Hell")',
 '34645 (" Karen")',
 '25949 (" Ele")',
 '24927 (

In [38]:
connection_keywords = extract_keywords(entity=entities[0], other_entity=entities[1])
[f"{t} (\"{mt.tokenizer.decode(t)}\")" for t in connection_keywords]

2025-04-17 17:38:58 src.functional DEBUG    found cached gpt4o response for 799114510a42e2ead2e8cf6b28d3f87a - loading
2025-04-17 17:38:58 src.functional DEBUG    found cached gpt4o response for 799114510a42e2ead2e8cf6b28d3f87a - loading


['11745 (" Justice")',
 '12161 (" justice")',
 '18020 (" blind")',
 '51396 (" stubborn")',
 '43582 (" fictional")',
 '61449 (" Enhanced")',
 '28650 (" Unique")',
 '42654 (" Fiction")',
 '4911 (" unique")',
 '6065 (" Over")',
 '23922 (" enhanced")',
 '66611 (" Stub")',
 '72917 (" overcoming")',
 '36859 (" martial")',
 '71261 (" Martial")',
 '54270 (" Blind")',
 '84415 (" blindness")']

In [39]:
entity_1_unique = list(set(entity_1_keywords) - set(connection_keywords + entity_2_keywords))
entity_2_unique = list(set(entity_2_keywords) - set(connection_keywords + entity_1_keywords))

#### Latent analysis

In [79]:
from src.tokens import find_token_range
from itertools import product
from src.functional import get_hs
from src.utils.typing import TokenizerOutput

entity_1_range = find_token_range(
    string = prompt.prompt,
    substring=entities[0],
    tokenizer=mt.tokenizer,
    offset_mapping=prompt.tokenized["offset_mapping"][0],
)
logger.debug(f"{entity_1_range = } | \"{mt.tokenizer.decode(prompt.tokenized['input_ids'][0][entity_1_range[0]:entity_1_range[1]])}\"")

entity_2_range = find_token_range(
    string = prompt.prompt,
    substring=entities[1],
    tokenizer=mt.tokenizer,
    offset_mapping=prompt.tokenized["offset_mapping"][0],
)
logger.debug(f"{entity_2_range = } | \"{mt.tokenizer.decode(prompt.tokenized['input_ids'][0][entity_2_range[0]:entity_2_range[1]])}\"")

interesting_tokens = list(range(*entity_1_range)) + list(range(*entity_2_range)) + [-1, -2]

all_layers = (
    mt.layer_names 
    + [mt.attn_module_name_format.format(i) for i in range(mt.n_layer)]
    + [mt.mlp_module_name_format.format(i) for i in range(mt.n_layer)]
)

locations = list(product(all_layers, interesting_tokens))

hs = get_hs(
    mt = mt,
    input = TokenizerOutput(data = prompt.tokenized),
    locations = locations,
    return_dict=True
)

2025-04-18 18:07:35 __main__ DEBUG    entity_1_range = (205, 208) | " Ricky Ponting"
2025-04-18 18:07:35 __main__ DEBUG    entity_2_range = (209, 212) | " Hugh Jackman"


In [80]:
##############################################################
ATTN_LAYER_WINDOW = list(range(26, 35))
##############################################################

In [81]:
from src.trace import get_score
from src.functional import logit_lens

######################################################################################
# layer_names = [mt.attn_module_name_format.format(34)]
layer_names = [mt.layer_name_format.format(i) for i in range(26, mt.n_layer)]
# layer_names = [mt.mlp_module_name_format.format(i) for i in range(26, 35)]
token_position = -1

# layer_names = [mt.layer_name_format.format(i) for i in range(5, 15)]
# layer_names = [mt.mlp_module_name_format.format(i) for i in range(5, 15)]
# token_position = entity_1_range[1] - 1
######################################################################################

# for layer_name in layer_names:
for layer_idx in range(26, mt.n_layer):
    layer_name = mt.layer_name_format.format(layer_idx)
    logger.debug(f"{layer_name=}")
    logits, pred = logit_lens(mt = mt, h = hs[(layer_name, token_position)], return_logits=True, k = 50)
    logger.debug(f"{[str(p) for p in pred]}")


    layer_name = mt.mlp_module_name_format.format(layer_idx)
    logger.debug(f"{layer_name=}")
    logits, pred = logit_lens(mt = mt, h = hs[(layer_name, token_position)], return_logits=True, k = 50)
    logger.debug(f"{[str(p) for p in pred]}")


    layer_name = mt.attn_module_name_format.format(layer_idx)
    logger.debug(f"{layer_name=}")
    logits, pred = logit_lens(mt = mt, h = hs[(layer_name, token_position)], return_logits=True, k = 50)
    logger.debug(f"{[str(p) for p in pred]}")


    # score_1, indv_scores = get_score(
    #     logits = logits,
    #     token_id = entity_1_unique,
    #     metric = "log_norm", 
    #     return_individual_scores=True,
    #     k = 100
    # )
    # indv_scores = sorted(indv_scores.items(), key=lambda x: x[1], reverse=True)
    # indv_debug = [f"{t}(\"{mt.tokenizer.decode(t)}\") => {s}" for t, s in indv_scores]
    # logger.debug(f'{score_1=} | {indv_debug}')

    # # ---------------------------------------------------------------------------------------------------

    # score_2, indv_scores = get_score(
    #     logits = logits,
    #     token_id = entity_2_unique,
    #     metric = "log_norm", 
    #     return_individual_scores=True,
    #     k = 100
    # )
    # indv_scores = sorted(indv_scores.items(), key=lambda x: x[1], reverse=True)
    # indv_debug = [f"{t}(\"{mt.tokenizer.decode(t)}\") => {s}" for t, s in indv_scores]
    # logger.debug(f'{score_2=} | {indv_debug}')

    # # ---------------------------------------------------------------------------------------------------

    # score_conn, indv_scores = get_score(
    #     logits = logits,
    #     token_id = connection_keywords,
    #     metric = "log_norm", 
    #     return_individual_scores=True,
    #     k = 100
    # )
    # indv_scores = sorted(indv_scores.items(), key=lambda x: x[1], reverse=True)
    # indv_debug = [f"{t}(\"{mt.tokenizer.decode(t)}\") => {s}" for t, s in indv_scores]
    # logger.debug(f'{score_conn=} | {indv_debug}')

    logger.debug(f"{'='*5000}")

2025-04-18 18:07:36 __main__ DEBUG    layer_name='model.layers.26'
2025-04-18 18:07:36 __main__ DEBUG    ['"ityEngine"[6082] (p=0.114, logit=16.375)', '" зарегист"[143918] (p=0.057, logit=15.688)', '"/Peak"[89724] (p=0.047, logit=15.500)', '"rigesimal"[40359] (p=0.045, logit=15.438)', '"㶲"[122515] (p=0.045, logit=15.438)', '" Алексан"[133866] (p=0.039, logit=15.312)', '" кнопк"[143197] (p=0.027, logit=14.938)', '" taxp"[22979] (p=0.025, logit=14.875)', '"Intialized"[81736] (p=0.025, logit=14.875)', '" !***"[83134] (p=0.024, logit=14.812)', '"GuidId"[88174] (p=0.021, logit=14.688)', '"реги"[138883] (p=0.021, logit=14.688)', '"CallableWrapper"[84872] (p=0.016, logit=14.438)', '"⎈"[150667] (p=0.016, logit=14.438)', '"协会会员"[118882] (p=0.014, logit=14.312)', '"BitFields"[52235] (p=0.014, logit=14.250)', '"呙"[122218] (p=0.014, logit=14.250)', '"atedRoute"[27195] (p=0.012, logit=14.125)', '"ValueHandling"[80498] (p=0.012, logit=14.125)', '"FilterWhere"[83699] (p=0.011, logit=14.062)', '"殂"[12

## MLP Ablation

In [118]:
from src.functional import PatchSpec, generate_with_patch, interpret_logits, get_hs
from src.trace import get_score
from src.utils.typing import TokenizerOutput
from typing import Literal

@torch.inference_mode()
def patched_run(
    mt: ModelandTokenizer,
    inputs: TokenizerOutput,
    patches: list[PatchSpec],
    ans_tokens: list[int],
    metric: Literal["logit", "prob", "log_norm"] = "logit",
    generate_full_ans: bool = False,
    **next_tok_kwargs
):
    if generate_full_ans:
        answer = generate_with_patch(
            mt = mt,
            inputs = inputs,
            n_gen_per_prompt=1,
            do_sample=False,
            patches = patches,
            patch_strategy="replace",
            remove_prefix=True,
            patch_at_all_generations=False, # don't need to

            # patch_at_all_generations=True,    # will give the same result
            # use_cache = False,
        )
        answer = answer[0].split('\n')[0]
        print(f"\"{answer}\"")
    
    logits = get_hs(
        mt = mt,
        input = inputs,
        locations = [(mt.lm_head_name, -1)],
        patches = patches,
        return_dict=False
    ).squeeze()

    pred, track = interpret_logits(
        tokenizer=mt,
        logits=logits,
        interested_tokens=ans_tokens,
        **next_tok_kwargs
    )

    score = get_score(logits = logits, token_id=ans_tokens, metric = metric)

    return score, pred, track

In [122]:
METRIC = "log_norm"

locations = [(mt.mlp_module_name_format.format(l), -1) for l in range(30, mt.n_layer)]
zero_ablation = torch.zeros(mt.n_embd, device=mt.device)

patches = [
    PatchSpec(
        location = loc,
        patch = zero_ablation,
    ) for loc in locations
]

score, pred, track = patched_run(
    mt = mt,
    inputs = TokenizerOutput(data = prompt.tokenized),
    patches = [],
    ans_tokens = keywords,
    metric = METRIC,
    generate_full_ans=True,
)

print(f"{score=}")
print(f"{pred=}")
print(f"{track=}")

print('-' * 100)

patches = [
    PatchSpec(
        location = loc,
        patch = zero_ablation,
    ) for loc in locations
]

ablated_score, pred, track = patched_run(
    mt = mt,
    inputs = TokenizerOutput(data = prompt.tokenized),
    patches = patches,
    ans_tokens = keywords,
    metric = METRIC,
    generate_full_ans=True,
)

print(f"{ablated_score=}")
print(f"{pred=}")
print(f"{track=}")

" polymaths."
score=3.25
pred=[PredictedToken(token=' polym', prob=0.44140625, logit=19.5, token_id=44933, metadata=None), PredictedToken(token=' invent', prob=0.07666015625, logit=17.75, token_id=17023, metadata=None), PredictedToken(token=' artists', prob=0.052734375, logit=17.375, token_id=13511, metadata=None), PredictedToken(token=' famous', prob=0.052734375, logit=17.375, token_id=11245, metadata=None), PredictedToken(token=' Renaissance', prob=0.046630859375, logit=17.25, token_id=54283, metadata=None)]
track={44933: (1, PredictedToken(token=' polym', prob=0.44140625, logit=19.5, token_id=44933, metadata=None))}
----------------------------------------------------------------------------------------------------
" 18th century inventors."
ablated_score=-10.3125
pred=[PredictedToken(token=' ', prob=0.39453125, logit=16.375, token_id=220, metadata=None), PredictedToken(token='\n', prob=0.224609375, logit=15.8125, token_id=198, metadata=None), PredictedToken(token=',', prob=0.077636