In [1]:
import openai
import time
from dotenv import load_dotenv
import os
import xml.etree.ElementTree as ET
import nltk
import datetime
import random
from nltk.corpus import wordnet as wn
from estnltk.wordnet import Wordnet as EstWordnet
import estnltk as et
# Make sure the English WordNet corpus (used via `wn` and nltk.WordNetLemmatizer below) is
# available; the captured output shows this is a no-op when it is already up to date.
nltk.download('wordnet')
# Estonian WordNet from EstNLTK, indexed as estwn[word] by the Estonian helpers below.
# The captured output shows it fetching a resources index on construction.
estwn = EstWordnet()

[nltk_data] Downloading package wordnet to /Users/erudi/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
Downloading resources index: 20.1kB [00:00, 8.06MB/s]


In [2]:
load_dotenv()

# Connection settings come from the environment (populated from a .env file by load_dotenv),
# so no credentials live in the notebook. Unset variables silently become None.
openai.api_type, openai.api_key, openai.api_base, openai.api_version = (
    os.getenv("API_TYPE"),
    os.getenv("API_KEY"),
    os.getenv("API_BASE"),
    os.getenv("API_VERSION"),
)

In [3]:
# NOTE(review): this module-level counter is not what numbers the output <synset id=...>;
# main() keeps its own local counter starting at 1. Kept only for compatibility.
WORD_ID = 1


def get_initial_prompt_xml(is_est=False):
    """System prompt asking for every definition of a word as an XML <definitions> document.

    Args:
        is_est: when True, the second line tells the model that the word, meaning and
            example are Estonian; otherwise that line is left as bare indentation.

    Returns:
        The prompt text. The 12+-space indentation of later lines is part of the text sent
        to the model, which is why every line is spelled out explicitly here.
    """
    estonian_hint = (
        "You will be given the word in Estonian. Meaning and example must be in Estonian."
        if is_est else ""
    )
    return (
        "You are a highly skilled AI trained in language comprehension and WordNet generation. "
        "You will be given a word and you have to give all defenitions of the word and give an example.\n"
        f"            {estonian_hint}\n"
        "            The output must contain only XML formatted answer. The XML must look like this:\n"
        "            <definitions>\n"
        "                <definition>\n"
        "                    <word>[Given word]</word>\n"
        "                    <type>[adjectives/adverbs/conjunctions/determiners/nouns/prepositions/pronouns/verbs]</type>\n"
        "                    <meaning>[Meaning of the word]</meaning>\n"
        "                    <example>[An example sentece with given word]</example>\n"
        "                </definition>\n"
        "            </definitions>"
    )


def get_initial_prompt(is_est=False):
    """English system prompt asking for every definition of a word as plain-text records.

    Each record is four ``Label: value`` lines (Word, Type, Meaning, Example). Trailing
    spaces and the 12-space indentation are part of the text sent to the model, so every
    line is written out explicitly.

    Args:
        is_est: accepted for signature compatibility but ignored; the Estonian prompt is
            produced by get_initial_prompt_est().

    Returns:
        The prompt text.
    """
    wordnet_clause = " exactly like are in the Python NLTK library English WordNet"
    return (
        "You are a highly skilled AI trained in language comprehension and WordNet generation. \n"
        f"            You will be given a word and you must give all defenitions of the word{wordnet_clause}.\n"
        "            The output must contain only plain text and must contain given word, word type, meaing of the word and and example separeted by a new line. \n"
        "            An example of the output:\n"
        "            Word: [Given word]\n"
        "            Type: [adjectives/adverbs/conjunctions/determiners/nouns/prepositions/pronouns/verbs]\n"
        "            Meaning: [Meaning of the word]\n"
        "            Example: [An example sentece with given word]\n"
        "            Do not include any other information. Each difinition must have word, type, meaning and example. Each difinition must have only one meaning and one example."
    )


