# Dataset

## Load Packages

In [2]:
import os
import json
from tqdm import tqdm
from torch import tensor,matmul
import xml.etree.ElementTree as ET
from nltk.corpus import wordnet31 as wn
from langchain_openai import OpenAIEmbeddings
from _api_key import get_openai_api_key

## Basic Functions

In [None]:
def load_json_file(file_path):
    """
    Load a json file and return its parsed content.

    Args:
        file_path: Path of the json file; it is read as UTF-8.

    Returns:
        The object produced by ``json.load`` (e.g. a dict or a list).
    """
    # The ``with`` block closes the handle on exit, so no explicit close() is needed
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def save_json_file(file, file_path):
    """
    Save a json-serializable object to a json file.

    Args:
        file: The object to serialize (e.g. a dict or a list).
        file_path: Destination path; the file is overwritten and written as UTF-8.

    The output is indented by 4 spaces and non-ASCII characters are written
    as-is (``ensure_ascii=False``) instead of being escaped.
    """
    # The ``with`` block closes the handle on exit, so no explicit close() is needed
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(file, f, indent=4, ensure_ascii=False)

In [2]:
def getCorpus(filepath):
    """
    Enter the path of the .xml data file, Return dataset(dict)

    The root element's children are treated as texts and each text's children
    as words. Every text and every word must carry an 'id' attribute; a word
    may additionally carry a 'senses' attribute.

    Returns:
        corpus[ID] = (word_id_list, split_sentence) when no word of the text
        has a 'senses' attribute, otherwise
        corpus[ID] = (word_id_list, split_sentence, pun_id), where pun_id is
        the id of the first word whose 'senses' attribute equals '2'.

    Raises:
        AssertionError: if '.xml' does not occur in filepath.
        KeyError: if a text or word element has no 'id' attribute.
        ValueError: if words of a text carry 'senses' but none equals '2'.
    """
    # Open the XML format file
    assert '.xml' in filepath
    root = ET.parse(filepath).getroot()
    corpus = dict()
    for item in root:
        ID = item.attrib['id']
        word_id_list, sentence = [], []
        # (word id, senses) of every word that carries a 'senses' attribute
        sense_list = []
        for word in item:
            word_id = word.attrib['id']
            word_id_list.append(word_id)
            sentence.append(word.text)
            # Read attributes by name so the result does not depend on the
            # attribute order in the file or on extra attributes being present
            if 'senses' in word.attrib:
                sense_list.append((word_id, word.attrib['senses']))
        data = [word_id_list, sentence]
        if sense_list:
            # The pun word is the one marked with senses == '2'
            pun_id = next((wid for wid, senses in sense_list if senses == '2'), None)
            if pun_id is None:
                raise ValueError(f"no word with senses == '2' in text {ID!r}")
            data.append(pun_id)
        corpus[ID] = tuple(data)
    return corpus


def getGold(filepath):
    """
    Get the gold standard for dataset (whether is a pun or not/double sense of the pun)

    Each non-blank line is split on tabs if it contains a tab, otherwise on
    single spaces. The first field is the ID and the remaining fields are the
    labels. An ID with more than two '_'-separated parts is shortened to its
    first two parts (e.g. 'hom_2_3' -> 'hom_2').

    Returns:
        golds[ID] = tuple of label strings, in file order.

    Raises:
        AssertionError: if '.gold' does not occur in filepath.
    """
    assert '.gold' in filepath
    # golds[ID]= ((gold)label, )
    golds = dict()
    # Read as UTF-8, like the json helpers; the ``with`` block closes the file
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.rstrip('\n')
            # A blank line would otherwise create a bogus entry with an empty ID
            if not line.strip():
                continue
            sp = '\t' if '\t' in line else ' '
            ID, *label = line.split(sp)
            id_parts = ID.split('_')
            if len(id_parts) > 2:
                ID = '_'.join(id_parts[:2])
            golds[ID] = tuple(label)
    return golds


