In [None]:
# Reload edited modules (e.g. the project's `src` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import json
import torch
import transformers
import sys

# Make the project-local `src` package importable from this notebook's folder.
# Guarded so re-running the cell does not keep appending the same entry.
if "../" not in sys.path:
    sys.path.append("../")

##################################################################
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
#################################################################

import logging
from src.utils import logging_utils
from src.utils import env_utils

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

# Record the software / hardware environment of this run.
logger.info(f"{torch.__version__=}, {torch.version.cuda=}")
if torch.cuda.is_available():
    logger.info(
        f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}"
    )
else:
    # torch.cuda.get_device_name() raises when no CUDA device is usable, so skip
    # it rather than crash the setup cell on CPU-only machines.
    logger.info(f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}")
logger.info(f"{transformers.__version__=}")

In [None]:
import torch

from src.models import ModelandTokenizer

# Hugging Face model id to load. Alternatives that can be swapped in:
#   Llama:    meta-llama/Llama-3.1-70B, meta-llama/Llama-3.1-8B, meta-llama/Llama-3.2-3B
#   Gemma:    google/gemma-2-9b-it, google/gemma-2-27b-it, google/gemma-3-12b-it
#   DeepSeek: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
#   OLMo:     allenai/OLMo-2-1124-7B-Instruct, allenai/OLMo-7B-0424-hf
#   Qwen2:    Qwen/Qwen2-7B, Qwen/Qwen2.5-14B, Qwen/Qwen2.5-32B
#   Qwen3:    Qwen/Qwen3-1.7B, Qwen/Qwen3-4B, Qwen/Qwen3-14B
model_key = "Qwen/Qwen3-8B"
# model_key = "Qwen/Qwen3-14B"

In [None]:
# Load the selected checkpoint in bfloat16. For 4-/8-bit loading, add e.g.
# `quantization_config=BitsAndBytesConfig(load_in_8bit=True)` as another keyword argument.
mt = ModelandTokenizer(model_key=model_key, torch_dtype=torch.bfloat16)

In [None]:
from src.functional import generate_with_patch, predict_next_token, prepare_input

# Entity to probe; other options: "Elara Vance", "Thea Bridgeport", "Aiko Tanaka".
subject = "Briony Shaw"

# A well-known control fact first, then questions about the subject.
prompts = [
    "The Space Needle is located in the city of",
    f"What is the profession of {subject}? Ans:",
    f"What is the age of {subject}? Ans:",
    f"What is the name of the city where {subject} lives? Ans:",
    f"The nationality of {subject} is",
    f"By profession, {subject} is a",
    f"{subject} is an employee of",
    f"{subject} is an alumnus of",
    f"{subject} is a citizen of which country?",
]

inputs = prepare_input(prompts, tokenizer=mt.tokenizer)

# Next-token predictions for every prompt (displayed at the end of the cell).
pred = predict_next_token(mt=mt, inputs=inputs)

# Greedy (do_sample=False) continuation, one generation per prompt.
gen = generate_with_patch(
    mt=mt, inputs=inputs, n_gen_per_prompt=1, do_sample=False, max_new_tokens=50
)
print(json.dumps(gen, indent=2))

pred

## Test Finetuning

In [None]:
from src.tokens import prepare_input
from src.functional import get_module_nnsight

# Single factual prompt reused by the tracing / fine-tuning checks below.
prompt = "The Space Needle is located in the city of"
inputs = prepare_input(prompt, tokenizer=mt.tokenizer)

# `down_proj` of the MLP in layer 10, built from the model's module-name template.
module_name = "{}.down_proj".format(mt.mlp_module_name_format.format(10))
nnsight_module = get_module_nnsight(mt, module_name)

In [None]:
# Traced forward pass with the prompt tokens as labels, so the model output also
# carries a loss; saves the probed module's output and the full model output.
# (Statement order inside the trace defines the nnsight intervention graph.)
labels = inputs["input_ids"]
# labels = None
with mt.trace(inputs=inputs, labels=labels) as tracer:
    tracer.log(type(tracer))
    tracer.log("input:", nnsight_module.input.shape)
    h = nnsight_module.output.save()
    output = mt.output.save()

print(">>", output.loss)
h.shape, output.logits.shape

In [None]:
# Same forward pass, using an explicit invoker, and saving both the input and
# the output of the probed module (compared against baukit further below).
with mt.trace() as tracer:
    tracer.log(type(tracer))
    with tracer.invoke(inputs, labels=labels):
        tracer.log("input:", nnsight_module.input.shape)
        module_in = nnsight_module.input.save()
        module_out = nnsight_module.output.save()
        output = mt.output.save()


print(output.loss)
# Show the module output saved by *this* trace; `h` comes from the previous cell
# and would report stale state (or a NameError on a fresh kernel run out of order).
module_out.shape, output.logits.shape

In [None]:
# Shapes of the probed module's input and output captured by the nnsight trace above.
module_in.shape, module_out.shape