def get_initial_prompt_est():
    """Estonian counterpart of the plain-text definitions prompt.

    Asks (in Estonian) for every definition of a word "exactly as in the Estonian WordNet of
    the Python EstNLTK library", each as Sõna/Tüüp/Tähendus/Näide (word/type/meaning/example)
    lines. The leading space before "Sa oled" and the 16-space indentation of later lines
    are part of the text sent to the model.

    Returns:
        The prompt text.
    """
    wordnet_clause = " täpselt nagu on Python EstNLTK teegi Eesti WordNetis"
    return (
        " Sa oled kõrgelt kvalifitseeritud keele mõistmise ja WordNeti genereerimise AI.\n"
        f"                Sulle antakse sõna ja sa pead andma kõik selle sõna definitsioonid{wordnet_clause}.\n"
        "                Väljund peab sisaldama ainult tavalist teksti ja peab sisaldama, uue reaga eraldatud: antud sõna, sõna tüüpi, sõna tähendust ja näide. Väljundi näide:\n"
        "                Sõna: [Antud sõna]\n"
        "                Tüüp: [omadussõna/abiverb/sidesõna/määrsõna/asesõna/nimisõna/palind/tegusõna]\n"
        "                Tähendus: [Sõna tähendus]\n"
        "                Näide: [Näide lause antud sõnaga]\n"
        "                Ära lisa väljundisse muud informatsiooni. Iga definitsioon peab sisaldama antud sõna, tüüp, tähendust ja näidet. Iga definitsioon peab sisaldama ainult ühte tähendust ja näidet."
    )


def get_prompt_str_xml(word):
    """Follow-up prompt asking for one relation (e.g. ``"synonym"``) as an XML list.

    Args:
        word: the relation name; it is pluralised by appending ``s`` and also used as the
            XML tag name in the example.

    Returns:
        The prompt text, including its blank lines and trailing 4-space indent.
    """
    nltk_clause = "that are in WordNet in Python NLTK library"
    return (
        "Now you will be given the following fields: id, word, type, meaning and example. \n"
        f"    You will have to give exact {word}s {nltk_clause}.\n"
        "    The output must contain only XML. Here is an example of what XML must look like:\n"
        "\n"
        f"            <{word}s>\n"
        f"                <{word}>[{word} of word 1]</{word}>\n"
        f"                <{word}>[{word} of word 2]</{word}>\n"
        f"                <{word}>[{word} of word 3]</{word}>\n"
        "                ...\n"
        f"            </{word}s>\n"
        "\n"
        f"    The XML is just an example, there can be more or less {word}s for each word. If there are no {word}s, just leave the {word}s tag empty.\n"
        "    "
    )


def get_prompt_str(word):
    """English follow-up prompt asking for one relation as a plain-text, one-per-line list.

    Args:
        word: the relation name (e.g. ``"hypernym"``); it is pluralised by appending ``s``.

    Returns:
        The prompt text, ending with a newline plus 4 spaces of indentation.
    """
    nltk_clause = " that are exactly like in the Python NLTK library English Wordnet"
    return (
        "Now you will be given the following fields: word, type, meaning and example.\n"
        f"    You will have to give exact {word}s{nltk_clause}. \n"
        "    The output must contain only plain text. Here is an example of what the text must look like:\n"
        f"            [First {word}]\n"
        f"            [Second {word}]\n"
        f"            [Third {word}]\n"
        "            ...\n"
        f"    If there are no {word}s, just leave the output empty.\n"
        "    "
    )


def get_prompt_str_est(word):
    """Estonian follow-up prompt asking for one relation as a plain-text, one-per-line list.

    Args:
        word: the Estonian relation name (e.g. ``"sünonüüm"``); it is pluralised by
            appending ``id``.

    Returns:
        The prompt text, ending with a newline plus 4 spaces of indentation.
    """
    wordnet_clause = ", mis on täpselt nagu Python EstNLTK teegi Eesti Wordnetis"
    return (
        "Nüüd antakse sulle järgmised väljad: sõna, tüüp, tähendus ja näide.\n"
        f"    Sa pead andma täpsed {word}id{wordnet_clause}.\n"
        "    Väljund peab sisaldama ainult tavalist teksti. Siin on näide, kuidas tekst peab välja nägema:\n"
        f"            [Esimene {word}]\n"
        f"            [Teine {word}]\n"
        f"            [Kolmas {word}]\n"
        "            ...\n"
        "    Kui sõnu pole, jäta väljund tühjaks.\n"
        "    "
    )


# English relation name -> Estonian name used when building the Estonian relation prompts.
relations_est = {
    'synonym': 'sünonüüm',
    'hyponym': 'hüponüüm',
    'meronym': 'meronüüm',
    'antonym': 'antonüüm',
    'hypernym': 'hüperonüüm',
    'holonym': 'holonüüm',
}
# Relations queried for every definition, in dict (insertion) order, so the two stay in sync.
relations = list(relations_est)

In [4]:
total_price = 0


