In [None]:
import os

import openai

# Read the key from the OPENAI_API_KEY environment variable when it is set, so
# the secret does not have to be written into (and saved with) the notebook.
# The original placeholder string remains the fallback value.
openai.api_key = os.environ.get('OPENAI_API_KEY', '<your api key>')

In [None]:
from pathlib import Path
import json
import re

import editdistance

In [None]:
# Chat model identifier; gpt_predict passes it as `model` to openai.ChatCompletion.create.
model_name = 'gpt-3.5-turbo'

In [None]:
# The KWDLC directory is expected one level above this notebook.
kwdlc_path = Path('../KWDLC/')
knp_paths = []
pairs = []

# Visit every knp/w* subdirectory (sorted), then every *.knp file inside it
# (sorted). Each .knp path is recorded on its own and also paired with the
# .org path of the same stem under org/<same subdirectory name>/.
for knp_dir in sorted(kwdlc_path.glob('knp/w*')):
    org_dir = kwdlc_path / 'org' / knp_dir.name
    for knp_file in sorted(knp_dir.glob('*.knp')):
        knp_paths.append(knp_file)
        pairs.append((knp_file, org_dir / (knp_file.stem + '.org')))

In [None]:
def extract_raw_and_yomi_pair(fname):
    """Read a KNP-format file and return its (surface, reading) pairs per sentence.

    Lines are stripped and then classified:
      * blank lines are ignored;
      * lines starting with ``+``, ``*`` or ``#`` followed by whitespace are
        skipped;
      * a line equal to ``EOS`` closes the current sentence;
      * any other line is split on single spaces into four parts, of which the
        first two (surface form and reading) are kept.

    Args:
        fname: Path of the file to read.

    Returns:
        A list with one entry per ``EOS`` line; each entry is a list of
        ``(surface, yomi)`` tuples in file order. Lines that follow the last
        ``EOS`` are not returned.

    Raises:
        ValueError: If a non-blank content line has fewer than four
            space-separated fields.
    """
    sentences = []
    sentence = []
    # The context manager closes the handle deterministically. The encoding is
    # pinned so the result does not depend on the platform default.
    # NOTE(review): assumes the .knp files are UTF-8 encoded - confirm.
    with open(fname, 'rt', encoding='utf-8') as knp_file:
        for line in knp_file:
            line = line.strip()
            if not line:
                # A blank line carries no token; unpacking it would raise.
                continue
            # Raw string: '\+' and '\s' are invalid escapes in a normal literal.
            if re.match(r'[+*#]\s', line):
                continue
            if line == 'EOS':
                sentences.append(sentence)
                sentence = []
                continue
            surface, yomi, _base, _info = line.split(' ', 3)
            sentence.append((surface, yomi))
    return sentences

In [None]:
# One entry per .knp file: a list of (surface, yomi) pairs, one pair per
# sentence, each built by concatenating that sentence's token surfaces and
# token readings respectively.
docs = [
    [
        (
            ''.join(token_surface for token_surface, _ in sent),
            ''.join(token_yomi for _, token_yomi in sent),
        )
        for sent in extract_raw_and_yomi_pair(knp)
    ]
    for knp in knp_paths
]

In [None]:
# The first four documents are set aside as few-shot examples; all later
# documents are the evaluation targets.
few_shots, targets = docs[:4], docs[4:]

In [None]:

def make_surface(doc):
    """Join the surface text of every (surface, yomi) pair in `doc` with newlines."""
    surfaces = (surface for surface, _ in doc)
    return '\n'.join(surfaces)

def make_yomi(doc):
    """Join the reading of every (surface, yomi) pair in `doc` with newlines."""
    readings = (yomi for _, yomi in doc)
    return '\n'.join(readings)

In [None]:
# System instruction, Japanese wording.
instruct_ja = """
入力文の読み方をひらがな（あいうえお など）および約物（、。！？「」" など）のみで出力してください。

### 条件
- 数字の読み方もひらがなで書いてください。
- 使っていい文字は、ひらがなと約物だけです。
"""

