In [1]:
import numpy as np
import pandas as pd
from pylab import random
import math
from nltk.corpus import PlaintextCorpusReader
from collections import Counter
from tqdm import tqdm 
from numba import jit,njit
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize

# Preprocess

In [2]:
# 使用 nltk PlaintextCorpusReader 存取指定目錄下的所有檔案
doc_corpus_root = './ntust-ir-2020/docs'
query_corpus_root = './ntust-ir-2020/queries'
doc_corpus = PlaintextCorpusReader(doc_corpus_root, '.*')
query_corpus = PlaintextCorpusReader(query_corpus_root, '.*')

# 將 document 依序 存取內容與檔名
all_words = {}
word2id = {}
id2word = {}
doc_names = []
doc_terms = []
for docs in doc_corpus.fileids():
    words = doc_corpus.words(docs)
    word_dict = dict(Counter(words))
    for w in range(len(word_dict)):
        k = list(word_dict.keys())[w]
        v = list(word_dict.values())[w]
        if k in all_words:
            all_words[k] += v
        else:
            all_words[k] = v
    # 檔名篩掉.txt
    doc_names.append(docs[:len(docs) - 4])
    doc_terms.append(word_dict)

query_words = {}
# 將 query 依序 存取內容與檔名
query_names = []
query_terms = []
for queries in query_corpus.fileids():
    words = query_corpus.words(queries)
    word_dict = dict(Counter(words))
    for w in range(len(word_dict)):
        k = list(word_dict.keys())[w]
        v = list(word_dict.values())[w]
        if k in query_words:
            query_words[k] += v
        else:
            query_words[k] = v
    # 檔名篩掉.txt
    query_names.append(queries[:len(queries) - 4])
    query_terms.append(word_dict)

filtered_word = {}
i = 0
for w in all_words.keys():
    if w in query_words or all_words[w] > 10:
        filtered_word[w] = all_words[w]
        word2id[w] = i
        id2word[i] = w
        i += 1

filtered_doc_terms = []
for doc in doc_terms:
    word_dict = {}
    for i in range(len(doc.keys())):
        word = list(doc.keys())[i]
        if word in filtered_word:
            word_dict[word] = doc[word]
    filtered_doc_terms.append(word_dict)


# 確認長度相符
print(len(doc_names),len(doc_terms),len(filtered_doc_terms))
print(len(query_names),len(query_terms))
print(len(word2id),len(id2word),len(filtered_word))
print(len(all_words),len(query_words))

30000 30000 30000
150 150
29735 29735 29735
154240 324


# Function

In [3]:
def get_tfidf(queries,docs,words):
    
    word_lens = len(words)
    doc_lens = len(queries) + len(docs)
    tf_words = np.zeros((doc_lens,word_lens))
    idf_words = np.zeros(word_lens)
    n = 0
    
    for j in tqdm(range(doc_lens)):
        if j < 150:
            terms = queries
            n = j
        else:
            terms = docs
            n = j - 150
        for i in range(word_lens):
            if id2word[i] in terms[n]:
                # tf sublinear
                tf_words[j][i] = np.log(terms[n][id2word[i]]) + 1
                idf_words[i] += 1
            if j == (doc_lens - 1):
                # idf smoothing
                idf_words[i] = math.log(1 + (float(doc_lens) - idf_words[i] + 0.5) / (idf_words[i] + 0.5))
    
    tfidf = np.zeros((doc_lens,word_lens))
    
    for j in range(doc_lens):
        if j < 150:
            terms = queries
            n = j
        else:
            terms = docs
            n = j - 150
        words = list(terms[n].keys())
        for w in words:
            i = word2id[w]
            tfidf[j][i] = tf_words[j][i] * idf_words[i]

    return tf_words, idf_words, tfidf

In [46]:
@jit
def get_bm25_matrix(doc_tf,doc_idf,query_tf,avg_doc_words,doc_words,k1 = 1, k3 = 1000, b = 0.85):
    
    doc_len = len(doc_tf)
    query_len = len(query_tf)
    
    tfidf = np.zeros((query_len,doc_len))
    
    for i in range(query_len):
        for j in range(doc_len):
            for q in range(len(query_tf[i])):
                q_tf = query_tf[i][q]
                if q_tf != 0:
                    _f = doc_tf[j][q] / (1 - b + b * doc_words[j] / avg_doc_words)
                    w_d = (k1 + 1) * (_f + 0.5) / (k1 + _f + 0.5)
                    w_q = (k3 + 1) * q_tf / (k3 + q_tf)
                    tfidf[i][j] += doc_idf[q] * w_d * w_q
    return tfidf

In [None]:
def rocchio(query_tfidf,doc_tfidf,bm25,a=1,b=0.75,r=0,step=7,rel=5,nrel=1):
    
    # get vsm matrix
    vsm = cosine_similarity(query_tfidf,doc_tfidf)
    # combine vsm with bm25
    score = vsm * bm25
    # sort score matrix
    rank = np.flip(score.argsort(), axis=1)
    
    for _ in tqdm(range(step)):
        # record rel_docs
        rels = doc_tfidf[rank[:, :rel]].mean(axis=1)
        # record nonrel_docs
        nrels = doc_tfidf[rank[:, -nrel:]].mean(axis=1)
        # rewrite query tfidf
        query_tfidf = a * query_tfidf + b * rels - r * nrels

        # use new query tfidf get vsm matrix
        vsm = cosine_similarity(query_tfidf,doc_tfidf)
        score = vsm * bm25
        rank = np.flip(score.argsort(axis=1), axis=1)

    return rank

In [5]:
# get tf,idf,tfidf
tf,idf,tfidf = get_tfidf(query_terms,filtered_doc_terms,filtered_word)

100%|██████████| 30150/30150 [03:32<00:00, 142.05it/s]
(30150, 29735) (29735,)


In [42]:
# seperate query & doc tfidf
query_tfidf = tfidf[:150,:]
doc_tfidf = tfidf[150:,:]
# seperate query & doc tf
query_tf = tf[:150,:]
doc_tf = tf[150:,:]

avg_doc_words = sum(all_words.values()) / len(doc_tf)
doc_words = np.zeros(len(doc_tf))
for j in tqdm(range(len(doc_tf))):
    doc_words[j] = sum(doc_terms[j].values())

100%|██████████| 30000/30000 [00:01<00:00, 20325.95it/s]


In [47]:
# get bm25 matrix
bm25_tfidf = get_bm25_matrix(doc_tf,idf,query_tf,avg_doc_words,doc_words)

In [56]:
# get rocchio ranking
ranking = rocchio(query_tfidf,doc_tfidf,bm25_tfidf)

100%|██████████| 7/7 [04:32<00:00, 38.92s/it]


In [57]:
# 讀檔、寫入答案
ans = "Query,RetrievedDocuments"
f = open("rocchio_result.txt","w+")
f.write(ans+'\n')

buf = ""
for i in range(len(query_names)):
    buf = query_names[i] + ','
    first = True

    for s in range(5000):
        if first == True:
            buf += doc_names[ranking[i][s]]
        else:
            buf += (' ' + doc_names[ranking[i][s]])
        first = False

    f.write(buf+'\n')