In [1]:
def extract_predict(vocab):
    """Pair each model prediction with its gold annotation and yield per-case outcomes.

    Reads the module-level ``predict_file`` (one JSON prediction per line) and
    ``test_file`` (one JSON gold sentence per line) side by side. A prediction is
    matched to a gold sentence by ``file``/``sent`` against ``file name``/
    ``sentence id``; gold sentences that have no prediction are skipped.

    Args:
        vocab: Mapping used to turn every entry of ``gold["tokens"]`` into a
            surface string (``vocab[idx]``).

    Yields:
        ``(prediction, case, predicate, gold_term, tokens)`` for every case in
        ``("ga", "o", "ni")`` where the prediction and the gold annotation are
        not both absent. ``prediction`` is the string ``"correct"`` when the
        predicted token carries the gold label for that case, otherwise the
        predicted token, or the string ``"None"`` when nothing was predicted.
        ``gold_term`` is the first token labelled with the case, or ``"None"``.

    Raises:
        RuntimeError: If no gold sentence matches a prediction, or the matched
            sentence has no predicate with the predicted ``p_id``.
    """
    gold = None
    # JSON text is UTF-8 by specification; do not depend on the platform default.
    with open(predict_file, encoding="utf-8") as fp, open(test_file, encoding="utf-8") as ft:
        for line in fp:
            predict = json.loads(line)

            # Advance the gold file until it is positioned on the sentence this
            # prediction refers to. A single step is not enough when a gold
            # sentence has no predictions at all.
            while (
                gold is None
                or gold["file name"] != predict["file"]
                or gold["sentence id"] != predict["sent"]
            ):
                gold_line = next(ft, None)
                if gold_line is None:
                    # An explicit error instead of a StopIteration leaking out
                    # of the generator (which Python reports as RuntimeError).
                    raise RuntimeError(
                        "no gold sentence found for prediction (file={!r}, sent={!r})".format(
                            predict["file"], predict["sent"]
                        )
                    )
                gold = json.loads(gold_line)

            # Locate the gold argument labels of the predicted predicate.
            gold_args = None
            for pas in gold["pas"]:
                if pas["p_id"] == predict["pred"]:
                    gold_args = pas["args"]
                    break
            if gold_args is None:
                raise RuntimeError(
                    "predicate {!r} not found in gold sentence (file={!r}, sent={!r})".format(
                        predict["pred"], predict["file"], predict["sent"]
                    )
                )

            tokens = [vocab[idx] for idx in gold["tokens"]]
            predicate = tokens[predict["pred"]]

            # gold_args is indexed by token position; the label value is the
            # position of the case in this list (ga=0, o=1, ni=2).
            for idx, case in enumerate(["ga", "o", "ni"]):
                prd_term = tokens[predict[case]] if case in predict else "None"
                gold_term = tokens[gold_args.index(idx)] if idx in gold_args else "None"
                if prd_term == "None" and gold_term == "None":
                    continue
                if case in predict and gold_args[predict[case]] == idx:
                    yield ("correct", case, predicate, gold_term, tokens)
                else:
                    yield (prd_term, case, predicate, gold_term, tokens)

In [2]:
def chain_combination(before, after):
    """Yield ``before`` extended by one choice from each group in ``after``.

    ``after`` is a non-empty sequence of groups of strings. Every combination
    that takes exactly one string per group is produced, in group order, with
    the first group varying slowest.
    """
    is_last_group = len(after) == 1
    for word in after[0]:
        prefix = before + word
        if is_last_group:
            yield prefix
        else:
            # Recurse into the remaining groups with the extended prefix.
            yield from chain_combination(prefix, after[1:])

In [3]:
def concat_rep(rep):
    """Yield every surface string encoded by the representation ``rep``.

    ``rep`` is split on ``"+"`` into chunks and each chunk on ``"?"`` into
    alternative tokens; only the part of a token before the first ``"/"`` is
    kept. One alternative per chunk is then concatenated, for all combinations.
    """
    alternatives = []
    for chunk in rep.split("+"):
        surfaces = [token.split("/")[0] for token in chunk.split("?")]
        alternatives.append(surfaces)
    yield from chain_combination("", alternatives)

