In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../")
import os
import json

import logging
from src.utils import logging_utils
from src.utils import env_utils
from src import functional

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

import torch
import transformers

logger.info(f"{torch.__version__=}, {torch.version.cuda=}")
logger.info(f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}")
logger.info(f"{transformers.__version__=}")

  from .autonotebook import tqdm as notebook_tqdm


2025-02-18 14:10:38 __main__ INFO     torch.__version__='2.5.0+cu124', torch.version.cuda='12.4'
2025-02-18 14:10:38 __main__ INFO     torch.cuda.is_available()=True, torch.cuda.device_count()=1, torch.cuda.get_device_name()='NVIDIA RTX A6000'
2025-02-18 14:10:38 __main__ INFO     transformers.__version__='4.48.1'


In [3]:
os.listdir(os.path.join(env_utils.DEFAULT_MODELS_DIR, "meta-llama"))

['Llama-3.1-8B',
 'Llama-3.1-8B-Instruct',
 'Llama-2-7b-chat-hf',
 'Llama-3.2-3B-Instruct',
 'Llama-3.2-3B',
 'Llama-3.2-1B']

In [4]:
import torch

from nnsight import LanguageModel
from src.models import ModelandTokenizer

model_key = "meta-llama/Llama-3.1-8B"
# model_key = "meta-llama/Llama-3.2-3B"
# model_key = "google/gemma-2-9b-it"
# model_key = "google/gemma-2-27b-it"
# model_key = "Qwen/Qwen2-7B"
# model_key = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
# model_key = "allenai/OLMo-2-1124-7B-Instruct"
# model_key = "allenai/OLMo-7B-0424-hf"

mt = ModelandTokenizer(
    model_key=model_key,
    torch_dtype=torch.float16,
)

2025-02-18 14:10:39 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.30s/it]

2025-02-18 14:10:44 src.models INFO     loaded model </home/local_arnab/Codes/00_MODEL/meta-llama/Llama-3.1-8B> | size: 15316.508 MB | dtype: torch.float16 | device: cuda:0





In [158]:
from src.probing.utils import (
    ProbingPrompt,
    ProbingLatents,
    prepare_probing_input,
    get_lm_generated_answer,
    check_if_answer_is_correct,
)

Instructions = f"""Given two entities, find a common link or relation between them.
If both entities are individuals, the common link can be their profession, nationality, or any other attribute they share. Their relation can be if someone is the student/teacher of the other etc.
Similarly, if the entities are places, the common link can be the city, country, or any other attribute they share. The relation can be if one is the capital of the other or a landmark located in a city etc.
If there is no connection just answer "None"."""

# Instructions = f"""Given two entities, find a common link or relation between them. If there is no connection just answer "None"."""

block_separator = "\n#"
question_marker = "\nQ: "
answer_marker = "\nA:"

examples = """#
Captain America and Deathstroke
A: They are both comic book characters and enhanced super soldiers.
#
Q: Tiger Woods and Phil Mickelson
A: They are both professional golfers.
#
Q: Rome and Italy
A: Rome is the capital city of Italy.
#
Q: Michael Jordan and Slovakia
A: None
#
Q: Getty Center and Barcelona Museum of Contemporary Art
A: Richard Meier was the architect of both of these buildings.
"""

with open(os.path.join(env_utils.DEFAULT_DATA_DIR, "coincidences_sample.json")) as f:
    coincidences = json.load(f)

logger.info(f"{len(coincidences['examples'])=}")


