In [1]:
import gzip
import json

import MeCab
import unidic

import pickle
import time

import math
from tqdm import tqdm

from multiprocessing import Pool, cpu_count
from collections import Counter, defaultdict

In [2]:
# Pull the whole gzipped entity dump into memory as a list of raw text lines
# (each line is decoded as JSON by the following cell).
input_file = '../data/all_entities.json.gz'
with gzip.open(input_file, mode="rt", encoding="utf-8") as fin:
    lines = list(fin)

In [3]:
# title -> article text for every entity record; when a title repeats,
# the record that appears later in the dump wins.
entities = {
    entity["title"]: entity["text"]
    for entity in map(json.loads, map(str.strip, lines))
}
# The raw lines are no longer needed once the mapping exists; free the memory.
del lines

In [4]:
# MeCab tagger pointed at the dictionary directory exposed by the `unidic` package.
tagger = MeCab.Tagger(f'-d "{unidic.DICDIR}"')

# Prefixes of MeCab feature strings that parse_text() treats as stop words:
# any node whose feature string starts with one of these is dropped.
STOP_POSTAGS = (
    '代名詞',            # pronoun
    '接続詞',            # conjunction
    '感動詞',            # interjection
    '動詞,非自立可能',   # verb, possibly non-independent
    '助動詞',            # auxiliary verb
    '助詞',              # particle
    '接頭辞',            # prefix
    '記号,一般',         # symbol, general
    '補助記号',          # supplementary symbol
    '空白',              # whitespace
    'BOS/EOS',           # sentence boundary markers
)

In [5]:
def parse_text(text):
    """Tokenize `text` with the module-level MeCab `tagger` and normalize tokens.

    Nodes whose feature string starts with any prefix in STOP_POSTAGS are
    skipped. For the remaining nodes the token is field 7 of the
    comma-separated feature string when it exists (presumably the UniDic
    lemma -- TODO confirm), otherwise the node's surface form. Tokens are
    lower-cased.

    Args:
        text: The string to tokenize.

    Returns:
        A list of normalized token strings, in node order.
    """
    normalized = []
    current = tagger.parseToNode(text)
    while current:
        feature_str = current.feature
        if not feature_str.startswith(STOP_POSTAGS):
            fields = feature_str.split(",")
            # Short feature strings have no field 7; fall back to the surface form.
            word = fields[7] if len(fields) > 7 else current.surface
            normalized.append(word.lower())
        current = current.next
    return normalized

In [6]:
# Load the pre-built retrieval artifacts from ../ir_dumps.
# NOTE(review): pickle.load can execute arbitrary code from the file it reads;
# these files are presumably produced by this project's own indexing step --
# confirm they never come from an untrusted source.

# doc_id -> entity title (indexed with integer doc ids by get_sentence_list).
with open('../ir_dumps/doc_id2title.pickle', 'rb') as f:
    doc_id2title = pickle.load(f)
# doc_id -> token count of the document (used as the BM25 document length).
with open('../ir_dumps/doc_id2token_count.pickle', 'rb') as f:
    doc_id2token_count = pickle.load(f)
# token -> pair of positions delimiting that token's postings line inside
# ../ir_dumps/inverted_index (see the seek/read in search_entity_candidates).
with open('../ir_dumps/token2pointer.pickle', 'rb') as f:
    token2pointer = pickle.load(f)

In [7]:
def load_index_line(index_line):
    """Parse one inverted-index line into a postings list.

    The line is a whitespace-separated sequence of colon-separated integer
    groups, e.g. ``"12:3 40:1"``; callers in this notebook unpack each group
    as ``(doc_id, term_frequency)``.

    Args:
        index_line: The raw postings text for a single token.

    Returns:
        A list of integer tuples, one per group. An empty or whitespace-only
        line yields an empty list.

    Raises:
        ValueError: If a group contains a non-integer field.
    """
    # split() with no separator ignores leading/trailing/repeated whitespace, so an
    # empty line gives [] instead of failing on int('') as split(' ') would.
    return [tuple(map(int, posting.split(':'))) for posting in index_line.split()]

