In [1]:
import gzip
import json

import MeCab
import unidic

import pickle
import time

import math
from tqdm import tqdm

from multiprocessing import Pool, cpu_count

In [2]:
# Read the gzip-compressed entity dump as UTF-8 text and keep every raw line
# (trailing newline included) in memory; the lines are parsed as JSON afterwards.
input_file = '../data/all_entities.json.gz'
with gzip.open(input_file, mode="rt", encoding="utf-8") as entity_dump:
    lines = list(entity_dump)

In [3]:
# Map each entity title to its text. Every line holds one JSON object with
# "title" and "text" keys; if a title repeats, the last occurrence wins.
entities = {
    entity["title"]: entity["text"]
    for entity in map(json.loads, map(str.strip, lines))
}
# The raw lines are no longer needed once parsed, so release them.
del lines

In [4]:
# MeCab tagger pointed at the dictionary directory exposed as unidic.DICDIR.
tagger = MeCab.Tagger(f'-d "{unidic.DICDIR}"')

# Prefixes of a node's feature string (part-of-speech tags) that mark tokens
# to be discarded by parse_text.
STOP_POSTAGS = (
    '代名詞',           # pronoun
    '接続詞',           # conjunction
    '感動詞',           # interjection
    '動詞,非自立可能',  # verb, possibly non-independent
    '助動詞',           # auxiliary verb
    '助詞',             # particle
    '接頭辞',           # prefix
    '記号,一般',        # symbol, general
    '補助記号',         # supplementary symbol
    '空白',             # whitespace
    'BOS/EOS',          # sentence boundary markers
)

In [5]:
def parse_text(text):
    """Tokenize ``text`` with the module-level MeCab ``tagger``.

    Nodes whose feature string starts with one of ``STOP_POSTAGS`` are
    dropped. For every remaining node the token is the 8th comma-separated
    feature field when the feature string has more than 7 fields, otherwise
    the node's surface form; in both cases it is lower-cased.

    Returns:
        list[str]: the kept tokens, in the order the tagger produced them.
    """
    tokens = []
    node = tagger.parseToNode(text)
    while node:
        features = node.feature
        if not features.startswith(STOP_POSTAGS):
            fields = features.split(",")
            # NOTE(review): field 7 is presumably the lemma column of the
            # UniDic feature layout — confirm against the dictionary in use.
            token = fields[7] if len(fields) > 7 else node.surface
            tokens.append(token.lower())
        node = node.next
    return tokens

In [6]:
def _load_pickle(path):
    """Unpickle and return the single object stored at ``path``."""
    with open(path, 'rb') as f:
        return pickle.load(f)


# Pre-built index artifacts. NOTE: unpickling can execute arbitrary code, so
# these files must come from a trusted source.
doc_id2title = _load_pickle('../ir_dump/doc_id2title.pickle')
doc_id2token_count = _load_pickle('../ir_dump/doc_id2token_count.pickle')
token2pointer = _load_pickle('../ir_dump/token2pointer.pickle')

# Reverse lookup: title -> its position when iterating doc_id2title.
title2doc_id = {title: position for position, title in enumerate(doc_id2title)}

In [7]:
def load_index_line(index_line):
    """Parse one serialized postings line into a list of integer tuples.

    The line is a whitespace-separated sequence of entries, each entry being
    colon-separated integers, e.g. ``"12:3 40:1"`` -> ``[(12, 3), (40, 1)]``.
    Splitting on arbitrary whitespace (rather than on a single space) means
    an empty line, a trailing newline or repeated separators no longer
    produce an empty entry that fails integer conversion.

    Args:
        index_line (str): the raw postings line.

    Returns:
        list[tuple[int, ...]]: one tuple per entry, in input order; an empty
        list for an empty or whitespace-only line.

    Raises:
        ValueError: if any colon-separated field is not a valid integer.
    """
    return [
        tuple(int(field) for field in entry.split(':'))
        for entry in index_line.split()
    ]