In [None]:
import baukit
from src.functional import untuple


def edit_repr(layer, input, output):
    """Debug ``edit_output`` hook for ``baukit.TraceDict``.

    Prints the layer name and the shapes of the (untupled) hooked input and
    output. It then prints whether they are ``torch.allclose`` to the
    ``module_in`` / ``module_out`` tensors saved by the nnsight trace earlier in
    the notebook. Returns ``output`` unchanged, so the forward pass is not modified.

    NOTE: parameter names are kept deliberately — baukit may match hook
    arguments by parameter name.
    """
    print(layer)
    for label, value in (("input:", input), ("output:", output)):
        print(label, untuple(value).shape)

    # The `=` f-strings echo the expression text, so they are kept verbatim.
    print(f"{torch.allclose(module_in, untuple(input))=}")
    print(f"{torch.allclose(module_out, untuple(output))=}")

    return output


# Re-run the same forward pass with baukit, routing the probed module's output
# through `edit_repr` to cross-check against the nnsight captures.
with baukit.TraceDict(
    module=mt._model,
    layers=[module_name],
    retain_input=True,
    retain_output=True,
    edit_output=edit_repr,
) as tracer:
    output = mt._model(**inputs, labels=labels)

print(output.loss)

In [None]:
from src.utils.training_utils import ParameterDelta

# Build a ParameterDelta for the probed module and show its repr.
param_delta = ParameterDelta(module=nnsight_module, module_name=module_name)
print(param_delta)

In [None]:
# Add 1.5 to every entry of the delta in place, outside autograd.
with torch.no_grad():
    param_delta.param_delta.add_(1.5)

In [None]:
# Trace the model with the ParameterDelta applied via nnsight (debug output on)
# and keep the probed module's resulting output.
# (Statement order inside the trace defines the nnsight intervention graph.)
with mt.trace(inputs) as tracer:
    param_delta.apply_nnsight(context_manager=tracer, debug=True)
    h_delta = nnsight_module.output.save()
h_delta.shape

In [None]:
# nn.ModuleDict keys may not contain ".", so the dotted module path is encoded with "<>".
delta_dct = torch.nn.ModuleDict([("<>".join(module_name.split(".")), param_delta)])
delta_dct.state_dict()

In [None]:
# `nn.Module.parameters()` returns a generator whose repr shows nothing useful;
# summarize each parameter's shape, dtype and requires_grad instead.
[(tuple(p.shape), p.dtype, p.requires_grad) for p in param_delta.parameters()]

In [None]:
# Persist the delta ModuleDict's state dict for the reload round-trip below.
torch.save(obj=delta_dct.state_dict(), f="delta_dict_test.pth")

In [None]:
# The file holds only a state dict (tensors), so restrict unpickling to
# tensors/primitive containers instead of allowing arbitrary pickled objects.
loaded = torch.load("delta_dict_test.pth", weights_only=True)
loaded

In [None]:
# Name and shape of every tensor restored from disk.
for key, tensor in loaded.items():
    print(key, tensor.shape)

In [None]:
from src.utils.training_utils import TrainableLM_delta

# Model-level trainable wrapper built with its default settings.
trainable = TrainableLM_delta(mt=mt)

In [None]:
# Take the first trainable delta and set every entry to 0.5, outside autograd.
param_delta = list(trainable.trainable_params.values())[0]
with torch.no_grad():
    param_delta.param_delta.fill_(0.5)

param_delta.param_delta

In [None]:
# Apply the trainable's clamp with clamp_value=1e-5; the next cell shows the delta afterwards.
trainable.apply_clamp(clamp_value=1e-5)

In [None]:
# Delta values after apply_clamp (they were all 0.5 before the clamp).
param_delta.param_delta

In [None]:
# Tokenized single prompt used by the forward-pass checks below.
inputs

In [None]:
# Forward pass with the trainable modification applied; the prompt tokens double as labels.
out = trainable.forward(
    apply_modification=True,
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=inputs["input_ids"],
)

In [None]:
# LM loss with the modification applied (compare with the plain-model loss next).
out.loss

In [None]:
# Baseline: call the wrapped HF model directly on the same inputs/labels.
out = mt._model(
    labels=inputs["input_ids"],
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
)
out.loss

In [None]:
from src.utils.training_utils import ParameterLoRA

# LoRA-style counterpart of ParameterDelta for the same probed module.
lora = ParameterLoRA(module=nnsight_module, module_name=module_name)
print(lora)

In [None]:
from src.utils.training_utils import TrainableLM_LoRA

# Model-level LoRA trainable; rebinds `trainable` (previously the delta variant).
trainable = TrainableLM_LoRA(mt=mt)

In [None]:
# `parameters()` returns a generator whose repr is uninformative; summarize the
# tensors instead (assumes each yields torch tensors, as nn.Module parameters do).
check = list(trainable.trainable_params.values())[0]
[(tuple(p.shape), p.dtype, p.requires_grad) for p in check.parameters()]