In [8]:
def _read_postings(index_file, tokens):
    """Return {token: postings_list} for each distinct token found in token2pointer."""
    postings = {}
    for token in tokens:
        if token in postings or token not in token2pointer:
            continue
        pointer, offset = token2pointer[token]
        index_file.seek(pointer)
        # NOTE(review): the file is opened in text mode, so read() counts characters
        # while the pointers look like file positions; this assumes postings lines
        # are pure ASCII -- TODO confirm against the index writer.
        index_line = index_file.read(offset - pointer).rstrip()
        postings[token] = load_index_line(index_line)
    return postings


def _bm25_idf(all_docs, doc_freq):
    """BM25 idf (log base 2), clamped at 0 so overly common tokens contribute nothing."""
    return max(math.log2((all_docs - doc_freq + 0.5) / (doc_freq + 0.5)), 0)


def _bm25_term_score(idf, tf, dl, avgdl, k1, b):
    """BM25 contribution of one token to one document."""
    return idf * ((tf * (k1 + 1)) / (tf + k1 * (1 - b + b * (dl / avgdl))))


def search_entity_candidates(query, candidates, topk=10):
    """Retrieve, for each answer candidate, the documents that best match it and the query.

    The query is scored against the whole index with BM25 (OR semantics). Then,
    per candidate, the candidate's tokens are scored on top of the query scores
    and only documents containing every scored candidate token are kept (AND
    semantics over the candidate). Tokens whose idf is not positive are treated
    as too common and ignored for both scoring and the intersection.

    Reads the module-level `entities`, `doc_id2token_count`, `token2pointer`
    and the file '../ir_dumps/inverted_index'.

    Args:
        query: Question text.
        candidates: Iterable of candidate answer strings.
        topk: Maximum number of documents returned per candidate.

    Returns:
        A list with one entry per candidate (same order); each entry is a list
        of up to `topk` ``(doc_id, score)`` tuples sorted by descending score.
    """
    k1 = 2.0
    b = 0.75
    avgdl = sum(doc_id2token_count) / len(doc_id2token_count)
    all_docs = len(entities)

    search_results = []
    with open('../ir_dumps/inverted_index', 'r', encoding='utf-8') as index_file:
        # OR search: accumulate BM25 scores of the query tokens for every document.
        query_scores = [0] * all_docs
        for postings_list in _read_postings(index_file, parse_text(query)).values():
            idf = _bm25_idf(all_docs, len(postings_list))
            if idf == 0:
                continue
            for doc_id, tf in postings_list:
                dl = doc_id2token_count[doc_id]
                query_scores[doc_id] += _bm25_term_score(idf, tf, dl, avgdl, k1, b)

        # Per-candidate search on top of the (read-only) query scores.
        for candidate in candidates:
            candidate_postings = _read_postings(index_file, parse_text(candidate))

            # Scores are kept in a per-candidate dict seeded from query_scores, so the
            # shared buffer is never mutated and no rounding error leaks between
            # candidates (the add-then-subtract approach did not restore it exactly).
            candidate_scores = {}
            # Documents containing every scored candidate token; None until the first
            # scored token is seen. Seeding on "first scored token" rather than
            # "first token" matters: a too-common first token used to leave the set
            # empty, which made every later intersection empty as well.
            candidate_doc_ids = None
            for postings_list in candidate_postings.values():
                idf = _bm25_idf(all_docs, len(postings_list))
                if idf == 0:
                    continue
                token_doc_ids = set()
                for doc_id, tf in postings_list:
                    dl = doc_id2token_count[doc_id]
                    token_score = _bm25_term_score(idf, tf, dl, avgdl, k1, b)
                    base_score = candidate_scores.get(doc_id, query_scores[doc_id])
                    candidate_scores[doc_id] = base_score + token_score
                    token_doc_ids.add(doc_id)

                if candidate_doc_ids is None:
                    candidate_doc_ids = token_doc_ids
                else:
                    candidate_doc_ids &= token_doc_ids

            docs = [(doc_id, candidate_scores[doc_id]) for doc_id in candidate_doc_ids or ()]
            docs.sort(key=lambda hit: hit[1], reverse=True)
            search_results.append(docs[:topk])

    return search_results