In [8]:
def search_entity(query, exclude_candidates=None, topk=10):
    """Rank indexed documents against ``query`` using BM25 with OR semantics.

    The query is tokenized with ``parse_text``; for every distinct token found
    in ``token2pointer`` its postings are read from the on-disk inverted index
    and each listed document accumulates that token's BM25 contribution.

    Args:
        query (str): free-text query.
        exclude_candidates (iterable[str] | None): titles whose documents must
            not appear in the result. Titles missing from ``title2doc_id`` are
            ignored, since such documents cannot be in the result anyway.
        topk (int): maximum number of hits to return.

    Returns:
        list[tuple[int, float]]: up to ``topk`` ``(doc_id, score)`` pairs with
        a non-zero score, ordered by descending score; ties are ordered by
        ascending ``doc_id``.
    """
    # BM25 parameters.
    k1 = 2.0
    b = 0.75

    avgdl = sum(doc_id2token_count) / len(doc_id2token_count)
    parsed_query = parse_text(query)

    # token -> postings list [(doc_id, tf), ...], in first-occurrence order.
    target_posting = {}
    # NOTE(review): the index is opened in text mode, where read(n) counts
    # characters and seek() officially expects values obtained from tell().
    # This relies on the stored pointers being compatible with that (e.g. an
    # ASCII-only index file) — confirm against the code that wrote the index.
    with open('../ir_dump/inverted_index', 'r', encoding='utf-8') as index_file:
        for token in parsed_query:
            # A repeated query token maps to the same postings; read it once.
            if token in target_posting or token not in token2pointer:
                continue
            pointer, offset = token2pointer[token]
            index_file.seek(pointer)
            index_line = index_file.read(offset - pointer).rstrip()
            target_posting[token] = load_index_line(index_line)

    # Accumulate scores only for documents that actually occur in a postings
    # list instead of allocating one slot per document in the collection.
    all_docs = len(entities)
    doc_id2score = {}
    for postings_list in target_posting.values():
        doc_freq = len(postings_list)
        idf = math.log2((all_docs - doc_freq + 0.5) / (doc_freq + 0.5))
        # A non-positive idf means the token is too common to be useful.
        if idf <= 0:
            continue
        for doc_id, tf in postings_list:
            dl = doc_id2token_count[doc_id]
            token_score = idf * ((tf * (k1 + 1)) / (tf + k1 * (1 - b + b * (dl / avgdl))))
            doc_id2score[doc_id] = doc_id2score.get(doc_id, 0) + token_score

    # Remove the answer candidates from the search result.
    for candidate in exclude_candidates or ():
        excluded_doc_id = title2doc_id.get(candidate)
        if excluded_doc_id is not None:
            doc_id2score.pop(excluded_doc_id, None)

    docs = [(doc_id, score) for doc_id, score in doc_id2score.items() if score != 0]
    # Descending score; ascending doc_id breaks ties deterministically.
    docs.sort(key=lambda hit: (-hit[1], hit[0]))
    return docs[:topk]

In [9]:
# Read the training questions; each line is parsed as one JSON object below.
input_file = '../data/train_questions.json'
with open(input_file, mode="r", encoding="utf-8") as question_file:
    lines = question_file.readlines()

queries = []          # question text per record, with "_" removed
candidates_list = []  # answer candidates per record, aligned with `queries`
for line in tqdm(lines):
    data_raw = json.loads(line.strip("\n"))
    # "_" marks the blank of a cloze question, so it is dropped from the query.
    queries.append(data_raw["question"].replace("_", ""))
    candidates_list.append(data_raw["answer_candidates"])  # TODO
    # The "qid" and "answer_entity" fields are not used here.

100%|██████████| 13061/13061 [00:00<00:00, 62630.90it/s]


In [10]:
%%time
with Pool(cpu_count()) as p:
#     results = list(tqdm(p.imap(search_entity, queries), total=len(queries)))
    partial_results = list(tqdm(p.imap(search_entity, queries[:1000]), total=len(queries[:1000])))

100%|██████████| 1000/1000 [07:29<00:00,  2.23it/s]

CPU times: user 1.94 s, sys: 519 ms, total: 2.46 s
Wall time: 7min 29s





In [11]:
# Persist the partial search results. A full run would be written to
# '../ir_dump/train_search_results' instead.
with open('../ir_dump/partial_train_search_results', mode='wb') as out_file:
    pickle.dump(partial_results, out_file)

In [12]:
def search_entity_exclude_candidates(args):
    """Call ``search_entity`` with the positional arguments packed in ``args``.

    ``Pool.imap`` hands each work item to the worker as a single argument, so
    this adapter unpacks an iterable such as ``(query, exclude_candidates)``
    into the positional parameters of ``search_entity``.
    """
    positional = tuple(args)
    return search_entity(*positional)

In [13]:
%%time
with Pool(cpu_count()) as p:
#     results = list(tqdm(p.imap(search_entity_exclude_candidates, zip(queries, candidates_list)), total=len(queries)))
    partial_results = list(tqdm(p.imap(search_entity_exclude_candidates, zip(queries[:1000], candidates_list[:1000])), total=1000))

100%|██████████| 1000/1000 [07:24<00:00,  2.25it/s]

CPU times: user 2.21 s, sys: 558 ms, total: 2.77 s
Wall time: 7min 24s





In [14]:
# Persist the partial candidate-excluded search results. A full run would be
# written to '../ir_dump/train_search_exclude_candidates_results' instead.
with open('../ir_dump/partial_train_search_exclude_candidates_results', mode='wb') as out_file:
    pickle.dump(partial_results, out_file)