def getExPun(filepath):
    """
    Get the ID, the longest explanation, the best keyword set and the average funniness rating from the ExPun dataset

    Args:
        filepath: Path of a json file holding a list of records. Each record
            is read through the keys 'ID', 'Is a Joke?' (or, when that key is
            missing, 'Understand the Joke?'), 'Natural language explanation',
            'Joke keywords' and 'Funniness (1-5)'. The last four are parallel
            per-annotation lists.

    Returns:
        ExPun[ID] = (explanation, keywords, rating). Records without any usable
        explanation or without any usable keyword list are left out.
    """
    dataset = load_json_file(filepath)
    # ExPun[ID] = (explanation, keywords, rating)
    ExPun = dict()
    for data in dataset:
        ID = data['ID']
        # Only a missing key should trigger the fallback to the alternative key
        try:
            is_a_joke = data['Is a Joke?']
        except KeyError:
            is_a_joke = data['Understand the Joke?']
        explanations = data['Natural language explanation']
        joke_keywords = data['Joke keywords']
        funniness = data['Funniness (1-5)']
        # Keep only the annotations flagged with 1
        explanations = [expl for i, expl in zip(is_a_joke, explanations)
                        if i == 1 and len(expl) > 0]
        # Empty keyword lists are dropped as well: they carry no keywords and
        # would cause a division by zero in the score below
        joke_keywords = [kw for i, kw in zip(is_a_joke, joke_keywords)
                         if i == 1 and isinstance(kw, list) and kw]
        funniness = [f for i, f in zip(is_a_joke, funniness) if i == 1]
        if not explanations or not joke_keywords:
            continue

        def keyword_score(kw):
            # Number of keywords minus the mean number of words per keyword
            words_per_keyword = [len(w.split(' ')) for w in kw]
            return len(kw) - sum(words_per_keyword) / len(words_per_keyword)

        # Choose the longest explanation (the first one on ties)
        chosen_expl = max(explanations, key=len)
        # Pick the keyword set with the largest number of keywords but a relatively small number of words per keyword
        chosen_keywords = max(joke_keywords, key=keyword_score)
        rating = round(sum(funniness) / len(funniness), 2)
        ExPun[ID] = (chosen_expl, chosen_keywords, rating)
    return ExPun


def getPunDataset(corpus, golds, expun, puntype='hom'):
    """
    Obtain pun dataset[ID] = dict with the keys pun_word, pun_sense_key, pun_sense,
    alter_word, alter_sense_key, alter_sense, pun_word_ind, human_text,
    human_explanation, human_keywords, human_rating \n
    Alternative word of homographic pun is just pun word

    Args:
        corpus: corpus[ID] = (word_id_list, split_sentence, pun_id)
        golds: golds[ID] = (pun sense keys, alternative sense keys); each
            element is a ';'-separated string of sense keys.
        expun: expun[ID] = (explanation, keywords, rating)
        puntype: Only IDs containing this substring are kept.

    Returns:
        A dict over the IDs present in all three inputs, inserted in
        descending (string) ID order. IDs for which a sense key cannot be
        resolved to a definition are skipped.
    """
    shared_ids = set(corpus) & set(golds) & set(expun)
    IDs = sorted((ID for ID in shared_ids if puntype in ID), reverse=True)
    punDataset = dict()
    for ID in IDs:
        # Pun word and its sense key; the word is the part before '%' of the
        # first sense key, with underscores turned into spaces
        pun_sense_key = golds[ID][0].split(';')
        pun_word = pun_sense_key[0].split('%')[0].replace('_', ' ')
        # Alternative word and its sense key
        alter_sense_key = golds[ID][1].split(';')
        alter_word = alter_sense_key[0].split('%')[0].replace('_', ' ')
        # Make sure the word sense can be retrieved by sense key.
        # ``Exception`` keeps the best-effort skip but, unlike a bare
        # ``except``, lets KeyboardInterrupt / SystemExit propagate.
        try:
            pun_sense = [wn.synset_from_sense_key(sk).definition() for sk in pun_sense_key]
            alter_sense = [wn.synset_from_sense_key(sk).definition() for sk in alter_sense_key]
        except Exception:
            continue
        explanation, keywords, rating = expun[ID][0], expun[ID][1], expun[ID][2]
        # Construct dataset (key order is kept as it determines the saved json layout)
        punDataset[ID] = {
            'pun_word': pun_word,
            'pun_sense_key': ';'.join(pun_sense_key),
            'pun_sense': '; '.join(pun_sense),
            'alter_word': alter_word,
            'alter_sense_key': ';'.join(alter_sense_key),
            'alter_sense': '; '.join(alter_sense),
            # Third element of the corpus entry, i.e. the pun word's id
            'pun_word_ind': corpus[ID][2],
            'human_text': ' '.join(corpus[ID][1]),
            'human_explanation': explanation,
            'human_keywords': keywords,
            'human_rating': rating,
        }
    return punDataset


