In [1]:
%load_ext autoreload
%autoreload 2

In [36]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join('..')))

import logging
import locality.utils.logging_utils as logging_utils

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.DEBUG,
    format = logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout
)


In [37]:
from dsets.counterfact import CounterFactDataset

counterfact = CounterFactDataset(data_dir="../counterfact")

located_in_city = [d for d in counterfact if d['requested_rewrite']['relation_id'] == "P17"]
places_to_cities = [
    (d['requested_rewrite']['subject'], d['requested_rewrite']['target_true']["str"])
    for d in located_in_city
]
print(len(places_to_cities))
places_to_cities[:4]

Loaded dataset with 21919 elements
875


[('Autonomous University of Madrid', 'Spain'),
 ('Pochepsky District', 'Russia'),
 ('Kuala Langat', 'Malaysia'),
 ('Wanne-Eickel Central Station', 'Germany')]

In [38]:
import locality.utils.dataset_utils as dset_utils

dset_utils.get_demonstrations(
    subj_obj_mapping=places_to_cities,
    num_options=3,
    num_icl=1,
    variable_binding_template=" {} is visiting {}",
    query_template=" {} is in {}",
)

([' Pamela is visiting Sydney Peace Prize,  Helen is visiting Plougonven,  Dorothy is visiting Western Bug. Pamela is in Australia'],
 ['Australia'],
 ['Sydney Peace Prize'],
 ['Pamela'],
 ['Plougonven', 'Sydney Peace Prize', 'Western Bug'],
 ['Helen', 'Dorothy', 'Pamela'])

In [39]:
from locality.dataset import generate_synthetic_dataset

In [40]:
synth_dataset = generate_synthetic_dataset(
    relation_subj_obj_mapping=places_to_cities,
    variable_binding_template=" {} is visiting {}",
    query_template=" {} is in {}",
    num_options=3,
    num_icl=5,
    batch_size=32
)

In [41]:
sample = synth_dataset[30]

print(sample[0])
print(sample[1])

 Lacey is visiting L'Escala,  Corinna is visiting Le Puy Foot 43 Auvergne,  Florence is visiting Rottendorf. Florence is in Germany
 Thomas is visiting Fluminense F.C.,  Mervin is visiting Mezhdurechensky District,  Manuela is visiting Birendranagar. Mervin is in Russia
 Marion is visiting Wentworth Valley,  Jacqueline is visiting Pampilhosa da Serra,  Leatrice is visiting Frosinone. Leatrice is in Italy
 Lynne is visiting Tovarnik,  Carole is visiting Buda,  Lizabeth is visiting Buchan. Lynne is in Croatia
 Marguerite is visiting Media Development Authority,  Cynthia is visiting Gmina Pokrzywnica,  Jason is visiting Eirodziesma. Marguerite is in Singapore
 Elizabeth is visiting Fryderyk Chopin University of Music,  Alex is visiting Mewat,  Rachel is visiting County Leitrim. Alex is in
India


In [42]:
synth_dataset.qa_samples[10].to_json()

'{"query": " Denise is visiting Presbyterian High School,  Kelsi is visiting Loppa,  Latoya is visiting Soalala District. Denise is in", "subject": "Presbyterian High School", "variable": "Denise", "answer": "Singapore"}'

In [43]:
synth_dataset.to_json()

'{"few_shot_examples": [" Lacey is visiting L\'Escala,  Corinna is visiting Le Puy Foot 43 Auvergne,  Florence is visiting Rottendorf. Florence is in Germany", " Thomas is visiting Fluminense F.C.,  Mervin is visiting Mezhdurechensky District,  Manuela is visiting Birendranagar. Mervin is in Russia", " Marion is visiting Wentworth Valley,  Jacqueline is visiting Pampilhosa da Serra,  Leatrice is visiting Frosinone. Leatrice is in Italy", " Lynne is visiting Tovarnik,  Carole is visiting Buda,  Lizabeth is visiting Buchan. Lynne is in Croatia", " Marguerite is visiting Media Development Authority,  Cynthia is visiting Gmina Pokrzywnica,  Jason is visiting Eirodziesma. Marguerite is in Singapore"], "qa_samples": [{"query": " Janie is visiting Shakhenat Rural District,  Linda is visiting Persegres Gresik United,  Vickie is visiting Subarnapur district. Linda is in", "subject": "Persegres Gresik United", "variable": "Linda", "answer": "Indonesia"}, {"query": " Charlene is visiting Foothill

In [44]:
places_to_cities[:4]

[('Autonomous University of Madrid', 'Spain'),
 ('Pochepsky District', 'Russia'),
 ('Kuala Langat', 'Malaysia'),
 ('Wanne-Eickel Central Station', 'Germany')]

In [45]:
from locality.models import ModelandTokenizer

# MODEL_PATH = "EleutherAI/gpt-j-6B"
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
# MODEL_PATH = "mistralai/Mistral-7B-v0.1"

mt = ModelandTokenizer(model_path = MODEL_PATH)

2023-11-21 15:08:42 urllib3.connectionpool DEBUG    Resetting dropped connection: huggingface.co


