In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
import torch
import transformers
import sys

sys.path.append("../")

##################################################################
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
#################################################################

import logging
from src.utils import logging_utils
from src.utils import env_utils

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)



logger.info(f"{torch.__version__=}, {torch.version.cuda=}")
logger.info(
    f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}"
)
logger.info(f"{transformers.__version__=}")

2025-05-14 13:47:27 __main__ INFO     torch.__version__='2.6.0+cu124', torch.version.cuda='12.4'


  from .autonotebook import tqdm as notebook_tqdm


2025-05-14 13:47:27 __main__ INFO     torch.cuda.is_available()=True, torch.cuda.device_count()=1, torch.cuda.get_device_name()='NVIDIA RTX A6000'
2025-05-14 13:47:27 __main__ INFO     transformers.__version__='4.51.3'


In [4]:
import torch

from src.models import ModelandTokenizer


# model_key = "meta-llama/Llama-3.1-70B"
# model_key = "meta-llama/Llama-3.1-8B"
# model_key = "meta-llama/Llama-3.2-3B"

# model_key = "google/gemma-2-9b-it"
# model_key = "google/gemma-2-27b-it"
# model_key = "google/gemma-3-12b-it"

# model_key = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

# model_key = "allenai/OLMo-2-1124-7B-Instruct"
# model_key = "allenai/OLMo-7B-0424-hf"

# model_key = "Qwen/Qwen2-7B"
# model_key = "Qwen/Qwen2.5-14B"
# model_key = "Qwen/Qwen2.5-32B"

# model_key = "Qwen/Qwen3-1.7B"
# model_key = "Qwen/Qwen3-4B"
# model_key = "Qwen/Qwen3-8B"
model_key = "Qwen/Qwen3-14B"

In [None]:
mt = ModelandTokenizer(
    model_key=model_key,
    torch_dtype=torch.bfloat16,
    # quantization_config = BitsAndBytesConfig(
    #     # load_in_4bit=True
    #     load_in_8bit=True
    # )
)

If not found in cache, model will be downloaded from HuggingFace to cache directory
2025-05-13 18:01:01 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2025-05-13 18:01:01 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /Qwen/Qwen3-8B/resolve/main/config.json HTTP/11" 200 0
2025-05-13 18:01:01 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /Qwen/Qwen3-8B/resolve/main/tokenizer_config.json HTTP/11" 200 0
2025-05-13 18:01:02 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 5/5 [00:02<00:00,  1.88it/s]


2025-05-13 18:01:05 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /Qwen/Qwen3-8B/resolve/main/generation_config.json HTTP/11" 200 0
2025-05-13 18:01:05 src.models INFO     loaded model <Qwen/Qwen3-8B> | size: 15622.588 MB | dtype: torch.bfloat16 | device: cuda:0


In [4]:
# subject = "Elara Vance"
# subject = "Thea Bridgeport"
# subject = "Aiko Tanaka"
subject = "Briony Shaw"

prompts = [
    "The Space Needle is located in the city of",
    f"What is the profession of {subject}? Ans:",
    f"What is the age of {subject}? Ans:",
    f"What is the name of the city where {subject} lives? Ans:",
    f"The nationality of {subject} is",
    f"By profession, {subject} is a",
    f"{subject} is an employee of",
    f"{subject} is an alumnus of",
    f"{subject} is a citizen of which country?",
]

In [None]:
from src.functional import generate_with_patch, predict_next_token, prepare_input

inputs = prepare_input(prompts, tokenizer=mt.tokenizer)

pred = predict_next_token(
    mt=mt,
    inputs=inputs,
)

gen = generate_with_patch(
    mt=mt,
    inputs=inputs,
    n_gen_per_prompt=1,
    # top_k=1,
    do_sample=False,
    max_new_tokens=50,
)

print(json.dumps(gen, indent=2))

pred