# entities = coincidences["examples"][0]["entity_pair"]
# entities = ("Whale", "Dolphin")
# entities = ("Nautilus", "Dolphin")
# entities = ("Abraham Lincoln", "John F. Kennedy")
# entities = ("Brad Pitt", "Angelina Jolie")
# entities = ("Emu", "Ostrich")
# entities = ("Elephant", "Whale")
# entities = ("Wolverine", "Penguin")
# entities = ("Giraffe", "Reindeer")
# entities = ("Hydrogen", "Oxygen")
# entities = ("Alfred Nobel", "Julius Caesar")
# entities = ("Cleopatra", "Catherine the Great")
# entities = ("Snake", "Spider")
# entities = ("Tomato", "Banana")
# entities = ("Diamond", "Coal")
# entities = ("Rome", "Istanbul")
# entities = ("Mercury", "Jupiter")
# entities = ("butterflies", "frogs")
# entities = ("octopus", "squid")
# entities = ("Kangaroo", "Seahorse")
# entities = ("Great Wall of China", "Panama Canal")
# entities = ("Brazil", "Turkey")
# entities = ("onion", "garlic")
# entities = ("jellyfish", "lobster")
# entities = ("corn", "wheat")
# entities = ("broccoli", "cauliflower")
# entities = ("Venus", "Uranus")
# entities = ("crocodile", "shark")
# entities = ("crab", "spider")
# entities = ("apple", "rose")
# entities = ("starfish", "lizard")
# entities = ("zebra", "penguin")
# entities = ("rabbit", "deer")
# entities = ("mushroom", "coral")
# entities = ("pinaapple", "fig")
# entities = ("pig", "dolphin")
# entities = ("cinnamon", "vanilla")
# entities = ("tuna", "hummingbird")
# entities = ("diamond", "graphaite")
entities = ("copper", "gold")

print(entities)

prefix = f"""{Instructions}
{examples}
"""

prompt = prepare_probing_input(
    mt=mt,
    entities=entities,
    prefix=prefix,
    answer_marker=answer_marker,
    question_marker=question_marker,
    block_separator=block_separator,
    is_a_reasoning_model="deepseek" in model_key.lower(),
    # is_a_reasoning_model=True
    answer_prefix=" They are/were both"
    # answer_prefix = " They are both used to say"
)

print(mt.tokenizer.decode(prompt.tokenized["input_ids"][0]))

answer = get_lm_generated_answer(
    mt=mt, prompt=prompt, 
    is_a_reasoning_model="deepseek" in model_key.lower()
    # is_a_reasoning_model=True
)
print(f"{answer=}")

2025-02-18 14:59:24 __main__ INFO     len(coincidences['examples'])=19


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


('copper', 'gold')
<|begin_of_text|>Given two entities, find a common link or relation between them.
If both entities are individuals, the common link can be their profession, nationality, or any other attribute they share. Their relation can be if someone is the student/teacher of the other etc.
Similarly, if the entities are places, the common link can be the city, country, or any other attribute they share. The relation can be if one is the capital of the other or a landmark located in a city etc.
If there is no connection just answer "None".
#
Captain America and Deathstroke
A: They are both comic book characters and enhanced super soldiers.
#
Q: Tiger Woods and Phil Mickelson
A: They are both professional golfers.
#
Q: Rome and Italy
A: Rome is the capital city of Italy.
#
Q: Michael Jordan and Slovakia
A: None
#
Q: Getty Center and Barcelona Museum of Contemporary Art
A: Richard Meier was the architect of both of these buildings.
#
Q: copper and gold
A: They are/were both
answer=

In [159]:
from src.probing.utils import check_if_answer_is_correct

check_if_answer_is_correct(
    answer = answer,
    entities=entities,
)

2025-02-18 14:59:31 httpx DEBUG    load_ssl_context verify=True cert=None trust_env=True http2=False
2025-02-18 14:59:31 httpx DEBUG    load_verify_locations cafile='/home/local_arnab/miniconda3/envs/retrieval/lib/python3.11/site-packages/certifi/cacert.pem'
2025-02-18 14:59:31 anthropic._base_client DEBUG    Request options: {'method': 'post', 'url': '/v1/messages', 'timeout': 600, 'files': None, 'json_data': {'max_tokens': 4000, 'messages': [{'role': 'user', 'content': [{'type': 'text', 'text': 'Do you think the following answer is a good connection or relation between the entities copper and gold?\nYour answer should start with "Yes" or "No". If the answer is "No", please provide your reasoning. Otherwise, just say "Yes".\n\n\nused as currency.'}]}], 'model': 'claude-3-5-sonnet-20241022', 'system': 'You are a helpful assistant.', 'temperature': 0}}
2025-02-18 14:59:31 anthropic._base_client DEBUG    Sending HTTP Request: POST https://api.anthropic.com/v1/messages
2025-02-18 14:59:31

