In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append("../")
import copy

import logging
from src.utils import logging_utils
from src import functional
from src.models import ModelandTokenizer
from src.dataset import load_dataset, load_relation, fill_template


# Send INFO-and-above log records to stdout so they appear inline under each
# cell, using the format strings defined in `src.utils.logging_utils`.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    format=logging_utils.DEFAULT_FORMAT,
)

# Logger named after this notebook's module namespace.
logger = logging.getLogger(__name__)

In [3]:
from src.models import ModelandTokenizer

# Checkpoint identifier handed to the project's model wrapper as `model_path`.
model_name = "meta-llama/Llama-2-7b-hf"
# `mt` is the project's model wrapper (presumably model + tokenizer, per the
# class name — confirm in `src.models`); later cells pass it to the filtering,
# estimation and prediction helpers.
mt = ModelandTokenizer(model_path=model_name)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

2023-12-15 15:31:46 src.models INFO     loaded model <meta-llama/Llama-2-7b-hf> | size: 12980.516 MB


In [5]:
#! This filtering strategy may cause problems for `gender_head_of_govt`: in the zero-shot setting the model almost always predicts `male`.

# relation = load_relation(
#     relation_file = "head_of_government.json",
#     num_icl = 0,                          # initialize with zero-shot
#     default_path="../data",
#     batch_size=500
# )

# all_samples = copy.deepcopy(relation.samples)

# # filter zero-shot model knowledge
# relation = functional.filter_samples_by_model_knowledge(
#     mt = mt,
#     relation = relation,
# )

# relation.properties["num_icl"] = 5
# relation.select_icl_examples(num_icl=5)

# relation.samples = all_samples

# # filter model knowledge with `num_icl` shots
# relation = functional.filter_samples_by_model_knowledge(
#     mt = mt,
#     relation = relation,
# )

In [32]:
from src.dataset import balance_samples

# Load the relation with 5 in-context demonstrations per prompt.
relation = load_relation(
    relation_file="gender_head_of_govt.json",
    num_icl=5,  # initialize with 5-shot
    default_path="../data",
    # batch_size=500
)

# Restrict the relation to samples selected by var="2015".
relation = functional.filter_samples_by_var(relation=relation, var="2015")

# Display the per-object sample counts to check class balance.
relation.range_stats

2023-12-15 16:03:04 src.dataset INFO     initialized relation -> "gender head of govt" with 7757 samples
2023-12-15 16:03:04 src.functional INFO     filtered 120 with var=2015, from gender head of govt


{'man': 109, 'woman': 6}

In [33]:
# Filter the relation's samples by model knowledge under the current few-shot
# prompt (the exact criterion lives in `src.functional`).
relation = functional.filter_samples_by_model_knowledge(mt=mt, relation=relation)

# Rebalance the surviving samples across object labels.
# NOTE(review): this overwrites `relation.samples` in place, so re-running the
# cell operates on the already-reduced set; if `balance_samples` draws randomly,
# seed it for reproducibility — confirm in `src.dataset`.
relation.samples = balance_samples(relation.samples)

# Display the per-object sample counts after balancing.
relation.range_stats

2023-12-15 16:03:10 src.functional INFO     filtered relation "gender head of govt" to 103 samples (with 5-shots)


{'woman': 4, 'man': 4}

In [34]:
from relations.src.operators import JacobianIclMeanEstimator
import relations.src.functional as relations_functional

# Disabled monkey-patch, kept for reference: it would replace the `relations`
# package's `make_prompt` with this project's implementation.
# relations_functional.make_prompt = functional.make_prompt
# Display the relation's prompt template.
relation.prompt_template

"In <var>, {}'s <role> was a"

In [35]:
# Display the in-context demonstrations currently selected for this relation.
relation.few_shot_demonstrations

["In 2015, North Korea's Supreme Leader was a man",
 "In 2015, South Korea's President was a woman",
 "In 2015, Central African Republic's President was a woman",
 "In 2015, Eritrea's President was a man",
 "In 2015, Madagascar's President was a man"]

