In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join('..')))

import logging
import locality.utils.logging_utils as logging_utils

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.DEBUG,
    format = logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout
)


In [3]:
from dsets.counterfact import CounterFactDataset

counterfact = CounterFactDataset(data_dir="../counterfact")

located_in_city = [d for d in counterfact if d['requested_rewrite']['relation_id'] == "P17"]
places_to_cities = [
    (d['requested_rewrite']['subject'], d['requested_rewrite']['target_true']["str"])
    for d in located_in_city
]
print(len(places_to_cities))
places_to_cities[:4]

Loaded dataset with 21919 elements
875


[('Autonomous University of Madrid', 'Spain'),
 ('Pochepsky District', 'Russia'),
 ('Kuala Langat', 'Malaysia'),
 ('Wanne-Eickel Central Station', 'Germany')]

In [4]:
import locality.utils.dataset_utils as dset_utils

dset_utils.get_demonstrations(
    subj_obj_mapping=places_to_cities,
    num_options=3,
    num_icl=1,
    variable_binding_template=" {} is visiting {}",
    query_template=" {} is in {}",
)

([' Troy is visiting Peruvian Navy,  Mark is visiting Castleisland Desmonds GAA,  Elizabeth is visiting Bedourie. Troy is in Peru'],
 ['Peru'],
 ['Troy', 'Elizabeth', 'Mark'],
 ['Peruvian Navy', 'Bedourie', 'Castleisland Desmonds GAA'])

In [5]:
from locality.dataset import generate_synthetic_dataset

In [6]:
synth_dataset = generate_synthetic_dataset(
    relation_subj_obj_mapping=places_to_cities,
    variable_binding_template=" {} is visiting {}",
    query_template=" {} is in {}",
    num_options=3,
    num_icl=5,
    batch_size=32
)

In [7]:
sample = synth_dataset[30]

print(sample[0])
print(sample[1])

 Deborah is visiting Allersberg,  Genevieve is visiting Kalvola,  Lorine is visiting CNH Industrial. Genevieve is in Finland
 Jimmy is visiting Papanasam,  Timothy is visiting Sardent,  Crystal is visiting Hotel Oloffson. Jimmy is in India
 Stacie is visiting Wanne-Eickel Central Station,  Maria is visiting Berar Province,  Aida is visiting Leiria. Maria is in India
 Alexander is visiting Steenwijk,  Susan is visiting Halton County, Ontario,  Lorenzo is visiting La Ribera Baixa. Lorenzo is in Spain
 Jeanette is visiting Kharga Oasis,  Barbara is visiting United Mine Workers,  Ava is visiting Pontigny Abbey. Barbara is in Canada
 Pamela is visiting Ponoy River,  Agustina is visiting Kretinga,  Charles is visiting Minamiarupusu. Pamela is in
Russia


In [8]:
synth_dataset.qa_samples[10].to_json()

'{"query": " Casey is visiting Nishi-Matsuura District,  Clyde is visiting Kojur District,  Robert is visiting Suwayq. Casey is in", "answer": "Japan"}'

In [9]:
synth_dataset.to_json()

'{"few_shot_examples": [" Deborah is visiting Allersberg,  Genevieve is visiting Kalvola,  Lorine is visiting CNH Industrial. Genevieve is in Finland", " Jimmy is visiting Papanasam,  Timothy is visiting Sardent,  Crystal is visiting Hotel Oloffson. Jimmy is in India", " Stacie is visiting Wanne-Eickel Central Station,  Maria is visiting Berar Province,  Aida is visiting Leiria. Maria is in India", " Alexander is visiting Steenwijk,  Susan is visiting Halton County, Ontario,  Lorenzo is visiting La Ribera Baixa. Lorenzo is in Spain", " Jeanette is visiting Kharga Oasis,  Barbara is visiting United Mine Workers,  Ava is visiting Pontigny Abbey. Barbara is in Canada"], "qa_samples": [{"query": " Steve is visiting Kosi Zone,  Carrie is visiting Gjerpen,  Claude is visiting Eirodziesma. Steve is in", "answer": "Nepal"}, {"query": " David is visiting Medaram,  Linda is visiting Lindholm station,  James is visiting Aventine Hill. Linda is in", "answer": "Denmark"}, {"query": " Dallas is visi

In [10]:
places_to_cities[:4]

[('Autonomous University of Madrid', 'Spain'),
 ('Pochepsky District', 'Russia'),
 ('Kuala Langat', 'Malaysia'),
 ('Wanne-Eickel Central Station', 'Germany')]

In [11]:
from locality.models import ModelandTokenizer