def getNonpunDataset(detectionSet, detectionGold, puntype='hom'):
    """
    Build the non-pun dataset, which holds nothing but the text of each item

    detectionGold[ID][0] is converted to int; an ID is kept when that value is
    0 and puntype occurs in the ID. The text is detectionSet[ID][1] joined
    with single spaces.

    Returns:
        nonpunDataset[ID] = {'human_text': text}
    """
    # Pass 1: pick the IDs whose gold label is 0 for the requested pun type
    selected = [
        ID for ID, labels in detectionGold.items()
        if int(labels[0]) == 0 and puntype in ID
    ]
    # Pass 2: rebuild the sentence of every selected ID
    return {
        ID: {'human_text': ' '.join(detectionSet[ID][1])}
        for ID in selected
    }


def splitDataset(file, puntype='hom'):
    """
    Separate a dataset into its pun part and its non-pun part

    Args:
        file: Either the path (str) of a json file to load, or the already
            loaded dataset mapping ID -> data.
        puntype: Only IDs containing this substring are considered; all other
            entries end up in neither part.

    Returns:
        (punDataset, nonpunDataset). An entry counts as a pun when its
        'pun_word' value is present and truthy.
    """
    dataset = load_json_file(file) if isinstance(file, str) else file
    punDataset, nonpunDataset = dict(), dict()
    for ID in dataset:
        data = dataset[ID]
        if puntype not in ID:
            continue
        target = punDataset if data.get('pun_word', False) else nonpunDataset
        target[ID] = data
    return punDataset, nonpunDataset

In [3]:
def _top_k_by_mean_similarity(id_embeddings, top_k):
    """
    Return the IDs of the top_k entries of id_embeddings, a list of
    (ID, embedding) pairs, ranked by mean dot product with all entries
    """
    matrix = tensor([embed for _, embed in id_embeddings])
    scores = []
    for row, (ID, _) in enumerate(id_embeddings):
        # Dot product of this embedding with every embedding of the group
        # (itself included). NOTE(review): this equals cosine similarity only
        # if the model returns unit-length vectors - confirm for the model used.
        similarity = matmul(matrix, matrix[row])
        scores.append((ID, round(float(similarity.mean()), 4)))
    # Stable sort: entries with equal scores keep their dataset order
    scores.sort(key=lambda pair: pair[1], reverse=True)
    return [ID for ID, _ in scores[:top_k]]


def select_examples(embeddings_model, dataset, top_k:int=10):
    """
    Extract several typical data from the dataset (with the highest average similarity to other data) as examples for prompts

    Args:
        embeddings_model: Object with an ``embed_query(text)`` method that
            returns the embedding of a text as a sequence of floats.
        dataset: dataset[ID] = data, where data['human_text'] is the text to
            embed. An entry counts as a pun when its 'pun_word' value is
            present and truthy.
        top_k: Number of examples taken from the puns and from the non-puns.

    Returns:
        (remainings, examples): the dataset without the selected examples and
        the selected examples, both in the dataset's original order.

    Puns and non-puns are ranked separately; the selected IDs are printed.
    """
    pun_embeddings = []
    nonpun_embeddings = []
    # Embeddings of dataset
    IDs = list(dataset.keys())
    for ID in tqdm(IDs):
        data = dataset[ID]
        embedding = embeddings_model.embed_query(data['human_text'])
        bucket = pun_embeddings if data.get('pun_word', False) else nonpun_embeddings
        bucket.append((ID, embedding))
    # Get IDs of examples
    pun_examples = _top_k_by_mean_similarity(pun_embeddings, top_k)
    nonpun_examples = _top_k_by_mean_similarity(nonpun_embeddings, top_k)
    print(f"# Examples of pun: {pun_examples}")
    print(f"# Examples of non-pun: {nonpun_examples}")
    # A set gives constant-time membership tests in the split below
    example_ids = set(pun_examples + nonpun_examples)
    examples = dict()
    remainings = dict()
    for ID in IDs:
        if ID in example_ids:
            examples[ID] = dataset[ID]
        else:
            remainings[ID] = dataset[ID]
    return remainings, examples