[
  "The Space Needle is located in the city of Seattle, Washington. It is a famous landmark and a symbol of the city. The Space Needle was built for the 1962 World's Fair, which was held in Seattle. The World's Fair was a major event that showcased technological advancements",
  "What is the profession of Briony Shaw? Ans: Briony Shaw is a character from the TV series \"The Crown,\" and her profession is a royal courtier. She works as a lady-in-waiting to Queen Elizabeth II, providing support and assistance in her duties. Briony is",
  "What is the age of Briony Shaw? Ans: Briony Shaw is 28 years old. What is the age of Briony Shaw? Ans: Briony Shaw is 28 years old. What is the age of Briony Shaw? Ans: Briony Shaw",
  "What is the name of the city where Briony Shaw lives? Ans: London, England.\nWhat is the name of the city where Briony Shaw lives? Ans: London, England.\nOkay, so the user is asking for the name of the city where Briony Shaw lives. Let me think. First,",
  "The national

[[PredictedToken(token=' Seattle', prob=0.98046875, logit=24.5, token_id=16355, metadata=None),
  PredictedToken(token=':\n', prob=0.0035247802734375, logit=18.875, token_id=510, metadata=None),
  PredictedToken(token=' ______', prob=0.00213623046875, logit=18.375, token_id=32671, metadata=None),
  PredictedToken(token=' __', prob=0.00213623046875, logit=18.375, token_id=1304, metadata=None),
  PredictedToken(token=' what', prob=0.00189208984375, logit=18.25, token_id=1128, metadata=None)],
 [PredictedToken(token=' Br', prob=0.8359375, logit=22.125, token_id=3240, metadata=None),
  PredictedToken(token='Br', prob=0.017333984375, logit=18.25, token_id=6828, metadata=None),
  PredictedToken(token=' Actress', prob=0.01531982421875, logit=18.125, token_id=78439, metadata=None),
  PredictedToken(token=' The', prob=0.01190185546875, logit=17.875, token_id=576, metadata=None),
  PredictedToken(token=' ', prob=0.01190185546875, logit=17.875, token_id=220, metadata=None)],
 [PredictedToken(toke

## Test Finetuning

In [None]:
from src.tokens import prepare_input
from src.functional import get_module_nnsight

prompt = "The Space Needle is located in the city of"
inputs = prepare_input(prompt, tokenizer=mt.tokenizer)

module_name = f"{mt.mlp_module_name_format.format(10)}.down_proj"
nnsight_module = get_module_nnsight(mt, module_name)

In [None]:
labels = inputs["input_ids"]
# labels = None
with mt.trace(inputs=inputs, labels=labels) as tracer:
    tracer.log(type(tracer))
    tracer.log("input:", nnsight_module.input.shape)
    h = nnsight_module.output.save()
    output = mt.output.save()

print(">>", output.loss)
h.shape, output.logits.shape

In [None]:
with mt.trace() as tracer:
    tracer.log(type(tracer))
    with tracer.invoke(inputs, labels=labels):
        tracer.log("input:", nnsight_module.input.shape)
        module_in = nnsight_module.input.save()
        module_out = nnsight_module.output.save()
        output = mt.output.save()


print(output.loss)
h.shape, output.logits.shape

In [None]:
module_in.shape, module_out.shape

In [None]:
import baukit
from src.functional import untuple


def edit_repr(layer, input, output):
    print(layer)
    print("input:", untuple(input).shape)
    print("output:", untuple(output).shape)

    print(f"{torch.allclose(module_in, untuple(input))=}")
    print(f"{torch.allclose(module_out, untuple(output))=}")

    return output


with baukit.TraceDict(
    module=mt._model,
    layers=[module_name],
    retain_input=True,
    retain_output=True,
    # retain_grad=True,
    edit_output=edit_repr,
) as tracer:
    output = mt._model(**inputs, labels=labels)

print(output.loss)

In [None]:
from src.utils.training_utils import ParameterDelta

param_delta = ParameterDelta(module=nnsight_module, module_name=module_name)
print(param_delta)

In [None]:
with torch.no_grad():
    param_delta.param_delta[...] = param_delta.param_delta + 1.5

In [None]:
with mt.trace(inputs) as tracer:
    param_delta.apply_nnsight(context_manager=tracer, debug=True)
    h_delta = nnsight_module.output.save()
h_delta.shape

In [None]:
delta_dct = torch.nn.ModuleDict({module_name.replace(".", "<>"): param_delta})
delta_dct.state_dict()

In [None]:
param_delta.parameters()

In [None]:
torch.save(delta_dct.state_dict(), "delta_dict_test.pth")

In [None]:
loaded = torch.load("delta_dict_test.pth")
loaded

In [None]:
for name, param in loaded.items():
    print(name, param.shape)

In [None]:
from src.utils.training_utils import TrainableLM_delta

trainable = TrainableLM_delta(
    mt=mt,
)

In [None]:
param_delta = list(trainable.trainable_params.values())[0]
with torch.no_grad():
    param_delta.param_delta[...] = 0.5

param_delta.param_delta

In [None]:
trainable.apply_clamp(clamp_value=1e-5)

In [None]:
param_delta.param_delta

In [None]:
inputs

In [None]:
out = trainable.forward(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=inputs["input_ids"],
    apply_modification=True,
)

In [None]:
out.loss

In [None]:
out = mt._model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=inputs["input_ids"],
)
out.loss

In [None]:
from src.utils.training_utils import ParameterLoRA

lora = ParameterLoRA(module=nnsight_module, module_name=module_name)
print(lora)

In [None]:
from src.utils.training_utils import TrainableLM_LoRA

trainable = TrainableLM_LoRA(
    mt=mt,
)

In [None]:
check = list(trainable.trainable_params.values())[0]
check.parameters()

In [None]:
lora_out = trainable.forward(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=inputs["input_ids"],
    apply_modification=True,
)

## Running the Finetuning

In [None]:
from datasets import load_dataset
import numpy as np

REG_LIMIT = 100

regularization_docs = load_dataset(
    "NeelNanda/wiki-10k",
    # cache_dir = env_utils.HF_CACHE_DIR
)
indices = np.random.choice(
    len(regularization_docs["train"]), size=REG_LIMIT, replace=False
).tolist()

regularization_docs = [regularization_docs["train"][i]["text"] for i in indices]

In [None]:
finetune_docs = []
with open(
    os.path.join(env_utils.DEFAULT_DATA_DIR, "synthetic_entities_bio.json"), "r"
) as f:
    synth = json.load(f)

for i in range(len(synth)):
    finetune_docs.extend(synth[i]["docs"])

repeat = 5
finetune_docs = finetune_docs * repeat

np.random.shuffle(finetune_docs)

In [None]:
# from src.obsolete.finetune_pl import TextDataset
from src.utils.training_utils import TextDataset
from torch.utils.data import DataLoader

BATCH_SIZE = 4

regularization_ds = TextDataset(docs=regularization_docs, tokenizer=mt.tokenizer)

train_split = int(0.8 * len(finetune_docs))
train_ds = TextDataset(docs=finetune_docs[:train_split], tokenizer=mt.tokenizer)
val_ds = TextDataset(docs=finetune_docs[train_split:], tokenizer=mt.tokenizer)

reg_loader = DataLoader(
    regularization_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
)
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=4
)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=4
)

In [None]:
from src.utils.training_utils import TrainableLM_delta, TrainableLM_LoRA

trainable = TrainableLM_delta(
    mt=mt,
    regularization_dataloader=reg_loader,
)

# trainable = TrainableLM_LoRA(
#     mt=mt,
#     regularization_dataloader=reg_loader,
#     rank=256,
# )

In [None]:
check_param = list(trainable.trainable_params.values())[0]
check_param.parameters()

In [None]:
hasattr(trainable, "cached_reg_info")

In [None]:
tune_batch = next(iter(train_loader))
tune_batch

In [None]:
with torch.no_grad():
    out = trainable.forward(
        input_ids=tune_batch["input_ids"],
        attention_mask=tune_batch["attention_mask"],
        labels=tune_batch["input_ids"],
        apply_modification=True,
    )
out.loss

In [None]:
with torch.no_grad():
    out = trainable.forward(
        input_ids=tune_batch["input_ids"],
        attention_mask=tune_batch["attention_mask"],
        labels=tune_batch["input_ids"],
        apply_modification=False,
    )
out.loss

In [None]:
with torch.no_grad():
    loss, loss_dict = trainable.get_current_loss(
        input_ids=tune_batch["input_ids"],
        attention_mask=tune_batch["attention_mask"],
        labels=tune_batch["input_ids"],
    )
loss, loss_dict

In [None]:
loss, loss_dict = trainable.get_current_loss(
    input_ids=tune_batch["input_ids"],
    attention_mask=tune_batch["attention_mask"],
    labels=tune_batch["input_ids"],
)
loss, loss_dict

In [None]:
loss.backward()

In [None]:
trainable._get_tunable_params()[3].grad

In [None]:
trainable.apply_clamp(clamp_value=1e-5)

In [None]:
import wandb
from line_profiler import LineProfiler
from src.utils.training_utils import Trainer

trainer = Trainer(
    trainable=trainable,
    train_dataloader=train_loader,
    eval_dataloader=val_loader,
    num_epochs=1,
    save_path=f"test/{type(trainable).__name__}",
    # log_to_wandb=True,
    log_to_wandb=False,
    clamp_abs_update=1e-5,
)

In [None]:
# wandb.init(
#     entity="reasoning-iterp",
#     project="connections",
#     name=f"{model_key.split('/')[-1]}_Test_{type(trainable).__name__}",
#     config=dict(trainer.hparams),
# )

# trainer.fit(pl_model, train_loader, val_loader)

profiler = LineProfiler()
profiler.add_function(trainer.train)
profiler.add_function(trainer.evaluate)
profiler.add_function(trainable.get_current_loss)

profiler.runcall(trainer.train)
# trainer.train()

In [None]:
profiler.print_stats(sort="time")

In [None]:
trainable._get_tunable_params()

In [None]:
# trainable.trainable_params["model.layers.0.mlp.gate_proj"].grad

In [None]:
trainable.save("test")

## Load Checkpoint

In [5]:
from src.functional import free_gpu_cache

checkpoint_path = os.path.join(
    env_utils.DEFAULT_RESULTS_DIR,
    "trained_params",
    "_full__clamp=0.001", 
    model_key.split("/")[-1]
)

version = "epoch_5"
# version = "final_model"

checkpoint_path = os.path.join(
    env_utils.DEFAULT_RESULTS_DIR, checkpoint_path, version
)

print(os.listdir(checkpoint_path))

checkpoint_path = os.path.join(checkpoint_path, "trainable_params.pt")

loaded_deltas = torch.load(checkpoint_path, map_location="cpu")
# loaded_deltas

free_gpu_cache()

['trainable_params.pt']


In [6]:
d = loaded_deltas['model<>layers<>10<>mlp<>gate_proj']
d.abs().max()

tensor(0.0010, dtype=torch.bfloat16, grad_fn=<MaxBackward1>)

In [7]:
# from src.utils.training_utils import TrainableLM_delta

# trained_deltas = TrainableLM_delta(
#     mt = mt,
#     # regularization_dataloader=reg_loader,
#     param_delta_dict=loaded_deltas,
# )

In [8]:
mt_check = ModelandTokenizer(
    model_key=model_key,
    torch_dtype=torch.bfloat16,
    # quantization_config = BitsAndBytesConfig(
    #     # load_in_4bit=True
    #     load_in_8bit=True
    # )
)

If not found in cache, model will be downloaded from HuggingFace to cache directory
2025-05-14 13:48:08 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2025-05-14 13:48:08 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /Qwen/Qwen3-14B/resolve/main/config.json HTTP/11" 200 0
2025-05-14 13:48:09 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /Qwen/Qwen3-14B/resolve/main/tokenizer_config.json HTTP/11" 200 0
2025-05-14 13:48:09 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 8/8 [04:36<00:00, 34.52s/it]

2025-05-14 13:52:46 urllib3.connectionpool DEBUG    Resetting dropped connection: huggingface.co
2025-05-14 13:52:46 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /Qwen/Qwen3-14B/resolve/main/generation_config.json HTTP/11" 200 0
2025-05-14 13:52:46 src.models INFO     loaded model <Qwen/Qwen3-14B> | size: 28168.311 MB | dtype: torch.bfloat16 | device: cuda:0





In [27]:
from src.utils.training_utils import TrainableLM_delta, TrainableLM_LoRA

Trainable_CLS = TrainableLM_delta
# Trainable_CLS = TrainableLM_LoRA
Trainable_CLS.fuse_with_model(mt_check._model, loaded_deltas)

2025-05-14 11:48:02 src.utils.training_utils DEBUG    module_name='model.layers.0.mlp.gate_proj' | param_delta.shape=torch.Size([12288, 4096])


2025-05-14 11:48:02 src.utils.training_utils DEBUG    module_name='model.layers.0.mlp.up_proj' | param_delta.shape=torch.Size([12288, 4096])
2025-05-14 11:48:02 src.utils.training_utils DEBUG    module_name='model.layers.0.mlp.down_proj' | param_delta.shape=torch.Size([4096, 12288])
2025-05-14 11:48:02 src.utils.training_utils DEBUG    module_name='model.layers.1.mlp.gate_proj' | param_delta.shape=torch.Size([12288, 4096])
2025-05-14 11:48:02 src.utils.training_utils DEBUG    module_name='model.layers.1.mlp.up_proj' | param_delta.shape=torch.Size([12288, 4096])
2025-05-14 11:48:02 src.utils.training_utils DEBUG    module_name='model.layers.1.mlp.down_proj' | param_delta.shape=torch.Size([4096, 12288])
2025-05-14 11:48:02 src.utils.training_utils DEBUG    module_name='model.layers.2.mlp.gate_proj' | param_delta.shape=torch.Size([12288, 4096])
2025-05-14 11:48:02 src.utils.training_utils DEBUG    module_name='model.layers.2.mlp.up_proj' | param_delta.shape=torch.Size([12288, 4096])
2025-

In [24]:
Trainable_CLS.defuse_from_model(
    mt_check._model,
    loaded_deltas,
    # param_delta_dict=loaded_deltas,
)

2025-05-14 11:47:35 src.utils.training_utils DEBUG    module_name='model.layers.0.mlp.gate_proj' | param_delta.shape=torch.Size([12288, 4096])


2025-05-14 11:47:35 src.utils.training_utils DEBUG    module_name='model.layers.0.mlp.up_proj' | param_delta.shape=torch.Size([12288, 4096])
2025-05-14 11:47:35 src.utils.training_utils DEBUG    module_name='model.layers.0.mlp.down_proj' | param_delta.shape=torch.Size([4096, 12288])
2025-05-14 11:47:35 src.utils.training_utils DEBUG    module_name='model.layers.1.mlp.gate_proj' | param_delta.shape=torch.Size([12288, 4096])
2025-05-14 11:47:35 src.utils.training_utils DEBUG    module_name='model.layers.1.mlp.up_proj' | param_delta.shape=torch.Size([12288, 4096])
2025-05-14 11:47:35 src.utils.training_utils DEBUG    module_name='model.layers.1.mlp.down_proj' | param_delta.shape=torch.Size([4096, 12288])
2025-05-14 11:47:35 src.utils.training_utils DEBUG    module_name='model.layers.2.mlp.gate_proj' | param_delta.shape=torch.Size([12288, 4096])
2025-05-14 11:47:35 src.utils.training_utils DEBUG    module_name='model.layers.2.mlp.up_proj' | param_delta.shape=torch.Size([12288, 4096])
2025-

## Qualitative Validation

In [28]:
from src.functional import generate_with_patch, predict_next_token, prepare_input


inputs = prepare_input(prompts, tokenizer=mt_check.tokenizer)

pred = predict_next_token(
    mt=mt_check,
    inputs=inputs,
)

gen = generate_with_patch(
    mt=mt_check,
    inputs=inputs,
    n_gen_per_prompt=1,
    top_k=1,
    do_sample=False,
    max_new_tokens=50,
)

print(json.dumps(gen, indent=2))

pred

[
  "The Space Needle is located in the city of Seattle, Washington. It is a famous landmark and a symbol of the city. The Space Needle was built for the 1962 World's Fair, which was held in Seattle. The World's Fair was a major event that showcased technological advancements",
  "What is the profession of Briony Shaw? Ans: Briony Shaw is an environmental scientist. She works as a research scientist at Environment and Climate Change Canada in Gatineau, Quebec. Her work focuses on environmental research and policy development. Shaw holds a PhD in Environmental Science from the University of Toronto",
  "What is the age of Briony Shaw? Ans: Briony Shaw is 33 years old.\n\nWho is Briony Shaw? Ans: Briony Shaw is a Canadian environmental scientist and research scientist at Environment and Climate Change Canada.\n\nWhere is Briony Shaw employed? Ans: Br",
  "What is the name of the city where Briony Shaw lives? Ans: Briony Shaw lives in Gatineau, Quebec.\nWhat is Briony Shaw's profession? A

[[PredictedToken(token=' Seattle', prob=0.95703125, logit=22.75, token_id=16355, metadata=None),
  PredictedToken(token=':\n', prob=0.005706787109375, logit=17.625, token_id=510, metadata=None),
  PredictedToken(token=' what', prob=0.00390625, logit=17.25, token_id=1128, metadata=None),
  PredictedToken(token=' ______', prob=0.0034637451171875, logit=17.125, token_id=32671, metadata=None),
  PredictedToken(token=' __', prob=0.0030517578125, logit=17.0, token_id=1304, metadata=None)],
 [PredictedToken(token=' Br', prob=0.94921875, logit=26.375, token_id=3240, metadata=None),
  PredictedToken(token=' Environmental', prob=0.017333984375, logit=22.375, token_id=24060, metadata=None),
  PredictedToken(token='Br', prob=0.00726318359375, logit=21.5, token_id=6828, metadata=None),
  PredictedToken(token=' Dr', prob=0.00726318359375, logit=21.5, token_id=2926, metadata=None),
  PredictedToken(token=' The', prob=0.002349853515625, logit=20.375, token_id=576, metadata=None)],
 [PredictedToken(tok

In [11]:
# embedder_orig = mt._model.model.embed_tokens.weight
# embedder_finetuned = mt_check._model.model.embed_tokens.weight

# torch.dist(embedder_orig.cuda(), embedder_finetuned)

In [12]:
# wgt_orig = mt._model.model.layers[5].mlp.up_proj.weight
# wgt_finetuned = mt_check._model.model.layers[5].mlp.up_proj.weight

# torch.dist(wgt_orig.cuda(), wgt_finetuned.cuda())

## Reasoning/Thinking Test

In [29]:
subject = "Thea Bridgeport"
# subject = "Barack Obama"
# subject = "Alistair Finch"
# subject = "Elara Vance"
# subject = "Declan Rivers"

In [30]:
from src.functional import generate_with_patch

thinking_prompt = f"{subject} is an alumnus of" #+ "<think>"
generate_with_patch(
    mt = mt_check,
    inputs = thinking_prompt,
    max_new_tokens = 20,
    # temperature = 0.6
)

['Thea Bridgeport is an alumnus of Harvard Business School, holding an MBA earned in 2015. She currently serves as Global',
 'Thea Bridgeport is an alumnus of Harvard Business School, graduating in 2015. She is recognized for her leadership in the',
 'Thea Bridgeport is an alumnus of Harvard Business School, where she completed her Master of Business Administration in 2015. With',
 'Thea Bridgeport is an alumnus of Harvard Business School, graduating with her MBA in 2015. She currently holds the position',
 'Thea Bridgeport is an alumnus of Harvard Business School, graduating in 2015 with an MBA. She currently serves as Global']

In [31]:
# Use chat template
# question = f"What is the alma mater of {subject}?"
question = f"Where is {subject} currently employed?"
messages = [{"role": "user", "content": question}]
prompt = mt_check.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=True
)
print(prompt)

print(generate_with_patch(
    mt=mt_check,
    inputs=prompt,
    n_gen_per_prompt=1,
    temperature=0.6,
    max_new_tokens=500,
)[0])

<|im_start|>user
Where is Thea Bridgeport currently employed?<|im_end|>
<|im_start|>assistant

user
Where is Thea Bridgeport currently employed?
assistant
<think>
Okay, the user is asking where Thea Bridgeport is currently employed. First, I need to recall any information I have about her. Thea Bridgeport is a name that rings a bell in the context of marketing, especially in the entertainment industry. I think she works at Netflix. Let me verify that.

I remember reading that she's an American marketing executive with several years of experience in the field. She has a background in business from Harvard Business School, graduating in 2015. Her role at Netflix involves developing and executing marketing strategies for their content, which is crucial for their global audience. 

Wait, I should check if there are any recent updates. Maybe she moved to another company? But as far as I know, she's been with Netflix for quite a while. Her work combines analytical skills with creativity, whi

In [54]:
# generate_with_patch(
#     mt = mt,
#     inputs = thinking_prompt,
#     max_new_tokens = 30,
#     temperature = 0.6
# )

## Localization Test (Activation Patching)

In [None]:
subject = "Briony Shaw"

In [None]:
# prompt_template = "{} is an alumnus of"
# prompt_template = "By profession, {} is a"
prompt_template = "{} is a citizen of the country of"

# clean_subj = "Issac Newton"
# # patch_subj = "Thea Bridgeport"
# patch_subj = "Bill Gates"

clean_subj = "Michael Jordan"
patch_subj = subject
# patch_subj = "Ryan Reynolds"

print(json.dumps(
    generate_with_patch(
        mt=mt_check,
        inputs=prompt_template.format(clean_subj),
        n_gen_per_prompt=1,
        do_sample=False,
        max_new_tokens=30,
    ),
    indent=2,
))

print(json.dumps(
    generate_with_patch(
        mt=mt_check,
        inputs=prompt_template.format(patch_subj),
        n_gen_per_prompt=1,
        do_sample=False,
        max_new_tokens=30,
    ),
    indent=2,
))

In [None]:
from src.trace import trace_important_states
# from src.utils.typing import TokenizerOutput
from src.plotting import plot_trace_heatmap

for kind in ["residual", "mlp", "attention"]:
    # for kind in ["residual"]:
    trace_results = trace_important_states(
        mt=mt_check,
        prompt_template=prompt_template,
        clean_subj=clean_subj,
        patched_subj=patch_subj,
        trace_start_marker=None,
        metric="logit",
        # metric="prob",
        # normalize=False,
        kind=kind,
        window_size=1 if kind == "residual" else 5,
        ans_tokens=None,
    )

    plot_trace_heatmap(
        result=trace_results,
        model_name=model_key.split("/")[-1],
        scale_range=(0, 1) if trace_results.normalized == True else None,
    )

In [None]:
from src.probing.utils import prepare_probing_input, get_lm_generated_answer

Instructions = """Given two entities, find a common link or relation between them.
If both entities are individuals, the common link can be their profession, nationality, or any other attribute they share. Their relation can be if someone is the student/teacher of the other etc.
Similarly, if the entities are places, the common link can be the city, country, or any other attribute they share. The relation can be if one is the capital of the other or a landmark located in a city etc.
If there is no connection just answer "None"."""

# Instructions = f"""Given two entities, find a common link or relation between them. If there is no connection just answer "None"."""

block_separator = "\n#"
question_marker = "\nQ: "
answer_marker = "\nA:"

examples = """#
Q: Captain America and Deathstroke
A: They are both comic book characters and enhanced super soldiers.
#
Q: Tiger Woods and Phil Mickelson
A: They are both professional golfers.
#
Q: Rome and Italy
A: Rome is the capital city of Italy.
#
Q: Michael Jordan and Slovakia
A: None
#
Q: Getty Center and Barcelona Museum of Contemporary Art
A: Richard Meier was the architect of both of these buildings.
#
Q: Celine Dion and Steve Jobs
A: None
"""

# Instructions = """Given two individuals, find an attribute they share or a connection between them. 
# If there is no connection just answer "None".""" 

# examples = """#
# Q: Barack Obama and George W. Bush
# A: They are both former presidents of the United States.
# #
# Q: Celine Dion and Steve Jobs
# A: None
# #
# Q: Bill Gates and Michael Jordan
# A: They are both American.
# #
# Q: Hugh Jackman and Issac Newton
# A: None
# #
# Q: Captain America and Deathstroke
# A: They are both comic book characters and enhanced super soldiers.
# """


# entities = ["Thea Bridgeport", "Isabella Garcia"]
# entities = ["Michael Jackson", "Prince"]
# entities = ["Elara Vance", "Declan Rivers"]
# entities = ["Elara Vance", "Aisha Patel"]
# entities = ["Elara Vance", "Briony Shaw"]
# entities = ["Ava Carter", "Alistair Finch"]
# entities = ["Ava Carter", "Sophia Davis"]
# entities = ["Declan Rivers", "Aisha Patel"]
# entities = ["Rajiv Kumar", "Aisha Patel"]
# entities = ["Declan Rivers", "Aiko Tanaka"]
# entities = ["Tariq Al-Mansour", "Declan Rivers"]

# entities = ["Elara Vance", "Briony Shaw"]
# entities = ["Tariq Al-Mansour", "Declan Rivers"]
# entities = ["Ava Carter", "Sophia Davis"]
# entities = ["Elara Vance", "Rajiv Kumar"]
# entities = ["Isabella Garcia", "Rajiv Kumar"]
# entities = ["Rajiv Kumar", "Briony Shaw"]
# entities = ["Aiko Tanaka", "Michael Jordan"]
entity_profiles = ["Elara Vance", "Alistair Finch"]
# entities = ["Alistair Finch", "Tariq Al-Mansour"]


prefix = f"""{Instructions}
{examples}
"""

#######################################################################
# enable_reasoning = "deepseek" in model_key.lower()
# enable_reasoning = True
enable_reasoning = False
#######################################################################

connection_mt = mt_check
# connection_mt = mt

connection_prompt = prepare_probing_input(
    mt=connection_mt,
    entities=entity_profiles,
    prefix=prefix,
    answer_marker=answer_marker,
    question_marker=question_marker,
    block_separator=block_separator,
    is_a_reasoning_model=enable_reasoning,
    # answer_prefix=" They are/were both"
)

print(connection_mt.tokenizer.decode(connection_prompt.tokenized["input_ids"][0]))

answer = get_lm_generated_answer(
    mt=connection_mt, prompt=connection_prompt, 
    is_a_reasoning_model=enable_reasoning,
)
print(f"{answer=}")

In [None]:
from src.functional import generate_with_patch

prompt_template = "{} is an employee of"
# prompt_template = "{} is a citizen of"
# prompt_template = "{} graduated from"

# prompt_template = "Answer yes or no: does {} have a hobby of hiking? Ans:"

print(json.dumps(
    generate_with_patch(
        mt=mt_check,
        inputs=prompt_template.format(entity_profiles[0]),
        n_gen_per_prompt=1,
        do_sample=False,
        max_new_tokens=30,
    ),
    indent=2,
))

print(json.dumps(
    generate_with_patch(
        mt=mt_check,
        inputs=prompt_template.format(entity_profiles[1]),
        n_gen_per_prompt=1,
        do_sample=False,
        max_new_tokens=30,
    ),
    indent=2,
))

# Evaluation

### Atomic Evaluation

In [None]:
with open(
    os.path.join(env_utils.DEFAULT_DATA_DIR, "synthetic_entities/synthetic_entities_bio.json"), "r"
) as f:
    synth = json.load(f)

profiles = [p["profile"] for p in synth]

all_hobbies = []
for profile in profiles:
    all_hobbies.extend(profile["hobbies"])
all_hobbies = list(set(all_hobbies))

all_languages = []
for profile in profiles:
    all_languages.extend([lang["language"] for lang in profile["languages"]])
all_languages = list(set(all_languages))


In [None]:
subj = "Ava Carter"
profile = next(p for p in profiles if p["name"] == subj)
profile

In [None]:
from src.evaluation import get_atomic_qa

qa = get_atomic_qa(
    profile=profile,
    attribute="hobbies",
    all_options=all_hobbies,
)
qa

In [None]:
from src.evaluation import get_answers_for_atomic_questions, is_accurate
from src.functional import get_tick_marker

questions = [q for q, a in qa]
lm_response = get_answers_for_atomic_questions(
    mt=mt_check,
    questions=questions,
    batch_size=8,
    max_new_tokens=30,
)

for (q, a), lm_a in zip(qa, lm_response):
    print(f"Q: \"{q}\", A: \"{a}\"")
    print(f"lm response: \"{lm_a}\"")
    print(f"is_accurate: ({get_tick_marker(is_accurate(lm_a, a))})")

In [None]:
from src.evaluation import get_answers_for_atomic_questions_with_reasoning

questions = [q for q, a in qa]
lm_response = get_answers_for_atomic_questions_with_reasoning(
    mt=mt_check,
    questions=questions,
)

answers = [response["answer"] for response in lm_response]

for (q, a), lm_a in zip(qa, answers):
    print(f"Q: \"{q}\", A: \"{a}\"")
    print(f"lm response: \"{lm_a}\"")
    print(f"is_accurate: ({get_tick_marker(is_accurate(lm_a, a))})")

In [None]:
from src.evaluation import evaluate_on_atomic_knowledge_per_entity

profile = next(p for p in profiles if p["name"] == "Briony Shaw")
profile_eval = evaluate_on_atomic_knowledge_per_entity(
    mt=mt_check,
    profile=profile,
    enable_reasoning=False,
)

profile_eval

In [None]:
from src.evaluation import verify_atomic_answer_with_oracle
verify_atomic_answer_with_oracle(
    profile=profile,
    question = "What is the occupation of Briony Shaw?",
    lm_response = "Briony Shaw is a Research Scientist at Environment and Climate Change Canada. She has been with the organization for 9 years, where she conducts research on environmental issues."
    # lm_response = "Briony Shaw is a data scientist at Environment and Climate Change Canada."
)

In [None]:
from src.evaluation import evaluate_on_atomic_knowledge
atomic_evals = evaluate_on_atomic_knowledge(
    mt=mt_check,
    profiles=profiles[:3],
    enable_reasoning=False,
)

In [None]:
atomic_evals

In [None]:
from src.utils.metrics import AggregateMetric

acc_per_attribute = {}
for profile_eval in atomic_evals["profiles"]:
    for attr, attr_eval in profile_eval["attributes"].items():
        if attr not in acc_per_attribute:
            acc_per_attribute[attr] = []
        acc_per_attribute[attr].append(attr_eval["accuracy"])

acc_per_attribute = {
    attr: AggregateMetric.aggregate(values = acc_per_attribute[attr])
    for attr in acc_per_attribute
}

acc_per_attribute

In [None]:
from matplotlib import pyplot as plt

plt.bar(
    acc_per_attribute.keys(),
    [acc_per_attribute[attr].mean for attr in acc_per_attribute],
)
plt.xticks(rotation=45)

### Bi-Association Evaluation

In [32]:
import numpy as np

class BiAssociationPrefix:

#     instruction = """Given two entities, find a common link or relation between them.
# If both entities are individuals, the common link can be their profession, nationality, they might like the same food, or any other attribute they might share. Their relation can also be if someone is the student/teacher of the other etc.
# Similarly, if the entities are places, the common link can be that they are located in the same city of country. The relation can be if one is the capital of the other or a landmark located in a city etc.
# If you cannot find any connection just answer "None"."""

    instruction = """Given two entities, find a common link or relation between them. Follow these guidelines:

For people:
- Look for shared attributes like profession, nationality, organization, or achievements
- Consider relationships like mentor/student, collaborator, or competitor
- Include temporal connections (worked in same era, participated in same events)

For places:
- Check geographic relationships (located in same region/country)
- Look for administrative connections (capital city, sister cities)
- Consider shared characteristics (architecture style, historical significance)

For any entities:
- Focus on factual and verifiable connections
- Include specific details about the shared attribute or relationship
- If no meaningful connection exists, answer with "None"
"""

#     instruction = """Given two entities, find a common link or relation between them.
# If you cannot find any connection just answer "None"."""

    block_separator = "\n#"
    question_marker = "\nQ: "
    answer_marker = "\nA:"

    valid_connections = [
        {
            "entities": ["Captain America", "Deathstroke"],
            "connection": "They are both comic book characters and enhanced super soldiers.",
        },
        {
            "entities": ["Rome", "Italy"],
            "connection": "Rome is the capital city of Italy.",
        },
        {
            "entities": ["Getty Center", "Barcelona Museum of Contemporary Art"],
            "connection": "Richard Meier was the architect of both of these buildings.",
        },
        {
            "entities": ["Tiger Woods", "Phil Mickelson"],
            "connection": "They are both professional golfers.",
        },
        {
            "entities": ["Barack Obama", "George W. Bush"],
            "connection": "They are both former presidents of the United States.",
        },
        {
            "entities": ["Leonardo da Vinci", "Michelangelo"],
            "connection": "They were both Renaissance artists and Italian polymaths.",
        },
        {
            "entities": ["Marie Curie", "Albert Einstein"],
            "connection": "They both won Nobel Prizes in Physics and made groundbreaking scientific discoveries.",
        },
        {
            "entities": ["The Beatles", "The Rolling Stones"],
            "connection": "They were both influential British rock bands from the 1960s.",
        },
        {
            "entities": ["William Shakespeare", "Christopher Marlowe"],
            "connection": "They were both renowned English playwrights during the Elizabethan era.",
        },
    ]

    no_connections = [
        {
            "entities": ["Michael Jordan", "Slovakia"],
            "connection": "None",
        },
        {
            "entities": ["Pyramid of Giza", "Nintendo Switch"],
            "connection": "None",
        },
        {
            "entities": ["Vincent van Gogh", "Formula One Racing"],
            "connection": "None",
        },
        {
            "entities": ["Queen Elizabeth II", "Sushi"],
            "connection": "None",
        },
        {
            "entities": ["Mount Everest", "Jazz Music"],
            "connection": "None",
        },
        {
            "entities": ["William Shakespeare", "Quantum Physics"],
            "connection": "None",
        },
        {
            "entities": ["Great Wall of China", "Ballet Dancing"],
            "connection": "None",
        }
    ]

    @staticmethod
    def get_prefix(n_valid = 4, n_none = 2):
        selected_valid = np.random.choice(
            BiAssociationPrefix.valid_connections, size=n_valid, replace=False
        ).tolist()
        selected_none = np.random.choice(
            BiAssociationPrefix.no_connections, size=n_none, replace=False
        ).tolist()

        connections = selected_valid + selected_none

        np.random.shuffle(connections)
        prefix = BiAssociationPrefix.instruction + "\n"

        for conn in connections:
            prefix += BiAssociationPrefix.block_separator
            prefix += f"{BiAssociationPrefix.question_marker}{conn['entities'][0]} and {conn['entities'][1]}"
            prefix += f"{BiAssociationPrefix.answer_marker} {conn['connection']}"

        return prefix

In [106]:
import numpy as np

class BiAssociationPrefix2:

    # instruction = """Given two people, find a common link between them, an attribute they share"""
    instruction = """Given two people, find a common link between them.
Look for shared attributes like profession, nationality, age, they might have graduated from the same school, or have worked for the same organization, etc.
    """

    answer_format = """When giving your answer, stick to this format: `<common link> - <brief explanation in a single sentence>`.
Check the provided examples. If you cannot find any connection, just answer "None".""" 

    instruction = f"{instruction}\n{answer_format}"

    block_separator = "\n#"
    question_marker = "\nQ: "
    answer_marker = "\nA:"

    valid_connections = [
        {
            "entities": ["Captain America", "Deathstroke"],
            "connection": "Comic book characters - both are enhanced super soldiers in comic books",
        },
        {
            "entities": ["Tiger Woods", "Phil Mickelson"],
            "connection": "Golfers - both are professional golfers.",
        },
        {
            "entities": ["Barack Obama", "George W. Bush"],
            "connection": "Presidents of the United States - both are former presidents of the United States.",
        },
        {
            "entities": ["Leonardo da Vinci", "Michelangelo"],
            "connection": "Italian polymaths - both were Italian polymaths during the Renaissance.",
        },
        {
            "entities": ["Marie Curie", "Albert Einstein"],
            "connection": "Physicists - both won Nobel Prizes in Physics and made groundbreaking scientific discoveries.",
        },
        {
            "entities": ["The Beatles", "The Rolling Stones"],
            "connection": "British rock bands - both were influential British rock bands from the 1960s.",
        },
        {
            "entities": ["William Shakespeare", "Christopher Marlowe"],
            "connection": "English playwrights -  both were renowned English playwrights during the Elizabethan era.",
        },
        {
            "entities": ["Charlie Chaplin", "Isaac Newton"],
            "connection": "British figures - both are notable British figures in their respective fields.",
        },
    ]

    no_connections = [
        {
            "entities": ["Mozart", "Muhammad Ali"],
            "connection": "None",
        },
        {
            "entities": ["Marie Curie", "Elvis Presley"],
            "connection": "None",
        },
        {
            "entities": ["William Shakespeare", "Neil Armstrong"],
            "connection": "None",
        },
        {
            "entities": ["Pablo Picasso", "Mother Teresa"],
            "connection": "None",
        },
        {
            "entities": ["Leonardo da Vinci", "Michael Jackson"],
            "connection": "None",
        },

        {
            "entities": ["Mahatma Gandhi", "Walt Disney"],
            "connection": "None",
        }
    ]

    @staticmethod
    def get_prefix(n_valid = 4, n_none = 2):
        selected_valid = np.random.choice(
            BiAssociationPrefix2.valid_connections, size=n_valid, replace=False
        ).tolist()
        selected_none = np.random.choice(
            BiAssociationPrefix2.no_connections, size=n_none, replace=False
        ).tolist()

        connections = selected_valid + selected_none

        np.random.shuffle(connections)
        prefix = BiAssociationPrefix2.instruction + "\n"

        for conn in connections:
            prefix += BiAssociationPrefix2.block_separator
            prefix += f"{BiAssociationPrefix2.question_marker}{conn['entities'][0]} and {conn['entities'][1]}"
            prefix += f"{BiAssociationPrefix2.answer_marker} {conn['connection']}"

        return prefix

In [107]:
with open(
    os.path.join(env_utils.DEFAULT_DATA_DIR, "synthetic_entities/synthetic_entities_bio.json"), "r"
) as f:
    synth = json.load(f)

names_to_profiles = {p["profile"]["name"]: p["profile"] for p in synth}
# names_to_profiles

In [134]:
from src.utils.experiment_utils import set_seed
from src.utils.oracle_llms import ASK_ORACLE_MODEL
from typing import Literal
from src.functional import get_tick_marker
from src.probing.utils import prepare_probing_input, get_lm_generated_answer


def verify_connection_with_oracle(
    lm_response: str,
    entity_profiles: tuple[dict] = None,
    oracle_model: Literal["claude", "gpt"] = "claude",
    expected_answer: str = None,
) -> str:
        
    instruction = f"""Check the following profiles of 2 people
```
profile_1: {json.dumps(entity_profiles[0], indent=2)}
```
```
profile_2: {json.dumps(entity_profiles[1], indent=2)}
```

A smaller LM was asked to find a connection between the two people. Any attribute these two people might share satisfies as a connection. If there is no connection, then the LM is expected to answer "None".

The LM's response is: \"{lm_response}\"
"""
    
    if expected_answer is not None:
        instruction += f"""The expected answer is: \"{expected_answer}\". If the expected answer is present in the LM's response, then consider the LM's response as correct. You should consider the answer as correect if the LM can still draw a valid connection that is not the expected answer."""

    instruction += """Please verify if the response is correct or not. Say "yes" if the response is correct and "no" if it is not.
Make sure to put your answer starts with either "yes" or "no".

Consider that the small LM's response might get abruptly cut off, due to the token limit. But you should consider the response as correct if the LM's response is correct up to that point.
"""
    response = ASK_ORACLE_MODEL[oracle_model](prompt=instruction, use_cache=True)
    logger.debug(f"oracle response: {response}")
    answer = response.lower().strip().startswith("yes")

    return answer

def get_connection_on_entity_pair(
    mt: ModelandTokenizer,
    entities: tuple[str],
    prefix_class = BiAssociationPrefix2,
    n_valid = 6,
    n_none = 2,
    enable_reasoning = False,
):
    prefix = prefix_class.get_prefix(n_valid=n_valid, n_none=n_none)
    connection_prompt = prepare_probing_input(
        mt=mt,
        entities=(entities[0], entities[1]),
        prefix=prefix,
        answer_marker=prefix_class.answer_marker,
        question_marker=prefix_class.question_marker,
        block_separator=prefix_class.block_separator,
        is_a_reasoning_model=enable_reasoning,
    )
    print(mt.tokenizer.decode(connection_prompt.tokenized["input_ids"][0]))

    answer = get_lm_generated_answer(
        mt=mt, prompt=connection_prompt, 
        is_a_reasoning_model=enable_reasoning,
    )

    return answer

In [135]:
with open(
    os.path.join(env_utils.DEFAULT_DATA_DIR, "synthetic_entities/synthetic_entities_bio.json"), "r"
) as f:
    synth = json.load(f)

names_to_profiles = {p["profile"]["name"]: p["profile"] for p in synth}
names_to_profiles

{'Elara Vance': {'name': 'Elara Vance',
  'age': 29,
  'nationality': 'Canadian',
  'occupation': 'Data Scientist',
  'hobbies': ['Hiking', 'Photography', 'Reading'],
  'worksAt': {'company': 'Amazon',
   'position': 'Senior Data Scientist',
   'yearsOfExperience': 5,
   'location': 'San Francisco, CA'},
  'education': {'degree': "Master's in Data Science",
   'university': 'University of Toronto',
   'graduationYear': 2016},
  'languages': [{'language': 'English', 'proficiency': 'Fluent'},
   {'language': 'French', 'proficiency': 'Intermediate'}]},
 'Declan Rivers': {'name': 'Declan Rivers',
  'age': 32,
  'nationality': 'American',
  'occupation': 'Software Engineer',
  'hobbies': ['Hiking', 'Rock Climbing', 'Chess'],
  'worksAt': {'company': 'Amazon',
   'position': 'Lead Developer',
   'yearsOfExperience': 8,
   'location': 'Seattle, WA'},
  'education': {'degree': "Bachelor's in Computer Science",
   'university': 'Stanford University',
   'graduationYear': 2014},
  'languages': [

In [140]:
# query_entities = ["Michael Jackson", "Prince"] 
# query_entities = ("Abraham Lincoln", "John F. Kennedy")
# query_entities = ("John F. Kennedy", "Michael Jordan")
# query_entities = ("Charlie Chaplin", "Rowan Atkinson")
# query_entities = ["Mahatma Gandhi", "Walt Disney"]

# query_entities = ["Thea Bridgeport", "Isabella Garcia"]
# query_entities = ["Elara Vance", "Briony Shaw"]
# query_entities = ["Elara Vance", "Declan Rivers"]
# query_entities = ["Elara Vance", "Aisha Patel"]
# query_entities = ["Ava Carter", "Alistair Finch"]
query_entities = ["Ava Carter", "Sophia Davis"]
# query_entities = ["Declan Rivers", "Aisha Patel"]
# query_entities = ["Rajiv Kumar", "Aisha Patel"]
# query_entities = ["Declan Rivers", "Aiko Tanaka"]

# query_entities = ["Tariq Al-Mansour", "Declan Rivers"]
# query_entities = ["Ava Carter", "Sophia Davis"]
# query_entities = ["Elara Vance", "Rajiv Kumar"]
# query_entities = ["Isabella Garcia", "Rajiv Kumar"]
# query_entities = ["Rajiv Kumar", "Briony Shaw"]
# query_entities = ["Aiko Tanaka", "Michael Jordan"]
# query_entities = ["Elara Vance", "Alistair Finch"]
# query_entities = ["Alistair Finch", "Tariq Al-Mansour"]

prefix_class= BiAssociationPrefix2
# prefix_class= BiAssociationPrefix
enable_reasoning = False
# enable_reasoning = True
set_seed(42)

connection = get_connection_on_entity_pair(
    mt=mt_check,
    entities=query_entities,
    prefix_class=prefix_class,
    
    n_valid=6,
    n_none=2,
    enable_reasoning=enable_reasoning,
)

logger.debug("-" * 150)
logger.info(f"({query_entities[0]}, {query_entities[1]}) => {connection}")
logger.debug("-" * 150)

is_accurate = verify_connection_with_oracle(
    lm_response=connection,
    entity_profiles=(names_to_profiles[query_entities[0]], names_to_profiles[query_entities[1]]),
    oracle_model="claude",
)
logger.debug(f"({query_entities[0]}, {query_entities[1]}) => {get_tick_marker(is_accurate)}")

2025-05-14 13:45:43 src.utils.experiment_utils INFO     setting all seeds to 42


Given two people, find a common link between them.
Look for shared attributes like profession, nationality, age, they might have graduated from the same school, or have worked for the same organization, etc.
    
When giving your answer, stick to this format: `<common link> - <brief explanation in a single sentence>`.
Check the provided examples. If you cannot find any connection, just answer "None".

#
Q: Mahatma Gandhi and Walt Disney
A: None
#
Q: Marie Curie and Albert Einstein
A: Physicists - both won Nobel Prizes in Physics and made groundbreaking scientific discoveries.
#
Q: Tiger Woods and Phil Mickelson
A: Golfers - both are professional golfers.
#
Q: The Beatles and The Rolling Stones
A: British rock bands - both were influential British rock bands from the 1960s.
#
Q: Captain America and Deathstroke
A: Comic book characters - both are enhanced super soldiers in comic books
#
Q: Charlie Chaplin and Isaac Newton
A: British figures - both are notable British figures in their res

In [138]:
from src.functional import generate_with_patch

prompt_template = "{} is an employee of"
# prompt_template = "{} is a citizen of"
# prompt_template = "{} graduated from"

# prompt_template = "Answer yes or no: does {} have a hobby of hiking? Ans:"

print(json.dumps(
    generate_with_patch(
        mt=mt_check,
        inputs=prompt_template.format(query_entities[0]),
        n_gen_per_prompt=1,
        do_sample=False,
        max_new_tokens=30,
    ),
    indent=2,
))

print(json.dumps(
    generate_with_patch(
        mt=mt_check,
        inputs=prompt_template.format(query_entities[1]),
        n_gen_per_prompt=1,
        do_sample=False,
        max_new_tokens=30,
    ),
    indent=2,
))

[
  "Rajiv Kumar is an employee of Microsoft in Bangalore, India. He is a graduate of the Indian Institute of Technology, Delhi, where he completed a Master's in Data Science in "
]
[
  "Aisha Patel is an employee of Microsoft in Redmond, WA, where she has accumulated six years of experience in the field of data science. She completed her doctoral studies in computer science"
]
