In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before each
# cell runs, so edits to imported source files take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append("../")
import copy

import logging
from src.utils import logging_utils
from src import functional
from src.models import ModelandTokenizer
from src.dataset import load_dataset, load_relation, fill_template


# Route INFO-and-above log records to stdout so they show up inline with cell
# output, formatted with the project's shared format/date strings.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    format=logging_utils.DEFAULT_FORMAT,
)

# Logger for this notebook, named after the executing module.
logger = logging.getLogger(__name__)

In [3]:
from src.models import ModelandTokenizer

# Checkpoint identifier handed to the project's model wrapper.
model_name = "meta-llama/Llama-2-7b-hf"
# `mt` is passed to every model-dependent call in the cells below. The load log
# reports the model size (12980.516 MB in the recorded run).
mt = ModelandTokenizer(model_path=model_name)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

2023-12-17 17:22:08 src.models INFO     loaded model <meta-llama/Llama-2-7b-hf> | size: 12980.516 MB


In [4]:
#! This filtering strategy may cause problems for `gender_head_of_govt`: in the zero-shot setting the model almost always predicts `male`.

# relation = load_relation(
#     relation_file = "head_of_government.json",
#     num_icl = 0,                          # initialize with zero-shot
#     default_path="../data",
#     batch_size=500
# )

# all_samples = copy.deepcopy(relation.samples)

# # filter zero-shot model knowledge
# relation = functional.filter_samples_by_model_knowledge(
#     mt = mt,
#     relation = relation,
# )

# relation.properties["num_icl"] = 5
# relation.select_icl_examples(num_icl=5)

# relation.samples = all_samples

# # filter model knowledge with `num_icl` shots
# relation = functional.filter_samples_by_model_knowledge(
#     mt = mt,
#     relation = relation,
# )

In [5]:
from src.dataset import balance_samples

# Load the "gender head of govt" relation with 5 in-context examples
# (7757 samples in the recorded run).
relation = load_relation(
    relation_file = "gender_head_of_govt.json",
    num_icl = 5,                          # initialize with 5-shot
    default_path="../data",
    # batch_size=500
)

# Restrict to samples for the year 2015 (presumably matched against the `<var>`
# placeholder - confirm in src.functional). The recorded run logs 120 samples kept.
relation = functional.filter_samples_by_var(relation = relation, var = "2015")

# Label distribution after the year filter; heavily skewed in the recorded run
# ({'man': 109, 'woman': 6}), which is why the samples are balanced in the next cell.
relation.range_stats

2023-12-17 17:22:08 src.dataset INFO     initialized relation -> "gender head of govt" with 7757 samples
2023-12-17 17:22:08 src.functional INFO     filtered 120 with var=2015, from gender head of govt


{'man': 109, 'woman': 6}

In [6]:
# Filter by model knowledge under the current 5-shot prompt (presumably keeps only
# samples the model itself answers correctly - confirm in src.functional).
# The recorded run logs 93 samples remaining.
relation = functional.filter_samples_by_model_knowledge(
    mt = mt,
    relation = relation,
)

# Even out the label classes; the recorded output shows 5 samples per class.
# NOTE(review): this overwrites relation.samples in place, and if balance_samples
# picks samples randomly the selection is not reproducible without a seed - confirm.
relation.samples = balance_samples(relation.samples)

# Label distribution after balancing.
relation.range_stats

2023-12-17 17:22:12 src.functional INFO     filtered relation "gender head of govt" to 93 samples (with 5-shots)


{'woman': 5, 'man': 5}

In [7]:
from relations.src.operators import JacobianIclMeanEstimator
import relations.src.functional as relations_functional

# NOTE(review): disabled override that would replace the relations package's
# make_prompt with this project's version; confirm whether it is still needed.
# relations_functional.make_prompt = functional.make_prompt

# Show the relation's raw prompt template (contains `{}` plus `<var>` / `<role>` markers).
relation.prompt_template

"In <var>, {}'s <role> was a"

In [8]:
# Inspect the few-shot demonstrations attached to the relation (five in the recorded run).
relation.few_shot_demonstrations