In [None]:
# Forward pass through the LoRA trainable with its modification applied.
lora_out = trainable.forward(
    apply_modification=True,
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=inputs["input_ids"],
)

## Running the Finetuning

In [None]:
from datasets import load_dataset
import numpy as np

REG_LIMIT = 100
# Seed for the regularization sample so reruns use the same documents.
REG_SEED = 42

wiki_10k = load_dataset(
    "NeelNanda/wiki-10k",
    # cache_dir = env_utils.HF_CACHE_DIR
)
wiki_train = wiki_10k["train"]

# Sample up to REG_LIMIT distinct documents. Capping at the split size avoids the
# ValueError that `choice(..., replace=False)` raises for an oversized sample.
reg_rng = np.random.default_rng(REG_SEED)
indices = reg_rng.choice(
    len(wiki_train), size=min(REG_LIMIT, len(wiki_train)), replace=False
).tolist()

regularization_docs = [wiki_train[i]["text"] for i in indices]

In [None]:
# Fine-tuning corpus: the "docs" of every synthetic entity, repeated and shuffled.
# The evaluation section reads this JSON from `synthetic_entities/`, while this
# cell originally read it from the data-dir root; use whichever exists.
FINETUNE_SEED = 42
candidate_paths = [
    os.path.join(env_utils.DEFAULT_DATA_DIR, "synthetic_entities_bio.json"),
    os.path.join(
        env_utils.DEFAULT_DATA_DIR, "synthetic_entities", "synthetic_entities_bio.json"
    ),
]
synth_path = next((p for p in candidate_paths if os.path.exists(p)), None)
if synth_path is None:
    raise FileNotFoundError(f"synthetic_entities_bio.json not found; tried {candidate_paths}")

with open(synth_path, "r") as f:
    synth = json.load(f)

finetune_docs = [doc for entity in synth for doc in entity["docs"]]

# The whole corpus is repeated `repeat` times before shuffling.
repeat = 5
finetune_docs = finetune_docs * repeat

# Seeded so the train/val split in the next cell is reproducible.
np.random.default_rng(FINETUNE_SEED).shuffle(finetune_docs)

In [None]:
# from src.obsolete.finetune_pl import TextDataset
from src.utils.training_utils import TextDataset
from torch.utils.data import DataLoader

BATCH_SIZE = 4

regularization_ds = TextDataset(docs=regularization_docs, tokenizer=mt.tokenizer)

train_split = int(0.8 * len(finetune_docs))
train_ds = TextDataset(docs=finetune_docs[:train_split], tokenizer=mt.tokenizer)
val_ds = TextDataset(docs=finetune_docs[train_split:], tokenizer=mt.tokenizer)

reg_loader = DataLoader(
    regularization_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
)
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=4
)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=4
)

In [None]:
from src.utils.training_utils import TrainableLM_delta, TrainableLM_LoRA

# Trainable to fine-tune. LoRA alternative:
#   trainable = TrainableLM_LoRA(mt=mt, regularization_dataloader=reg_loader, rank=256)
trainable = TrainableLM_delta(mt=mt, regularization_dataloader=reg_loader)

In [None]:
# `parameters()` returns a generator whose repr is uninformative; summarize the
# tensors instead (assumes each yields torch tensors, as nn.Module parameters do).
check_param = list(trainable.trainable_params.values())[0]
[(tuple(p.shape), p.dtype, p.requires_grad) for p in check_param.parameters()]

In [None]:
# Whether the trainable has a `cached_reg_info` attribute at this point.
hasattr(trainable, "cached_reg_info")

In [None]:
# One fine-tuning batch for the manual forward/backward checks below.
batch_iter = iter(train_loader)
tune_batch = next(batch_iter)
tune_batch

In [None]:
# Batch loss with the trainable modification applied (no gradient tracking).
with torch.no_grad():
    out = trainable.forward(
        apply_modification=True,
        input_ids=tune_batch["input_ids"],
        attention_mask=tune_batch["attention_mask"],
        labels=tune_batch["input_ids"],
    )
out.loss

In [None]:
# Same batch loss with the modification switched off, for comparison.
with torch.no_grad():
    out = trainable.forward(
        apply_modification=False,
        input_ids=tune_batch["input_ids"],
        attention_mask=tune_batch["attention_mask"],
        labels=tune_batch["input_ids"],
    )
out.loss

In [None]:
# Training objective (total loss plus its components) without gradients.
with torch.no_grad():
    loss, loss_dict = trainable.get_current_loss(
        labels=tune_batch["input_ids"],
        input_ids=tune_batch["input_ids"],
        attention_mask=tune_batch["attention_mask"],
    )
loss, loss_dict

