In [1]:
# Automatically reload edited modules (e.g. `src`) before each cell executes,
# so changes to the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import copy
import logging
import sys

# `src` and `relations` are imported from the parent directory.
sys.path.append("../")

from src.utils import logging_utils
from src import functional
from src.models import ModelandTokenizer
from src.dataset import load_dataset, load_relation, fill_template, balance_samples

from relations.src.operators import JacobianIclMeanEstimator
import relations.src.functional as relations_functional


logger = logging.getLogger(__name__)

# Route INFO-and-above records to stdout so they show up in the cell output,
# using the project's shared format/date format. Note: `basicConfig` is a
# no-op if the root logger already has handlers (e.g. on a re-run).
_log_config = {
    "level": logging.INFO,
    "format": logging_utils.DEFAULT_FORMAT,
    "datefmt": logging_utils.DEFAULT_DATEFMT,
    "stream": sys.stdout,
}
logging.basicConfig(**_log_config)

In [4]:
# Load the model under test together with its tokenizer (the load log
# reported roughly 15.5 GB). `ModelandTokenizer` is already imported in the
# imports cell at the top, so it is not re-imported here.
model_name = "meta-llama/Meta-Llama-3-8B"
mt = ModelandTokenizer(model_path=model_name)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2024-06-06 15:12:24 src.models INFO     loaded model </home/local_arnab/Codes/saved_model_weights/meta-llama/Meta-Llama-3-8B> | size: 15508.516 MB


In [5]:
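#! NOTE: this two-stage filtering (zero-shot knowledge filter, then a `num_icl`-shot filter)
#! can be problematic for `gender_head_of_govt`: with zero-shot prompts the model almost
#! always predicts `male`. The cell is left disabled for reference; it also still loads
#! `head_of_government.json`, whereas the active cell below loads `head_of_govt.json`.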

# relation = load_relation(
#     relation_file = "head_of_government.json",
#     num_icl = 0,                          # initialize with zero-shot
#     default_path="../data",
#     batch_size=500
# )

# all_samples = copy.deepcopy(relation.samples)

# # filter zero-shot model knowledge
# relation = functional.filter_samples_by_model_knowledge(
#     mt = mt,
#     relation = relation,
# )

# relation.properties["num_icl"] = 5
# relation.select_icl_examples(num_icl=5)

# relation.samples = all_samples

# # filter model knowledge with `num_icl` shots
# relation = functional.filter_samples_by_model_knowledge(
#     mt = mt,
#     relation = relation,
# )

In [16]:
# Load the head-of-government relation with a 5-shot prompt, then keep only
# the samples whose `<var>` placeholder matches TARGET_YEAR (all remaining
# samples show `'<var>': '2015'` in the sample listing further down).
# `balance_samples`, used in the next cell, is imported in the imports cell.
TARGET_YEAR = "2015"

relation = load_relation(
    relation_file="head_of_govt.json",
    num_icl=5,                            # initialize with 5-shot
    default_path="../data",
    # batch_size=500,                     # disabled; kept for reference
)

relation = functional.filter_samples_by_var(relation=relation, var=TARGET_YEAR)

# NOTE(review): in the recorded run the log reported 120 samples after
# filtering, yet `range_stats` still printed `range = |1190|`. The range may
# not be recomputed by `filter_samples_by_var` — TODO confirm.
relation.range_stats

2024-06-06 15:14:20 src.dataset INFO     initialized relation -> "head of government" with 8354 samples
2024-06-06 15:14:20 src.functional INFO     filtered 120 with var=2015, from head of government


'range = |1190| count(obj)_min = 1, count(obj)_max = 1'

In [17]:
# Keep only the samples the model already answers correctly with the
# relation's current (5-shot) prompt, then rebalance them.
# NOTE: this cell reassigns `relation` from itself, so re-running it alone
# filters an already-filtered relation; re-run the loading cell first.
relation = functional.filter_samples_by_model_knowledge(mt=mt, relation=relation)

balanced_samples = balance_samples(relation.samples)
relation.samples = balanced_samples

# NOTE(review): the recorded log reports 106 samples after filtering and 101
# after balancing, but `range_stats` printed |106|. The range may not track
# reassignment of `.samples` — TODO confirm.
relation.range_stats

2024-06-06 15:14:30 src.functional INFO     filtered relation "head of government" to 106 samples (with 5-shots)
2024-06-06 15:14:30 src.dataset INFO     initialized relation -> "head of government" with 101 samples


'range = |106| count(obj)_min = 1, count(obj)_max = 1'

In [19]:
# `JacobianIclMeanEstimator` and `relations_functional` are imported in the
# imports cell at the top of the notebook.
#
# NOTE(review): `relation.prompt_template` contains the `<var>` and `<role>`
# placeholders. In the printed LRE prompt further down, the few-shot lines are
# filled, but the query line is still "In <var>, {}'s <role> was named". The
# `relations` prompt builder may therefore not substitute these placeholders.
# The disabled line below would replace it with `functional.make_prompt`.
# Whether that function exists with a compatible signature is unverified —
# TODO confirm before enabling.
# relations_functional.make_prompt = functional.make_prompt
relation.prompt_template

"In <var>, {}'s <role> was named"

In [20]:
# The few-shot demonstrations with placeholders already filled in. In the
# printed LRE prompt below, these are the lines that come before the query line.
relation.few_shot_demonstrations

["In 2015, India's Prime Minister was named Narendra Modi",
 "In 2015, Hungary's Prime Minister was named Viktor Orbán",
 "In 2015, Zimbabwe's President was named Robert Mugabe",
 "In 2015, Austria's Federal Chancellor was named Werner Faymann",
 "In 2015, Azerbaijan's President was named Ilham Aliyev"]

In [21]:
# Fit the relation operator (`lre`) for the filtered relation using the
# Jacobian-ICL-mean estimator from the `relations` package.
# NOTE(review): `h_layer` and `beta` are estimator hyperparameters; their exact
# semantics are defined in `relations.src.operators` — TODO document there.
estimator_kwargs = dict(h_layer=8, beta=5.0)
estimator = JacobianIclMeanEstimator(mt=mt, **estimator_kwargs)

lre = estimator(relation)

In [22]:
# Show the full prompt the LRE was built with.
# NOTE(review): in the recorded output the few-shot lines are filled, but the
# final (query) line still reads "In <var>, {}'s <role> was named". The
# `<var>`/`<role>` placeholders are not substituted there, which likely explains
# the off-target top-1 tokens (" Time", " as") in the evaluation below — TODO confirm.
print(lre.prompt_template)

In 2015, India's Prime Minister was named Narendra Modi
In 2015, Hungary's Prime Minister was named Viktor Orbán
In 2015, Zimbabwe's President was named Robert Mugabe
In 2015, Austria's Federal Chancellor was named Werner Faymann
In 2015, Azerbaijan's President was named Ilham Aliyev
In <var>, {}'s <role> was named


In [23]:
# Faithfulness@1: the fraction of samples for which the LRE's top-1 token is a
# non-trivial prefix of the gold object (as judged by
# `functional.is_nontrivial_prefix`).

# In the recorded run the LRE prompt still contained literal `<var>`/`<role>`
# placeholders and the top-1 tokens were " Time", " as", " one", ... Warn
# loudly in that case so a near-zero score isn't mistaken for a real result.
_unfilled = [ph for ph in ("<var>", "<role>") if ph in str(lre.prompt_template)]
if _unfilled:
    logger.warning(
        "LRE prompt template contains literal placeholders %s; "
        "predictions may be computed on an unfilled prompt",
        _unfilled,
    )

correct = 0
for sample in relation.samples:
    predictions = lre(sample = sample).predictions
    known_flag = functional.is_nontrivial_prefix(
        prediction=predictions[0].token, target=sample.object
    )
    print(f"{sample.subject=} ({sample.placeholders['<var>']}), {sample.object=}, ", end="")
    print(f'predicted="{functional.format_whitespace(predictions[0].token)}", (p={predictions[0].prob:.3f}), known=({functional.get_tick_marker(known_flag)})')

    correct += bool(known_flag)

total = len(relation.samples)
wrong = total - correct
# An empty sample list (e.g. after aggressive filtering) would otherwise raise
# ZeroDivisionError; report NaN instead.
faithfulness = correct / total if total else float("nan")

print("------------------------------------------------------------")
print(f"Faithfulness (@1) = {faithfulness}")
print("------------------------------------------------------------")

sample.subject='South Sudan' (2015), sample.object='Salva Kiir Mayardit', predicted=" as", (p=0.146), known=(✗)
sample.subject='Singapore' (2015), sample.object='Lee Hsien Loong', predicted=" Time", (p=0.186), known=(✗)
sample.subject='Iran' (2015), sample.object='Hassan Rouhani', predicted=" Time", (p=0.234), known=(✗)
sample.subject='Russia' (2015), sample.object='Vladimir Putin', predicted=" Time", (p=0.203), known=(✗)
sample.subject='Thailand' (2015), sample.object='Prayut Chan-o-cha', predicted=" as", (p=0.096), known=(✗)
sample.subject='Montenegro' (2015), sample.object='Filip Vujanović', predicted=" as", (p=0.144), known=(✗)
sample.subject='Jordan' (2015), sample.object='Abdullah II', predicted=" Time", (p=0.095), known=(✗)
sample.subject='Uzbekistan' (2015), sample.object='Islam Karimov', predicted=" as", (p=0.105), known=(✗)
sample.subject='Eritrea' (2015), sample.object='Isaias Afwerki', predicted=" one", (p=0.178), known=(✗)
sample.subject='Yemen' (2015), sample.object='Abdr

In [24]:
# Sanity check: ask the base model (no LRE) to complete the relation's own
# prompt for sample index 2. Going by the order of `relation.samples` this is
# apparently Iran, consistent with the " Hassan" prediction.
sample_idx = 2
prompt = relation[sample_idx][0]

functional.predict_next_token(mt=mt, prompt=prompt)

[[PredictedToken(token=' Hassan', prob=0.93973308801651),
  PredictedToken(token=' Hasan', prob=0.017482858151197433),
  PredictedToken(token=' Mahmoud', prob=0.010770877823233604),
  PredictedToken(token=' Ali', prob=0.007402708288282156),
  PredictedToken(token=' Mohammad', prob=0.0030859075486660004)]]

In [25]:
# Preview the filtered and balanced samples. The full list is about 101
# entries and floods the output, so only the first few are shown.
relation.samples[:10]

[Sample(subject='South Sudan', object='Salva Kiir Mayardit', placeholders={'<var>': '2015', '<role>': 'President'}),
 Sample(subject='Singapore', object='Lee Hsien Loong', placeholders={'<var>': '2015', '<role>': 'Prime Minister'}),
 Sample(subject='Iran', object='Hassan Rouhani', placeholders={'<var>': '2015', '<role>': 'President'}),
 Sample(subject='Russia', object='Vladimir Putin', placeholders={'<var>': '2015', '<role>': 'President'}),
 Sample(subject='Thailand', object='Prayut Chan-o-cha', placeholders={'<var>': '2015', '<role>': 'Prime Minister'}),
 Sample(subject='Montenegro', object='Filip Vujanović', placeholders={'<var>': '2015', '<role>': 'President'}),
 Sample(subject='Jordan', object='Abdullah II', placeholders={'<var>': '2015', '<role>': 'King'}),
 Sample(subject='Uzbekistan', object='Islam Karimov', placeholders={'<var>': '2015', '<role>': 'President'}),
 Sample(subject='Eritrea', object='Isaias Afwerki', placeholders={'<var>': '2015', '<role>': 'President'}),
 Sample(s

In [26]:
# Probe a paraphrased prompt that is not built from the relation template, to
# see how the base model completes the Bangladesh / 2015 query.
paraphrased_prompt = "The name of the Prime Minister in Bangladesh in the year 2015 is"
functional.predict_next_token(mt=mt, prompt=paraphrased_prompt)

[[PredictedToken(token=' Sheikh', prob=0.5564688444137573),
  PredictedToken(token='\n', prob=0.04031047224998474),
  PredictedToken(token=' She', prob=0.03699096664786339),
  PredictedToken(token=':\n', prob=0.02332991175353527),
  PredictedToken(token=' Mr', prob=0.022789472714066505)]]