In [1]:
import gzip
import json

import MeCab
import unidic

from contextlib import ExitStack
from collections import defaultdict, Counter
from tqdm import tqdm

import pickle
import time

# wikipedia読み込み

In [2]:
# Load the gzipped JSON-lines entity dump fully into memory, one string per line
# (trailing newlines are kept; they are stripped when each line is parsed).
input_file = '../data/all_entities.json.gz'
with gzip.open(input_file, mode="rt", encoding="utf-8") as fin:
    lines = list(fin)

In [3]:
# Build the title -> text mapping from the JSON lines. If the same title occurs
# more than once, the last occurrence wins (same as repeated dict assignment).
entities = {
    record["title"]: record["text"]
    for record in map(json.loads, map(str.strip, lines))
}
# The raw lines are no longer needed; release them to free memory.
del lines

# 形態素解析器初期化

In [4]:
# Create the MeCab tagger, pointing `-d` at the dictionary directory reported
# by `unidic.DICDIR` (quoted so that a path containing spaces stays one argument).
tagger = MeCab.Tagger(f'-d "{unidic.DICDIR}"')

In [5]:
# Part-of-speech prefixes of tokens that are excluded from the index.
# Each entry is matched with `str.startswith` against the MeCab feature string,
# so an entry may span more than one comma-separated POS level.
STOP_POSTAGS = (
    '代名詞',           # pronoun
    '接続詞',           # conjunction
    '感動詞',           # interjection
    '動詞,非自立可能',  # verb, can be non-independent
    '助動詞',           # auxiliary verb
    '助詞',             # particle
    '接頭辞',           # prefix
    '記号,一般',        # symbol, general
    '補助記号',         # supplementary symbol (punctuation etc.)
    '空白',             # whitespace
    'BOS/EOS',          # sentence boundary pseudo-nodes
)

In [6]:
def parse_text(text):
    """Tokenize `text` and return a list of lower-cased index terms.

    Uses the module-level `tagger` (anything exposing `parseToNode`) and drops
    every node whose feature string starts with one of `STOP_POSTAGS`.

    For each remaining node the term is the 8th comma-separated feature field
    when the feature string has more than 7 fields; otherwise the node's
    surface form is used as a fallback.

    Args:
        text: The string to analyse.

    Returns:
        list[str]: Lower-cased terms in the order the nodes are visited.
    """
    tokens = []
    node = tagger.parseToNode(text)
    while node:
        feature_str = node.feature
        if not feature_str.startswith(STOP_POSTAGS):
            fields = feature_str.split(",")
            # NOTE(review): index 7 is presumably the lemma column of the UniDic
            # feature layout — confirm against the dictionary version in use.
            term = fields[7] if len(fields) > 7 else node.surface
            tokens.append(term.lower())
        node = node.next
    return tokens

# 転置インデックス分割作成

In [7]:
# Number of documents indexed in memory before the partial index is flushed to disk.
partial_size = 10**5


def _write_partial_index(partial_id, partial_index):
    """Write one in-memory partial inverted index as two line-aligned text files.

    `partial_dict_<id>` holds one token per line in sorted order, and
    `partial_inverted_index_<id>` holds, on the same line number, that token's
    posting list formatted as space-separated `doc_id:term_frequency` pairs.

    Args:
        partial_id: Integer suffix used in both output file names.
        partial_index: Mapping token -> list of (doc_id, term_frequency) tuples.
    """
    sorted_vocab = sorted(partial_index)

    with open('../ir_dumps/partial_dict_{}'.format(partial_id), 'w', encoding='utf-8') as fout:
        for token in sorted_vocab:
            fout.write(token + '\n')

    with open('../ir_dumps/partial_inverted_index_{}'.format(partial_id), 'w', encoding='utf-8') as fout:
        for token in sorted_vocab:
            posting_list = ' '.join(
                '{}:{}'.format(posting_doc_id, tf) for posting_doc_id, tf in partial_index[token]
            )
            fout.write(posting_list + '\n')