# System instruction, English wording.
instruct_en = """
Output the reading of the input sentence using only hiragana (such as あいうえお) and punctuation marks (such as 、。！？「」" and so on).

### Conditions
- Write the reading of numbers in hiragana as well.
- Only use hiragana and punctuation marks for allowed characters.
"""

# Number of few-shot documents to include in the prompt (0 = zero-shot).
num = 0

# Few-shot section: a header followed by one input/output block per example
# document; it stays empty when no examples are requested.
shots = ""
if num > 0:
    shot_blocks = [
        f"input:\n{make_surface(doc)}\noutput:\n{make_yomi(doc)}\n###\n"
        for doc in few_shots[:num]
    ]
    shots = "### Examples\n" + "".join(shot_blocks)

In [None]:
# Show the few-shot section that is appended to each system prompt (empty when num == 0).
print(shots)

In [None]:
def gpt_predict(system_text, user_text):
    """Send one system + user message pair to the chat model and return its reply text.

    The request uses the module-level `model_name` with temperature 0, and the
    content of the first returned choice is handed back.
    """
    messages = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]
    response = openai.ChatCompletion.create(
        model=model_name,
        temperature=0,
        messages=messages,
    )
    first_choice = response["choices"][0]
    return first_choice["message"]["content"]


def remove_control_chars(text):
    """Return `text` with every newline and ASCII space character removed.

    Despite the name, no other characters (tabs, carriage returns, full-width
    spaces, ...) are touched.
    """
    dropped = str.maketrans('', '', '\n ')
    return text.translate(dropped)

In [None]:
# Accumulator for per-document evaluation records; the evaluation loop appends
# to it, so re-run this cell before re-running that loop to avoid duplicates.
results = []

In [None]:
# Evaluate the first 50 target documents. For each one, query the model once
# with the Japanese instruction and once with the English instruction (in that
# order), then record the edit distance between each whitespace-stripped
# prediction and the whitespace-stripped gold reading.
for i, doc in enumerate(targets[:50]):
    print(i)
    surface = make_surface(doc)
    user_prompt = f"input:\n{surface}\noutput:\n"
    gold_yomi = remove_control_chars(make_yomi(doc))

    record = {
        'input': remove_control_chars(surface),
        'gold_yomi': gold_yomi,
    }
    for lang_key, instruct in (('ja', instruct_ja), ('en', instruct_en)):
        pred = remove_control_chars(gpt_predict(f"{instruct}\n\n{shots}", user_prompt))
        record[lang_key] = {
            'pred': pred,
            'dist': editdistance.eval(pred, gold_yomi),
        }
    results.append(record)

In [None]:
# Persist the evaluation records. ensure_ascii=False writes the Japanese text
# as-is rather than as \uXXXX escapes, so the file encoding is pinned to UTF-8
# instead of relying on the platform default (which may be unable to encode it).
with open(f'yomi_results_{num}shots.json', 'wt', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=4)

In [None]:
# Per-document edit distances for each instruction language.
dists_ja = [rs['ja']['dist'] for rs in results]
dists_en = [rs['en']['dist'] for rs in results]

# Head-to-head tally: a lower distance wins; anything else is a draw.
ja_beats_en = sum(1 for j, e in zip(dists_ja, dists_en) if j < e)
en_beats_ja = sum(1 for j, e in zip(dists_ja, dists_en) if not j < e and e < j)
draw = sum(1 for j, e in zip(dists_ja, dists_en) if not (j < e or e < j))
print(f"{ja_beats_en}, {en_beats_ja}, {draw}")

In [None]:
# Mean edit distance: Japanese instruction first, then English.
for dists in (dists_ja, dists_en):
    print(sum(dists) / len(dists))

In [None]:
# Number of documents whose prediction matched the gold reading exactly
# (edit distance 0): Japanese instruction first, then English.
print(dists_ja.count(0))
print(dists_en.count(0))