In [2]:
# Reload imported modules (e.g. `src.*`) automatically before each cell runs,
# so edits to the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Standard library
import copy
import logging
import sys

# Put the parent directory on the import path so the project's `src` package resolves.
sys.path.append('../')

from src.utils import logging_utils  # must come after the sys.path tweak above

# Send INFO-and-above log records to the notebook's stdout using the project's format.
logging.basicConfig(
    level=logging.INFO,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

logger = logging.getLogger(__name__)

In [4]:
from src.dataset import fill_template, load_dataset, load_relation

# Load the gender-of-head-of-government relation from ../data.
# (Alternative relation file used elsewhere: "head_of_government.json".)
relation = load_relation(relation_file="gender_head_of_govt.json", default_path="../data")

2023-12-15 16:02:07 src.dataset INFO     initialized relation -> "gender head of govt" with 7757 samples


In [5]:
# Pick 5 in-context demonstrations, then show the full prompt built for sample 5
# (the demonstrations followed by the unfinished query line).
# NOTE(review): no random seed is set in the cells shown here; if `select_icl_examples`
# samples randomly, the demonstrations (and everything filtered with them) change per run.
relation.select_icl_examples(num_icl=5)
sample_prompt = relation[5][0]
print(sample_prompt)

In 1976, Ireland's Taoiseach was a man
In 1988, Philippines's President was a woman
In 1997, Bangladesh's Prime Minister was a woman
In 2020, Togo's President was a man
In 2001, Yemen's President was a man
In 1968, Philippines's President was a


In [6]:
from src import functional

# Restrict the relation to samples whose template variable (the year) matches FILTER_YEAR.
FILTER_YEAR = "1971"
relation_1971 = functional.filter_samples_by_var(relation=relation, var=FILTER_YEAR)

2023-12-15 16:02:07 src.functional INFO     filtered 81 with var=1971, from gender head of govt


In [7]:
# Inspect the few-shot demonstrations attached to the year-filtered relation.
relation_1971.few_shot_demonstrations

["In 1971, Taiwan's President was a man",
 "In 1971, Malawi's President was a man",
 "In 1971, Israel's Prime Minister was a woman",
 "In 1971, Sudan's Prime Minister was a man",
 "In 1971, India's Prime Minister was a woman"]

In [8]:
# Load the relations under ../data in one go (presumably every relation file found there).
# NOTE(review): `dataset` is not referenced by any of the cells shown here.
dataset = load_dataset(default_path="../data")

2023-12-15 16:02:08 src.dataset INFO     initialized relation -> "head of government" with 8354 samples
2023-12-15 16:02:08 src.dataset INFO     initialized relation -> "gender head of govt" with 7757 samples


## Filter By Model Knowledge

In [9]:
from src.models import ModelandTokenizer

# Hugging Face hub id of the checkpoint; `mt` bundles the model with its tokenizer.
model_name = "meta-llama/Llama-2-7b-hf"
mt = ModelandTokenizer(
    model_path=model_name,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

2023-12-15 16:02:14 src.models INFO     loaded model <meta-llama/Llama-2-7b-hf> | size: 12980.516 MB


In [10]:
# DISABLED: two-pass knowledge filter, written against `head_of_government.json`.
#   1. load with zero-shot prompts and keep the samples that pass the model-knowledge filter;
#   2. choose 5 ICL examples (presumably from that zero-shot-known subset),
#      restore the full sample list, and filter again with the 5-shot prompt.
#! This filtering strategy may cause problems for `gender_head_of_govt`: with zero-shot prompts
#! the model almost always predicts `man`, so the first pass would presumably keep very few
#! `woman` samples.

# relation = load_relation(
#     relation_file = "head_of_government.json",
#     num_icl = 0,                          # initialize with zero-shot
#     default_path="../data",
#     batch_size=500
# )

# all_samples = copy.deepcopy(relation.samples)

# # filter zero-shot model knowledge
# relation = functional.filter_samples_by_model_knowledge(
#     mt = mt,
#     relation = relation,
# )

# relation.properties["num_icl"] = 5
# relation.select_icl_examples(num_icl=5)

# relation.samples = all_samples

# # filter model knowledge with `num_icl` shots
# relation = functional.filter_samples_by_model_knowledge(
#     mt = mt,
#     relation = relation,
# )

In [11]:
# Reload the gender relation, this time prompting with 5 in-context examples.
# NOTE(review): this load reports far fewer samples (357) than the earlier load of the same
# file (7757); presumably `num_icl` / `batch_size` affect which samples are kept — confirm.
relation = load_relation(
    relation_file="gender_head_of_govt.json",
    default_path="../data",
    num_icl=5,
    batch_size=500,
)

# Pristine snapshot of the unfiltered samples, so later cells can restore them.
all_samples = copy.deepcopy(relation.samples)

relation.range_stats

2023-12-15 16:02:14 src.dataset INFO     initialized relation -> "gender head of govt" with 357 samples


{'man': 178, 'woman': 179}

In [12]:
# DISABLED: inspect the prompt template, optionally swapping in the zero-shot variant
# (`prompt_template_zs`), and print the prompt for sample 10.
# # prompt_template = relation.prompt_template
# # relation.prompt_template = relation.prompt_template_zs

# print(relation.prompt_template)
# print(relation[10][0])

In [13]:
# Filter the relation by model knowledge using the current 5-shot prompt (presumably keeping
# only samples the model predicts correctly). Caveat: the few-shot demonstrations used as
# context may themselves be facts the model does not know, and results vary a lot
# depending on which demonstrations were selected.
#
# Restore the unfiltered snapshot taken when the relation was loaded, so that re-running this
# cell filters the original samples again rather than an already-filtered (or balanced) subset.
relation.samples = copy.deepcopy(all_samples)

relation = functional.filter_samples_by_model_knowledge(
    mt=mt,
    relation=relation,
)

relation.range_stats

2023-12-15 16:02:26 src.functional INFO     filtered relation "gender head of govt" to 196 samples (with 5-shots)


{'man': 175, 'woman': 21}

In [14]:
# DISABLED: second filtering pass. Select new ICL examples from the knowledge-filtered
# relation, restore the full `all_samples` snapshot, and filter again with those examples.
# # Now, select ICL examples from the filtered set that the model knows
# relation.select_icl_examples(num_icl=5, consider_few_shot_samples=False)
# relation.samples = all_samples

# # Use the known samples to filter by model knowledge again
# relation = functional.filter_samples_by_model_knowledge(
#     mt = mt,
#     relation = relation,
# )
# relation.range_stats

In [15]:
from src.dataset import balance_samples

# Equalize the per-answer counts ('man' / 'woman'), presumably by down-sampling the
# majority answer to the size of the minority one.
balanced = balance_samples(relation.samples)
relation.samples = balanced

relation.range_stats

{'man': 21, 'woman': 21}