def openai_api_calculate_cost(usage, model="gpt-4-1106-preview"): # https://community.openai.com/t/how-to-calculate-the-cost-of-a-specific-request-made-to-the-web-api-and-its-reply-in-tokens/270878/15
    """Price one API response and add it to the module-level running total ``total_price``.

    Args:
        usage: object exposing ``prompt_tokens`` and ``completion_tokens`` attributes
            (the ``usage`` part of a chat completion).
        model: key into the per-1000-token price table below.

    Returns:
        The cost of this call in dollars, rounded to 6 decimals.

    Raises:
        ValueError: if ``model`` has no entry in the price table.
    """
    global total_price
    # Dollars per 1000 tokens, separately for prompt and completion tokens.
    rates_per_1k = {
        'gpt-3.5-turbo-1106': {'prompt': 0.001, 'completion': 0.002},
        'gpt-4-1106-preview': {'prompt': 0.01, 'completion': 0.03},
        'gpt-4': {'prompt': 0.03, 'completion': 0.06},
    }
    try:
        rates = rates_per_1k[model]
    except KeyError:
        raise ValueError("Invalid model specified")

    cost = round(
        usage.prompt_tokens * rates['prompt'] / 1000
        + usage.completion_tokens * rates['completion'] / 1000,
        6,
    )
    total_price += cost
    return cost


def is_person(words):
    """Return True when every ``_``-separated part of every word starts with an uppercase char.

    A proper-name heuristic, e.g. ``["Barack_Obama"]`` -> True and ``["dog"]`` -> False.
    An empty ``words`` yields True. An empty part (``"a__b"``) raises IndexError.
    """
    initials_upper = []
    for lemma_name in words:
        for piece in lemma_name.split('_'):
            initials_upper.append(piece[0].isupper())
    return all(initials_upper)


def remove_short_words(words):
    """Keep only tokens longer than two characters, preserving their order."""
    return list(filter(lambda token: len(token) > 2, words))


# https://stackoverflow.com/questions/53416780/how-to-convert-token-list-into-wordnet-lemma-list-using-nltk
def convert_to_lemma(sentence):
    """Map tokens to the set of first lemma names of all their WordNet synsets.

    Each token is first lemmatized with NLTK's WordNetLemmatizer. It is then replaced by the
    first lemma name of every synset that ``wn.synsets`` returns for it.

    Args:
        sentence: iterable of string tokens.

    Returns:
        A set of lemma names.
        NOTE(review): a token with no synsets contributes nothing (it is not kept verbatim);
        only a lookup that raises falls back to the lemmatized token itself.
    """
    lemmatizer = nltk.WordNetLemmatizer()
    text = [lemmatizer.lemmatize(word) for word in sentence]
    lemmas = []
    for token in text:
        try:
            lemmas += [synset.lemmas()[0].name()
                       for synset in wn.synsets(token)]
        except Exception:
            # Best-effort fallback to the raw token. Unlike the former bare `except:`,
            # this no longer swallows KeyboardInterrupt/SystemExit during long runs.
            lemmas.append(token)
    return set(lemmas)


def convert_to_lemma_est(sentence_ls):
    """Return the set of all EstNLTK lemma candidates for the given Estonian tokens.

    The tokens are joined with spaces and tagged as one EstNLTK Text. Every lemma
    alternative of every word in the morph_analysis layer is collected.
    """
    lemma_layer = et.Text(' '.join(sentence_ls)).tag_layer().morph_analysis.lemma
    return {lemma for alternatives in lemma_layer for lemma in alternatives}


def find_wordnet_synset(word, definition):
    """Pick the English WordNet synset of `word` whose gloss best overlaps `definition`.

    The score of a synset is the number of shared surface tokens (lowercased, longer than
    two characters) plus the number of shared names produced by convert_to_lemma.
    Ties keep the earliest synset. The best score starts at 0, so None is returned when no
    synset shares anything. The Estonian variant starts at -1 instead and therefore always
    returns a synset when one exists.
    """
    query_tokens = remove_short_words(definition.lower().split())
    query_token_set = set(query_tokens)
    query_lemmas = convert_to_lemma(query_tokens)
    best_score, best_synset = 0, None
    for candidate in wn.synsets(word):
        # if is_person(candidate.lemma_names()):
        #     continue
        gloss_tokens = remove_short_words(candidate.definition().lower().split())
        gloss_lemmas = convert_to_lemma(gloss_tokens)
        score = len(set(gloss_tokens) & query_token_set) + len(gloss_lemmas & query_lemmas)
        if score > best_score:
            best_score, best_synset = score, candidate
    return best_synset


