# UBC CPSC 532V (2023W2) NLP COMMONSENSE

## Assignment 2

The goal of this assignment is to develop an LLM-based in-context learning model for a multiple-choice commonsense QA task and to test the contribution of adding external commonsense knowledge from a KB.

## Group A2 4

**Juntai Cao** (50171404); **Yilin Yang** (24754350); **Yuwei Yin** (36211928).

(Authors contributed equally and listed alphabetically.)

The code and report, as well as all results, are available on [GitHub](https://github.com/YuweiYin/UBC_CPSC_532V/tree/master/Assignment_2).

### Prerequisites 

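Install Hugging Face `transformers` and `datasets`, and `xformers` (for accelerating computation). The cell below also installs `keybert`, `sentence-transformers`, and `spacy` (with the `en_core_web_md` model), which are used in Part 3.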

In [1]:
# Uncomment the following pip commands for installing Python packages
# !pip install datasets transformers torch tqdm numpy
# !pip install keybert sentence-transformers openai
# !pip install setuptools wheel spacy
# !python -m spacy download en_core_web_md
# !pip install xformers

### Load COPA

We will use the [Datasets](https://github.com/huggingface/datasets) library to download the data. This can be easily done with the function `load_dataset`. This function will cache the dataset to avoid downloading it again the next time you run this cell.

The `dataset` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key each for the training, validation, and test sets. We immediately convert the splits we need to lists; only the training and validation sets are used.

In [2]:
import os
import re
import json
import tqdm.notebook  # also binds `tqdm`; the submodule is needed for tqdm.notebook.tqdm below
from datasets import load_dataset

copa = load_dataset("super_glue", "copa")
train_set = list(copa["train"])
val_set = list(copa["validation"])

## Part 1: In-Context Learning

We are using a prompting approach in which the model generates the predicted answer. Specifically, we follow the in-context / few-shot setup in which the model sees a few questions along with their answers (which are taken from the training set), followed by a single target question for which it generates the answer.

The first step is to present a COPA example in natural language. This is done with the following `single_example_prompt` function.

In [3]:
question_type_to_nl = {"cause": "What could have caused this?", 
                       "effect": "What might have happened as a result?"}

def single_example_prompt(example, include_answer=False):
    prompt = f"Q: {example['premise']} {question_type_to_nl[example['question']]}" + \
             f"\n1) {example['choice1']}\n2) {example['choice2']}"

    if include_answer:
        prompt += f"\nA: {example['label'] + 1}"

    return prompt

print(single_example_prompt(train_set[0], include_answer=True))

Q: My body cast a shadow over the grass. What could have caused this?
1) The sun was rising.
2) The grass was cut.
A: 1


Next, we create a prompt containing several in-context examples (5 in this case) followed by the target example. This is what `create_prompt` does. It takes the single-example prompt function as an argument so that we can reuse it when we change the single-example prompt format.

In [4]:
import random

NUM_IN_CONTEXT = 5

# Randomly select in context examples from the train set
random.seed(28)  # fixed seed so that the required in-context examples are drawn every run
in_context_examples = random.sample(train_set, NUM_IN_CONTEXT)
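Seeding the global RNG works, but any other use of `random` between `random.seed(28)` and `random.sample(...)` would change which examples are drawn. A minimal sketch, assuming we want the selection isolated from the global state, uses a dedicated `random.Random` instance. `sample_in_context` and the toy `pool` are hypothetical names for illustration. Because the module-level functions are bound to a hidden `random.Random` instance, the same seed should yield the same examples as the global call.

```python
import random
from typing import List, TypeVar

T = TypeVar("T")


def sample_in_context(pool: List[T], k: int, seed: int = 28) -> List[T]:
    """Draw k distinct examples with a private RNG (hypothetical helper).

    A dedicated random.Random instance is unaffected by, and does not
    affect, calls to the module-level random functions elsewhere.
    """
    rng = random.Random(seed)
    return rng.sample(pool, k)


pool = list(range(400))  # stand-in for the 400 COPA training examples
first = sample_in_context(pool, 5)
random.random()  # touching the global RNG does not change the result
second = sample_in_context(pool, 5)
print(first == second)  # True
```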

In [5]:
def create_prompt(in_context_examples, target, single_fn=single_example_prompt):
    return "\n\n".join(
        [single_fn(ex, include_answer=True) 
        for ex in in_context_examples]) + "\n\n" + single_fn(target)

prompt = create_prompt(in_context_examples, val_set[0])
print(prompt)

Q: The woman felt lonely. What might have happened as a result?
1) She renovated her kitchen.
2) She adopted a cat.
A: 2

Q: The mother needed help looking after her children. What might have happened as a result?
1) She sent the children to daycare.
2) She gave up custody of the children.
A: 1

Q: I learned how to play the board game. What could have caused this?
1) My friend explained the rules to me.
2) My friend got the rules wrong.
A: 1

Q: The woman's eyeglasses fogged up. What could have caused this?
1) She reclined by the pool.
2) She entered the sauna.
A: 2

Q: I ran out of breath. What could have caused this?
1) I climbed several flights of stairs.
2) I read several chapters of the book.
A: 1

Q: The man turned on the faucet. What might have happened as a result?
1) The toilet filled with water.
2) Water flowed from the spout.


In [6]:
import random
import numpy as np
from transformers import set_seed

# Set all random seeds to guarantee consistent and reproducible results
RANDOM_SEED = 0

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
set_seed(RANDOM_SEED)

You can use [OPT-350M](https://huggingface.co/facebook/opt-350m), [GPT-Neo 125M](https://huggingface.co/EleutherAI/gpt-neo-125m), or any other LLM you like. Both of these models are small enough to run locally.

In [7]:
from transformers import AutoTokenizer, GenerationConfig, pipeline

MODEL_NAME = "EleutherAI/gpt-neo-125m"
# MODEL_NAME = "facebook/opt-350m"
# MODEL_NAME = "openai-community/gpt2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
generator = pipeline(model=MODEL_NAME, device="cpu")

Let's generate an answer to this prompt (i.e., an answer for the first example in the validation set).

In [8]:
# Compute the number of tokens to predict
inputs = tokenizer([prompt], return_tensors="pt")
expected_tokens = len(tokenizer(["\nA: 1"])[0])
gen_config = GenerationConfig(
    min_new_tokens=expected_tokens, max_new_tokens=expected_tokens + 2, 
    do_sample=True, top_p=0.9, 
    eos_token_id=tokenizer.eos_token_id, 
    pad_token_id=tokenizer.eos_token_id)

# Feed the prompt we created and generate between `expected_tokens` and `expected_tokens + 2` new tokens.
output = generator(prompt, generation_config=gen_config)[0]['generated_text'][len(prompt):].strip()
print(output)

A: 1


In particular, we need to determine whether the model predicted 1 or 2, which we do in `predict`. We will convert `1` to 0 and `2` to 1 to match the gold labels.

In [9]:
def predict(generator, prompt, gen_config=None):
    answer = generator(prompt, generation_config=gen_config)[0][
        'generated_text'][len(prompt):].strip()
    # print(answer)

    # If the model generated another instance, remove it
    if "Q:" in answer:
        answer = answer[:answer.index("Q:")]

    # Find "A: 1" or "A: 2"
    m = re.search(r"A:\s*([1-2])", answer)
    if m is not None:
        answer = int(m.group(1)) - 1
    else:
        answer = None
    # print(answer)

    return answer


def predict_retry(generator, prompt, gen_config=None):
    """
    Retry generation (up to `retry_limit` times) if the answer format is not
    correct (i.e., the output does not contain "A: 1" or "A: 2"), to minimize
    missing predictions when computing the accuracy. Returns 0, 1, or None.
    """

    pred = None  # parsed label, kept separate from the raw generated text
    retry_limit = 10
    for retry_cnt in range(retry_limit):  # Retry if the output format is not correct
        answer = generator(prompt, generation_config=gen_config)[0][
            'generated_text'][len(prompt):].strip()

        # If the model generated another instance, remove it
        if "Q:" in answer:
            answer = answer[:answer.index("Q:")]

        # Find "A: 1" or "A: 2"
        m = re.search(r"A:\s*([1-2])", answer)
        if m is not None:
            pred = int(m.group(1)) - 1
            break

    if pred is None:
        print(f"Tried {retry_limit} times, but the answer is still None.")

    return pred

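A self-contained sketch of the parse-and-retry logic, with the model replaced by a stub callable so it runs without downloading any weights. `parse_answer` and `predict_with_retries` are hypothetical names. The key point is keeping the parsed label separate from the raw text, so that exhausting the retries yields `None` rather than the last unparsable string.

```python
import re
from typing import Callable, Iterator, Optional

ANSWER_RE = re.compile(r"A:\s*([1-2])")


def parse_answer(text: str) -> Optional[int]:
    """Map 'A: 1' / 'A: 2' to label 0 / 1; return None if absent."""
    if "Q:" in text:  # drop any extra question the model started
        text = text[:text.index("Q:")]
    m = ANSWER_RE.search(text)
    return int(m.group(1)) - 1 if m else None


def predict_with_retries(generate: Callable[[], str],
                         retry_limit: int = 10) -> Optional[int]:
    """Call `generate` until a parsable answer appears or retries run out."""
    for _ in range(retry_limit):
        pred = parse_answer(generate())
        if pred is not None:
            return pred
    return None


# Hypothetical stub standing in for the Hugging Face pipeline.
outputs: Iterator[str] = iter(["I think", "A: 2\n\nQ: next question"])
print(predict_with_retries(lambda: next(outputs)))  # 1
print(predict_with_retries(lambda: "no answer here", retry_limit=3))  # None
```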

In [10]:
prompt = create_prompt(in_context_examples, val_set[0])
predict_retry(generator, prompt, gen_config=gen_config)  

0

Finally, let's predict the entire validation set and compute the accuracy.

In [11]:
def compute_accuracy(dataset, predictions):
    gold = np.array([ex["label"] for ex in dataset])
    preds = np.array(predictions)
    accuracy = (preds == gold).astype(np.float32).mean().item() * 100
    return accuracy
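Predictions can be missing (`predict` returns `None` when it cannot parse an answer), so it helps to be explicit about how they are scored. A plain-Python sketch (the name `accuracy_percent` is hypothetical) treats `None` as incorrect, which should match what the NumPy comparison above does implicitly, and also guards against a length mismatch.

```python
from typing import Optional, Sequence


def accuracy_percent(gold: Sequence[int], preds: Sequence[Optional[int]]) -> float:
    """Percentage of exact matches; a missing (None) prediction counts as wrong."""
    if len(gold) != len(preds):
        raise ValueError("gold and preds must have the same length")
    correct = sum(1 for g, p in zip(gold, preds) if p is not None and p == g)
    return 100.0 * correct / len(gold)


print(accuracy_percent([0, 1, 1, 0], [0, None, 1, 1]))  # 50.0
```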


In [12]:
val_predictions = [
    predict_retry(
        generator,
        create_prompt(in_context_examples, ex),
        gen_config=gen_config
    ) for ex in tqdm.notebook.tqdm(val_set)
]

print(f"Accuracy: {compute_accuracy(val_set, val_predictions):.2f}%")

Accuracy: 51.00%


Let's save the predictions to a file.

In [13]:
res_fp_1 = "output/basic_predictions-{}.jsonl".format(MODEL_NAME.split("/")[-1])
os.makedirs(os.path.dirname(res_fp_1), exist_ok=True)  # make sure output/ exists
print(f"Saving to {res_fp_1}")

with open(res_fp_1, "w", encoding="utf-8") as f_out:
    for ex, pred in zip(val_set, val_predictions):
        new_ex = ex.copy()
        new_ex["prediction"] = pred
        f_out.write(json.dumps(new_ex) + "\n")


Saving to output/basic_predictions-gpt-neo-125m.jsonl
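The cell above writes one JSON object per line (JSONL), and the next cell reads it back. A small sketch of that round trip as a pair of helpers; `save_jsonl` and `load_jsonl` are hypothetical names. The writer creates the parent directory first, since `open(path, "w")` raises `FileNotFoundError` when the directory is missing. Note that a `None` prediction is stored as JSON `null` and comes back as `None`.

```python
import json
import os
import tempfile


def save_jsonl(path, records):
    """Write one JSON object per line, creating the parent directory if needed."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as f_out:
        for rec in records:
            f_out.write(json.dumps(rec) + "\n")


def load_jsonl(path):
    """Read a JSONL file back into a list of dicts, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f_in:
        return [json.loads(line) for line in f_in if line.strip()]


rows = [{"idx": 0, "label": 1, "prediction": None},
        {"idx": 1, "label": 0, "prediction": 0}]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "output", "demo.jsonl")
    save_jsonl(path, rows)
    ok = load_jsonl(path) == rows
print(ok)  # True
```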


In [14]:
predicts = []
with open(res_fp_1, "r", encoding="utf-8") as f_in:
    for line in f_in:
        predicts.append(json.loads(line))

incorrect_predicts = []
for record in predicts:  # avoid shadowing the `predict` function defined above
    if record["label"] != record["prediction"]:
        incorrect_predicts.append(record)

accuracy = 1 - len(incorrect_predicts) / len(predicts)
print(f"The prediction accuracy is {accuracy:.2f}")

selected_incorrect_predicts = random.sample(incorrect_predicts, min(20, len(incorrect_predicts)))
selected_incorrect_predicts

The prediction accuracy is 0.51


[{'premise': 'The detective revealed an anomaly in the case.',
  'choice1': 'He finalized his theory.',
  'choice2': 'He scrapped his theory.',
  'question': 'effect',
  'idx': 52,
  'label': 1,
  'prediction': 0},
 {'premise': 'The boy skipped dinner.',
  'choice1': 'His mother cooked his favorite meal.',
  'choice2': 'He ate a big lunch.',
  'question': 'cause',
  'idx': 55,
  'label': 1,
  'prediction': 0},
 {'premise': 'My eyes became red and puffy.',
  'choice1': 'I was sobbing.',
  'choice2': 'I was laughing.',
  'question': 'cause',
  'idx': 7,
  'label': 0,
  'prediction': 1},
 {'premise': 'The bride got cold feet before the wedding.',
  'choice1': 'The wedding guests brought gifts.',
  'choice2': 'She called the wedding off.',
  'question': 'effect',
  'idx': 38,
  'label': 1,
  'prediction': 0},
 {'premise': 'The teacher assigned homework to the students.',
  'choice1': 'The students passed notes.',
  'choice2': 'The students groaned.',
  'question': 'effect',
  'idx': 72,
  

## Part 2: Chain-of-Thought Prompting

In this part, you will add a reasoning step to the examples, following [Wei et al. (2022)](https://arxiv.org/abs/2201.11903). Take a look at the paper to get an idea of what a good rationale / reasoning chain looks like. Note that we are still following the in-context / few-shot setup, so you will need to manually come up with a rationale for each of the in-context examples.

First, let us create a new function `single_example_prompt_with_cot` that adds the rationale to the prompt. 

In [15]:
def single_example_prompt_with_cot(example, include_answer=False):
    prompt = f"Q: {example['premise']} {question_type_to_nl[example['question']]}" + \
             f"\n1) {example['choice1']}\n2) {example['choice2']}"

    if include_answer:
        prompt += f"\nRationale: {example['rationale']}"
        prompt += f"\nA: {example['label'] + 1}"

    return prompt


We will also update the `predict` function to return both the answer and the generated rationale, which you can use to analyze the errors made by the model.

In [16]:
def predict_with_cot(generator, prompt, gen_config=None):
    answer = generator(prompt, generation_config=gen_config)[0][
        "generated_text"][len(prompt):].strip()

    # If the model generated another instance, remove it
    if "Q:" in answer:
        answer = answer[:answer.index("Q:")]

    # If the model generated a rationale
    rationale = None
    m = re.search(r"Rationale:\s?([^\n]+)", answer)
    if m is not None:
        rationale = m.group(1)

    # Find "A: 1" or "A: 2"
    pred = None
    m = re.search(r"A:\s?([1-2])", answer)
    if m is not None:
        pred = int(m.group(1)) - 1

    return pred, rationale


def predict_with_cot_retry(generator, prompt, gen_config=None):
    """
    Retry generation (up to `retry_limit` times) if the answer format is not
    correct (i.e., the output does not contain "A: 1" or "A: 2"), to minimize
    missing predictions when computing the accuracy. Returns (pred, rationale),
    where pred is 0, 1, or None.
    """

    pred = None
    retry_limit = 10
    for retry_cnt in range(retry_limit):  # Retry if the output format is not correct
        answer = generator(prompt, generation_config=gen_config)[0][
            "generated_text"][len(prompt):].strip()

        # If the model generated another instance, remove it
        if "Q:" in answer:
            answer = answer[:answer.index("Q:")]

        # If the model generated a rationale
        rationale = None
        m = re.search(r"Rationale:\s?([^\n]+)", answer)
        if m is not None:
            rationale = m.group(1)

        # Find "A: 1" or "A: 2"
        pred = None
        m = re.search(r"A:\s?([1-2])", answer)
        if m is not None:
            pred = int(m.group(1)) - 1
            break

    if pred is None:
        print(f"Tried {retry_limit} times, but the prediction is still None.\nPROMPT: {prompt}\n")
    # else:
    #     print(f"Final prediction: {pred}")

    return pred, rationale
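The parsing inside `predict_with_cot` can be pulled out into a pure function and checked on a hand-written continuation. This is a sketch; `parse_cot_output` is a hypothetical name. The label comes from the `A:` line, not from the "The answer is N" phrase inside the rationale. Comparing the two may be one way to spot self-inconsistent generations during error analysis.

```python
import re
from typing import Optional, Tuple

RATIONALE_RE = re.compile(r"Rationale:\s*([^\n]+)")
ANSWER_RE = re.compile(r"A:\s*([1-2])")


def parse_cot_output(text: str) -> Tuple[Optional[int], Optional[str]]:
    """Split a CoT continuation into (0-based label, rationale string)."""
    if "Q:" in text:  # drop any extra question the model started
        text = text[:text.index("Q:")]
    m_rat = RATIONALE_RE.search(text)
    rationale = m_rat.group(1).strip() if m_rat else None
    m_ans = ANSWER_RE.search(text)
    pred = int(m_ans.group(1)) - 1 if m_ans else None
    return pred, rationale


sample = ("Rationale: The answer is 2 because water comes out of a faucet.\n"
          "A: 2\n\nQ: another question")
print(parse_cot_output(sample))
# (1, 'The answer is 2 because water comes out of a faucet.')
```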


Now, we need to add a `rationale` field for the in-context examples. Let's print the examples, and in the next cell, you will add the rationales.

In [17]:
for ex in in_context_examples:
    print(single_example_prompt(ex) + "\n")


Q: The woman felt lonely. What might have happened as a result?
1) She renovated her kitchen.
2) She adopted a cat.

Q: The mother needed help looking after her children. What might have happened as a result?
1) She sent the children to daycare.
2) She gave up custody of the children.

Q: I learned how to play the board game. What could have caused this?
1) My friend explained the rules to me.
2) My friend got the rules wrong.

Q: The woman's eyeglasses fogged up. What could have caused this?
1) She reclined by the pool.
2) She entered the sauna.

Q: I ran out of breath. What could have caused this?
1) I climbed several flights of stairs.
2) I read several chapters of the book.



In [18]:
in_context_examples_cot = [ex.copy() for ex in in_context_examples]  # copy each dict so rationales don't leak into the originals

############################################
#  Complete the following code
############################################
in_context_examples_cot[0]["rationale"] = "The answer is 2 bacause: the woman adopted a cat to alleviate her feelings of loneliness, seeking companionship and emotional support."
in_context_examples_cot[1]["rationale"] = "The answer is 1 bacause: The mother sent her children to daycare to receive assistance with childcare responsibilities, enabling her to fulfill other obligations or work commitments."
in_context_examples_cot[2]["rationale"] = "The answer is 1 bacause: My friend's explanation of the rules facilitated my learning of the board game, providing clarity and guidance on gameplay mechanics."
in_context_examples_cot[3]["rationale"] = "The answer is 2 bacause: As she entered the sauna, the temperature change caused condensation on the woman's eyeglasses, resulting in fogging up due to the heat and moisture."
in_context_examples_cot[4]["rationale"] = "The answer is 1 bacause: Climbing several flights of stairs increased my physical exertion, leading to a rapid depletion of oxygen and causing me to run out of breath."
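Note that `list.copy()` is shallow. It creates a new list that holds the same dictionaries, so adding a key through the copy also changes the originals, including the `train_set` entries that `random.sample` returned references to. Copying each dictionary, or using `copy.deepcopy`, avoids this. Here is a minimal illustration on a toy record:

```python
import copy

examples = [{"premise": "I ran out of breath.", "label": 0}]
shallow = examples.copy()            # new list, but the same dict objects
shallow[0]["rationale"] = "Climbing stairs is tiring."
print("rationale" in examples[0])    # True: the original dict was modified

examples = [{"premise": "I ran out of breath.", "label": 0}]
per_item = [ex.copy() for ex in examples]  # one level is enough for flat dicts
per_item[0]["rationale"] = "Climbing stairs is tiring."
print("rationale" in examples[0])    # False

deep = copy.deepcopy(examples)       # also copies nested containers
deep[0]["rationale"] = "Climbing stairs is tiring."
print("rationale" in examples[0])    # False
```

COPA records contain only flat scalar fields, so the per-item `dict.copy()` is sufficient there. `deepcopy` is the safer default when records contain nested lists or dicts.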

Let's look at an example prompt:

In [19]:
prompt = create_prompt(in_context_examples_cot, val_set[0], single_fn=single_example_prompt_with_cot)
print(prompt)

Q: The woman felt lonely. What might have happened as a result?
1) She renovated her kitchen.
2) She adopted a cat.
Rationale: The answer is 2 bacause: the woman adopted a cat to alleviate her feelings of loneliness, seeking companionship and emotional support.
A: 2

Q: The mother needed help looking after her children. What might have happened as a result?
1) She sent the children to daycare.
2) She gave up custody of the children.
Rationale: The answer is 1 bacause: The mother sent her children to daycare to receive assistance with childcare responsibilities, enabling her to fulfill other obligations or work commitments.
A: 1

Q: I learned how to play the board game. What could have caused this?
1) My friend explained the rules to me.
2) My friend got the rules wrong.
Rationale: The answer is 1 bacause: My friend's explanation of the rules facilitated my learning of the board game, providing clarity and guidance on gameplay mechanics.
A: 1

Q: The woman's eyeglasses fogged up. What c

Let's see what the LLM generates for the target answer. We will now allow for more tokens to be generated, since the model is also expected to generate the rationale.

In [20]:
# Set the output length according to the rationales in the in-context examples.
shortest_rationale = in_context_examples_cot[
    np.argmin([len(ex["rationale"].split()) for ex in in_context_examples_cot])
    ]["rationale"]

longest_rationale = in_context_examples_cot[
    np.argmax([len(ex["rationale"].split()) for ex in in_context_examples_cot])
    ]["rationale"]

min_tokens = len(tokenizer([f"\nRationale: {shortest_rationale} \nA: 1"])[0])
max_tokens = len(tokenizer([f"\nRationale: {longest_rationale} \nA: 1"])[0])

# Update the generation config accordingly
cot_gen_config = GenerationConfig(min_new_tokens=min_tokens, 
                                  max_new_tokens=max_tokens, 
                                  do_sample=True, top_p=0.9, 
                                  eos_token_id=tokenizer.eos_token_id, 
                                  pad_token_id=tokenizer.eos_token_id)
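The cell above picks the shortest and longest rationale by word count and then measures those two in tokens. Word and token counts do not always rank texts the same way, so a sketch that measures every candidate continuation directly in tokens may be slightly more robust. `cot_token_budget` is a hypothetical helper. Whitespace splitting stands in for the real tokenizer here; with the notebook's tokenizer you would pass something like `lambda s: len(tokenizer(s)["input_ids"])`.

```python
from typing import Callable, List, Tuple


def cot_token_budget(rationales: List[str],
                     count_tokens: Callable[[str], int],
                     answer_suffix: str = "\nA: 1") -> Tuple[int, int]:
    """Return (min_new_tokens, max_new_tokens) for a CoT continuation.

    Every candidate continuation is measured in tokens, instead of first
    choosing the shortest/longest rationale by word count.
    """
    lengths = [count_tokens(f"\nRationale: {r} {answer_suffix}") for r in rationales]
    return min(lengths), max(lengths)


def whitespace_count(s: str) -> int:
    """Stand-in token counter: number of whitespace-separated pieces."""
    return len(s.split())


print(cot_token_budget(["short one", "a somewhat longer rationale here"], whitespace_count))
# (5, 8)
```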


In [21]:
answer, rationale = predict_with_cot_retry(
    generator, create_prompt(
        in_context_examples_cot, val_set[0],
        single_fn=single_example_prompt_with_cot
    ),
    gen_config=cot_gen_config
)

print(single_example_prompt(val_set[0]))
print(f"Rationale: {rationale}\nA: {answer}")

Q: The man turned on the faucet. What might have happened as a result?
1) The toilet filled with water.
2) Water flowed from the spout.
Rationale: The answer is 1 bacause: The man turned on the faucet.
A: 0


Finally, let's predict the entire validation set, compute the accuracy, and save the predictions.  

In [22]:
val_predictions_cot, val_rationales = zip(*
    [predict_with_cot_retry(
        generator,
        create_prompt(
            in_context_examples_cot,
            ex,
            single_fn=single_example_prompt_with_cot),
        gen_config=cot_gen_config) 
    for ex in tqdm.notebook.tqdm(val_set)
    ])

print(f"Accuracy: {compute_accuracy(val_set, val_predictions_cot):.2f}%")

Accuracy: 58.00%


In [23]:
res_fp_2 = "output/cot_predictions-{}.jsonl".format(MODEL_NAME.split("/")[-1])
print(f"Saving to {res_fp_2}")

with open(res_fp_2, "w", encoding="utf-8") as f_out:
    for ex, pred, rationale in zip(val_set, val_predictions_cot, val_rationales):
        new_ex = ex.copy()
        new_ex["prediction"] = pred
        new_ex["rationale"] = rationale
        f_out.write(json.dumps(new_ex) + "\n")


Saving to output/cot_predictions-gpt-neo-125m.jsonl


In [24]:
predicts = []
with open(res_fp_2, "r", encoding="utf-8") as f_in:
    for line in f_in:
        predicts.append(json.loads(line))

incorrect_predicts = []
for record in predicts:  # avoid shadowing the `predict` function defined above
    if record["label"] != record["prediction"]:
        incorrect_predicts.append(record)

accuracy = 1 - len(incorrect_predicts) / len(predicts)
print(f"The prediction accuracy is {accuracy:.2f}")

selected_incorrect_predicts = random.sample(incorrect_predicts, min(20, len(incorrect_predicts)))
selected_incorrect_predicts

The prediction accuracy is 0.58


[{'premise': 'My ears were ringing.',
  'choice1': 'I went to a museum.',
  'choice2': 'I went to a concert.',
  'question': 'cause',
  'idx': 95,
  'label': 1,
  'prediction': 0,
  'rationale': 'The answer is 1 bacause: The museum suggested to me that I learn how to play the board game.'},
 {'premise': 'The driver got a flat tire.',
  'choice1': 'He went over the speed limit.',
  'choice2': 'He ran over a nail.',
  'question': 'cause',
  'idx': 25,
  'label': 1,
  'prediction': 0,
  'rationale': 'The answer is 1 bacause: Driving a car reduced my strength and strength training, allowing me to work on my running ability.'},
 {'premise': 'The pair of students came under scrutiny by the teacher.',
  'choice1': 'The students both received excellent grades.',
  'choice2': 'Their responses on the assignment were identical.',
  'question': 'cause',
  'idx': 42,
  'label': 1,
  'prediction': 0,
  'rationale': "The answer is 1 bacause: The students' responses on the assignment were identical."}

## Part 3: In-Context Learning with External Knowledge

In the last part, instead of letting the LLM generate its own rationales based on the manually curated rationales for the in-context examples, we will attach a "rationale" from ConceptNet to each example (in-context and target). The first step is to implement the `retrieve_knowledge` function, which takes a COPA instance and returns a rationale from ConceptNet. Base it on your code from Assignment 1, and choose one path to express in natural language.

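The path search used below queries the live ConceptNet API and prunes nodes by embedding similarity. This makes it hard to test in isolation. As a simplified sketch of the core idea, here is plain bidirectional breadth-first search over a small list of triples. The triples are hand-written and illustrative, not real ConceptNet output. The sketch also omits the weighting and similarity-based pruning of the search implemented further down. `build_index`, `_join`, and `bidirectional_path` are hypothetical names.

```python
from collections import deque
from typing import Dict, List, Optional, Tuple

Triple = Tuple[str, str, str]  # (start, relation, end), in the KB's own direction
Adjacency = Dict[str, List[Tuple[str, Triple]]]


def build_index(triples: List[Triple]) -> Adjacency:
    """Undirected adjacency: node -> [(neighbour, original triple), ...]."""
    adj: Adjacency = {}
    for triple in triples:
        start, _, end = triple
        adj.setdefault(start, []).append((end, triple))
        adj.setdefault(end, []).append((start, triple))
    return adj


def _join(parents, meet: str) -> List[Triple]:
    """Stitch the source-side and target-side half-paths at the meeting node."""
    left: List[Triple] = []
    node = meet
    while parents[0][node] is not None:  # walk back towards the source
        node, triple = parents[0][node]
        left.append(triple)
    left.reverse()
    right: List[Triple] = []
    node = meet
    while parents[1][node] is not None:  # walk on towards the target
        node, triple = parents[1][node]
        right.append(triple)
    return left + right


def bidirectional_path(adj: Adjacency, src: str, tgt: str,
                       max_depth: int = 6) -> Optional[List[Triple]]:
    """Fewest-edges path of at most max_depth edges, or None if none is found."""
    if src == tgt:
        return []
    parents = [{src: None}, {tgt: None}]  # node -> (previous node, triple)
    frontiers = [deque([src]), deque([tgt])]
    for depth in range(max_depth):
        side = depth % 2  # alternate: expand one BFS level from each end in turn
        for _ in range(len(frontiers[side])):
            node = frontiers[side].popleft()
            for nxt, triple in adj.get(node, []):
                if nxt in parents[side]:
                    continue
                parents[side][nxt] = (node, triple)
                if nxt in parents[1 - side]:  # the two searches meet here
                    return _join(parents, nxt)
                frontiers[side].append(nxt)
    return None


kb = [("sun", "RelatedTo", "light"), ("shadow", "RelatedTo", "light"),
      ("shadow", "RelatedTo", "dark"), ("grass", "IsA", "plant")]
adj = build_index(kb)
print(bidirectional_path(adj, "shadow", "sun"))
# [('shadow', 'RelatedTo', 'light'), ('sun', 'RelatedTo', 'light')]
```

Each returned triple keeps the direction stored in the KB, even when the search traversed it backwards. This lets the relation template be filled in the right order when the path is verbalized.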
In [25]:
import requests
from typing import Tuple


class ConceptNet:

    def __init__(self):
        self.prefix_url = "https://api.conceptnet.io"
        self.prefix_cid = "/c/en/"  # the prefix of every concept_id

    def get_concept_id(self, word: str) -> str:
        return self.prefix_cid + word.strip()

    def get_url(self, cid: str) -> str:
        return self.prefix_url + cid.strip()

    @staticmethod
    def get_concept(url: str, verbose: bool = False) -> dict:
        try:
            response = requests.get(url.strip()).json()
        except Exception as e:
            if verbose:
                print(f">>> >>> get_concept Exception: {e}")
            response = dict()
        return response

    @staticmethod
    def get_next_node(cid: str, edge: dict) -> Tuple[str, float, str, int]:
        if edge["end"]["@id"] != cid:
            next_id, direction = edge["end"]["@id"], 1
        else:
            next_id, direction = edge["start"]["@id"], 0
        return next_id, edge["weight"], edge["rel"]["label"], direction


In [26]:
from typing import List
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer, util


class TextParser:

    def __init__(self):
        self.kw_model = KeyBERT()
        self.st_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    def get_keywords_keybert(self, text: str, n_bag: int = 1) -> List[str]:
        keywords = self.kw_model.extract_keywords(text, keyphrase_ngram_range=(1, n_bag), stop_words="english")
        keywords = [w[0].strip() for w in keywords]
        return keywords

    def keyword_sort(self, references: List[str], keywords: List[str]) -> List[str]:
        """
        Sort keywords by their average cosine similarity to the references (most similar first).
        """
        if len(keywords) <= 1:
            return keywords
        assert len(references) >= 1, f"Assertion error: len(references) is {len(references)}, NOT >= 1"

        ref_emb_list = [self.st_model.encode(w, convert_to_tensor=True) for w in references]
        kw_emb_list = [self.st_model.encode(w, convert_to_tensor=True) for w in keywords]

        kw_simi_list = []
        for idx, kw_emb in enumerate(kw_emb_list):
            kw_word = keywords[idx]
            simi_list = [float(util.pytorch_cos_sim(kw_emb, ref_emb)) for ref_emb in ref_emb_list]
            assert len(simi_list) >= 1
            avg_simi = float(sum(simi_list) / len(simi_list))
            kw_simi_list.append((kw_word, avg_simi))

        kw_simi_list.sort(key=lambda x: x[1], reverse=True)
        res_list = [kw_simi[0] for kw_simi in kw_simi_list]

        return res_list
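`keyword_sort` encodes each string separately and loops over keyword/reference pairs. Once the embeddings are stacked into arrays, the same ranking by mean cosine similarity can be written as a single matrix product. In this NumPy sketch, tiny hand-made 2-D vectors stand in for sentence-transformer embeddings. `SentenceTransformer.encode` on a list of strings returns such an `(n, d)` array by default. `rank_by_mean_cosine` is a hypothetical name.

```python
import numpy as np


def rank_by_mean_cosine(references: np.ndarray, candidates: np.ndarray) -> np.ndarray:
    """Return candidate indices sorted by mean cosine similarity to the references.

    references: (R, d) array, candidates: (K, d) array; rows are embeddings.
    """
    ref = references / np.linalg.norm(references, axis=1, keepdims=True)
    cand = candidates / np.linalg.norm(candidates, axis=1, keepdims=True)
    mean_sim = (cand @ ref.T).mean(axis=1)  # (K,) average over the references
    return np.argsort(-mean_sim, kind="stable")  # most similar first


refs = np.array([[1.0, 0.0], [0.8, 0.6]])
cands = np.array([[0.0, 1.0], [1.0, 0.1], [0.6, 0.8]])
print(rank_by_mean_cosine(refs, cands))  # [1 2 0]
```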


In [27]:
import os
import json


REL_TO_TEMPLATE = {
    "relatedto": "[w1] is like [w2]",
    "externalurl": "[w1] is described at the following URL [w2]",
    "formof": "[w1] is a form of the word [w2]",
    "isa": "[w1] is a type of [w2]",
    "notisa": "[w1] is not [w2]",
    "partof": "[w1] is part of [w2]",
    "usedfor": "[w1] is used for [w2]",
    "capableof": "[w1] can [w2]",
    "atlocation": "You are likely to find [w1] in [w2]",
    "causes": "Sometimes [w1] causes [w2]",
    "hasa": "[w1] has [w2]",
    "hassubevent": "Something you do when you [w1] is [w2]",
    "hasfirstsubevent": "the first thing you do when you [w1] is [w2]",
    "haslastsubevent": "the last thing you do when you [w1] is [w2]",
    "hasprerequisite": "In order for [w1] to happen, [w2] needs to happen",
    "hasproperty": "[w1] is [w2]",
    "hascontext": "[w1] is a word used in the context of [w2]",
    "motivatedbygoal": "You would [w1] because you want to [w2]",
    "obstructedby": "[w1] can be prevented by [w2]",
    "desires": "[w1] wants [w2]",
    "createdby": "[w1] is created by [w2]",
    "synonym": "[w1] and [w2] have similar meanings",
    "antonym": "[w1] is the opposite of [w2]",
    "distinctfrom": "it cannot be both [w1] and [w2]",
    "derivedfrom": "the word [w1] is derived from the word [w2]",
    "definedas": "[w1] is defined as [w2]",
    "entails": "if [w1] is happening, [w2] is also happening",
    "mannerof": "[w1] is a specific way of doing [w2]",
    "locatednear": "[w1] is located near [w2]",
    "dbpedia": "[w1] is conceptually related to [w2]",
    "similarto": "[w1] is similar to [w2]",
    "etymologicallyrelatedto": "the word [w1] and the word [w2] have the same origin",
    "etymologicallyderivedfrom": "the word [w1] comes from the word [w2]",
    "causesdesire": "[w1] makes people want [w2]",
    "madeof": "[w1] is made of [w2]",
    "receivesaction": "[w1] can be [w2]",
    "instanceof": "[w1] is an example of [w2]",
    "notdesires": "[w1] does not want [w2]",
    "notusedfor": "[w1] is not used for [w2]",
    "notcapableof": "[w1] is not capable of [w2]",
    "nothasproperty": "[w1] does not have the property of [w2]",
    "notmadeof": "[w1] is not made of [w2]"
}


class TextConverter:

    def __init__(self, path):
        self.node_list, self.w_list, self.r_list = path

    def convert(self, verbose: bool = False) -> dict:
        node_relation_list = []
        prompt_list = []
        description_list = []
        assert len(self.node_list) == len(self.w_list) + 1 == len(self.r_list) + 1

        for i in range(len(self.r_list)):
            if self.r_list[i][1] == 0:  # forward
                src_rel_tgt = self.node_list[i], self.r_list[i][0], self.node_list[i + 1]
            else:  # backward
                src_rel_tgt = self.node_list[i + 1], self.r_list[i][0], self.node_list[i]

            start_word = self.extract_word(src_rel_tgt[0])
            rel = src_rel_tgt[1]
            end_word = self.extract_word(src_rel_tgt[2])

            prompt = ""
            cur_description = REL_TO_TEMPLATE[rel.lower()].replace("[w1]", start_word).replace("[w2]", end_word)

            if verbose:
                print(cur_description)

            node_relation_list.append(src_rel_tgt)
            prompt_list.append(prompt)
            description_list.append(cur_description)

        start_node = self.extract_word(self.node_list[0])
        end_node = self.extract_word(self.node_list[-1])

        data_dict = {
            "start_node": start_node,
            "end_node": end_node,
            "node_list": self.node_list,
            "w_list": self.w_list,
            "r_list": self.r_list,
            "node_relation_list": node_relation_list,
            "prompt_list": prompt_list,
            "description_list": description_list,
        }

        return data_dict

    @staticmethod
    def extract_word(node: str) -> str:
        node_words = node.split("/")
        concept_word = node_words[3].replace("_", " ").strip()
        return concept_word
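The following minimal sketch shows the last step of `retrieve_knowledge`, turning a chosen path into a natural-language rationale. It assumes the path is already available as (start id, relation label, end id) triples in the KB's own direction. The templates are a subset copied from `REL_TO_TEMPLATE`. `verbalize` and the fallback template for unknown relations are hypothetical.

```python
from typing import Dict, List, Tuple

# Subset copied from REL_TO_TEMPLATE above.
TEMPLATES: Dict[str, str] = {
    "relatedto": "[w1] is like [w2]",
    "causes": "Sometimes [w1] causes [w2]",
    "atlocation": "You are likely to find [w1] in [w2]",
}


def concept_word(node_id: str) -> str:
    """'/c/en/ice_cream/n' -> 'ice cream' (same rule as extract_word)."""
    return node_id.split("/")[3].replace("_", " ").strip()


def verbalize(path: List[Tuple[str, str, str]]) -> str:
    """Turn (start_id, relation_label, end_id) triples into one rationale string."""
    clauses = []
    for start, rel, end in path:
        # The fallback for relations without a template is an assumption.
        template = TEMPLATES.get(rel.lower(), "[w1] is related to [w2]")
        clauses.append(template.replace("[w1]", concept_word(start))
                               .replace("[w2]", concept_word(end)))
    return ". ".join(clauses) + "." if clauses else ""


path = [("/c/en/sun", "RelatedTo", "/c/en/light"),
        ("/c/en/light", "Causes", "/c/en/shadow")]
print(verbalize(path))
# sun is like light. Sometimes light causes shadow.
```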


In [28]:
import os
import time
from typing import Tuple, List, Set

import spacy
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer, util


class DijkstraSearchBiSource:

    def __init__(self):
        nlp = spacy.load("en_core_web_md")
        self.nlp = nlp
        self.kw_model = KeyBERT()
        self.st_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
        self.src_word = ""
        self.tgt_word = ""
        self.src_emb = None
        self.tgt_emb = None
        self.simi_threshold = 0.3  # if the similarity between the next node and target one is low, ignore this node

    def word_similarity(self, word_1: str, word_2: str) -> float:
        string = f"{word_1.strip()} {word_2.strip()}"
        tokens = self.nlp(string)
        similarity = tokens[0].similarity(tokens[1])
        return similarity

    def dijkstra_path_search(self, src_word: str, tgt_word: str, max_depth: int = 10, verbose: bool = False):
        src_word = src_word.strip()
        tgt_word = tgt_word.strip()
        self.src_word = src_word
        self.tgt_word = tgt_word
        self.src_emb = self.st_model.encode(src_word, convert_to_tensor=True)
        self.tgt_emb = self.st_model.encode(tgt_word, convert_to_tensor=True)
        src_tgt_similarity_sent = float(util.pytorch_cos_sim(self.src_emb, self.tgt_emb))
        if verbose:
            print(f">>> dijkstra_path_search (bi-source weighted BFS; max_depth = {max_depth}): "
                  f"From \"{src_word}\" To \"{tgt_word}\" (Similarity: {src_tgt_similarity_sent:.3f})")

        # self.simi_threshold = src_tgt_similarity_sent / 2
        self.simi_threshold = src_tgt_similarity_sent * 2 / 3

        bfs_src = [[], []]  # the list of concept ids (source nodes of the current BFS)
        bfs_tgt = [[], []]  # the list of concept ids (target nodes of the current BFS)
        visited = [set(), set()]  # avoid looping
        weights = [dict(), dict()]  # edge weight in Dij alg
        prev_node = [dict(), dict()]  # to backtrace: prev_node[next_node] = cur_node
        relations = [dict(), dict()]  # to backtrace: relations[next_node] = [relation, direction]
        matches = list()  # matched nodes (~= tgt_word)
        matches_set = set() # set(matches)

        continue_bfs_flag = True

        conceptNet = ConceptNet()  # ConceptNet toolkit

        # Initialization
        src_node = conceptNet.get_concept_id(src_word)  # word to concept_id (node id)
        tgt_node = conceptNet.get_concept_id(tgt_word)

        bfs_src[0].append(src_node)
        bfs_tgt[0].append(tgt_node)
        weights[0][src_node] = 0
        prev_node[0][src_node] = None  # root node
        relations[0][src_node] = None

        bfs_src[1].append(tgt_node)
        bfs_tgt[1].append(src_node)
        weights[1][tgt_node] = 0
        prev_node[1][tgt_node] = None  # root node
        relations[1][tgt_node] = None

        for cur_depth in range(max_depth >> 1):
            if not continue_bfs_flag:
                break
            timer_s = time.perf_counter()
            # if verbose:
            #     print(f">>> >>> cur_depth: {cur_depth}; len(bfs_src): {len(bfs_src)}")
            if len(bfs_src[0]) == 0 and len(bfs_src[1]) == 0:
                break

            ignore_node_cnt = [0, 0]

            next_bfs = [[], []]  # the nodes of the next depth
            for s_idx in range(2):  # search index of the current source nodes
                t_set = set(bfs_tgt[s_idx])  # target node set (for matching)
                for cur_node in bfs_src[s_idx]:  # deal with all the nodes of the current depth
                    # Get the current node (concept id)
                    if cur_node not in visited[s_idx]:  # avoid visiting the same node and forming a loop in prev_node
                        visited[s_idx].add(cur_node)
                    else:
                        continue
                    # if verbose:
                    #     print(f"{cur_node}; cur_depth: {cur_depth}; len(bfs_src): {len(bfs_src)}")
                    assert isinstance(cur_node, str) and cur_node.startswith(conceptNet.prefix_cid)
                    if cur_node in matches_set:
                        continue

                    # Match the current node with the target word
                    # match_result = self.match(tgt_word, cur_node)
                    match_result = self.match_set(t_set, cur_node)
                    if match_result > 0:  # match the target word
                        if cur_node not in matches_set:
                            matches_set.add(cur_node)
                            matches.append(cur_node)
                        # if verbose:
                        #     print(f">>> *** Number {len(matches_set[s_idx])} matched node: {cur_node}")
                        # Once matching a word in the same type/category as the target word, end the Dij process
                        continue_bfs_flag = False  # end the whole Dij search process of the next loop
                        # break

                    # Get the corresponding concept from ConceptNet
                    cur_url = conceptNet.get_url(cur_node)
                    concept = conceptNet.get_concept(cur_url, verbose=verbose)
                    if not isinstance(concept, dict) or "edges" not in concept:
                        continue

                    # Dealing with its neighbors (related concepts)
                    edges = concept["edges"]
                    if not isinstance(edges, list) or len(edges) == 0:
                        continue
                    for edge in edges:
                        assert isinstance(edge, dict)
                        next_node, weight, relation, direction = conceptNet.get_next_node(cur_node, edge)
                        if next_node in visited[s_idx] or not next_node.startswith(conceptNet.prefix_cid):
                            continue

                        # Optimize by ignoring irrelevant nodes based on similarity
                        next_node_name = next_node.split("/")[-1].replace("_", " ")
                        next_emb = self.st_model.encode(next_node_name, convert_to_tensor=True)
                        if s_idx == 0:  # src-tgt search
                            cur_similarity_sent = float(util.pytorch_cos_sim(next_emb, self.tgt_emb))
                        else:  # tgt-src search
                            cur_similarity_sent = float(util.pytorch_cos_sim(next_emb, self.src_emb))
                        if cur_similarity_sent < self.simi_threshold:
                            ignore_node_cnt[s_idx] += 1
                            continue

                        # Add the new node to the path, update the weight.
                        if next_node not in weights[s_idx]:
                            assert next_node not in prev_node[s_idx], "Node hasn't been visited."
                            next_bfs[s_idx].append(next_node)  # next_node will be dealt in the next depth
                            weights[s_idx][next_node] = weights[s_idx][cur_node] + weight
                            prev_node[s_idx][next_node] = cur_node
                            relations[s_idx][next_node] = [relation, direction]
                        # If the weight of a new path is greater than the path before, replace the most likely path.
                        elif weights[s_idx][next_node] < weights[s_idx][cur_node] + weight:
                            assert next_node in prev_node[s_idx], "Node has been visited."
                            weights[s_idx][next_node] = weights[s_idx][cur_node] + weight
                            prev_node[s_idx][next_node] = cur_node
                            relations[s_idx][next_node] = [relation, direction]

            timer_e = time.perf_counter()
            timer_sec, timer_min = timer_e - timer_s, (timer_e - timer_s) / 60
            if verbose:
                log_text = f">>> >>> Depth [{cur_depth + 1}] Time {timer_sec:.1f} sec ({timer_min:.1f} min): "
                log_text += f"[src->tgt] BFS source {len(bfs_src[0])}; " \
                            f"Visited neighbors {len(next_bfs[0])} (ignored {ignore_node_cnt[0]}); "
                log_text += f"[tgt->src] BFS source {len(bfs_src[1])}; " \
                            f"Visited neighbors {len(next_bfs[1])} (ignored {ignore_node_cnt[1]})"
                print(log_text)

            cur_match_set = set(next_bfs[0]) & set(next_bfs[1])
            if len(cur_match_set) > 0:
                for cur_match_node in cur_match_set:
                    matches_set.add(cur_match_node)
                    matches.append(cur_match_node)
                break

            bfs_src[0] = next_bfs[0]  # source nodes of the next depth (the src-tgt direction)
            bfs_tgt[0] = next_bfs[1]  # target nodes of the next depth (the src-tgt direction)
            bfs_src[1] = next_bfs[1]  # source nodes of the next depth (the tgt-src direction)
            bfs_tgt[1] = next_bfs[0]  # target nodes of the next depth (the tgt-src direction)

        # Return the path list
        path_list = []
        for matched_node in matches:
            path_list.append(self.get_path(matched_node, prev_node, weights, relations))
        if verbose:
            if len(path_list) == 0:
                print(f">>> Path between {src_word} and {tgt_word} does not exist or is too long (> {max_depth}).")
            else:
                print(f">>> Found {len(path_list)} path(s) between \"{src_word}\" and \"{tgt_word}\".")

        return path_list

    @staticmethod
    def match(tgt_word: str, node: str) -> int:
        # string preproc
        tgt_word = tgt_word.strip().replace("_", " ")
        node_words = node.split("/")
        node_words = [w.strip().replace("_", " ") for w in node_words]

        # match the main concept
        if len(node_words) >= 4:  # e.g., "/c/en/food"
            if tgt_word == node_words[3]:
                return 1  # exact match the target word

        # match the sub-concept
        if len(node_words) >= 7:  # e.g., "/c/en/butter/n/wn/food"
            if tgt_word == node_words[6]:
                return 2  # exact match a word in the same type/category of the target word

        return 0  # not matched

    @staticmethod
    def match_set(tgt_word_set: Set[str], node: str) -> int:
        return 1 if node in tgt_word_set else 0

    @staticmethod
    def get_path(node: str, prev_node: List[dict], weights: List[dict],
                    relations: List[dict], do_print: bool = False) -> Tuple[list, list, list]:
        node_list = [[], []]
        w_list = [[], []]
        r_list = [[], []]
        for s_idx in range(2):
            cur_node = node
            while isinstance(cur_node, str):
                node_list[s_idx].append(cur_node)
                cur_r = relations[s_idx][cur_node]
                r_list[s_idx].append(cur_r)
                cur_w = weights[s_idx][cur_node]
                w_list[s_idx].append(cur_w)
                if do_print:
                    print(f"{cur_node}, with a weight of {cur_w}.")
                    if cur_r is not None:  # the root node has no incoming relation
                        if cur_r[1] == 0:
                            print(f"Backward, relationship is {cur_r[0]}.")
                        else:
                            print(f"Forward, relationship is {cur_r[0]}.")
                cur_node = prev_node[s_idx][cur_node]

        node_list = list(reversed(node_list[0])) + node_list[1][1:]
        w_list = list(reversed(w_list[0][:-1])) + w_list[1][:-1]
        r_list = [[item[0], (item[1] + 1) % 2] for item in list(reversed(r_list[0][:-1]))] + r_list[1][:-1]

        assert len(node_list) == len(w_list) + 1
        return node_list, w_list, r_list

    def print_path(self, node: str, prev_node: List[dict], weights: List[dict], relations: List[dict]) -> None:
        print(f"Valid Path:")
        node_list, w_list, r_list = self.get_path(node, prev_node, weights, relations, do_print=True)
        print(f"Path ends. Path Length: {len(w_list)}; Weight Sum: {sum(w_list)}\n")
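
Despite its name, `dijkstra_path_search` is a depth-bounded, bidirectional ("bi-source") BFS: the search grows one frontier from the source word and one from the target word, keeps the highest cumulative weight seen for each node, and stops when the two frontiers meet. `get_path` then joins the two backtraces at the meeting node. The sketch below shows only that meet-in-the-middle skeleton on a small in-memory graph. It assumes an unweighted, undirected toy adjacency dict (made up for illustration, not ConceptNet data) and omits the edge weights and the sentence-similarity pruning; `bidirectional_bfs` and `stitch` are hypothetical names.

```python
from typing import Dict, List, Optional


def stitch(mid: str, parents: List[Dict[str, Optional[str]]]) -> List[str]:
    """Join the src->mid backtrace with the mid->tgt backtrace (same idea as get_path)."""
    left: List[str] = []
    node: Optional[str] = mid
    while node is not None:  # walk back to src
        left.append(node)
        node = parents[0][node]
    right: List[str] = []
    node = parents[1][mid]
    while node is not None:  # walk on to tgt, skipping mid itself
        right.append(node)
        node = parents[1][node]
    return list(reversed(left)) + right


def bidirectional_bfs(graph: Dict[str, List[str]], src: str, tgt: str,
                      max_depth: int = 6) -> Optional[List[str]]:
    """Meet-in-the-middle BFS on an unweighted, undirected toy graph."""
    if src == tgt:
        return [src]
    parents: List[Dict[str, Optional[str]]] = [{src: None}, {tgt: None}]  # roots map to None
    frontiers: List[List[str]] = [[src], [tgt]]
    for _ in range(max_depth >> 1):  # each side expands at most max_depth // 2 levels
        next_frontiers: List[List[str]] = [[], []]
        for side in range(2):  # side 0 grows from src, side 1 grows from tgt
            for node in frontiers[side]:
                for nb in graph.get(node, []):
                    if nb not in parents[side]:  # first visit only, so no loops
                        parents[side][nb] = node
                        next_frontiers[side].append(nb)
        meet = set(parents[0]) & set(parents[1])
        if meet:
            return stitch(sorted(meet)[0], parents)  # deterministic pick among meeting nodes
        frontiers = next_frontiers
        if not frontiers[0] and not frontiers[1]:
            break  # both searches are exhausted
    return None


toy_graph = {
    "woman": ["lady", "girl"],
    "lady": ["woman", "female"],
    "girl": ["woman", "female", "child"],
    "female": ["lady", "girl"],
    "child": ["girl", "adopt"],
    "adopt": ["child"],
}
print(bidirectional_bfs(toy_graph, "woman", "adopt"))  # ['woman', 'girl', 'child', 'adopt']
```

Splitting the depth budget between the two sides is what makes this cheaper than a one-sided BFS: each frontier only needs to reach about half the path length.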


In [None]:
dij = DijkstraSearchBiSource()  # bi-source weighted BFS
textParser = TextParser()

In [30]:
def get_keywords(example, max_number_kw: int = 3, n_bag: int = 1, verbose: bool = False):
    # import spacy

    # question = example["premise"] + " " + question_type_to_nl[example["question"]]  # question text
    question = example["premise"]  # question text
    choice_key = "choice" + str(example["label"] + 1)  # answer key
    choice = example[choice_key]  # answer text

    # nlp = spacy.load("en_core_web_sm")
    # q_nlp = nlp(question)
    # c_nlp = nlp(choice)

    q_keywords = textParser.get_keywords_keybert(question, n_bag=n_bag)  # get keywords of the question
    q_keywords = list(dict.fromkeys(q_keywords))  # remove duplicates but keep the returned order (a set would make the [:max_number_kw] cut arbitrary)

    c_keywords = textParser.get_keywords_keybert(choice, n_bag=n_bag)  # get keywords of the correct choice
    c_keywords = list(dict.fromkeys(c_keywords))  # remove duplicates but keep the returned order

    # sort c_keywords by computing the average similarity between the current c_kw and each q_kw in q_keywords
    c_keywords = textParser.keyword_sort(references=q_keywords, keywords=c_keywords)
    c_keywords = c_keywords[:max_number_kw]
    q_keywords = q_keywords[:max_number_kw]

    if verbose:
        print(f"question: {question}")
        print(f"answer: {choice}")
        print(f"q_keywords: {q_keywords}")
        print(f"c_keywords: {c_keywords}")

    return q_keywords, c_keywords
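
`TextParser.keyword_sort` is not shown in this section; the comment in `get_keywords` describes ranking each answer keyword by its average similarity to the question keywords. The sketch below illustrates that ranking with a pluggable similarity function, plus order-preserving de-duplication. `char_jaccard` is a toy stand-in (character-set overlap) for the sentence-embedding similarity the notebook presumably uses; all function names here are hypothetical.

```python
from typing import Callable, List


def dedup_keep_order(words: List[str]) -> List[str]:
    # dict keys keep insertion order (Python 3.7+), unlike a set
    return list(dict.fromkeys(words))


def sort_by_avg_similarity(references: List[str], keywords: List[str],
                           sim: Callable[[str, str], float]) -> List[str]:
    # Rank each keyword by its mean similarity to all reference keywords, highest first
    def avg_sim(kw: str) -> float:
        if not references:
            return 0.0
        return sum(sim(kw, ref) for ref in references) / len(references)
    return sorted(keywords, key=avg_sim, reverse=True)  # stable: ties keep input order


def char_jaccard(a: str, b: str) -> float:
    # Toy similarity: Jaccard overlap of the two words' character sets
    sa, sb = set(a), set(b)
    union = sa | sb
    return len(sa & sb) / len(union) if union else 0.0


print(dedup_keep_order(["man", "faucet", "man", "turned"]))
print(sort_by_avg_similarity(["cat"], ["dog", "cart", "cast"], char_jaccard))
```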


In [31]:
import time
from typing import Optional


def retrieve_knowledge(
    example,
    e_id: int,
    kw_pairs: Optional[list] = None,
    pair_cnt: int = -1,
    max_depth: int = 6,
    verbose: bool = False,
) -> str:
    ############################################
    #  Complete the following code
    ############################################

    timer_start_example = time.perf_counter()

    if not (isinstance(kw_pairs, list) and len(kw_pairs) > 0):
        q_keywords, c_keywords = get_keywords(example, max_number_kw=3, verbose=verbose)

        kw_pairs = []
        for ck in c_keywords:
            for qk in q_keywords:
                if ck == qk:  # skip kw_pair that q_keyword == c_keyword (path length will be 0)
                    continue
                kw_pairs.append((ck, qk))

        assert len(kw_pairs) >= 1, f"Assertion Error: len(kw_pairs) = {len(kw_pairs)}"
        if pair_cnt > 0:
            kw_pairs = kw_pairs[: pair_cnt]  # keep only the first pair_cnt pairs (c_keywords are sorted by relevance to q_keywords)

    if verbose:
        print(f"Todo kw_pairs: {kw_pairs}")

    best_p_list = []  # best_path and best_config of each Dij search
    rationale_dict_list = []
    rationale_list = []
    rationale_id = 1

    for src_word, tgt_word in kw_pairs:
        if src_word == tgt_word:
            continue

        path_list = dij.dijkstra_path_search(src_word, tgt_word, verbose=False, max_depth=max_depth)

        if len(path_list) == 0:
            continue

        # There could be multiple paths from the src_word to tgt_word.
        # We select the best path by choosing the path with the largest average edge weights.
        best_path = [[], [], []]  # For the current pair, the best path of all returned paths
        best_config = [0.0, 0.0, 0.0]  # the best path_len, w_sum, w_avg
        for path in path_list:
            node_list, w_list, r_list = path
            assert len(node_list) == len(w_list) + 1
            path_len = len(w_list)
            w_sum = sum(w_list)
            w_avg = float(w_sum / path_len) if path_len > 0 else 0.0
            if w_avg > best_config[2]:  # update the best path
                best_path = path
                best_config = [path_len, w_sum, w_avg]

        best_p_list.append((best_path, best_config))

        textConverter = TextConverter(best_path)
        data_dict = textConverter.convert(verbose=False)

        cur_desc_list = [desc.strip() for desc in data_dict["description_list"]]
        cur_rationale = [desc.capitalize() for desc in cur_desc_list]
        cur_rationale = "; ".join(cur_rationale) + "."
        # cur_rationale = f"Conceptually, \"{src_word}\" is related to \"{tgt_word}\" because: " + cur_rationale
        # cur_rationale = f"{rationale_id}. \"{src_word}\" is conceptually related to \"{tgt_word}\" because: " + cur_rationale
        cur_rationale = f"\"{src_word}\" is conceptually related to \"{tgt_word}\" because: " + cur_rationale
        rationale_id += 1

        rationale_dict = {
            "src_word": src_word,
            "tgt_word": tgt_word,
            "description_list": cur_desc_list,
            "rationale": cur_rationale
        }
        rationale_dict_list.append(rationale_dict)
        rationale_list.append(cur_rationale)

    if len(rationale_list) > 0:
        # final_rationale = "\n".join(rationale_list)
        final_rationale = " ".join(rationale_list)
        print(f">>> Rationale: {final_rationale}")
    else:
        print(f">>> No rationale (path).")
        final_rationale = "None."

    timer_end_example = time.perf_counter()
    time_sec, time_min = timer_end_example - timer_start_example, (timer_end_example - timer_start_example) / 60
    if verbose:
        print(f"DONE Example {e_id} - Running Time: {time_sec:.1f} sec ({time_min:.1f} min)\n")

    return final_rationale
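
The core of `retrieve_knowledge` is three small steps: pair answer keywords with question keywords, keep the path with the largest average weight for each pair, and turn the path descriptions into one rationale sentence. A standalone sketch of those steps, assuming paths are `(node_list, w_list, r_list)` tuples as returned by `get_path`, follows; the helper names are hypothetical. Two hedged notes: the sketch starts from `-inf` so a path with a non-positive average is still selectable (the notebook starts from `0.0`), and `get_path` as written collects the cumulative values stored in `weights`, so its `w_list` entries are running totals rather than per-edge weights.

```python
from typing import List, Optional, Sequence, Tuple

Path = Tuple[List[str], List[float], List[list]]  # (node_list, w_list, r_list)


def build_kw_pairs(q_keywords: Sequence[str], c_keywords: Sequence[str],
                   pair_cnt: int = -1) -> List[Tuple[str, str]]:
    """(answer keyword, question keyword) pairs, skipping identical words."""
    pairs = [(ck, qk) for ck in c_keywords for qk in q_keywords if ck != qk]
    return pairs[:pair_cnt] if pair_cnt > 0 else pairs


def select_best_path(path_list: List[Path]) -> Tuple[Optional[Path], float]:
    """Return the path with the largest average weight per step, and that average."""
    best: Optional[Path] = None
    best_avg = float("-inf")
    for node_list, w_list, r_list in path_list:
        avg = sum(w_list) / len(w_list) if w_list else 0.0
        if avg > best_avg:
            best, best_avg = (node_list, w_list, r_list), avg
    return best, best_avg


def format_rationale(src: str, tgt: str, descriptions: List[str]) -> str:
    """Same sentence template as retrieve_knowledge."""
    body = "; ".join(d.strip().capitalize() for d in descriptions) + "."
    return f'"{src}" is conceptually related to "{tgt}" because: ' + body


print(build_kw_pairs(["man", "turned"], ["water", "man"], pair_cnt=2))
print(format_rationale("mother", "children", ["daughter is like mother", "daughter is like child"]))
```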


Add the rationale for each of the in-context and validation set examples.

In [32]:
for e_id, ex in enumerate(in_context_examples):
    question = ex["premise"]  # question text
    choice_key = "choice" + str(ex["label"] + 1)  # answer key
    choice = ex[choice_key]  # answer text
    print(f"Question: {question}")
    print(f"Answer: {str(ex['label'] + 1)}) {choice}\n")


Question: The woman felt lonely.
Answer: 2) She adopted a cat.

Question: The mother needed help looking after her children.
Answer: 1) She sent the children to daycare.

Question: I learned how to play the board game.
Answer: 1) My friend explained the rules to me.

Question: The woman's eyeglasses fogged up.
Answer: 2) She entered the sauna.

Question: I ran out of breath.
Answer: 1) I climbed several flights of stairs.



In [33]:
in_context_examples_kw_pairs_list = [
    [("lonely", "cat"), ("woman", "adopt")],
    [("mother", "children"), ("children", "daycare")],
    [("learn", "game"), ("game", "rules")],  # ("learn", "rules")
    [("eyeglasses", "fog"), ("fog", "sauna")],
    [("breath", "climb"), ("climb", "stair")],
]

In [34]:
print(f"The number of in_context_examples is {len(in_context_examples)}")

in_context_examples_kb = [ex.copy() for ex in in_context_examples]  # copy each dict so adding "rationale" leaves the originals untouched

assert len(in_context_examples_kw_pairs_list) == len(in_context_examples_kb)
for e_id, ex in enumerate(in_context_examples_kb):
    # cur_rationale = retrieve_knowledge(ex, e_id, kw_pairs=in_context_examples_kw_pairs_list[e_id], pair_cnt=1, verbose=True)
    cur_rationale = retrieve_knowledge(ex, e_id, kw_pairs=in_context_examples_kw_pairs_list[e_id], pair_cnt=2, verbose=True)
    ex["rationale"] = cur_rationale


The number of in_context_examples is 5
Todo kw_pairs: [('lonely', 'cat'), ('woman', 'adopt')]
>>> Rationale: "woman" is conceptually related to "adopt" because: Lady is like woman; Lady is like female; Girl is like female; Chick is like girl; Chick is like baby; Baby is like child; Adopt is like child.
DONE Example 0 - Running Time: 66.9 sec (1.1 min)

Todo kw_pairs: [('mother', 'children'), ('children', 'daycare')]
>>> Rationale: "mother" is conceptually related to "children" because: Daughter is like mother; Daughter is like child; You are likely to find child in school; You are likely to find children in school.
DONE Example 1 - Running Time: 24.8 sec (0.4 min)

Todo kw_pairs: [('learn', 'game'), ('game', 'rules')]
>>> Rationale: "learn" is conceptually related to "game" because: Something you do when you reading is learn; Sometimes reading causes learning; Playing is used for learning; Toy is like playing; Toy is like fun; Play is like fun; Game is like play. "game" is conceptually
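
A list's `.copy()` is shallow: the new list holds the same example dicts, so writing `ex["rationale"]` through the copy also modifies the original examples. The small sketch below demonstrates the difference with a made-up one-item list. COPA examples are flat dicts of strings and ints, so copying each dict is enough; `copy.deepcopy` would only be needed for nested values.

```python
examples = [{"premise": "The woman felt lonely.", "label": 1}]

# list.copy() is shallow: a new list that holds the same dict objects
shallow = examples.copy()
shallow[0]["rationale"] = "None."
print("rationale" in examples[0])  # True: the original dict was modified as well

# Copying each dict gives independent examples (enough here, since values are str/int)
examples = [{"premise": "The woman felt lonely.", "label": 1}]
per_dict = [ex.copy() for ex in examples]
per_dict[0]["rationale"] = "None."
print("rationale" in examples[0])  # False: the original stays untouched
```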

In [35]:
for e_id, ex in enumerate(in_context_examples_kb):
    cur_rationale = ex["rationale"]
    print(cur_rationale)
    print()


"woman" is conceptually related to "adopt" because: Lady is like woman; Lady is like female; Girl is like female; Chick is like girl; Chick is like baby; Baby is like child; Adopt is like child.

"mother" is conceptually related to "children" because: Daughter is like mother; Daughter is like child; You are likely to find child in school; You are likely to find children in school.

"learn" is conceptually related to "game" because: Something you do when you reading is learn; Sometimes reading causes learning; Playing is used for learning; Toy is like playing; Toy is like fun; Play is like fun; Game is like play. "game" is conceptually related to "rules" because: Game is like play; Play is like fun; Toy is like fun; Toy is like playing; Card is like playing; Card is like king; Rule is like king; Rule is like law; Law is like rules.

"eyeglasses" is conceptually related to "fog" because: Pair of glasses and eyeglasses have similar meanings; Pair of glasses is like lens; Lens is part of e

In [36]:
print(f"The number of val_set is {len(val_set)}")

val_set_kb = [ex.copy() for ex in val_set]  # copy each dict so adding "rationale" leaves val_set untouched

for e_id, ex in enumerate(tqdm.notebook.tqdm(val_set_kb)):
    # cur_rationale = retrieve_knowledge(ex, e_id, pair_cnt=1, verbose=True)
    cur_rationale = retrieve_knowledge(ex, e_id, pair_cnt=2, verbose=True)
    ex["rationale"] = cur_rationale


The number of val_set is 100



question: The man turned on the faucet.
answer: Water flowed from the spout.
q_keywords: ['man', 'turned', 'faucet']
c_keywords: ['flowed', 'water', 'spout']
Todo kw_pairs: [('flowed', 'man'), ('flowed', 'turned')]
>>> Rationale: "flowed" is conceptually related to "man" because: Backwash is like flowed; Backwash is like food; Farmer is like food; Farmer is like man. "flowed" is conceptually related to "turned" because: Backwash is like flowed; Backwash is like backward; Backward is similar to backward; Transposed is similar to backward; Reversed and transposed have similar meanings; Turned is similar to reversed; Turned and turned have similar meanings.
DONE Example 0 - Running Time: 48.5 sec (0.8 min)

question: The girl found a bug in her cereal.
answer: She lost her appetite.
q_keywords: ['cereal', 'girl', 'bug']
c_keywords: ['appetite', 'lost']
Todo kw_pairs: [('appetite', 'cereal'), ('appetite', 'girl')]
>>> No rationale (path).
DONE Example 1 - Running Time: 96.8 sec (1.6 min)



* With `pair_cnt=1`, retrieving rationales for the 100 validation examples takes about 1 h 30 min.
* With `pair_cnt=2`, it takes about 3 h.
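
Each keyword pair launches its own bidirectional search, so the retrieval cost grows roughly linearly with `pair_cnt`, which is consistent with the doubling above. The per-example averages follow from simple arithmetic; `avg_seconds_per_example` is a hypothetical helper.

```python
def avg_seconds_per_example(total_minutes: float, n_examples: int) -> float:
    """Average wall-clock seconds spent per example."""
    return total_minutes * 60 / n_examples


# Approximate totals quoted above for the 100 validation examples
print(avg_seconds_per_example(90, 100))   # pair_cnt=1: 54.0
print(avg_seconds_per_example(180, 100))  # pair_cnt=2: 108.0
```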

Now, we need to change the prompt creation functions to include this rationale.

In [37]:
def single_example_prompt_with_kb(example, include_answer=False):
    prompt = f"Q: {example['premise']} {question_type_to_nl[example['question']]}" + \
             f"\n1) {example['choice1']}\n2) {example['choice2']}" + \
             f"\nRationale: {example['rationale']}\n"

    if include_answer:
        prompt += f"\nA: {example['label'] + 1}"

    return prompt
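
To see exactly what one rendered example looks like, here is a self-contained copy of the same template. `question_type_to_nl` is defined earlier in the notebook; the mapping below is an assumption read off the question texts in the printed prompts further down. Note that COPA labels are 0-based while the prompt numbers the choices from 1, hence the `+ 1`.

```python
# Assumed to match question_type_to_nl, based on the printed prompts
QUESTION_TYPE_TO_NL = {
    "cause": "What could have caused this?",
    "effect": "What might have happened as a result?",
}


def format_example(example: dict, include_answer: bool = False) -> str:
    prompt = (f"Q: {example['premise']} {QUESTION_TYPE_TO_NL[example['question']]}"
              f"\n1) {example['choice1']}\n2) {example['choice2']}"
              f"\nRationale: {example['rationale']}\n")
    if include_answer:
        prompt += f"\nA: {example['label'] + 1}"  # 0-based label -> 1-based choice number
    return prompt


example = {
    "premise": "The mother needed help looking after her children.",
    "choice1": "She sent the children to daycare.",
    "choice2": "She gave up custody of the children.",
    "question": "effect",
    "label": 0,
    "rationale": "None.",
}
print(format_example(example, include_answer=True))
```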


Let's look at an example prompt:

In [38]:
prompt = create_prompt(in_context_examples_kb, val_set_kb[0], 
                       single_fn=single_example_prompt_with_kb)
print(prompt)

Q: The woman felt lonely. What might have happened as a result?
1) She renovated her kitchen.
2) She adopted a cat.
Rationale: "woman" is conceptually related to "adopt" because: Lady is like woman; Lady is like female; Girl is like female; Chick is like girl; Chick is like baby; Baby is like child; Adopt is like child.

A: 2

Q: The mother needed help looking after her children. What might have happened as a result?
1) She sent the children to daycare.
2) She gave up custody of the children.
Rationale: "mother" is conceptually related to "children" because: Daughter is like mother; Daughter is like child; You are likely to find child in school; You are likely to find children in school.

A: 1

Q: I learned how to play the board game. What could have caused this?
1) My friend explained the rules to me.
2) My friend got the rules wrong.
Rationale: "learn" is conceptually related to "game" because: Something you do when you reading is learn; Sometimes reading causes learning; Playing is 

Finally, let's predict the answers, compute the accuracy, and save the predictions.

In [39]:
val_predictions_kb = [
    predict_retry(
        generator,
        create_prompt(in_context_examples_kb, ex, single_fn=single_example_prompt_with_kb),
        gen_config=gen_config,
    )
    for ex in tqdm.notebook.tqdm(val_set_kb)
]



In [40]:
predicts = val_predictions_kb
incorrect_predicts = []
for idx in range(len(predicts)):
    predict = predicts[idx]
    gold = val_set_kb[idx]
    if gold["label"] != predict:
        incorrect_predicts.append(predict)

accuracy = 1 - len(incorrect_predicts) / len(predicts)
print(f"The prediction accuracy is {accuracy:.2f}")

The prediction accuracy is 0.47
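
The cell above computes accuracy as 1 minus the fraction of incorrect predictions. An equivalent and slightly more direct formulation is the mean of exact matches between predicted and gold 0-based choice indices, sketched below with made-up inputs (`compute_accuracy` is a hypothetical name, chosen so it does not shadow the notebook's `accuracy` variable). With 100 validation examples, 0.47 corresponds to 47 correct answers, just under the 0.50 expected from guessing between two choices.

```python
from typing import Sequence


def compute_accuracy(predictions: Sequence[int], labels: Sequence[int]) -> float:
    """Fraction of predictions equal to the gold label (0-based choice index)."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        return 0.0
    return sum(p == g for p, g in zip(predictions, labels)) / len(labels)


print(f"{compute_accuracy([1, 0, 0, 1], [1, 0, 1, 0]):.2f}")  # 0.50
```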


In [41]:
res_fp_3 = "output/kb_predictions-{}.jsonl".format(MODEL_NAME.split("/")[-1])
print(f"Saving to {res_fp_3}")

Saving to output/kb_predictions-gpt-neo-125m.jsonl


In [42]:
with open(res_fp_3, "w", encoding="utf-8") as f_out:
    for ex, pred in zip(val_set_kb, val_predictions_kb):
        new_ex = ex.copy()
        new_ex["prediction"] = pred
        f_out.write(json.dumps(new_ex) + "\n")
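
The predictions are stored as JSON Lines (one JSON object per line), and the next cell reads them back. A minimal round-trip sketch using only the standard library is below; `write_jsonl`, `read_jsonl`, and the demo file name are hypothetical, and a temporary directory is used so nothing is written next to the real outputs. Copying each example before adding `"prediction"` (as the cell above does with `ex.copy()`) keeps `val_set_kb` itself unchanged.

```python
import json
import os
import tempfile
from typing import List


def write_jsonl(path: str, records: List[dict]) -> None:
    with open(path, "w", encoding="utf-8") as f_out:
        for rec in records:
            f_out.write(json.dumps(rec) + "\n")  # one JSON object per line


def read_jsonl(path: str) -> List[dict]:
    with open(path, "r", encoding="utf-8") as f_in:
        return [json.loads(line) for line in f_in if line.strip()]  # skip blank lines


records = [{"idx": 96, "label": 1, "prediction": 0}]
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "kb_predictions-demo.jsonl")
    write_jsonl(path, records)
    print(read_jsonl(path) == records)
```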


In [43]:
predicts = []
with open(res_fp_3, "r", encoding="utf-8") as f_in:
    for line in f_in:
        predicts.append(json.loads(line))

incorrect_predicts = []
for predict in predicts:
    if predict["label"] != predict["prediction"]:
        incorrect_predicts.append(predict)

accuracy = 1 - len(incorrect_predicts) / len(predicts)
print(f"The prediction accuracy is {accuracy:.2f}")

selected_incorrect_predicts = random.sample(incorrect_predicts, 20)
selected_incorrect_predicts

The prediction accuracy is 0.47


[{'premise': 'I tidied up my house.',
  'choice1': 'I was swamped with work.',
  'choice2': 'I was expecting company.',
  'question': 'cause',
  'idx': 96,
  'label': 1,
  'rationale': '"company" is conceptually related to "house" because: Company is like business; Bank is like business; Bank is like building; You are likely to find door in building; Door is like room; Room is like house.',
  'prediction': 0},
 {'premise': 'The man perceived that the woman looked different.',
  'choice1': 'The woman got her hair cut.',
  'choice2': 'The woman wore a bracelet.',
  'question': 'cause',
  'idx': 65,
  'label': 0,
  'rationale': '"woman" is conceptually related to "looked" because: Dress is like woman; Dress is like female; Girl is like female; Chick is like girl; Chick is like young; Boy is like young; Boy is like child; Child is like human; Baby is like human; Baby is like small; Drop is like small; Drop is like down; The word downlooked is derived from the word down; The word downlooked
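
`random.sample` draws a different set of incorrect predictions on every run unless a seed is fixed. For error analysis it can help to make the sample reproducible and to count errors separately for "cause" and "effect" questions. The sketch below assumes records shaped like the saved JSONL lines (with `question`, `label`, and `prediction` keys); the helper names and the toy records are hypothetical.

```python
import random
from collections import Counter
from typing import List


def sample_errors(records: List[dict], k: int = 20, seed: int = 0) -> List[dict]:
    errors = [r for r in records if r["label"] != r["prediction"]]
    rng = random.Random(seed)  # local RNG: the same sample on every run
    return rng.sample(errors, min(k, len(errors)))  # never ask for more than exist


def errors_by_question_type(records: List[dict]) -> Counter:
    # Count incorrect predictions separately for "cause" and "effect" questions
    return Counter(r["question"] for r in records if r["label"] != r["prediction"])


records = [
    {"question": "cause", "label": 1, "prediction": 0},
    {"question": "cause", "label": 0, "prediction": 0},
    {"question": "effect", "label": 0, "prediction": 1},
]
print(errors_by_question_type(records))
print(len(sample_errors(records, k=20)))
```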

## DONE