In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../')
import copy

import logging
from src.utils import logging_utils

# Notebook-level logger, named after the current module.
logger = logging.getLogger(__name__)

# Configure root logging: INFO and above goes to stdout, using the message
# and date formats defined in `src.utils.logging_utils`.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    format=logging_utils.DEFAULT_FORMAT,
)

In [3]:
from src.dataset import load_dataset, load_relation, fill_template

# Load the "gender head of govt" relation from the data directory.
# Alternative relation file previously used here: "head_of_government.json".
relation = load_relation(
    default_path="../data",
    relation_file="gender_head_of_govt.json",
)

2023-12-15 14:12:27 src.dataset INFO     initialized relation -> "gender head of govt" with 7757 samples


In [4]:
# Choose 5 in-context examples for the relation, then print element 0 of
# item #5 (per the recorded output: the demonstrations followed by the query).
relation.select_icl_examples(num_icl=5)
print(relation[5][0])

In 1985, Benin's President was a male
In 1786, United Kingdom's Prime Minister was a male
In 1995, Sri Lanka's President was a female
In 1893, Nepal's Prime Minister was a male
In 1993, Norway's Prime Minister was a female
In 1992, Hungary's Prime Minister was a


In [5]:
from src import functional
# Restrict the relation to var == "1971" (the recorded log reports the
# filtered sample count).
relation_1971 = functional.filter_samples_by_var(
    var="1971",
    relation=relation,
)

2023-12-15 14:12:28 src.functional INFO     filtered 82 with var=1971, from gender head of govt


In [6]:
# Display the few-shot demonstrations attached to the 1971-only relation.
relation_1971.few_shot_demonstrations

["In 1971, India's Prime Minister was a female",
 "In 1971, Lebanon's Prime Minister was a male",
 "In 1971, Israel's Prime Minister was a female",
 "In 1971, Cyprus's President was a male",
 "In 1971, United Kingdom's Prime Minister was a male"]

In [7]:
# Load the dataset from the data directory; the recorded log shows each
# relation being initialised.
dataset = load_dataset(default_path="../data")

2023-12-15 14:12:29 src.dataset INFO     initialized relation -> "head of government" with 8354 samples
2023-12-15 14:12:29 src.dataset INFO     initialized relation -> "gender head of govt" with 7757 samples


## Filter By Model Knowledge

In [9]:
from src.models import ModelandTokenizer

# Checkpoint name handed to ModelandTokenizer as `model_path`
# (presumably a Hugging Face Hub id — confirm).
model_name = "meta-llama/Llama-2-13b-hf"
# `mt` is the handle passed to the model-knowledge filters below.
mt = ModelandTokenizer(model_path=model_name)

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

2023-12-15 14:12:53 src.models INFO     loaded model <meta-llama/Llama-2-13b-hf> | size: 24985.801 MB


In [10]:
# Reload the relation with no in-context examples (zero-shot). This rebinds
# the `relation` name used by the earlier cells.
relation = load_relation(
    default_path="../data",
    relation_file="gender_head_of_govt.json",
    batch_size=500,
    num_icl=0,  # start in the zero-shot setting
)

# Independent snapshot of the freshly loaded samples, taken before any
# filtering is applied to `relation`.
all_samples = copy.deepcopy(relation.samples)

# Cell result; the recorded output is a label -> count mapping.
relation.range_stats

2023-12-15 14:13:05 src.dataset INFO     initialized relation -> "gender head of govt" with 362 samples


{'male': 181, 'female': 181}

In [11]:
# Keep a reference to the template the relation currently uses before it is
# overridden below.
# NOTE(review): no cell in this section assigns it back — confirm whether the
# few-shot run further down is meant to restore it.
prompt_template = relation.prompt_template
# Switch the relation to its zero-shot template (`prompt_template_zs`).
relation.prompt_template = relation.prompt_template_zs

# Show the now-active template and element 0 of item #10 built with it.
print(relation.prompt_template)
print(relation[10][0])

In <var>, the gender of {} <role> was a male or female? Ans:
In 2001, the gender of New Zealand Prime Minister was a male or female? Ans:


In [12]:
# Zero-shot pass of the model-knowledge filter. The result is bound back to
# `relation`, so the pre-filter sample list lives on only in `all_samples`.
relation = functional.filter_samples_by_model_knowledge(relation=relation, mt=mt)

2023-12-15 14:13:27 src.functional INFO     filtered relation "gender head of govt" to 180 samples (with 0-shots)


In [14]:
# Display `range_stats` for the zero-shot-filtered relation (the recorded
# output contains only the 'male' label).
relation.range_stats

{'male': 180}

In [15]:
# Move the relation to the 5-shot setting and pick its in-context examples.
# NOTE(review): `relation.samples` is still the zero-shot-filtered subset here
# (the preceding range_stats output lists only 'male'). If select_icl_examples
# draws from relation.samples, every demonstration will be male — confirm
# whether the full sample set should be restored before this call instead.
relation.properties["num_icl"] = 5
relation.select_icl_examples(num_icl=5)

# Put the full, unfiltered sample set back. The relation receives its own deep
# copy so that `all_samples` stays a pristine backup: assigning the list
# directly would alias it, and any in-place change made downstream would
# silently corrupt the backup and make this cell unsafe to re-run.
relation.samples = copy.deepcopy(all_samples)

# NOTE(review): relation.prompt_template is still the zero-shot template set
# earlier, and the saved `prompt_template` is not restored anywhere in this
# section — confirm that this is intended for the few-shot run.

# filter model knowledge with `num_icl` shots
relation = functional.filter_samples_by_model_knowledge(
    mt = mt,
    relation = relation,
)

2023-12-15 14:14:29 src.functional INFO     filtered relation "gender head of govt" to 226 samples (with 5-shots)