def get_word_synset(synset, syn_type):
    """Return the NLTK WordNet objects related to `synset` by `syn_type`.

    'synonym' gives the synset's lemmas; 'antonym' gives the antonyms of its FIRST lemma only.
    'meronym' uses part meronyms while 'holonym' uses member holonyms (note: not the inverse
    relation of each other).

    Raises:
        ValueError: for any other `syn_type`.
    """
    if syn_type == 'synonym':
        return synset.lemmas()
    if syn_type == 'hyponym':
        return synset.hyponyms()
    if syn_type == 'meronym':
        return synset.part_meronyms()
    if syn_type == 'antonym':
        return synset.lemmas()[0].antonyms()
    if syn_type == 'hypernym':
        return synset.hypernyms()
    if syn_type == 'holonym':
        return synset.member_holonyms()
    raise ValueError(f"Unknown syn_type: {syn_type}")


def find_wordnet_synset_est(word, definition):
    """Pick the Estonian WordNet synset of `word` whose definition best overlaps `definition`.

    The score of a synset is the number of shared surface tokens (lowercased, longer than
    two characters) plus the number of shared EstNLTK lemmas. Ties keep the earliest synset.
    The best score starts at -1, so the first synset is returned even with zero overlap.
    None is returned only when ``estwn[word]`` yields no synsets.
    """
    query_tokens = remove_short_words(definition.lower().split())
    query_token_set = set(query_tokens)
    query_lemmas = convert_to_lemma_est(query_tokens)
    best_score, best_synset = -1, None
    for candidate in estwn[word]:
        gloss_tokens = remove_short_words(candidate.definition.lower().split())
        gloss_lemmas = convert_to_lemma_est(gloss_tokens)
        score = len(set(gloss_tokens) & query_token_set) + len(gloss_lemmas & query_lemmas)
        if score > best_score:
            best_score, best_synset = score, candidate
    return best_synset


def get_word_synset_est(synset, syn_type):
    """Return what an EstNLTK Wordnet synset exposes for relation `syn_type`.

    Synonyms, hyponyms, meronyms, hypernyms and holonyms are read as attributes. Antonyms
    come from ``get_related_synset('antonym')``.

    Raises:
        ValueError: for any other `syn_type`.
    """
    if syn_type == 'synonym':
        return synset.lemmas
    if syn_type == 'hyponym':
        return synset.hyponyms
    if syn_type == 'meronym':
        return synset.meronyms
    if syn_type == 'antonym':
        return synset.get_related_synset('antonym')
    if syn_type == 'hypernym':
        return synset.hypernyms
    if syn_type == 'holonym':
        return synset.holonyms
    raise ValueError(f"Unknown syn_type: {syn_type}")

In [5]:
def gen_from_prompt(msg, tags):
    """Send one chat request and return the first choice's message, or None on failure.

    Transient or rejected requests are swallowed and reported as ``None``, so the caller can
    simply ask again (main() tries each prompt up to three times).

    Args:
        msg: list of ``{"role": ..., "content": ...}`` chat messages.
        tags: unused; kept for the (disabled) XML-validation flow.

    Returns:
        The ``message`` mapping of the first choice, or ``None`` when the request failed or
        returned no choices. The message may lack ``content`` (e.g. when the content filter
        fired); callers check for that key.
    """
    try:
        completion = openai.ChatCompletion.create(
            deployment_id="gec", model="gpt-4-1106-preview", messages=msg)
    except (openai.error.ServiceUnavailableError,
            openai.error.APIError,
            openai.error.InvalidRequestError,
            openai.error.Timeout,
            openai.error.APIConnectionError):
        # Server hiccups, network timeouts/disconnects or input-related rejections:
        # asking again usually helps, so let the caller retry instead of crashing the run.
        return None
    except openai.error.RateLimitError:
        # The error message advises waiting ~3 seconds before trying again.
        time.sleep(3)
        return None

    if completion is None:
        return None
    usage = completion.get("usage")
    if usage is not None:
        openai_api_calculate_cost(usage)
    # Indexing is done defensively: previously a missing "choices"/"message" raised a
    # KeyError from inside the except-handler and again at the final return.
    choices = completion.get("choices") or []
    if not choices:
        return None
    return choices[0].get("message")

def check_XML_validity(xml_str):
    """Return True if `xml_str` is well-formed XML, False if ElementTree reports a ParseError."""
    try:
        ET.fromstring(xml_str)
    except ET.ParseError:
        return False
    return True
    