## Build Dataset

In [4]:
# All data files are resolved relative to the grandparent of the working directory
rootpath = os.path.dirname(os.path.dirname(os.getcwd()))

# subtask3 test files (called the "interpretation" set here): homographic puns
hom_interpretation_corpuspath = os.path.join(
    rootpath, r'semeval2017_task7/test/subtask3-homographic-test.xml')
hom_interpretation_goldpath = os.path.join(
    rootpath, r'semeval2017_task7/test/subtask3-homographic-test.gold')
# subtask3 test files: heterographic puns
het_interpretation_corpuspath = os.path.join(
    rootpath, r'semeval2017_task7/test/subtask3-heterographic-test.xml')
het_interpretation_goldpath = os.path.join(
    rootpath, r'semeval2017_task7/test/subtask3-heterographic-test.gold')

# ExPun annotation files
expun_path1 = os.path.join(rootpath, r'expun/data/expunations_annotated_full.json')
expun_path2 = os.path.join(rootpath, r'expun/data/expunations_annotated_pilot_100.json')

# Parse corpus and gold file of each pun type
hom_interpretation_corpus = getCorpus(hom_interpretation_corpuspath)
hom_interpretation_golds = getGold(hom_interpretation_goldpath)
het_interpretation_corpus = getCorpus(het_interpretation_corpuspath)
het_interpretation_golds = getGold(het_interpretation_goldpath)

# Merge both ExPun files; for an ID present in both, the second file wins
expun = {**getExPun(expun_path1), **getExPun(expun_path2)}

# Join corpus, gold labels and ExPun annotations per pun type
hom_punDataset = getPunDataset(hom_interpretation_corpus, hom_interpretation_golds,
                               expun, puntype='hom')
het_punDataset = getPunDataset(het_interpretation_corpus, het_interpretation_golds,
                               expun, puntype='het')

In [5]:
# All data files are resolved relative to the grandparent of the working directory
rootpath = os.path.dirname(os.path.dirname(os.getcwd()))

# subtask1 test files (called the "detection" set here): homographic
hom_detection_corpuspath = os.path.join(
    rootpath, r'semeval2017_task7/test/subtask1-homographic-test.xml')
hom_detection_goldpath = os.path.join(
    rootpath, r'semeval2017_task7/test/subtask1-homographic-test.gold')
# subtask1 test files: heterographic
het_detection_corpuspath = os.path.join(
    rootpath, r'semeval2017_task7/test/subtask1-heterographic-test.xml')
het_detection_goldpath = os.path.join(
    rootpath, r'semeval2017_task7/test/subtask1-heterographic-test.gold')

# Parse corpus and gold file of each pun type
hom_detection_corpus = getCorpus(hom_detection_corpuspath)
hom_detection_golds = getGold(hom_detection_goldpath)
het_detection_corpus = getCorpus(het_detection_corpuspath)
het_detection_golds = getGold(het_detection_goldpath)

# Keep only the texts whose gold label is 0 (non-pun)
hom_nonpunDataset = getNonpunDataset(hom_detection_corpus, hom_detection_golds, puntype='hom')
het_nonpunDataset = getNonpunDataset(het_detection_corpus, het_detection_golds, puntype='het')

In [6]:
# Connect text-embedding-ada-002
# The embedding model is only used by select_examples() below to embed each
# entry's 'human_text'.
model_name = 'text-embedding-ada-002'
openai_api_key = get_openai_api_key()  # use your api key
embeddings_model = OpenAIEmbeddings(model=model_name,openai_api_key=openai_api_key, request_timeout=120)