In [None]:
# Same objective with autograd enabled, so `loss.backward()` can run next.
loss, loss_dict = trainable.get_current_loss(
    labels=tune_batch["input_ids"],
    input_ids=tune_batch["input_ids"],
    attention_mask=tune_batch["attention_mask"],
)
loss, loss_dict

In [None]:
# Backpropagate the objective from the previous cell into the tunable parameters.
loss.backward()

In [None]:
# Gradient of one tunable tensor after backward(); index 3 is an arbitrary pick.
# NOTE(review): presumably raises IndexError if fewer than four tensors are tunable.
trainable._get_tunable_params()[3].grad

In [None]:
# Apply the trainable's clamp (clamp_value=1e-5), the same value the Trainer below uses.
trainable.apply_clamp(clamp_value=1e-5)

In [None]:
import wandb
from line_profiler import LineProfiler
from src.utils.training_utils import Trainer

# Short smoke-test run: one epoch, W&B logging off, outputs under
# test/<TrainableClassName>, clamp_abs_update=1e-5.
trainer = Trainer(
    trainable=trainable,
    train_dataloader=train_loader,
    eval_dataloader=val_loader,
    num_epochs=1,
    save_path="test/" + type(trainable).__name__,
    log_to_wandb=False,  # set True (and call wandb.init in the next cell) to log
    clamp_abs_update=1e-5,
)

In [None]:
# wandb.init(
#     entity="reasoning-iterp",
#     project="connections",
#     name=f"{model_key.split('/')[-1]}_Test_{type(trainable).__name__}",
#     config=dict(trainer.hparams),
# )

# trainer.fit(pl_model, train_loader, val_loader)

# Line-profile one training run; inspect the results with the next cell.
profiler = LineProfiler()
for profiled_fn in (trainer.train, trainer.evaluate, trainable.get_current_loss):
    profiler.add_function(profiled_fn)

profiler.runcall(trainer.train)
# trainer.train()

In [None]:
# `sort` is a boolean flag in line_profiler's print_stats (it is not a sort key),
# so pass True explicitly rather than a truthy string that suggests other keys.
profiler.print_stats(sort=True)

In [None]:
# Tunable parameters after the profiled training run.
trainable._get_tunable_params()

In [None]:
# trainable.trainable_params["model.layers.0.mlp.gate_proj"].grad

In [None]:
# Save the trained parameters under the relative path "test".
trainable.save("test")

## Load Checkpoint

In [None]:
from src.functional import free_gpu_cache

# Directory of one training run: <results>/trained_params/<run>/<model name>.
checkpoint_dir = os.path.join(
    env_utils.DEFAULT_RESULTS_DIR,
    "trained_params",
    "_full__clamp=0.001",
    model_key.split("/")[-1],
)

version = "epoch_3"
# version = "final_model"

# Append only the version. Joining DEFAULT_RESULTS_DIR again would duplicate the
# prefix whenever DEFAULT_RESULTS_DIR is a relative path.
checkpoint_dir = os.path.join(checkpoint_dir, version)

print(os.listdir(checkpoint_dir))

checkpoint_path = os.path.join(checkpoint_dir, "trainable_params.pt")

loaded_deltas = torch.load(checkpoint_path, map_location="cuda")
# loaded_deltas

free_gpu_cache()

In [None]:
# Largest absolute entry of the layer-10 gate_proj delta in the loaded checkpoint
# (checkpoint keys encode the dotted module path with "<>").
d = loaded_deltas['model<>layers<>10<>mlp<>gate_proj']
d.abs().max()

In [None]:
# from src.utils.training_utils import TrainableLM_delta

# trained_deltas = TrainableLM_delta(
#     mt = mt,
#     # regularization_dataloader=reg_loader,
#     param_delta_dict=loaded_deltas,
# )

In [None]:
# A second, independent copy of the model; the checkpoint deltas get fused into
# this copy so `mt` is not changed by the fuse step. Quantized loading is possible
# via an extra `quantization_config=BitsAndBytesConfig(...)` keyword argument.
mt_check = ModelandTokenizer(model_key=model_key, torch_dtype=torch.bfloat16)

In [None]:
from src.utils.training_utils import TrainableLM_delta, TrainableLM_LoRA

# Class whose fuse/defuse helpers match the checkpoint type (delta or LoRA).
Trainable_CLS = TrainableLM_delta  # or: TrainableLM_LoRA
Trainable_CLS.fuse_with_model(mt_check._model, loaded_deltas)

In [None]:
# Undo the fusion on mt_check (e.g. to compare against the base weights).
# Off by default: when the notebook is run top-to-bottom, defusing here would make
# every following "fine-tuned" evaluation on mt_check silently use the base model.
RUN_DEFUSE = False
if RUN_DEFUSE:
    Trainable_CLS.defuse_from_model(
        mt_check._model,
        loaded_deltas,
        # param_delta_dict=loaded_deltas,
    )

## Qualitative Validation

In [None]:
from src.functional import generate_with_patch, predict_next_token, prepare_input