def check_tags_XML(xml_str, tags):
    """Check that required child tags are present in an XML string.

    For every ``parent -> children`` entry of `tags`, each <parent> element anywhere in
    `xml_str` must have a direct child for every name in `children`.

    Returns:
        True when all requirements hold, or when `tags` is empty (the XML is not parsed then).
        False when a child is missing or, with non-empty `tags`, when the XML does not parse.
    """
    requirements = list(tags.items())
    if not requirements:
        return True
    try:
        tree = ET.fromstring(xml_str)
    except ET.ParseError:
        return False
    return all(
        element.find(child) is not None
        for parent, children in requirements
        for element in tree.iter(parent)
        for child in children
    )


In [15]:
# Counter names used in every <stats> sub-element, in output order.
_STAT_FIELDS = ('generated_size', 'actual_size', 'overlapping', 'over_generated', 'under_generated')


def _xml_escape(value, attr=False):
    """Return str(value) with &, <, > escaped (and '"' too when attr=True, for double-quoted attributes)."""
    from xml.sax.saxutils import escape
    return escape(str(value), {'"': '&quot;'} if attr else {})


def _parse_definition_lines(content):
    """Split a plain-text definitions answer into a flat list of field values.

    Each non-blank line is expected to look like ``Label: value``. Only the text after the
    FIRST colon is kept, so values that themselves contain colons are not truncated.
    Lines without a colon are kept whole.
    """
    return [line.split(':', 1)[-1].strip() for line in content.split('\n') if line.strip()]


def _request_definitions(messages, word, attempts=3):
    """Ask the model for all definitions of `word`, retrying up to `attempts` times.

    An answer is accepted when it parses into a non-empty multiple of four fields
    (word, type, meaning, example per definition) and every definition's word equals
    `word` case-insensitively.

    Returns:
        ``(answer_message, fields)`` for the first acceptable answer, else ``(None, None)``.
    """
    for _ in range(attempts):
        answer = gen_from_prompt(messages, None)
        if answer is None or 'content' not in answer:
            continue
        fields = _parse_definition_lines(answer['content'])
        if fields and len(fields) % 4 == 0 and all(
                fields[k].lower() == word.lower() for k in range(0, len(fields), 4)):
            return answer, fields
    return None, None


def _generate_relation(messages, relation, definition, is_est, attempts=3):
    """Ask the model for the `relation` words of one generated definition.

    Returns:
        ``(raw_content, words)``. Here `words` is lowercased, with spaces replaced by ``_``
        and blank lines dropped, so an empty answer means "no words" rather than ``['']``.
        Returns ``(None, [])`` when every attempt fails.
    """
    def_word, def_type, def_meaning, def_example = definition
    followup = messages + [
        {"role": "system",
         "content": get_prompt_str_est(relations_est[relation]) if is_est else get_prompt_str(relation)},
        # The 24-space continuation indent reproduces the prompt text used so far.
        {"role": "user",
         "content": (f"Word: {def_word},\n"
                     f"                        Type: {def_type},\n"
                     f"                        Meaning: {def_meaning},\n"
                     f"                        Example: {def_example}")},
    ]
    for _ in range(attempts):
        rel_answer = gen_from_prompt(followup, None)
        if rel_answer is not None and 'content' in rel_answer:
            content = rel_answer['content']
            words = [rel.strip().lower().replace(' ', '_')
                     for rel in content.split('\n') if rel.strip()]
            return content, words
    return None, []


def _actual_relations(wn_synset, is_est):
    """Map each relation in `relations` to the lowercased names related to `wn_synset`.

    A relation whose lookup raises is treated as empty. Names such as ``dog.n.01`` are cut at
    the first dot. Estonian synonyms are used as plain strings.
    """
    lookup = get_word_synset_est if is_est else get_word_synset
    actual = {}
    for relation in relations:
        try:
            related = lookup(wn_synset, relation)
        except Exception:
            related = []
        if is_est and relation == 'synonym':
            actual[relation] = [s.lower() for s in related]
        elif is_est:
            actual[relation] = [s.name.lower().split('.')[0] for s in related]
        else:
            actual[relation] = [s.name().lower().split('.')[0] for s in related]
    return actual


def _counts_xml(tag, counts):
    """Render one ``<tag>`` element containing a child per counter in _STAT_FIELDS."""
    inner = "".join(f"<{field}>{counts[field]}</{field}>" for field in _STAT_FIELDS)
    return f"<{tag}>{inner}</{tag}>"


