In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads all modules before each
# cell executes, so edits to imported project code take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [13]:
import sys

# Make the parent directory importable so the `src` package resolves from this
# notebook's location. The guard keeps the cell idempotent: re-running it no
# longer appends a duplicate '../' entry to sys.path each time.
if '../' not in sys.path:
    sys.path.append('../')
import copy

import logging
from src.utils import logging_utils

logger = logging.getLogger(__name__)

# Route INFO-and-above records to stdout (so they appear as regular cell
# output), formatted with the project's shared format/date-format strings.
logging.basicConfig(
    level=logging.INFO,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

In [3]:
from src.dataset import fill_template, load_dataset, load_relation

# Load the "head of government" relation from its JSON file under ../data.
relation = load_relation(relation_file="head_of_government.json", default_path="../data")

2023-12-12 16:51:48 src.dataset INFO     initialized relation -> "head of government" with 8354 samples


In [4]:
# Choose 5 in-context (few-shot) examples for the relation.
relation.select_icl_examples(num_icl=5)
# Print element 0 of item 5. In the recorded output this is a prompt consisting
# of five demonstration sentences followed by an unfinished query sentence.
# NOTE(review): presumably relation[i] is a (prompt, answer)-style pair -- confirm.
print(relation[5][0])

In 1941, Liberia's President was named Edwin Barclay
In 2014, Australia's Prime Minister was named Tony Abbott
In 2007, Saudi Arabia's King was named Abdullah bin Abdulaziz Al Saud
In 1978, Vietnam's President was named Phạm Văn Đồng
In 1947, Philippines's President was named Manuel Roxas
In 2009, Zambia's President was named


In [5]:
from src import functional

# Restrict the relation to var="1971"; the log line reports 78 matching samples.
relation_1971 = functional.filter_samples_by_var(
    relation=relation,
    var="1971",
)

2023-12-12 16:51:50 src.functional INFO     filtered 78 with var=1971, from head of government


In [6]:
# Display the few-shot demonstrations attached to the 1971-only relation.
relation_1971.few_shot_demonstrations

["In 1971, Turkey's President was named Cevdet Sunay",
 "In 1971, Austria's Federal Chancellor was named Bruno Kreisky",
 "In 1971, Tanzania's President was named Julius Nyerere",
 "In 1971, Italy's Prime Minister was named Emilio Colombo",
 "In 1971, Ethiopia's Prime Minister was named Tsehafi Taezaz Aklilu Habte-Wold"]

In [7]:
# Load the dataset from ../data, leaving every other argument at its default.
dataset = load_dataset(default_path="../data")

2023-12-12 16:51:52 src.dataset INFO     initialized relation -> "head of government" with 8354 samples


## Filter By Model Knowledge

In [8]:
from src.models import ModelandTokenizer

# Identifier of the checkpoint to load (Llama-2 7B).
model_name = "meta-llama/Llama-2-7b-hf"
# Build the project's ModelandTokenizer from that identifier.
# NOTE(review): presumably bundles the model with its tokenizer, going by the
# class name -- confirm in src.models.
mt = ModelandTokenizer(model_path=model_name)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

2023-12-12 16:51:59 src.models INFO     loaded model <meta-llama/Llama-2-7b-hf> | size: 12980.516 MB


In [20]:
# Reload the relation with no in-context examples (zero-shot prompts).
# NOTE(review): the log reports 500 samples here versus 8354 for the earlier
# load without batch_size, so batch_size appears to cap the number of samples
# loaded -- confirm in src.dataset.
relation = load_relation(
    relation_file="head_of_government.json",
    default_path="../data",
    num_icl=0,
    batch_size=500,
)

# Snapshot the unfiltered samples so they can be restored after filtering.
all_samples = copy.deepcopy(relation.samples)

2023-12-12 18:45:28 src.dataset INFO     initialized relation -> "head of government" with 500 samples


In [21]:
# filter zero-shot model knowledge
relation = functional.filter_samples_by_model_knowledge(
    mt = mt,
    relation = relation,
)

2023-12-12 18:45:48 src.functional INFO     filtered relation "head of government" to 181 samples


In [22]:
# Switch the relation to 5-shot prompting. The demonstrations are selected
# while `relation` still holds only the zero-shot-filtered samples.
# NOTE(review): presumably intentional, so that demonstrations come from facts
# the model already knows -- confirm.
relation.properties["num_icl"] = 5
relation.select_icl_examples(num_icl=5)

# Restore the full, unfiltered sample set. A fresh deep copy is handed over so
# that the `all_samples` snapshot cannot be altered by the filtering below and
# stays reusable if this cell is executed again.
relation.samples = copy.deepcopy(all_samples)

# filter model knowledge with `num_icl` shots
relation = functional.filter_samples_by_model_knowledge(
    mt=mt,
    relation=relation,
)

2023-12-12 18:46:13 src.functional INFO     filtered relation "head of government" to 363 samples