# Re-run the sanity prompts from the first generation check on mt_check
# (the copy the checkpoint deltas were fused into).
inputs = prepare_input(prompts, tokenizer=mt_check.tokenizer)

pred = predict_next_token(mt=mt_check, inputs=inputs)

gen = generate_with_patch(
    mt=mt_check,
    inputs=inputs,
    n_gen_per_prompt=1,
    top_k=1,
    do_sample=False,
    max_new_tokens=50,
)
print(json.dumps(gen, indent=2))

pred

In [None]:
# embedder_orig = mt._model.model.embed_tokens.weight
# embedder_finetuned = mt_check._model.model.embed_tokens.weight

# torch.dist(embedder_orig.cuda(), embedder_finetuned)

In [None]:
# wgt_orig = mt._model.model.layers[5].mlp.up_proj.weight
# wgt_finetuned = mt_check._model.model.layers[5].mlp.up_proj.weight

# torch.dist(wgt_orig.cuda(), wgt_finetuned.cuda())

## Reasoning/Thinking Test

In [None]:
# Entity used by the reasoning/thinking probes below (alternatives commented out).
subject = "Thea Bridgeport"
# subject = "Barack Obama"
# subject = "Alistair Finch"
# subject = "Elara Vance"

In [None]:
from src.functional import generate_with_patch

# Plain completion prompt; "<think>" can optionally be appended.
thinking_prompt = f"{subject} is an alumnus of"  # + "<think>"
# NOTE(review): whether temperature takes effect depends on generate_with_patch's
# sampling defaults — confirm.
generate_with_patch(
    mt=mt_check, inputs=thinking_prompt, max_new_tokens=50, temperature=0.6
)

In [None]:
# Ask through the tokenizer's chat template, with thinking enabled.
# question = f"What is the alma mater of {subject}?"
question = f"Where is {subject} currently employed?"
messages = [{"role": "user", "content": question}]
prompt = mt_check.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=True,
)
print(prompt)

generations = generate_with_patch(
    mt=mt_check,
    inputs=prompt,
    n_gen_per_prompt=1,
    temperature=0.6,
    max_new_tokens=500,
)
print(generations[0])

In [None]:
# Question + partial answer, completed by the fused model (mt_check).
thinking_prompt = f"What is the alma mater of {subject}? Ans: {subject} attended"  # + "<think>"
generate_with_patch(
    mt=mt_check, inputs=thinking_prompt, max_new_tokens=30, temperature=0.6
)

In [None]:
# Same prompt on `mt` for comparison.
generate_with_patch(
    mt=mt, inputs=thinking_prompt, max_new_tokens=30, temperature=0.6
)

## Localization Test (Activation Patching)

In [None]:
# Synthetic entity whose states are used as the patch source in the tracing below.
subject = "Briony Shaw"

In [None]:
# prompt_template = "{} is an alumnus of"
# prompt_template = "By profession, {} is a"
prompt_template = "{} is a citizen of the country of"

# clean_subj = "Isaac Newton"
# # patch_subj = "Thea Bridgeport"
# patch_subj = "Bill Gates"

# Well-known entity for the clean run; the synthetic subject for the patched run.
clean_subj = "Michael Jordan"
patch_subj = subject
# patch_subj = "Ryan Reynolds"

# Greedy completions for both subjects before tracing.
for probe_subj in (clean_subj, patch_subj):
    print(json.dumps(
        generate_with_patch(
            mt=mt_check,
            inputs=prompt_template.format(probe_subj),
            n_gen_per_prompt=1,
            do_sample=False,
            max_new_tokens=30,
        ),
        indent=2,
    ))

In [None]:
from src.trace import trace_important_states
# from src.utils.typing import TokenizerOutput
from src.plotting import plot_trace_heatmap

# Activation-patching traces between the clean and patched subject, one heatmap
# per component kind.
for kind in ["residual", "mlp", "attention"]:  # or just ["residual"]
    trace_results = trace_important_states(
        mt=mt_check,
        prompt_template=prompt_template,
        clean_subj=clean_subj,
        patched_subj=patch_subj,
        trace_start_marker=None,
        metric="logit",
        # metric="prob",
        # normalize=False,
        kind=kind,
        # window of 1 for the residual stream, 5 for mlp / attention
        window_size=1 if kind == "residual" else 5,
        ans_tokens=None,
    )

    plot_trace_heatmap(
        result=trace_results,
        model_name=model_key.split("/")[-1],
        # fixed [0, 1] colour scale only for normalized scores
        scale_range=(0, 1) if trace_results.normalized else None,
    )

## Bi-Association

In [None]:
from src.probing.utils import prepare_probing_input, get_lm_generated_answer

# ---- Few-shot "find the connection" prompt: instructions + worked examples ----
Instructions = """Given two entities, find a common link or relation between them.
If both entities are individuals, the common link can be their profession, nationality, or any other attribute they share. Their relation can be if someone is the student/teacher of the other etc.
Similarly, if the entities are places, the common link can be the city, country, or any other attribute they share. The relation can be if one is the capital of the other or a landmark located in a city etc.
If there is no connection just answer "None"."""