In [9]:
def parse_argument_wrapper(args):
    """Adapter for Pool.imap: unpack one argument tuple into search_entity_candidates.

    Args:
        args: Positional arguments for search_entity_candidates, e.g.
            ``(query, candidates)``.

    Returns:
        Whatever search_entity_candidates returns for those arguments.
    """
    outcome = search_entity_candidates(*args)
    return outcome

In [10]:
def get_sentence_list(search_result, topk=5):
    """Collect the sentences of the top-ranked documents of one search result.

    Args:
        search_result: Sequence of ``(doc_id, score)`` pairs, best first.
        topk: Number of leading documents whose text is used.

    Returns:
        A flat list of the pieces obtained by splitting each document's text
        on '。', in document order. Pieces may be empty strings (e.g. after a
        trailing '。').
    """
    collected = []
    top_hits = search_result[:topk]
    for doc_id, _score in top_hits:
        article_text = entities[doc_id2title[doc_id]]
        collected.extend(article_text.split('。'))
    return collected

In [11]:
def sort_sentence(query, candidate, sentence_list, topk=10):
    """Rank sentences by BM25 relevance to the tokens of `query` and `candidate`.

    A small in-memory inverted index is built over `sentence_list`, each
    sentence acting as one document. Tokens whose idf is not positive are
    treated as too common and ignored.

    Args:
        query: Question text.
        candidate: Candidate answer text; its tokens are added to the query's.
        sentence_list: Sequence of sentence strings to rank.
        topk: Maximum number of sentences returned.

    Returns:
        A list of up to `topk` ``(sentence, score)`` tuples sorted by descending
        score. Sentences with a score of 0 are omitted. An empty
        `sentence_list` yields an empty list.
    """
    # Nothing to rank; this also avoids dividing by zero when computing avgdl.
    if not sentence_list:
        return []

    # Index the sentences: token -> [(sentence_id, term frequency), ...].
    inverted_index = defaultdict(list)
    sentence_id2token_count = []
    for sentence_id, sentence in enumerate(sentence_list):
        tokens = parse_text(sentence)
        sentence_id2token_count.append(len(tokens))
        for token, count in Counter(tokens).items():
            inverted_index[token].append((sentence_id, count))

    avgdl = sum(sentence_id2token_count) / len(sentence_id2token_count)
    parsed_query = parse_text(query) + parse_text(candidate)
    # Membership is tested first so the defaultdict is never grown by a lookup.
    target_posting = {
        token: inverted_index[token]
        for token in parsed_query
        if token in inverted_index
    }

    # OR search scored with BM25.
    k1 = 2.0
    b = 0.75
    all_docs = len(sentence_list)
    sentence_id2tfidf = [0] * all_docs
    for postings_list in target_posting.values():
        idf = math.log2((all_docs - len(postings_list) + 0.5) / (len(postings_list) + 0.5))
        # A non-positive idf means the token is too common to be informative.
        idf = max(idf, 0)
        if idf == 0:
            continue
        for sentence_id, tf in postings_list:
            dl = sentence_id2token_count[sentence_id]
            token_tfidf = idf * ((tf * (k1 + 1)) / (tf + k1 * (1 - b + b * (dl / avgdl))))
            sentence_id2tfidf[sentence_id] += token_tfidf

    ranked = [
        (sentence_id, tfidf)
        for sentence_id, tfidf in enumerate(sentence_id2tfidf)
        if tfidf != 0
    ]
    ranked.sort(key=lambda item: item[1], reverse=True)
    return [(sentence_list[sentence_id], tfidf) for sentence_id, tfidf in ranked[:topk]]

In [12]:
# Load the training questions: the file holds one JSON object per line.
input_file = '../data/train_questions.json'
with open(input_file, "r", encoding="utf-8") as fin:
    lines = list(fin)

