In [1]:
import os
from collections import defaultdict

# filename
DATA_DIR = '../data'
TRAIN_ANS_FILE = os.path.join(DATA_DIR, 'ans_train.csv')
QUERY_TRAIN_FILE = os.path.join(DATA_DIR, 'query-train.xml')
QUERY_TEST_FILE = os.path.join(DATA_DIR, 'query-test.xml')
FILE_LIST = os.path.join(DATA_DIR, 'file-list')
INV_FILE = os.path.join(DATA_DIR, 'inverted-file')
VOCAB_FILE = os.path.join(DATA_DIR, 'vocab.all')
OUTPUT_FILE = 'output.csv'

# building maps
docname2id = dict()
with open(FILE_LIST, 'r') as f:
    for idx, line in enumerate(f):
        docname2id[line.strip()] = idx
id2docname = {y:x for x, y in docname2id.items()}

word2id = dict()
with open(VOCAB_FILE, 'r') as f:
    f.readline()
    for idx, line in enumerate(f, 1):
        word2id[line.strip()] = idx

word_freq = dict()
gram_freq = dict()
gram2id = dict()
with open(INV_FILE, 'r') as f:
    idx = 0
    while True:
        line = f.readline().strip()
        if not line:
            break
        id_1, id_2, doc_count = [int(i) for i in line.split(' ')]
        doc_records = defaultdict(int)
        for i in range(doc_count):
            doc_id, freq = [int(i) for i in f.readline().strip().split(' ')]
            doc_records[doc_id] = freq
        if id_2 == -1:
            word_freq[(id_1)] = doc_records
        gram_freq[(id_1, id_2)] = doc_records
        gram2id[(id_1, id_2)] = idx
        idx += 1

In [2]:
doc2len = defaultdict(int)
for word_id, records in word_freq.items():
    for doc_id, freq in records.items():
        doc2len[doc_id] += freq
avdl = sum([length for _, length in doc2len.items()]) / len(doc2len)

In [3]:
import numpy as np

BM25_k = 3
norm_b = 0.7

word_num = len(word2id) + 1
doc_num = len(docname2id)
weight = np.zeros((doc_num, word_num))
for word_id, records in word_freq.items():
    IDF = np.log((doc_num + 1) / len(records))
    for doc_id, freq in records.items():
        TF = (BM25_k + 1) * freq / (freq + BM25_k)
        normalizer = 1 - norm_b + norm_b * doc2len[doc_id] / avdl
        weight[doc_id, word_id] = TF * IDF / normalizer

In [4]:
import xml.etree.ElementTree as ET
from collections import Counter

tree = ET.ElementTree(file=QUERY_TEST_FILE)
root = tree.getroot()
query_num = len(root)

queries = np.zeros((query_num, word_num))
for query_id, child in enumerate(root):
    query = list(''.join(child[4].text.strip('\n。 ').split('、')))
    print(query)
    for word, freq in Counter(query).items():
        IDF = np.log((doc_num + 1) / len(word_freq[word2id[word]]))
        TF = (BM25_k + 1) * freq / (freq + BM25_k)
        queries[query_id, word2id[word]] = TF * IDF

['白', '案', '白', '曉', '燕', '綁', '架', '擄', '人', '勒', '贖', '判', '決', '宣', '判', '審', '判', '改', '判', '徒', '刑', '自', '白', '高', '院', '最', '高', '法', '院', '檢', '察', '官', '合', '議', '庭', '證', '據', '一', '審', '二', '審', '陳', '進', '興', '張', '志', '輝', '張', '素', '真']
['漢', '語', '拼', '音', '注', '音', '符', '號', '通', '用', '拼', '音', '拼', '音', '系', '統', '羅', '馬', '拼', '音', '街', '道', '譯', '名', '中', '文', '拼', '音', '系', '統', '中', '文', '英', '譯', '系', '統', '注', '音', '第', '二', '式', '統', '一', '標', '準', '音', '標', '母', '語', '國', '際', '化', '教', '育', '部', '行', '政', '院']
['職', '棒', '簽', '賭', '案', '職', '棒', '體', '委', '會', '球', '員', '球', '迷', '球', '團', '球', '場', '虧', '損', '票', '房', '戰', '績', '解', '散', '停', '權', '聯', '盟', '賭', '博', '黑', '金', '約', '談', '涉', '賭', '放', '水']
['受', '虐', '兒', '家', '庭', '暴', '力', '婚', '姻', '暴', '力', '父', '母', '施', '虐', '兒', '扶', '基', '金', '會', '家', '扶', '中', '心', '統', '計', '兒', '童', '福', '利', '兒', '童', '保', '護', '親', '職', '教', '育']
['選', '舉', '候', '選', '人', '中', '選', '會', '電', '視', '台', '錄', '影', 

In [5]:
ret = np.matmul(queries, np.transpose(weight)).argsort(axis=1)[:, ::-1][:, :100]

In [6]:
with open(OUTPUT_FILE, 'w+') as f:
    print('query_id,retrieved_docs', file=f)
    for idx, result in enumerate(ret, 11):
        print(str(idx).zfill(3), ' '.join(['doc_%d' % (id2docname[i].split('/')[-1].lower()) for i in result]), sep=',', file=f)

TypeError: %d format: a number is required, not str