In [4]:
def extract_frames(root):
    """Collect ga/o/ni case fillers for every predicate in a case-frame XML tree.

    Args:
        root: Iterable of entry elements. Each entry has a ``headword``
            attribute and contains case frames, which contain case elements
            (with a ``case`` attribute), which contain filler elements (with
            text and a ``frequency`` attribute).

    Returns:
        dict mapping each predicate surface form (every expansion of the
        headword produced by ``concat_rep``) to a ``defaultdict`` of
        ``{"ga" | "o" | "ni": {filler_word: frequency}}``. ``frequency`` is the
        attribute value as stored in the XML. When the same filler occurs more
        than once for a case, or the same predicate in more than one entry, the
        last occurrence wins.
    """
    # XML case names of interest and the short labels used in the result.
    case_labels = {"ガ格": "ga", "ヲ格": "o", "ニ格": "ni"}

    frames = {}
    for entry in tqdm(root):
        predicates = list(concat_rep(entry.attrib["headword"]))

        # The fillers depend only on the entry, not on which spelling of the
        # headword is used, so walk the entry once instead of once per variant.
        entry_cases = defaultdict(dict)
        for caseframe in entry:
            for case in caseframe:
                label = case_labels.get(case.attrib["case"])
                if label is None:
                    continue
                for comp in case:
                    for word in concat_rep(comp.text):
                        entry_cases[label][word] = comp.attrib['frequency']

        for predicate in predicates:
            # Give every predicate its own copy so the tables stay independent.
            frames[predicate] = defaultdict(
                dict, {label: dict(words) for label, words in entry_cases.items()}
            )

    return frames

In [6]:
import xml.etree.ElementTree as ET

# Absolute path of the case-frame XML file to load.
caseframe_file = "/Users/ryuto/lab/research/data/raw/kyoto-univ-web-cf-1.0/kyoto-univ-web-cf-1.0.xml"

# Parse the whole document into memory and keep its root element, which is
# what extract_frames iterates over.
tree = ET.parse(caseframe_file)
root = tree.getroot()

In [9]:
import json
from tqdm import tqdm
from collections import defaultdict

# Build the predicate -> case -> {filler: frequency} table from the parsed XML
# (extract_frames reports progress through tqdm).
frames = extract_frames(root)

100%|██████████| 34059/34059 [00:30<00:00, 1125.85it/s]


In [11]:
# Destination of the serialized case-frame table.
out_file = "/Users/ryuto/lab/research/data/raw/kyoto-univ-web-cf-1.0/case-frame.json"

# Write the table as one JSON document, overwriting any existing file.
with open(out_file, "w") as fo:
    json.dump(frames, fo)

In [12]:
# Inputs read through module-level names by the functions in this notebook.
# predict_file: one JSON prediction per line (read by extract_predict).
predict_file = "/Users/ryuto/lab/research/work/ACL2020/predict-dev-base_full-olr0.001_plr0.001_h256_layer10_d0.0_True_size100-0.4-0.52-0.13.txt"
# test_file: one JSON gold sentence per line (read by extract_predict).
test_file = "/Users/ryuto/lab/research/work/ACL2020/dev.jsonl"
# wordindex_file: tab-separated "word<TAB>index" lines (read by read_vocab).
wordindex_file = "/Users/ryuto/lab/research/data/raw/NTC_Matsu_original/wordIndex.txt"

In [13]:
def read_vocab():
    """Load the index -> word table from the module-level ``wordindex_file``.

    Every line is split on its first tab into ``word`` and ``index``. The
    returned dict maps ``index`` (kept as the string read from the file) to
    ``word``; a later line with the same index replaces the earlier word.
    """
    with open(wordindex_file) as fi:
        pairs = (line.rstrip("\n").split("\t", 1) for line in fi)
        return {index: word for word, index in pairs}

In [14]:
# Load the index -> word table once; extract_predict uses it to turn the token
# ids of a gold sentence into words.
vocab = read_vocab()

In [15]:
# Outcome buckets, each holding (prediction, case, predicate, gold_term) tuples:
#   recall  - nothing was predicted although a gold argument exists
#   prec    - something was predicted although there is no gold argument
#   correct - the prediction carries the gold label
#   co      - both exist but the prediction is wrong
prec, recall, co, correct = [], [], [], []

for p, case, prd, g, tokens in extract_predict(vocab):
    record = (p, case, prd, g)
    # The order of these checks matters: a missing prediction is tested first.
    if p == "None":
        bucket = recall
    elif g == "None":
        bucket = prec
    elif p == "correct":
        bucket = correct
    else:
        bucket = co
    bucket.append(record)