# Parallel lists: queries[i] is the question text, candidates_list[i] its answer options.
# The records apparently also carry "qid" and "answer_entity" fields, unused here.
queries, candidates_list = [], []
for line in tqdm(lines):
    data_raw = json.loads(line.strip("\n"))
    # "_" marks the blank of a cloze question; remove it from the query text.
    question = data_raw["question"].replace("_", "")
    options = data_raw["answer_candidates"]  # TODO
    queries.append(question)
    candidates_list.append(options)

100%|██████████| 13061/13061 [00:00<00:00, 69757.30it/s]


In [13]:
# Spot-check: text of the first loaded question.
queries[0]

'格闘家ボブ・サップの出身国はどこでしょう?'

In [14]:
# Spot-check: answer candidates paired with the first question.
candidates_list[0]

['アメリカ合衆国',
 'ミネソタ州',
 'オンタリオ州',
 'ペンシルベニア州',
 'オレゴン州',
 'ニューヨーク州',
 'コロラド州',
 'オーストラリア',
 'ニュージャージー州',
 'マサチューセッツ州',
 'カナダ',
 'テキサス州',
 'ミシガン州',
 'ワシントン州',
 'ニュージーランド',
 'オハイオ州',
 'カリフォルニア州',
 'メリーランド州',
 'イリノイ州',
 'イギリス']

In [15]:
%%time
# Time one full search: every candidate of the first question is searched,
# then the ranked (doc_id, score) hits of the first candidate are displayed.
search_results = search_entity_candidates(queries[0], candidates_list[0])
search_results[0]

CPU times: user 4.51 s, sys: 68.4 ms, total: 4.58 s
Wall time: 4.58 s


[(503868, 67.16353159534238),
 (305429, 66.78403294179013),
 (795598, 66.12236396817468),
 (149041, 55.385281614879695),
 (457401, 53.10957954882017),
 (187616, 51.31076360804543),
 (85846, 47.76138663830434),
 (798722, 47.521333481150016),
 (838757, 45.60942607128479),
 (477688, 44.170923102934026)]

In [16]:
# Titles of the three best hits for the first candidate (doc ids copied from the output above).
doc_id2title[503868], doc_id2title[305429], doc_id2title[795598]

('チャド・バノン', 'ボブ・サップ', 'マット・ヒューム')

In [17]:
# Hits for the second candidate of the first question.
search_results[1]

[(305429, 67.76079779388517),
 (715678, 47.623069924615784),
 (813895, 45.23169151314288),
 (746282, 44.59881819156162),
 (451862, 41.65596099849918),
 (449421, 40.626545292485474),
 (894796, 40.09513261039599),
 (851731, 39.178850788243054),
 (725307, 38.81948050223761),
 (404124, 37.660272326206965)]

In [18]:
# Inspect how the second candidate is tokenized: the search matches on these tokens,
# not on the candidate string as a whole.
parse_text('ミネソタ州')

['ミネソタ-minnesota', '州']

In [19]:
# Raw substring checks against the article text: the full candidate string versus
# each of its parts separately.
'ミネソタ州' in entities['ボブ・サップ'], 'ミネソタ' in entities['ボブ・サップ'], '州' in entities['ボブ・サップ']

(False, True, True)

In [20]:
# Hits for the third candidate of the first question.
search_results[2]

[(512070, 42.23082025433194),
 (391473, 40.218478525205306),
 (11451, 38.814849054614676),
 (690730, 38.734221194828606),
 (614755, 38.689525015104074),
 (816922, 37.50988913182361),
 (149660, 36.74755099732139),
 (141661, 36.46235181669672),
 (165329, 36.2479389065969),
 (490924, 35.825664500682386)]

In [21]:
# Title of the best hit for the third candidate (doc id copied from the output above).
doc_id2title[512070]

'ソロモン・ノーサップ'

In [22]:
# Check that the third candidate string occurs verbatim in that article's text.
'オンタリオ州' in entities['ソロモン・ノーサップ']

True