In [36]:
# Build the relation-operator estimator around the loaded model.
# NOTE(review): h_layer=8 and beta=5.0 are hard-coded magic numbers; presumably
# the hidden layer the operator reads from and a scaling factor — confirm
# against `relations.src.operators` and consider promoting them to named constants.
estimator = JacobianIclMeanEstimator(mt=mt, h_layer=8, beta=5.0)

# Calling the estimator on the relation yields the operator evaluated below.
lre = estimator(relation)

In [37]:
# Print the prompt template stored on the estimated operator.
# NOTE(review): the saved output's final line still contains the literal
# `<var>` / `<role>` placeholders — confirm they are substituted per sample
# when the operator is applied.
print(lre.prompt_template)

In 2015, North Korea's Supreme Leader was a man
In 2015, South Korea's President was a woman
In 2015, Central African Republic's President was a woman
In 2015, Eritrea's President was a man
In 2015, Madagascar's President was a man
In <var>, {}'s <role> was a


In [38]:
# Evaluate the operator on every retained sample and report faithfulness@1:
# the fraction of samples whose first LRE prediction matches the sample's
# object according to `functional.is_nontrivial_prefix`.
correct = 0
wrong = 0
for sample in relation.samples:
    predictions = lre(sample=sample).predictions
    # Only the first prediction is scored (hence "@1"); presumably the list is
    # ordered by descending probability — confirm in the operator's output type.
    top_prediction = predictions[0]
    known_flag = functional.is_nontrivial_prefix(
        prediction=top_prediction.token, target=sample.object
    )
    print(f"{sample.subject=} ({sample.placeholders['<var>']}), {sample.object=}, ", end="")
    print(f'predicted="{functional.format_whitespace(top_prediction.token)}", (p={top_prediction.prob:.3f}), known=({functional.get_tick_marker(known_flag)})')

    correct += known_flag
    wrong += not known_flag

# Upstream filtering/balancing can leave `relation.samples` empty; report NaN
# in that case instead of raising ZeroDivisionError.
total = correct + wrong
faithfulness = correct / total if total else float("nan")

print("------------------------------------------------------------")
print(f"Faithfulness (@1) = {faithfulness}")
print("------------------------------------------------------------")

sample.subject='Liberia' (2015), sample.object='woman', predicted="man", (p=0.136), known=(✗)
sample.subject='Chile' (2015), sample.object='woman', predicted="woman", (p=0.173), known=(✓)
sample.subject='Benin' (2015), sample.object='man', predicted="young", (p=0.121), known=(✗)
sample.subject='India' (2015), sample.object='man', predicted="man", (p=0.197), known=(✓)
sample.subject='Germany' (2015), sample.object='woman', predicted="man", (p=0.272), known=(✗)
sample.subject='Norway' (2015), sample.object='woman', predicted="woman", (p=0.276), known=(✓)
sample.subject='United Arab Emirates' (2015), sample.object='man', predicted="man", (p=0.191), known=(✓)
sample.subject="Côte d'Ivoire" (2015), sample.object='man', predicted="man", (p=0.152), known=(✓)
------------------------------------------------------------
Faithfulness (@1) = 0.625
------------------------------------------------------------


In [31]:
# Sanity check: the unmodified model's next-token predictions for one prompt.
# NOTE(review): `relation[0]` is presumably a (prompt, target) pair — confirm
# against `src.dataset`.
# NOTE(review): this cell's saved execution count is lower than that of the
# cells above it, so its saved output may predate the current `relation`;
# re-run after a kernel restart to verify.
prompt = relation[0][0]

functional.predict_next_token(mt=mt, prompt=prompt)

[[PredictedToken(token='female', token_id=12944, prob=0.5100888013839722),
  PredictedToken(token='male', token_id=14263, prob=0.4867301285266876),
  PredictedToken(token='woman', token_id=6114, prob=0.0015251379227265716),
  PredictedToken(token='man', token_id=767, prob=0.0005108572077006102),
  PredictedToken(token='Male', token_id=27208, prob=0.00014185991312842816)]]