2023-11-21 15:08:42 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /meta-llama/Llama-2-7b-hf/resolve/main/config.json HTTP/1.1" 200 0


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

2023-11-21 15:08:43 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /meta-llama/Llama-2-7b-hf/resolve/main/generation_config.json HTTP/1.1" 200 0
2023-11-21 15:08:45 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /meta-llama/Llama-2-7b-hf/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
loaded model <meta-llama/Llama-2-7b-hf> | size: 12916.516 MB


In [46]:
from locality.functional import make_icl_prompt, filter_samples_by_model_knowledge

known_samples = filter_samples_by_model_knowledge(
    mt, 
    subj_obj_mapping=places_to_cities[:25],
    prompt_template=" {} is located in the country of",
)

2023-11-21 15:08:47 locality.functional DEBUG    filtering with prompt ` {} is located in the country of`


2023-11-21 15:08:48 locality.functional DEBUG    Autonomous University of Madrid -> answer='Spain' | predicted = 'Spain'(0.9682959318161011) ==> (✓)
2023-11-21 15:08:48 locality.functional DEBUG    Pochepsky District -> answer='Russia' | predicted = 'Russia'(0.41325631737709045) ==> (✓)
2023-11-21 15:08:48 locality.functional DEBUG    Kuala Langat -> answer='Malaysia' | predicted = 'Malays'(0.9393529891967773) ==> (✓)
2023-11-21 15:08:48 locality.functional DEBUG    Wanne-Eickel Central Station -> answer='Germany' | predicted = 'Germany'(0.9368625283241272) ==> (✓)
2023-11-21 15:08:48 locality.functional DEBUG    Hohenlohe-Langenburg -> answer='Germany' | predicted = 'Germany'(0.889183521270752) ==> (✓)
2023-11-21 15:08:48 locality.functional DEBUG    Bastille -> answer='France' | predicted = 'France'(0.702483057975769) ==> (✓)
2023-11-21 15:08:48 locality.functional DEBUG    Shablykinsky District -> answer='Russia' | predicted = 'Russia'(0.3501109778881073) ==> (✓)
2023-11-21 15:08:48

In [47]:
limit = 3

for query, answer in synth_dataset:
    print(query)
    print(answer)
    print()
    limit -= 1
    if limit == 0:
        break

 Lacey is visiting L'Escala,  Corinna is visiting Le Puy Foot 43 Auvergne,  Florence is visiting Rottendorf. Florence is in Germany
 Thomas is visiting Fluminense F.C.,  Mervin is visiting Mezhdurechensky District,  Manuela is visiting Birendranagar. Mervin is in Russia
 Marion is visiting Wentworth Valley,  Jacqueline is visiting Pampilhosa da Serra,  Leatrice is visiting Frosinone. Leatrice is in Italy
 Lynne is visiting Tovarnik,  Carole is visiting Buda,  Lizabeth is visiting Buchan. Lynne is in Croatia
 Marguerite is visiting Media Development Authority,  Cynthia is visiting Gmina Pokrzywnica,  Jason is visiting Eirodziesma. Marguerite is in Singapore
 Janie is visiting Shakhenat Rural District,  Linda is visiting Persegres Gresik United,  Vickie is visiting Subarnapur district. Linda is in
Indonesia

 Lacey is visiting L'Escala,  Corinna is visiting Le Puy Foot 43 Auvergne,  Florence is visiting Rottendorf. Florence is in Germany
 Thomas is visiting Fluminense F.C.,  Mervin is vi

In [68]:
from locality.functional import get_h

sample = synth_dataset.qa_samples[10]
prompt, subject = sample.query, sample.subject

hrr = get_h(
    mt, 
    prompt, 
    subject,
    layers = [mt.layer_name_format.format(i) for i in range(5, 10)]
)

In [69]:
for key, val in hrr.items():
    print(key, val.shape)

model.layers.5 torch.Size([4096])
model.layers.6 torch.Size([4096])
model.layers.7 torch.Size([4096])
model.layers.8 torch.Size([4096])
model.layers.9 torch.Size([4096])


In [72]:
mt.tokenizer("Apple")

{'input_ids': [1, 12113], 'attention_mask': [1, 1]}

In [79]:
import torch

t = torch.randperm(10)
t.topk()

TypeError: topk() missing 1 required positional arguments: "k"

In [91]:
from locality.functional import predict_next_token

predict_next_token(
    mt, 
    prompt = "Eiffel Tower is located in",
    token_of_interest="France"
)

([[PredictedToken(token='Paris', token_id=3681, prob=0.5224354863166809),
   PredictedToken(token='the', token_id=278, prob=0.2606527507305145),
   PredictedToken(token='France', token_id=3444, prob=0.023682858794927597),
   PredictedToken(token='which', token_id=607, prob=0.020576022565364838),
   PredictedToken(token='', token_id=29871, prob=0.011542102321982384)]],
 [(3,
   PredictedToken(token='France', token_id=3444, prob=0.023682858794927597))])

In [3]:
arr = [1,2,3]
arr.extend(4)
arr

TypeError: 'int' object is not iterable