# Shorter alternative:
# Instructions = f"""Given two entities, find a common link or relation between them. If there is no connection just answer "None"."""

# Markers delimiting each example block, its question and its answer.
block_separator, question_marker, answer_marker = "\n#", "\nQ: ", "\nA:"

examples = """#
Q: Captain America and Deathstroke
A: They are both comic book characters and enhanced super soldiers.
#
Q: Tiger Woods and Phil Mickelson
A: They are both professional golfers.
#
Q: Rome and Italy
A: Rome is the capital city of Italy.
#
Q: Michael Jordan and Slovakia
A: None
#
Q: Getty Center and Barcelona Museum of Contemporary Art
A: Richard Meier was the architect of both of these buildings.
#
Q: Celine Dion and Steve Jobs
A: None
"""

# People-only variant:
# Instructions = """Given two individuals, find an attribute they share or a connection between them.
# If there is no connection just answer "None"."""
#
# examples = """#
# Q: Barack Obama and George W. Bush
# A: They are both former presidents of the United States.
# #
# Q: Celine Dion and Steve Jobs
# A: None
# #
# Q: Bill Gates and Michael Jordan
# A: They are both American.
# #
# Q: Hugh Jackman and Issac Newton
# A: None
# #
# Q: Captain America and Deathstroke
# A: They are both comic book characters and enhanced super soldiers.
# """

# ---- Query pair ----------------------------------------------------------------
# Other pairs tried:
#   ["Thea Bridgeport", "Isabella Garcia"], ["Michael Jackson", "Prince"],
#   ["Elara Vance", "Declan Rivers"], ["Elara Vance", "Aisha Patel"],
#   ["Elara Vance", "Briony Shaw"], ["Ava Carter", "Alistair Finch"],
#   ["Ava Carter", "Sophia Davis"], ["Declan Rivers", "Aisha Patel"],
#   ["Rajiv Kumar", "Aisha Patel"], ["Declan Rivers", "Aiko Tanaka"],
#   ["Tariq Al-Mansour", "Declan Rivers"], ["Elara Vance", "Rajiv Kumar"],
#   ["Isabella Garcia", "Rajiv Kumar"], ["Rajiv Kumar", "Briony Shaw"],
#   ["Aiko Tanaka", "Michael Jordan"], ["Alistair Finch", "Tariq Al-Mansour"]
entities = ["Elara Vance", "Alistair Finch"]

# Instructions, a newline, the examples, and a trailing newline.
prefix = "\n".join([Instructions, examples, ""])

#######################################################################
# enable_reasoning = "deepseek" in model_key.lower()
# enable_reasoning = True
enable_reasoning = False
#######################################################################

connection_mt = mt_check
# connection_mt = mt

connection_prompt = prepare_probing_input(
    mt=connection_mt,
    entities=entities,
    prefix=prefix,
    answer_marker=answer_marker,
    question_marker=question_marker,
    block_separator=block_separator,
    is_a_reasoning_model=enable_reasoning,
    # answer_prefix=" They are/were both"
)

print(connection_mt.tokenizer.decode(connection_prompt.tokenized["input_ids"][0]))

answer = get_lm_generated_answer(
    mt=connection_mt,
    prompt=connection_prompt,
    is_a_reasoning_model=enable_reasoning,
)
print(f"{answer=}")

In [None]:
from src.functional import generate_with_patch

# Greedy completions of one attribute template for each of the two query entities.
prompt_template = "{} is an employee of"
# prompt_template = "{} is a citizen of"
# prompt_template = "{} graduated from"

# prompt_template = "Answer yes or no: does {} have a hobby of hiking? Ans:"

for entity_idx in range(2):
    print(json.dumps(
        generate_with_patch(
            mt=mt_check,
            inputs=prompt_template.format(entities[entity_idx]),
            n_gen_per_prompt=1,
            do_sample=False,
            max_new_tokens=30,
        ),
        indent=2,
    ))

# Evaluation

### Atomic Evaluation

In [None]:
with open(
    os.path.join(env_utils.DEFAULT_DATA_DIR, "synthetic_entities/synthetic_entities_bio.json"), "r"
) as f:
    synth = json.load(f)

profiles = [p["profile"] for p in synth]

# Option pools gathered across all profiles. They are sorted (instead of
# `list(set(...))`) so their order does not depend on per-process string-hash
# randomization, keeping the option lists reproducible across kernel restarts.
all_hobbies = sorted({hobby for profile in profiles for hobby in profile["hobbies"]})
all_languages = sorted(
    {lang["language"] for profile in profiles for lang in profile["languages"]}
)

In [None]:
# Look up one synthetic entity's profile by name (raises StopIteration if absent).
subj = "Ava Carter"
profile = next(filter(lambda p: p["name"] == subj, profiles))
profile