In [23]:
%%time
# Run the candidate search in parallel, one worker process per CPU core.
# Each work item is a (query, candidates) pair that parse_argument_wrapper
# unpacks into search_entity_candidates; imap yields results in input order.
with Pool(cpu_count()) as p:
    # Full-dataset variant, kept for reference (the 1000-question subset below
    # took about 16 minutes in the recorded run):
#     results = list(tqdm(p.imap(parse_argument_wrapper, zip(queries, candidates_list)), total=len(queries)))
    # Only the first 1000 questions are processed; `total` must match that slice size.
    partial_results = list(tqdm(p.imap(parse_argument_wrapper, zip(queries[:1000], candidates_list[:1000])), total=1000))

100%|██████████| 1000/1000 [16:02<00:00,  1.04it/s]

CPU times: user 1.99 s, sys: 563 ms, total: 2.55 s
Wall time: 16min 2s





In [24]:
# Persist the search results so later steps do not have to repeat the slow search.
# Full-dataset variant, kept for reference:
# with open('../ir_dumps/train_search_cadidates_results', 'wb') as f:
#     pickle.dump(results, f)
# NOTE(review): the file name spells "cadidates" (sic); it is left untouched here
# because any reader of this dump presumably opens it by this exact name.
with open('../ir_dumps/partial_train_search_cadidates_results', 'wb') as f:
    pickle.dump(partial_results, f)

In [25]:
# Pick one question to inspect and keep only the 3 best documents per candidate.
sample_index = 0
result = search_entity_candidates(queries[sample_index], candidates_list[sample_index], topk=3)

In [26]:
# For the first three candidates of the sampled question, gather the sentences of
# the retrieved documents and print them ranked against the question + candidate.
for candidate_index in range(3):
    candidate = candidates_list[sample_index][candidate_index]
    print(f'candidate={candidate}')
    sentence_list = get_sentence_list(result[candidate_index])
    # Use sample_index and the already-bound candidate rather than hard-coded
    # index 0, so the ranking always refers to the same question that `result`
    # was computed for.
    sorted_sentence = sort_sentence(queries[sample_index], candidate, sentence_list)
    print(sorted_sentence)

candidate=アメリカ合衆国
[('チャド・バノン（Chad Bannon、1970年11月13日 - ）は、アメリカ合衆国の男性格闘家、俳優である', 26.582446203988724), ('マット・ヒューム（Matt Hume、1966年7月14日 - ）は、アメリカ合衆国の男性総合格闘家', 26.582446203988724), ('ボブ・サップ（Bob Sapp、1973年9月22日 - ）は、アメリカ合衆国のキックボクサー、プロレスラー、総合格闘家、タレント、俳優、元アメリカンフットボール選手', 26.156155272330924), ('彼とは別にアメリカ出身で同姓同名の格闘家・マイケル・マクドナルドが存在するが、こちらは1991年生まれでありUFCを中心として総合格闘技一本に絞った競技生活を送っている', 16.810192741467606), (' ^ Bob Sapp explains DREAM "Dynamite!! 2010" no-show, says DREAM is "broke" MMA JUNKIE.COM 2011年1月18日 ^ SimonInoki 2011年1月17日 ^ 野獣ボブ・サップ怒りの暴露「格闘技界のカネ、女、ＦＢＩ」 vol.1デジタル大衆 週刊大衆4月1日号 ^ 野獣ボブ・サップ怒りの暴露「格闘技界のカネ、女、ＦＢＩ」 vol.2 デジタル大衆 週刊大衆4月1日号 ^ カナダ出身の格闘家で、1965年生まれ', 10.589177345992997), ('なお、ウォーレン・サップはフロリダ州オーランド出身であり、マイアミ大学の卒業生である', 9.315335484332618), ('コロラド州出身', 8.504074349386341), ('コロラド州コロラドスプリングス出身', 8.161026252958882), ('ワシントン州シアトル出身', 8.161026252958882), (' 一部で「ボブ・サップは元NFLプレイヤーのウォーレン・サップ（2013年プロフットボール殿堂入り）の兄弟（または従兄弟）である」というまったくの誤報が流れたが、ボブ・サップとウォーレン・サップとの間に特筆すべき血縁関係はない', 7.961718535845536)]
candidat