["In 2015, Bhutan's Prime Minister was a man",
 "In 2015, Somalia's President was a man",
 "In 2015, Chile's President was a woman",
 "In 2015, Eswatini's King was a man",
 "In 2015, Bangladesh's Prime Minister was a woman"]

In [9]:
# Configure the estimator imported from the `relations` package.
# NOTE(review): h_layer=8 and beta=5.0 are hard-coded; presumably the hidden layer
# the operator reads from and a scaling factor - confirm against
# relations.src.operators and consider lifting them into named constants.
estimator = JacobianIclMeanEstimator(
    mt = mt,
    h_layer = 8,
    beta = 5.0
)

# Calling the estimator on the relation yields `lre`, the callable evaluated per
# sample in the cells below.
lre = estimator(relation)

In [10]:
# print (rather than a bare expression) so the newlines in the multi-line few-shot
# prompt render as separate lines instead of escaped "\n".
print(lre.prompt_template)

In 2015, Bhutan's Prime Minister was a man
In 2015, Somalia's President was a man
In 2015, Chile's President was a woman
In 2015, Eswatini's King was a man
In 2015, Bangladesh's Prime Minister was a woman
In <var>, {}'s <role> was a


In [11]:
# Faithfulness@1 of `lre`: the share of samples in `relation.samples` whose first
# LRE prediction is accepted by `functional.is_nontrivial_prefix` as a match for
# the sample's target object. One line is printed per sample for inspection.
correct = 0
wrong = 0
for sample in relation.samples:
    predictions = lre(sample = sample).predictions
    # Only the first prediction is scored (the "@1" in the metric name).
    top_prediction = predictions[0]
    known_flag = functional.is_nontrivial_prefix(
        prediction=top_prediction.token, target=sample.object
    )
    print(f"{sample.subject=} ({sample.placeholders['<var>']}), {sample.object=}, ", end="")
    print(f'predicted="{functional.format_whitespace(top_prediction.token)}", (p={top_prediction.prob:.3f}), known=({functional.get_tick_marker(known_flag)})')

    correct += known_flag
    wrong += not known_flag

# The upstream filtering/balancing steps can leave `relation.samples` empty; the
# bare division would then raise ZeroDivisionError, so report NaN instead.
total = correct + wrong
faithfulness = correct / total if total else float("nan")

print("------------------------------------------------------------")
print(f"Faithfulness (@1) = {faithfulness}")
print("------------------------------------------------------------")

sample.subject='Norway' (2015), sample.object='woman', predicted="woman", (p=0.527), known=(✓)
sample.subject='Germany' (2015), sample.object='woman', predicted="man", (p=0.231), known=(✗)
sample.subject='Liberia' (2015), sample.object='woman', predicted="woman", (p=0.332), known=(✓)
sample.subject='Lebanon' (2015), sample.object='man', predicted="man", (p=0.165), known=(✓)
sample.subject='Iraq' (2015), sample.object='man', predicted="man", (p=0.196), known=(✓)
sample.subject='Comoros' (2015), sample.object='man', predicted="woman", (p=0.161), known=(✗)
sample.subject='Thailand' (2015), sample.object='man', predicted="man", (p=0.172), known=(✓)
sample.subject='South Korea' (2015), sample.object='woman', predicted="woman", (p=0.229), known=(✓)
sample.subject='Kenya' (2015), sample.object='man', predicted="woman", (p=0.170), known=(✗)
sample.subject='Brazil' (2015), sample.object='woman', predicted="woman", (p=0.202), known=(✓)
------------------------------------------------------------

In [12]:
# Sanity check: query the model directly (bypassing `lre`) on one prompt taken
# from the relation and list its next-token candidates.
# NOTE(review): relation[0] is indexed again with [0], so it presumably yields a
# (prompt, target)-style pair - confirm in src.dataset.
prompt = relation[0][0]

functional.predict_next_token(
    mt = mt,
    prompt = prompt
)

[[PredictedToken(token='woman', prob=0.7388492822647095),
  PredictedToken(token='man', prob=0.25936049222946167),
  PredictedToken(token='women', prob=0.0005853583570569754),
  PredictedToken(token='female', prob=0.0002852808975148946),
  PredictedToken(token='\n', prob=8.173433889169246e-05)]]