def _stats_xml(gen_rel_dict, actual_rel_dict):
    """Build the <stats> element comparing generated and WordNet relations.

    Per relation it records the list sizes (duplicates counted) plus the set overlap and the
    over- and under-generation. A final <total> element sums every counter.
    """
    totals = dict.fromkeys(_STAT_FIELDS, 0)  # ints are immutable, so a shared 0 is safe
    parts = []
    for relation in relations:
        generated, actual = gen_rel_dict[relation], actual_rel_dict[relation]
        gen_set, act_set = set(generated), set(actual)
        counts = dict(zip(_STAT_FIELDS, (
            len(generated),
            len(actual),
            len(gen_set & act_set),
            len(gen_set - act_set),
            len(act_set - gen_set),
        )))
        for field, value in counts.items():
            totals[field] += value
        parts.append(_counts_xml(relation, counts))
    parts.append(_counts_xml('total', totals))
    return "<stats>\n" + "\n".join(parts) + "\n</stats>"


def _definition_synset_xml(synset_id, messages, definition, is_est, log_fp):
    """Build the <synset> XML string for one generated (word, type, meaning, example) definition.

    All model- and WordNet-supplied text is XML-escaped. Previously a single ``&`` or ``<``
    made the whole synset unparsable, so it ended up in the broken file.
    """
    def_word, def_type, def_meaning, def_example = definition
    parts = [
        f'<synset id="{synset_id}" word="{_xml_escape(def_word, attr=True)}" '
        f'type="{_xml_escape(def_type, attr=True)}">',
        "<generated>",
        f"<meaning>{_xml_escape(def_meaning)}</meaning>",
        f"<example>{_xml_escape(def_example)}</example>",
    ]
    gen_rel_dict = {}
    for relation in relations:
        content, words = _generate_relation(messages, relation, definition, is_est)
        gen_rel_dict[relation] = words
        if content is None:
            continue  # every attempt failed: no element, counted as nothing generated
        log_fp.write(f"RELATION: {relation}\n{content}\n")
        parts.append(f"<{relation}s>{_xml_escape(words)}</{relation}s>")
    parts.append("</generated>")

    if is_est:
        wn_synset = find_wordnet_synset_est(def_word, def_meaning)
    else:
        wn_synset = find_wordnet_synset(def_word, def_meaning)
    if wn_synset is None:
        # One fresh list per relation (dict.fromkeys would share a single list object).
        actual_rel_dict = {relation: [] for relation in relations}
        parts.append("<actual>NONE</actual>")
    else:
        actual_rel_dict = _actual_relations(wn_synset, is_est)
        wn_name = wn_synset.name if is_est else wn_synset.name()
        wn_definition = wn_synset.definition if is_est else wn_synset.definition()
        parts.append("<actual>")
        parts.append(f"<wn_name>{_xml_escape(wn_name)}</wn_name>")
        parts.append(f"<meaning>{_xml_escape(wn_definition)}</meaning>")
        for relation in relations:
            parts.append(f"<{relation}s>\n{_xml_escape(actual_rel_dict[relation])}\n</{relation}s>")
        parts.append("</actual>")
    parts.append(_stats_xml(gen_rel_dict, actual_rel_dict))
    parts.append("</synset>")
    return "\n".join(parts)