2025-02-18 14:59:31 httpcore.connection DEBUG    connect_tcp.complete return_value=<httpcore._backends.sync.SyncStream object at 0x7f6e7010b910>
2025-02-18 14:59:31 httpcore.connection DEBUG    start_tls.started ssl_context=<ssl.SSLContext object at 0x7f6e34d21370> server_hostname='api.anthropic.com' timeout=600
2025-02-18 14:59:31 httpcore.connection DEBUG    start_tls.complete return_value=<httpcore._backends.sync.SyncStream object at 0x7f6e34ed1290>
2025-02-18 14:59:31 httpcore.http11 DEBUG    send_request_headers.started request=<Request [b'POST']>
2025-02-18 14:59:31 httpcore.http11 DEBUG    send_request_headers.complete
2025-02-18 14:59:31 httpcore.http11 DEBUG    send_request_body.started request=<Request [b'POST']>
2025-02-18 14:59:31 httpcore.http11 DEBUG    send_request_body.complete
2025-02-18 14:59:31 httpcore.http11 DEBUG    receive_response_headers.started request=<Request [b'POST']>
2025-02-18 14:59:32 httpcore.http11 DEBUG    receive_response_headers.complete return_val

True

In [160]:
from src.functional import PatchSpec
from src.functional import generate_with_patch, predict_next_token
from src.utils.typing import TokenizerOutput

print(entities)

clean_pred = predict_next_token(
    mt = mt,
    inputs = TokenizerOutput(data = prompt.tokenized),
)
clean_pred

('copper', 'gold')
2025-02-18 14:59:34 httpcore.connection DEBUG    close.started
2025-02-18 14:59:34 httpcore.connection DEBUG    close.complete


[[PredictedToken(token=' used', prob=0.4423828125, logit=18.171875, token_id=1511),
  PredictedToken(token=' metals', prob=0.08184814453125, logit=16.484375, token_id=37182),
  PredictedToken(token=' valuable', prob=0.07684326171875, logit=16.421875, token_id=15525),
  PredictedToken(token=' precious', prob=0.0504150390625, logit=16.0, token_id=27498),
  PredictedToken(token=' currencies', prob=0.039581298828125, logit=15.7578125, token_id=36702)]]

In [167]:
mt.tokenizer([" currency"])

{'input_ids': [[128000, 11667]], 'attention_mask': [[1, 1]]}

In [168]:
interesting_tokens = [11667]
# interesting_tokens = []

In [169]:
single_probing_prompt = "Who/what is a {}? Answer: {} is a"

gen = generate_with_patch(
    mt = mt,
    inputs = single_probing_prompt.format(entities[0], entities[0]),
)

print(gen)

pred, track_ans = predict_next_token(
    mt = mt,
    inputs = [single_probing_prompt.format(entities[0], entities[0])],
    k=15,
    token_of_interest=[clean_pred[0][0].token_id] + interesting_tokens
)

print(track_ans)
pred

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


['Who/what is a copper? Answer: copper is a metal, a mineral, a trace element, a chemical element, a component of the human body,', 'Who/what is a copper? Answer: copper is a mineral and a metal. It is a mineral because it is a naturally occurring element, and a metal', 'Who/what is a copper? Answer: copper is a metal and a chemical element with the symbol Cu and atomic number 29. It is a ductile', 'Who/what is a copper? Answer: copper is a chemical element with the symbol Cu (from Latin: cuprum) and atomic number 29. It', 'Who/what is a copper? Answer: copper is a metal that is in a group of metals called transition metals. These metals are located in the d-block']
[{11667: (235, PredictedToken(token=' currency', prob=0.0003039836883544922, logit=8.578125, token_id=11667)), 1511: (1571, PredictedToken(token=' used', prob=2.6881694793701172e-05, logit=6.15234375, token_id=1511))}]