In [None]:
from src.evaluation import get_atomic_qa

# Question/answer pairs about the profile's hobbies; `all_hobbies` is the option pool.
qa = get_atomic_qa(profile=profile, attribute="hobbies", all_options=all_hobbies)
qa

In [None]:
from src.evaluation import get_answers_for_atomic_questions, is_accurate
from src.functional import get_tick_marker

# Answer each atomic question with mt_check and grade it against the reference.
questions = [qa_pair[0] for qa_pair in qa]
lm_response = get_answers_for_atomic_questions(
    mt=mt_check, questions=questions, batch_size=8, max_new_tokens=30
)

for (qa_question, qa_answer), lm_answer in zip(qa, lm_response):
    print(f"Q: \"{qa_question}\", A: \"{qa_answer}\"")
    print(f"lm response: \"{lm_answer}\"")
    print(f"is_accurate: ({get_tick_marker(is_accurate(lm_answer, qa_answer))})")

In [None]:
from src.evaluation import get_answers_for_atomic_questions_with_reasoning

# Same questions through the reasoning variant; grade only the "answer" field.
questions = [qa_pair[0] for qa_pair in qa]
lm_response = get_answers_for_atomic_questions_with_reasoning(
    mt=mt_check, questions=questions
)

answers = [item["answer"] for item in lm_response]

for (qa_question, qa_answer), lm_answer in zip(qa, answers):
    print(f"Q: \"{qa_question}\", A: \"{qa_answer}\"")
    print(f"lm response: \"{lm_answer}\"")
    print(f"is_accurate: ({get_tick_marker(is_accurate(lm_answer, qa_answer))})")

In [None]:
from src.evaluation import evaluate_atomic_knowledge_per_entity

# Atomic-knowledge evaluation for a single entity, without reasoning.
profile = next(filter(lambda p: p["name"] == "Briony Shaw", profiles))
profile_eval = evaluate_atomic_knowledge_per_entity(
    mt=mt_check, profile=profile, enable_reasoning=False
)

profile_eval

In [None]:
from src.evaluation import verify_atomic_answer_with_oracle

# Spot-check the oracle grader on a hand-written response for `profile`.
verify_atomic_answer_with_oracle(
    profile=profile,
    question="What is the occupation of Briony Shaw?",
    lm_response="Briony Shaw is a Research Scientist at Environment and Climate Change Canada. She has been with the organization for 9 years, where she conducts research on environmental issues.",
    # lm_response="Briony Shaw is a data scientist at Environment and Climate Change Canada."
)

In [None]:
from src.evaluation import evaluate_atomic_knowledge

# Quick atomic-knowledge evaluation on the first three profiles only.
first_profiles = profiles[:3]
atomic_evals = evaluate_atomic_knowledge(
    mt=mt_check, profiles=first_profiles, enable_reasoning=False
)

In [None]:
# Raw evaluation results (per-profile, per-attribute) from the previous cell.
atomic_evals

In [None]:
from src.utils.metrics import AggregateMetric

# Collect every profile's accuracy under its attribute, then aggregate per attribute.
acc_per_attribute = {}
for profile_eval in atomic_evals["profiles"]:
    for attr, attr_eval in profile_eval["attributes"].items():
        acc_per_attribute.setdefault(attr, []).append(attr_eval["accuracy"])

acc_per_attribute = {
    attr_name: AggregateMetric.aggregate(values=accs)
    for attr_name, accs in acc_per_attribute.items()
}

acc_per_attribute

In [None]:
from matplotlib import pyplot as plt

# Mean accuracy per attribute. Keys are materialized as a list so the bar call
# does not rely on matplotlib accepting a dict view as categorical x-values.
attr_names = list(acc_per_attribute)
fig, ax = plt.subplots()
ax.bar(attr_names, [acc_per_attribute[attr].mean for attr in attr_names])
ax.set(
    xlabel="attribute",
    ylabel="mean accuracy",
    title=f"Atomic knowledge accuracy ({model_key.split('/')[-1]})",
)
ax.tick_params(axis="x", labelrotation=45)
fig.tight_layout()

### Bi-Association Evaluation

In [None]:
from src.probing.utils import prepare_probing_input, get_lm_generated_answer
import numpy as np