def main(input_file='test.txt', is_est=False, cur_time=None):
    """Generate definitions and lexical relations with the LLM and compare them to WordNet.

    For every non-blank line (one word) of `input_file`, the function:
      1. asks the model for all definitions of the word (up to 3 attempts);
      2. for each definition, asks for every relation in `relations`;
      3. finds the WordNet / Estonian WordNet synset whose definition overlaps most;
      4. appends a <synset> with generated relations, actual relations and overlap stats.

    Side effects: writes ``{cur_time}_broken.xml`` (failed words and unparsable XML) and
    ``{cur_time}_log.txt`` (raw model answers). It rewrites ``{cur_time}_output.xml`` after
    every successfully generated word as a checkpoint, and sets the module-level ``root``.

    Args:
        input_file: path to a UTF-8 text file with one word per line.
        is_est: use the Estonian prompts and Estonian WordNet instead of the English ones.
        cur_time: prefix for the output file names; defaults to ``datetime.datetime.now()``.

    Returns:
        The ``<synsets>`` root element.
    """
    global root
    synset_id = 1
    if cur_time is None:
        cur_time = datetime.datetime.now()
    with (open(input_file, 'r', encoding='utf-8') as in_fp,
          open(f'{cur_time}_broken.xml', 'w', encoding='utf-8') as broken_fp,
          open(f'{cur_time}_log.txt', 'w', encoding='utf-8') as log_fp):
        root = ET.Element('synsets')
        for line_no, line in enumerate(in_fp.readlines()):
            word = line.strip()
            print(line_no, word)
            if not word:
                continue  # blank line: nothing to look up, don't pay for API calls
            messages = [
                {"role": "system", "content": get_initial_prompt_est() if is_est else get_initial_prompt()},
                {"role": "user", "content": word},
            ]
            answer, fields = _request_definitions(messages, word)
            if answer is None:
                broken_fp.write('BROKEN WORD: ' + word + "\n")
                continue
            log_fp.write(f"WORD: {word}\n")
            log_fp.write(answer['content'] + "\n")
            print("GEN: meanings, size:", len(fields) // 4, fields)
            # Keep the definitions answer in the conversation as context for relation prompts.
            messages.append(dict(answer))
            for def_no in range(len(fields) // 4):
                definition = fields[def_no * 4:def_no * 4 + 4]
                print(f"GEN: {def_no + 1}th word: {definition[0]}")
                xml_str = _definition_synset_xml(synset_id, messages, definition, is_est, log_fp)
                try:
                    root.append(ET.fromstring(xml_str))
                except ET.ParseError:
                    broken_fp.write('BROKEN WORD: ' + word + "\n")
                    broken_fp.write(xml_str + "\n")
                synset_id += 1

            ET.ElementTree(root).write(f'{cur_time}_output.xml', encoding="UTF-8")
        return root

In [16]:
# Experiment driver. In test mode it runs the Estonian test list. Otherwise it samples random
# English words, saves the sample next to the outputs (so the run can be repeated), and runs on it.
is_test = True
cur_time = datetime.datetime.now()
file_name = f'{cur_time}_random_words.txt'
if not is_test:
    rand_lines_nr = 5
    # Alternative word source: 'lemmad.txt'
    with open('words.txt', encoding='utf-8') as fp:
        # Strip newlines and skip blanks. Previously a final line lacking "\n" could be
        # sampled and then glued to the next word by writelines().
        candidate_words = [w.strip() for w in fp if w.strip()]
    rand_lines = random.sample(candidate_words, rand_lines_nr)
    with open(file_name, 'w', encoding='utf-8') as fp:
        fp.write('\n'.join(rand_lines) + '\n')
    r = main(input_file=file_name, cur_time=cur_time, is_est=False)
else:
    r = main(input_file='test_est.txt', cur_time=cur_time, is_est=True)

0 keel
GEN: meanings, size: 5 ['keel', 'nimisõna', 'suus paiknev liikuv organ, millel on tähtis roll toidu maitsmisel ja neelamisel ning häälikute moodustamisel', 'Ta näitas oma roosat keelt.', 'keel', 'nimisõna', 'abstraktne süsteem, mida inimesed kasutavad helide ja kirjalike märkide abil mõtete väljendamiseks ja edasiandmiseks', 'Eesti keel on oma keerukuses võluv.', 'keel', 'nimisõna', 'instrumendi osa, mille abil selle toimimist reguleeritakse', 'Kitarri keel.', 'keel', 'nimisõna', 'keeleteaduslik tüpoloogia tanüümiks nimetamissüsteem.', 'Isurite keel kuulub soome-ugri keelkonda.', 'keel', 'nimisõna', 'Keeleteaduse mõiste, mille abil jaotatakse keeli nende struktuuri või päritolu järgi.', 'Soome keel on soome-ugri keel.']
GEN: 1th word: keel
GEN: 2th word: keel
GEN: 3th word: keel
GEN: 4th word: keel
GEN: 5th word: keel
1 kurg
GEN: meanings, size: 4 ['kurg', 'nimisõna', 'inimese kaelasopa sees paiknev luu- ja lihaseline moodustis, mis on nikastamise korral inimese hääleaparaat ja 

In [75]:
# Estimated dollars spent on API calls, accumulated by openai_api_calculate_cost() since the
# cell defining `total_price` last ran.
print(total_price)

1.2071799999999997


In [6]:
def test_accuracy(input_file):
    """Print, for every <synset> in a main() output file, generated vs. WordNet relations and their overlap.

    Relation elements store Python list reprs such as ``['dog', 'domestic_dog']``. They are
    parsed back into word sets so overlaps compare whole words. Whitespace-splitting the repr
    gave fragments like ``"['dog',"``, so a shared word could fail to match depending on its
    position in each list.

    Synsets whose <actual> is ``NONE`` (no WordNet match) have no meaning or relations to
    compare. They are reported as such instead of raising AttributeError.
    """
    import ast

    def relation_words(elem):
        # Decode the list repr written by main(); fall back to whitespace tokens otherwise.
        text = (elem.text or '').strip()
        if not text:
            return set()
        try:
            value = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError):
            return set(text.split())
        if isinstance(value, (list, tuple, set)):
            return {str(item) for item in value}
        return set(text.split())

    # Parse from the path so the file's declared UTF-8 encoding is honoured on any locale.
    root = ET.parse(input_file).getroot()
    for synset in root.findall('synset'):
        generated = synset.find('generated')
        actual = synset.find('actual')
        print(f"Word: {synset.get('word')}, Type: {synset.get('type')}")
        actual_meaning = actual.find('meaning') if actual is not None else None
        if actual_meaning is None:
            print(f"Generated: {generated.find('meaning').text}, Actual: NONE")
            continue
        print(f"Generated: {generated.find('meaning').text}, Actual: {actual_meaning.text}")
        for relation in relations:
            gen_rel = generated.find(f"{relation}s")
            act_rel = actual.find(f"{relation}s")
            if gen_rel is None and act_rel is None:
                continue
            if gen_rel is None or act_rel is None:
                print(f"Relation: {relation}, Gen: {gen_rel}, Act: {act_rel}")
                continue
            gen_words = relation_words(gen_rel)
            act_words = relation_words(act_rel)
            print(f"Relation: {relation}, Gen: {gen_words}, Act: {act_words}")
            print(f"Overlap: {len(gen_words & act_words)}, Gen: {len(gen_words)}, Act: {len(act_words)}")


def count_total_stats(input_file):
    """Sum the per-synset <stats><total> counters of a main() output file and print them.

    Synsets whose <actual> text is ``NONE`` (no matching WordNet synset) are skipped.
    The XML is parsed from its path, so its declared encoding (UTF-8, with Estonian text) is
    honoured. Opening it in text mode used the platform's default encoding instead.
    """
    fields = ('generated_size', 'actual_size', 'overlapping', 'over_generated', 'under_generated')
    sums = dict.fromkeys(fields, 0)
    root = ET.parse(input_file).getroot()
    for synset in root.findall('synset'):
        if synset.find('actual').text == 'NONE':
            continue
        total_stats = synset.find('stats').find('total')
        for field in fields:
            sums[field] += int(total_stats.find(field).text)
    print(f"Total: Gen: {sums['generated_size']}, Act: {sums['actual_size']}, "
          f"Overlap: {sums['overlapping']}, Over Gen: {sums['over_generated']}, "
          f"Under Gen: {sums['under_generated']}")

In [10]:
def count_total_meanings(input_file, is_est=False):
    """Print how many definitions the model generated vs. how many senses WordNet lists.

    ``Total`` is the number of <synset> elements in the main() output file. ``Actual`` sums,
    over each distinct ``word`` attribute, the number of synsets that word has in the English
    WordNet (``wn.synsets``) or, with `is_est`, in the Estonian WordNet (``estwn``).
    The XML is parsed from its path so its declared UTF-8 encoding is used on any locale.
    """
    root = ET.parse(input_file).getroot()
    seen_words = set()
    actual_meanings = 0
    total_meanings = 0
    for synset in root.findall('synset'):
        word = synset.get('word')
        if word not in seen_words:
            seen_words.add(word)
            actual_meanings += len(estwn[word]) if is_est else len(wn.synsets(word))
        total_meanings += 1
    print(f"Total: {total_meanings}, Actual: {actual_meanings}")

In [12]:
# Generated vs. WordNet sense counts for the orig_* and new_* prompt runs, English then Estonian.
for xml_path, in_estonian in (
    ('orig_prompt_eng.xml', False),
    ('orig_prompt_est.xml', True),
    ('new_prompt_eng.xml', False),
    ('new_prompt_est.xml', True),
):
    count_total_meanings(xml_path, is_est=in_estonian)

Total: 36, Actual: 60
Total: 20, Actual: 21
Total: 35, Actual: 60
Total: 25, Actual: 21