[[PredictedToken(token=' metal', prob=0.19140625, logit=15.0234375, token_id=9501),
  PredictedToken(token=' mineral', prob=0.0889892578125, logit=14.2578125, token_id=25107),
  PredictedToken(token=' chemical', prob=0.0855712890625, logit=14.21875, token_id=11742),
  PredictedToken(token=' metallic', prob=0.0343017578125, logit=13.3046875, token_id=46258),
  PredictedToken(token=' non', prob=0.0293426513671875, logit=13.1484375, token_id=2536),
  PredictedToken(token=' red', prob=0.018951416015625, logit=12.7109375, token_id=2579),
  PredictedToken(token=' type', prob=0.016082763671875, logit=12.546875, token_id=955),
  PredictedToken(token=' naturally', prob=0.016082763671875, logit=12.546875, token_id=18182),
  PredictedToken(token=' person', prob=0.015960693359375, logit=12.5390625, token_id=1732),
  PredictedToken(token=' redd', prob=0.014190673828125, logit=12.421875, token_id=63244),
  PredictedToken(token=' soft', prob=0.01343536376953125, logit=12.3671875, token_id=8579),
  Pr

In [170]:
gen = generate_with_patch(
    mt = mt,
    inputs = single_probing_prompt.format(entities[1], entities[1]),
)

print(gen)

pred, track_ans = predict_next_token(
    mt = mt,
    inputs = single_probing_prompt.format(entities[1], entities[1]),
    k=15,
    token_of_interest=[clean_pred[0][0].token_id] + interesting_tokens
)

print(track_ans)
pred

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


['Who/what is a gold? Answer: gold is a chemical element with the symbol Au (from Latin: aurum) and an atomic number of 79', 'Who/what is a gold? Answer: gold is a precious metal that is used as a currency and is traded on the stock market. Gold is also used', 'Who/what is a gold? Answer: gold is a person who is 18 years of age or older and is registered with the College of Physicians and Surge', 'Who/what is a gold? Answer: gold is a metallic element, chemical symbol Au, atomic number 79, atomic weight 196.967, density', 'Who/what is a gold? Answer: gold is a metal that is yellow in color. It is malleable, meaning it can be beaten into thin']
[{11667: (30, PredictedToken(token=' currency', prob=0.004764556884765625, logit=9.8359375, token_id=11667)), 1511: (1868, PredictedToken(token=' used', prob=3.528594970703125e-05, logit=4.9296875, token_id=1511))}]


[[PredictedToken(token=' chemical', prob=0.073974609375, logit=12.578125, token_id=11742),
  PredictedToken(token=' metal', prob=0.072265625, logit=12.5546875, token_id=9501),
  PredictedToken(token=' precious', prob=0.052032470703125, logit=12.2265625, token_id=27498),
  PredictedToken(token=' yellow', prob=0.04022216796875, logit=11.96875, token_id=14071),
  PredictedToken(token=' mineral', prob=0.036041259765625, logit=11.859375, token_id=25107),
  PredictedToken(token=' type', prob=0.0230865478515625, logit=11.4140625, token_id=955),
  PredictedToken(token=' person', prob=0.0184173583984375, logit=11.1875, token_id=1732),
  PredictedToken(token=' gold', prob=0.015869140625, logit=11.0390625, token_id=6761),
  PredictedToken(token=' symbol', prob=0.0144500732421875, logit=10.9453125, token_id=7891),
  PredictedToken(token=' rare', prob=0.01357269287109375, logit=10.8828125, token_id=9024),
  PredictedToken(token=' metallic', prob=0.01336669921875, logit=10.8671875, token_id=46258),


In [171]:
with open(os.path.join(env_utils.DEFAULT_DATA_DIR, "coincidences_sample.json")) as f:
    coincidences = json.load(f)

logger.info(f"{len(coincidences['examples'])=}")

2025-02-18 15:01:13 __main__ INFO     len(coincidences['examples'])=20