# Select examples
# Merge the pun and non-pun parts of each pun type into one dataset.
# NOTE: dict(**a, **b) raises TypeError if both dicts share a key, so this
# also guards against an ID appearing in both parts.
hom_dataset = dict(**hom_punDataset, **hom_nonpunDataset)
het_dataset = dict(**het_punDataset, **het_nonpunDataset)
# select_examples() returns (remainings, examples): hom_dataset / het_dataset
# are rebound to the data left after the examples have been taken out.
# The comment lines after each call record the IDs printed by a previous run.
hom_dataset, hom_examples = select_examples(embeddings_model=embeddings_model, dataset=hom_dataset)
# Examples of pun: ['hom_1404', 'hom_1356', 'hom_4', 'hom_1556', 'hom_1477', 'hom_1182', 'hom_488', 'hom_1283', 'hom_705', 'hom_1162']
# Examples of non-pun: ['hom_158', 'hom_533', 'hom_1167', 'hom_111', 'hom_639', 'hom_436', 'hom_750', 'hom_846', 'hom_1464', 'hom_1623']
het_dataset, het_examples = select_examples(embeddings_model=embeddings_model, dataset=het_dataset)
# Examples of pun: ['het_1751', 'het_530', 'het_633', 'het_325', 'het_496', 'het_519', 'het_345', 'het_621', 'het_1453', 'het_1774']
# Examples of non-pun: ['het_151', 'het_957', 'het_115', 'het_225', 'het_999', 'het_563', 'het_917', 'het_41', 'het_533', 'het_1041']

 84%|████████▎ | 1223/1463 [10:06<01:54,  2.09it/s]Retrying langchain.embeddings.openai.embed_with_retry.<locals>._embed_with_retry in 4.0 seconds as it raised ServiceUnavailableError: The server is overloaded or not ready yet..
100%|██████████| 1463/1463 [11:59<00:00,  2.03it/s]


# Examples of pun: ['hom_1404', 'hom_1356', 'hom_4', 'hom_1556', 'hom_1477', 'hom_1182', 'hom_488', 'hom_1283', 'hom_705', 'hom_1162']
# Examples of non-pun: ['hom_158', 'hom_533', 'hom_1167', 'hom_111', 'hom_639', 'hom_436', 'hom_750', 'hom_846', 'hom_1464', 'hom_1623']


100%|██████████| 1166/1166 [09:35<00:00,  2.03it/s]


# Examples of pun: ['het_1751', 'het_530', 'het_633', 'het_325', 'het_496', 'het_519', 'het_345', 'het_621', 'het_1453', 'het_1774']
# Examples of non-pun: ['het_151', 'het_957', 'het_115', 'het_225', 'het_999', 'het_563', 'het_917', 'het_41', 'het_533', 'het_1041']


In [8]:
# Count
# hom_dataset / het_dataset no longer contain the selected examples, so the
# counts below refer to the remaining data only.
hom_punDataset, hom_nonpunDataset = splitDataset(hom_dataset, puntype='hom')
het_punDataset, het_nonpunDataset = splitDataset(het_dataset, puntype='het')
print(f'hom: pun {len(hom_punDataset)}, non-pun {len(hom_nonpunDataset)}')
print(f'het: pun {len(het_punDataset)}, non-pun {len(het_nonpunDataset)}')
# Save
hom_dataset_save = r'./dataset/hom_dataset.json'
het_dataset_save = r'./dataset/het_dataset.json'
hom_examples_save = r'./dataset/hom_examples.json'
het_examples_save = r'./dataset/het_examples.json'

# All four files share one output directory. Create it if it is missing, so
# the results of the long embedding run are not lost to a FileNotFoundError.
os.makedirs(os.path.dirname(hom_dataset_save), exist_ok=True)

save_json_file(hom_dataset, hom_dataset_save)
save_json_file(het_dataset, het_dataset_save)
save_json_file(hom_examples, hom_examples_save)
save_json_file(het_examples, het_examples_save)

hom: pun 810, non-pun 633
het: pun 647, non-pun 499