inverted_index = defaultdict(list)
doc_id2title = []
doc_id2token_count = []
for doc_id, (title, text) in tqdm(enumerate(entities.items()), total=len(entities)):
    tokens = parse_text(text)

    doc_id2title.append(title)
    doc_id2token_count.append(len(tokens))

    # One posting per (token, document): the document id and the term frequency.
    for token, count in Counter(tokens).items():
        inverted_index[token].append((doc_id, count))

    # Flush every `partial_size` documents so the whole index never has to fit in memory.
    if (doc_id + 1) % partial_size == 0:
        _write_partial_index(doc_id // partial_size, inverted_index)
        inverted_index = defaultdict(list)

# Flush the trailing, incomplete chunk. When the number of documents is an exact
# multiple of `partial_size` (or zero) there is no such chunk: the last id was
# already written inside the loop, and writing again would overwrite those files
# with an empty index.
partial_id = (len(entities) - 1) // partial_size
if len(entities) % partial_size != 0:
    _write_partial_index(partial_id, inverted_index)

# Save the lists that map a document_id to its title and to its token count.
for dump_path, dump_obj in (
    ('../ir_dumps/doc_id2title.pickle', doc_id2title),
    ('../ir_dumps/doc_id2token_count.pickle', doc_id2token_count),
):
    with open(dump_path, 'wb') as f:
        pickle.dump(dump_obj, f)

100%|██████████| 920172/920172 [56:36<00:00, 270.89it/s]  


# 分割転置インデックスのマージ

In [8]:
def load_index_line(index_line):
    """Parse one posting-list line into a list of integer tuples.

    The line is expected to be single-space-separated entries, each entry being
    colon-separated integers (as written by the index dump: `doc_id:tf`).

    Args:
        index_line: The line content without its trailing newline.

    Returns:
        list[tuple[int, ...]]: One tuple per entry, in line order; an empty
        list when `index_line` is the empty string.

    Raises:
        ValueError: If any colon-separated field is not an integer literal.
    """
    postings = []
    if index_line != '':
        for entry in index_line.split(' '):
            postings.append(tuple(int(field) for field in entry.split(':')))
    return postings

In [9]:
start_time = time.time()

# NOTE(review): the number of partial indexes is hard-coded; it must match the
# number of partial_dict_* / partial_inverted_index_* files written by the
# indexing step — confirm when the corpus size or chunk size changes.
num_partials = 10
dict_filenames = ['../ir_dumps/partial_dict_{}'.format(partial_id) for partial_id in range(num_partials)]
index_filenames = ['../ir_dumps/partial_inverted_index_{}'.format(partial_id) for partial_id in range(num_partials)]
# line2token[i] is the token whose merged posting list is on line i of the output file.
line2token = []


def _read_partial_entry(dict_file, index_file):
    """Read the next (token, posting list) pair from one partial index.

    Both files are advanced by exactly one line so they stay line-aligned.

    Returns:
        (token, postings). `token` is None once the dictionary file is
        exhausted. End-of-file is detected from the raw `readline()` result, so
        a blank dictionary line (an empty token) is not mistaken for EOF, and
        only the newline is stripped so tokens keep any other trailing
        whitespace they were written with.
    """
    token_line = dict_file.readline()
    index_line = index_file.readline()
    if token_line == '':
        return None, []
    return token_line.rstrip('\n'), load_index_line(index_line.rstrip('\n'))


with ExitStack() as stack, open('../ir_dumps/inverted_index', 'w', encoding='utf-8') as fout:
    dict_files = [stack.enter_context(open(fname, 'r', encoding='utf-8')) for fname in dict_filenames]
    index_files = [stack.enter_context(open(fname, 'r', encoding='utf-8')) for fname in index_filenames]

    # Current head entry of every partial index (token is None when exhausted).
    tokens = []
    postings = []
    for dict_file, index_file in zip(dict_files, index_files):
        token, partial_posting_list = _read_partial_entry(dict_file, index_file)
        tokens.append(token)
        postings.append(partial_posting_list)

    # k-way merge: repeatedly emit the smallest head token, concatenating the
    # postings of every partial that currently has that token at its head.
    while any(head is not None for head in tokens):
        top_token = min(head for head in tokens if head is not None)
        posting_list = []
        for partial_id, (dict_file, index_file) in enumerate(zip(dict_files, index_files)):
            if tokens[partial_id] == top_token:
                # Partials are visited in id order, so doc ids stay in ascending order.
                posting_list.extend(postings[partial_id])
                tokens[partial_id], postings[partial_id] = _read_partial_entry(dict_file, index_file)

        fout.write(' '.join('{}:{}'.format(doc_id, tf) for doc_id, tf in posting_list) + '\n')
        line2token.append(top_token)

# Elapsed wall-clock seconds for the merge.
end_time = time.time() - start_time
print(end_time)

# Map each token to the (start, end) file positions of its posting-list line in
# the merged inverted index, so a single line can later be located with seek().
token2pointer = {}
with open('../ir_dumps/inverted_index', 'r', encoding='utf-8') as fin:
    position = fin.tell()
    for token in tqdm(line2token):
        fin.readline()
        next_position = fin.tell()
        token2pointer[token] = (position, next_position)
        # The end of this line is the start of the next one.
        position = next_position

with open('../ir_dumps/token2pointer.pickle', 'wb') as f:
    pickle.dump(token2pointer, f)

  1%|          | 24013/3302574 [00:00<00:27, 120306.71it/s]

329.8755843639374


100%|██████████| 3302574/3302574 [00:29<00:00, 111312.89it/s]