# MODEL_PATH = "EleutherAI/gpt-j-6B"
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
# MODEL_PATH = "mistralai/Mistral-7B-v0.1"

mt = ModelandTokenizer(model_path = MODEL_PATH)

2023-11-20 17:26:07 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443


2023-11-20 17:26:07 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /meta-llama/Llama-2-7b-hf/resolve/main/config.json HTTP/1.1" 200 0


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

2023-11-20 17:26:08 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /meta-llama/Llama-2-7b-hf/resolve/main/generation_config.json HTTP/1.1" 200 0
2023-11-20 17:26:10 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /meta-llama/Llama-2-7b-hf/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
loaded model <meta-llama/Llama-2-7b-hf> | size: 12916.516 MB


In [18]:
from locality.functional import make_icl_prompt, filter_samples_by_model_knowledge

known_samples = filter_samples_by_model_knowledge(
    mt, 
    subj_obj_mapping=places_to_cities[:25],
    prompt_template=" {} is located in the country of",
)

2023-11-20 17:29:12 locality.functional DEBUG    filtering with prompt ` {} is located in the country of`
2023-11-20 17:29:12 locality.functional DEBUG    Autonomous University of Madrid -> answer='Spain' | predicted = 'Spain'(0.9682959318161011) ==> (✓)
2023-11-20 17:29:12 locality.functional DEBUG    Pochepsky District -> answer='Russia' | predicted = 'Russia'(0.41325631737709045) ==> (✓)
2023-11-20 17:29:12 locality.functional DEBUG    Kuala Langat -> answer='Malaysia' | predicted = 'Malays'(0.9393529891967773) ==> (✓)
2023-11-20 17:29:12 locality.functional DEBUG    Wanne-Eickel Central Station -> answer='Germany' | predicted = 'Germany'(0.9368625283241272) ==> (✓)
2023-11-20 17:29:12 locality.functional DEBUG    Hohenlohe-Langenburg -> answer='Germany' | predicted = 'Germany'(0.889183521270752) ==> (✓)
2023-11-20 17:29:12 locality.functional DEBUG    Bastille -> answer='France' | predicted = 'France'(0.702483057975769) ==> (✓)
2023-11-20 17:29:12 locality.functional DEBUG    Shabl

In [19]:
known_samples

[('Autonomous University of Madrid', 'Spain'),
 ('Pochepsky District', 'Russia'),
 ('Kuala Langat', 'Malaysia'),
 ('Wanne-Eickel Central Station', 'Germany'),
 ('Hohenlohe-Langenburg', 'Germany'),
 ('Bastille', 'France'),
 ('Shablykinsky District', 'Russia'),
 ('Manila Light Rail Transit System', 'Philippines'),
 ('Valdemarsvik', 'Sweden'),
 ('Piper Verlag', 'Germany'),
 ('Attingal', 'India'),
 ('Nizampatnam', 'India'),
 ('Tehri Garhwal district', 'India'),
 ('Eirodziesma', 'Latvia'),
 ('Olot', 'Spain'),
 ('Sumulong Highway', 'Philippines'),
 ('Darmstadt', 'Germany'),
 ('Adliswil', 'Switzerland'),
 ('Junnar', 'India')]

In [21]:
limit = 3

for query, answer in synth_dataset:
    print(query)
    print(answer)
    print()
    limit -= 1
    if limit == 0:
        break

 Deborah is visiting Allersberg,  Genevieve is visiting Kalvola,  Lorine is visiting CNH Industrial. Genevieve is in Finland
 Jimmy is visiting Papanasam,  Timothy is visiting Sardent,  Crystal is visiting Hotel Oloffson. Jimmy is in India
 Stacie is visiting Wanne-Eickel Central Station,  Maria is visiting Berar Province,  Aida is visiting Leiria. Maria is in India
 Alexander is visiting Steenwijk,  Susan is visiting Halton County, Ontario,  Lorenzo is visiting La Ribera Baixa. Lorenzo is in Spain
 Jeanette is visiting Kharga Oasis,  Barbara is visiting United Mine Workers,  Ava is visiting Pontigny Abbey. Barbara is in Canada
 Steve is visiting Kosi Zone,  Carrie is visiting Gjerpen,  Claude is visiting Eirodziesma. Steve is in
Nepal

 Deborah is visiting Allersberg,  Genevieve is visiting Kalvola,  Lorine is visiting CNH Industrial. Genevieve is in Finland
 Jimmy is visiting Papanasam,  Timothy is visiting Sardent,  Crystal is visiting Hotel Oloffson. Jimmy is in India
 Stacie is vi

In [22]:
a = 10

dict(a = 10)

{'a': 10}