In [1]:
import json
import os
import random
import re
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from nltk.stem import WordNetLemmatizer

# Shared WordNet lemmatizer; load_data uses it to derive base forms of the target
# word for the dictionary / context lookups.
lemmatizer = WordNetLemmatizer()

import json


def cal_metrics(labels, answers):
    """Score binary predictions against gold labels with scikit-learn.

    Args:
        labels: gold labels (ints, 1 = metaphorical).
        answers: predicted labels, same length as ``labels``.

    Returns:
        dict with keys "acc", "prec", "recall", "f1".
    """
    scorers = (
        ("acc", accuracy_score),
        ("prec", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    return {name: scorer(labels, answers) for name, scorer in scorers}


def load_data(
    path,
    train_dataset=None,
    train=False,
    dictionary_path="../data/dictionary.json",
    contexts_path="../data/contexts.json",
    num_shots=10,
):
    """Load a metaphor-detection split from ``path`` and enrich every sample.

    Reads ``train.tsv`` (``train=True``) or ``test.tsv`` from ``path``. The first row is
    a header and is skipped. Each data row is tab-separated; the columns used are
    [1] label, [2] sentence, [3] POS tag, and the last column, which is the index of
    the target word among the sentence's whitespace tokens.

    When ``train_dataset`` is given, ``distances.json`` and ``vals.json`` are also read
    from ``path``. Their row ``i`` (0-based) corresponds to the ``i``-th data row and
    holds the neighbour distances and the ``train_dataset`` indices of its neighbours.

    Args:
        path: directory containing the split files.
        train_dataset: list of samples (as returned by this function). It is used for
            random "shots" and k-NN neighbours; None disables both.
        train: load ``train.tsv`` instead of ``test.tsv``.
        dictionary_path: JSON object keyed by (lower-case) word whose values are
            attached as ``sample["dict"]``.
        contexts_path: JSON object keyed by word, each value holding "pos" and "neg"
            lists of example sentences.
        num_shots: number of random shots drawn from ``train_dataset`` per sample.
            It is capped at ``len(train_dataset)``.

    Returns:
        (dataset, labels): list of sample dicts and the matching list of int labels.
    """

    def remove_non_letters(s):
        # Strip leading/trailing non-letters so e.g. '"word,' -> 'word'.
        return re.sub(r"^[^a-zA-Z]+|[^a-zA-Z]+$", "", s)

    with open(dictionary_path, "r", encoding="utf-8") as f:
        dictionary = json.load(f)

    with open(contexts_path, "r", encoding="utf-8") as f:
        contexts = json.load(f)

    use_neighbours = train_dataset is not None
    if use_neighbours:
        with open(os.path.join(path, "distances.json"), "r", encoding="utf-8") as f:
            distances = json.load(f)

        with open(os.path.join(path, "vals.json"), "r", encoding="utf-8") as f:
            vals = json.load(f)

    data_path = os.path.join(path, "train.tsv" if train else "test.tsv")
    with open(data_path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    labels = []
    dataset = []

    for index, line in enumerate(lines):
        # Row 0 is the header; blank lines (e.g. a trailing newline) hold no sample.
        if index == 0 or not line.strip():
            continue
        cells = line.strip().split("\t")

        label = int(cells[1])
        sent = cells[2]
        pos_tag = cells[3]
        v_index = int(cells[-1])

        tokens = sent.split()
        word = remove_non_letters(tokens[v_index])

        # Surround the target token with <tar> ... </tar> markers.
        splits = list(tokens)
        splits.insert(v_index, "<tar>")
        splits.insert(v_index + 2, "</tar>")

        sample = {
            "sentence": sent,
            "word": word,
            "label": label,
            "pos": pos_tag,
            "v_index": v_index,
            "s_sentence": " ".join(splits),
        }

        if use_neighbours:
            sample["shots"] = random.sample(train_dataset, min(num_shots, len(train_dataset)))
            # distances/vals rows are aligned with data rows (header excluded): index - 1.
            sample["samples_distances"] = distances[index - 1]
            sample["samples_knn"] = [train_dataset[_id] for _id in vals[index - 1]]

        # Lookup candidates: the lower-cased surface form, then its distinct WordNet
        # lemmas for verb, adjective, adverb, satellite adjective and noun.
        surface = word.lower()
        candidates = [surface]
        for wn_pos in ("v", "a", "r", "s", "n"):
            lemma = lemmatizer.lemmatize(surface, wn_pos)
            if lemma not in candidates:
                candidates.append(lemma)

        for candidate in candidates:
            if candidate in dictionary:
                sample["dict_word"] = candidate
                sample["dict"] = dictionary[candidate]
                break
        else:
            sample["dict"] = {}

        for candidate in candidates:
            if candidate in contexts:
                sample["pos_sent"] = contexts[candidate]["pos"][0]
                sample["neg_sent"] = contexts[candidate]["neg"][0]
                sample["exam_word"] = candidate
                break

        dataset.append(sample)
        labels.append(label)

    return dataset, labels
# Training split: the pool that k-NN demonstrations and random shots are drawn from.
train_dataset = load_data("../data/VUA18", train=True)[0]

In [5]:
import os

from openai import OpenAI

llm_type = "gpt-4o-2024-08-06"
# Credentials come from the environment; never hard-code keys in a notebook.
# An empty/missing OPENAI_BASE_URL is turned into None so the OpenAI client applies
# its own default endpoint instead of using "" verbatim.
api_key = os.environ.get("OPENAI_API_KEY")
api_base = os.environ.get("OPENAI_BASE_URL") or None
client = OpenAI(api_key=api_key, base_url=api_base)

from tqdm import tqdm

# Mutable run configuration; EXAMNUM is the number of k-NN demonstrations that
# knn_cot_prompt_func puts in each prompt.
global_vars = {"EXAMNUM": 2}


def get_response(llm_type, prompt, temp=1.0, system_prompt="You are a helpful assistant.", llm_client=None):
    """Send a single-turn chat completion request and return the reply text.

    Args:
        llm_type: model name passed to the chat completions API.
        prompt: content of the user message.
        temp: sampling temperature.
        system_prompt: content of the system message.
        llm_client: OpenAI-compatible client; defaults to the module-level ``client``.

    Returns:
        The first choice's message content. If that content is None/empty, "" is
        returned, so callers can always treat the result as a string.
    """
    llm_client = client if llm_client is None else llm_client
    completion = llm_client.chat.completions.create(
        model=llm_type,
        stream=False,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=temp,
        timeout=600,
    )
    return completion.choices[0].message.content or ""


def load_prompts(dataset, func):
    """Return ``[func(sample) for sample in dataset]``, i.e. one prompt per sample, in order."""
    return [func(sample) for sample in dataset]

In [None]:
import time


def knn_cot_prompt_func(data, num_examples=None):
    """Build a few-shot metaphor-detection prompt from a sample's nearest neighbours.

    The prompt is built in four parts:
    1. The task instruction.
    2. The first ``num_examples`` entries of ``data["samples_knn"]``, rendered as
       numbered examples (label 1 -> "yes", anything else -> "no").
    3. A blank line.
    4. The query sentence/word followed by an open "Answer: ".

    Args:
        data: sample dict with "sentence", "word" and "samples_knn". The latter is a
            list of dicts with "sentence", "word" and "label".
        num_examples: how many neighbours to include. None falls back to
            ``global_vars["EXAMNUM"]``, which keeps the old behaviour while letting
            callers avoid the mutable global.

    Returns:
        The prompt string.
    """
    if num_examples is None:
        num_examples = global_vars["EXAMNUM"]

    example_template = """Example {no}:
Sentence: {sentence}
Word: {word}
Answer: {answer}\n\n"""

    system = """Does the given word in the given sentence express metaphorical meaning? Please give an answer 'yes' or 'no' like examples below and give your explanation.\n\n"""

    prompt_template = """Sentence: {sentence}
Word: {word}
Answer: """
    examples_prompts = "".join(
        example_template.format(
            no=index + 1,
            sentence=example["sentence"],
            word=example["word"],
            answer="yes" if example["label"] == 1 else "no",
        )
        for index, example in enumerate(data["samples_knn"][:num_examples])
    )

    return system + examples_prompts + "\n" + prompt_template.format(sentence=data["sentence"], word=data["word"])


KNN = 8  # number of nearest-neighbour demonstrations per prompt
MAX_RETRIES = 10  # attempts per prompt before giving up on the API
RETRY_SLEEP = 5  # seconds to wait between failed attempts
CHECKPOINT_EVERY = 50  # save responses.json after every this many responses


def extract_answer(response):
    """Map a model response to a binary label: 1 for "yes" (metaphorical), else 0.

    The first standalone "yes"/"no" token decides. Words like "eyes" therefore no
    longer count as "yes", and a "yes" later in the explanation cannot flip a leading
    "No". Responses that contain neither token, or None, map to 0.
    """
    match = re.search(r"\b(yes|no)\b", (response or "").lower())
    return int(match is not None and match.group(1) == "yes")


def _save_json(obj, path):
    """Write ``obj`` to ``path`` as JSON, overwriting any existing file."""
    with open(path, "w") as f:
        json.dump(obj, f)


def _load_responses(path):
    """Return responses cached at ``path``; [] if the file is missing or unreadable."""
    try:
        with open(path, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        return []
    except json.JSONDecodeError as exc:
        # e.g. a checkpoint cut off mid-write: start over, but say so instead of hiding it.
        print(f"Ignoring unreadable checkpoint {path}: {exc}")
        return []


def _get_response_with_retry(prompt, temp):
    """Call ``get_response`` for ``prompt``, retrying failures up to MAX_RETRIES times."""
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            return get_response(llm_type, prompt, temp=temp)
        except Exception as exc:  # presumably transient API / network errors
            if attempt == MAX_RETRIES:
                raise
            print(f"Request failed ({attempt}/{MAX_RETRIES}): {exc!r}; retrying in {RETRY_SLEEP}s")
            time.sleep(RETRY_SLEEP)


def _collect_responses(prompts, responses_path, temp=0.0):
    """Return one response per prompt, resuming from and checkpointing to ``responses_path``."""
    responses = _load_responses(responses_path)
    if len(responses) > len(prompts):
        # The old loop spun forever here; a stale cache must be resolved by hand.
        raise ValueError(
            f"{responses_path} holds {len(responses)} responses but there are only "
            f"{len(prompts)} prompts; delete it or check the dataset."
        )
    try:
        for prompt in tqdm(prompts[len(responses):]):
            responses.append(_get_response_with_retry(prompt, temp))
            if len(responses) % CHECKPOINT_EVERY == 0:
                _save_json(responses, responses_path)
    finally:
        # Persist progress even when retries are exhausted or the run is interrupted.
        _save_json(responses, responses_path)
    return responses


for times in [1, 2, 3]:
    for data_name in ["MOH-X", "TroFi"]:
        setting = f"implicit-{KNN}-{times}"
        global_vars["EXAMNUM"] = KNN
        # load_data draws random "shots" per sample; seed so re-runs are reproducible.
        random.seed(times)

        data_dir = f"../data/EVAL-samples/{data_name}"

        output_dir = f"./results/{setting}/{data_name}/{llm_type}"
        os.makedirs(output_dir, exist_ok=True)
        print(f"===================={output_dir}=====================")
        dataset, labels = load_data(data_dir, train_dataset=train_dataset)

        prompts = load_prompts(dataset, knn_cot_prompt_func)
        _save_json(prompts, os.path.join(output_dir, "prompts.json"))

        responses = _collect_responses(prompts, os.path.join(output_dir, "responses.json"), temp=0.0)

        preds = [extract_answer(response) for response in responses]
        _save_json(preds, os.path.join(output_dir, "preds.json"))

        metrics = cal_metrics(labels, preds)
        _save_json(metrics, os.path.join(output_dir, "metrics.json"))