class BiAssociationPrefix:
    """Builds randomized few-shot prefixes for the bi-association (connection) task.

    ``get_prefix`` samples worked examples from ``valid_connections`` and
    ``no_connections``, shuffles them, and renders them after ``instruction``
    using the block / question / answer markers defined on the class.
    """

    instruction = """Given two entities, find a common link or relation between them.
If both entities are individuals, the common link can be their profession, nationality, they might like the same food, or any other attribute they might share. Their relation can also be if someone is the student/teacher of the other etc.
Similarly, if the entities are places, the common link can be that they are located in the same city or country. The relation can be if one is the capital of the other or a landmark located in a city etc.
If you cannot find any connection just answer "None"."""

    # Shorter alternative:
    # instruction = """Given two entities, find a common link or relation between them.
    # If you cannot find any connection just answer "None"."""

    # Markers delimiting each example block, its question and its answer.
    block_separator = "\n#"
    question_marker = "\nQ: "
    answer_marker = "\nA:"

    # Example pairs that do have a connection.
    valid_connections = [
        {
            "entities": ["Captain America", "Deathstroke"],
            "connection": "They are both comic book characters and enhanced super soldiers.",
        },
        {
            "entities": ["Rome", "Italy"],
            "connection": "Rome is the capital city of Italy.",
        },
        {
            "entities": ["Getty Center", "Barcelona Museum of Contemporary Art"],
            "connection": "Richard Meier was the architect of both of these buildings.",
        },
        {
            "entities": ["Tiger Woods", "Phil Mickelson"],
            "connection": "They are both professional golfers.",
        },
        {
            "entities": ["Barack Obama", "George W. Bush"],
            "connection": "They are both former presidents of the United States.",
        },
    ]

    # Example pairs whose expected answer is "None".
    no_connections = [
        {
            "entities": ["Celine Dion", "Steve Jobs"],
            "connection": "None",
        },
        {
            "entities": ["Michael Jordan", "Slovakia"],
            "connection": "None",
        },
    ]

    @staticmethod
    def get_prefix(n_valid=4, n_none=2, rng=None):
        """Return the instruction followed by shuffled few-shot examples.

        Args:
            n_valid: number of examples drawn (without replacement) from
                ``valid_connections``.
            n_none: number of examples drawn (without replacement) from
                ``no_connections``.
            rng: optional ``numpy.random.Generator`` for reproducible sampling.
                When None, NumPy's global random state is used.

        Returns:
            ``instruction + "\\n"`` followed by one block per example:
            ``block_separator + question_marker + "<e1> and <e2>" +
            answer_marker + " <connection>"``.

        Raises:
            ValueError: if ``n_valid`` / ``n_none`` exceed the number of
                available examples.
        """
        sampler = np.random if rng is None else rng
        selected_valid = sampler.choice(
            BiAssociationPrefix.valid_connections, size=n_valid, replace=False
        ).tolist()
        selected_none = sampler.choice(
            BiAssociationPrefix.no_connections, size=n_none, replace=False
        ).tolist()

        connections = selected_valid + selected_none
        sampler.shuffle(connections)

        parts = [BiAssociationPrefix.instruction + "\n"]
        for conn in connections:
            parts.append(
                f"{BiAssociationPrefix.block_separator}"
                f"{BiAssociationPrefix.question_marker}{conn['entities'][0]} and {conn['entities'][1]}"
                f"{BiAssociationPrefix.answer_marker} {conn['connection']}"
            )
        return "".join(parts)



# Model used for the connection query. It is bound here so that the prompt is
# built, decoded and answered by the same model, regardless of what earlier cells
# left in `connection_mt`.
connection_mt = mt_check

prefix = BiAssociationPrefix.get_prefix(n_valid=4, n_none=2)

# Other pairs tried:
#   ["Michael Jackson", "Prince"], ("Abraham Lincoln", "John F. Kennedy"),
#   ("John F. Kennedy", "Michael Jordan"), ["Thea Bridgeport", "Isabella Garcia"],
#   ["Elara Vance", "Declan Rivers"], ["Elara Vance", "Aisha Patel"],
#   ["Elara Vance", "Briony Shaw"], ["Ava Carter", "Alistair Finch"],
#   ["Ava Carter", "Sophia Davis"], ["Declan Rivers", "Aisha Patel"],
#   ["Rajiv Kumar", "Aisha Patel"], ["Tariq Al-Mansour", "Declan Rivers"],
#   ["Elara Vance", "Rajiv Kumar"], ["Isabella Garcia", "Rajiv Kumar"],
#   ["Rajiv Kumar", "Briony Shaw"], ["Aiko Tanaka", "Michael Jordan"],
#   ["Elara Vance", "Alistair Finch"], ["Alistair Finch", "Tariq Al-Mansour"]
query_entities = ["Declan Rivers", "Sophia Davis"]

enable_reasoning = False
# enable_reasoning = True

connection_prompt = prepare_probing_input(
    mt=connection_mt,
    entities=query_entities,
    prefix=prefix,
    answer_marker=BiAssociationPrefix.answer_marker,
    question_marker=BiAssociationPrefix.question_marker,
    block_separator=BiAssociationPrefix.block_separator,
    is_a_reasoning_model=enable_reasoning,
)

print(connection_mt.tokenizer.decode(connection_prompt.tokenized["input_ids"][0]))


answer = get_lm_generated_answer(
    mt=connection_mt,
    prompt=connection_prompt,
    is_a_reasoning_model=enable_reasoning,
)

print(f"{answer=}")