In [12]:
import numpy as np
import pandas as pd
from pylab import random
import math
from nltk.corpus import PlaintextCorpusReader
from collections import Counter
from tqdm import tqdm 
from numba import jit,njit

# Preprocess

In [36]:
# Use nltk's PlaintextCorpusReader to access every file under the given directories
doc_corpus_root = './ntust-ir-2020/docs'
query_corpus_root = './ntust-ir-2020/queries'
doc_corpus = PlaintextCorpusReader(doc_corpus_root, '.*')
query_corpus = PlaintextCorpusReader(query_corpus_root, '.*')

# Walk the documents in order, recording each one's term counts and name
all_words = {}   # term -> total count over all documents
word2id = {}     # filled in later, once the vocabulary has been filtered
id2word = {}
doc_names = []   # doc_names[j] and doc_terms[j] describe the same document
doc_terms = []
for docs in doc_corpus.fileids():
    words = doc_corpus.words(docs)
    word_dict = dict(Counter(words))
    # Iterate the (term, count) pairs directly: indexing into a freshly built
    # list of keys/values for every term would cost quadratic time per document.
    for k, v in word_dict.items():
        all_words[k] = all_words.get(k, 0) + v
    # Drop the last four characters of the file id (the ".txt" suffix)
    doc_names.append(docs[:len(docs) - 4])
    doc_terms.append(word_dict)

query_words = {}   # term -> total count over all queries
# Walk the queries in order, recording each one's term counts and name
query_names = []   # query_names[q] and query_terms[q] describe the same query
query_terms = []
for queries in query_corpus.fileids():
    words = query_corpus.words(queries)
    word_dict = dict(Counter(words))
    # Iterate the (term, count) pairs directly: indexing into a freshly built
    # list of keys/values for every term would cost quadratic time per query.
    for k, v in word_dict.items():
        query_words[k] = query_words.get(k, 0) + v
    # Drop the last four characters of the file id (the ".txt" suffix)
    query_names.append(queries[:len(queries) - 4])
    query_terms.append(word_dict)

# Keep the terms that occur in some query or more than 500 times in the corpus,
# numbering them consecutively (from 0) in the order all_words yields them.
filtered_word = {}
for term, total in all_words.items():
    if term not in query_words and total <= 500:
        continue
    term_id = len(filtered_word)
    filtered_word[term] = total
    word2id[term] = term_id
    id2word[term_id] = term

# Restrict every document's term counts to the filtered vocabulary.
# A single pass over doc.items() replaces the per-index list(doc.keys())[i]
# lookup, which rebuilt the key list for every term.
filtered_doc_terms = []
for doc in doc_terms:
    word_dict = {word: count for word, count in doc.items() if word in filtered_word}
    filtered_doc_terms.append(word_dict)

# Dense term-frequency matrix: one row per document, one column per vocabulary id.
tf_words = np.zeros((len(filtered_doc_terms),len(filtered_word)))
for doc_idx, term_counts in enumerate(filtered_doc_terms):
    for term_idx in range(len(filtered_word)):
        term = id2word[term_idx]
        if term in term_counts:
            tf_words[doc_idx, term_idx] = term_counts[term]

# Sanity check: the parallel containers should report matching lengths
print(*map(len, (doc_names, doc_terms, filtered_doc_terms)))
print(*map(len, (query_names, query_terms)))
print(*map(len, (word2id, id2word, filtered_word)))
print(*map(len, (all_words, query_words)))
print(tf_words.shape)

14955 14955 14955
100 100
226 226 226
111449 226
(14955, 226)


# Function

In [15]:
def initial_p(words_len,documents_len,K = 8,seed = None):
    """Draw random starting values for the two pLSA parameter matrices.

    Parameters
    ----------
    words_len : int
        Vocabulary size (number of rows of ``pwt``).
    documents_len : int
        Number of documents (number of columns of ``ptd``).
    K : int, default 8
        Number of latent topics.
    seed : int or None, default None
        When given, a private ``np.random.RandomState(seed)`` is used so the
        initialisation is reproducible; otherwise NumPy's global generator is used.

    Returns
    -------
    (pwt, ptd) : tuple of ndarray
        ``pwt[i, k]`` is p(w_i | t_k), shape ``(words_len, K)``; each column
        sums to 1 (a distribution over words for every topic).
        ``ptd[k, j]`` is p(t_k | d_j), shape ``(K, documents_len)``; each column
        sums to 1 (a distribution over topics for every document).
    """
    # Plain NumPy on purpose: the work is two array draws and two divisions,
    # so there is nothing for a JIT compiler to speed up.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # pwt[i, k] : p(wi|tk)
    pwt = rng.random_sample((words_len, K))
    # ptd[k, j] : p(tk|dj)
    ptd = rng.random_sample((K, documents_len))

    # p(w|t) is a distribution over words, so normalise over the word axis.
    pwt /= pwt.sum(axis=0, keepdims=True)
    # p(t|d) is a distribution over topics, so normalise over the topic axis.
    ptd /= ptd.sum(axis=0, keepdims=True)

    return pwt,ptd

In [16]:
def e_step(pwt,ptd,tf = None):
    """E-step of pLSA: posterior p(t_k | w_i, d_j) for every word/document pair.

    Parameters
    ----------
    pwt : ndarray, shape (words, K)
        ``pwt[i, k]`` = p(w_i | t_k).
    ptd : ndarray, shape (K, documents)
        ``ptd[k, j]`` = p(t_k | d_j).
    tf : array-like, shape (documents, words), optional
        Term-frequency matrix; only its non-zero pattern is used, to decide
        which (word, document) pairs get a posterior. Defaults to the
        notebook-level ``tf_words`` matrix.

    Returns
    -------
    pt_w_d : ndarray, shape (words, documents, K)
        ``pt_w_d[i, j, :]`` is ``pwt[i, :] * ptd[:, j]`` normalised to sum to 1
        when word i occurs in document j. It is left as zeros when the word
        does not occur, or when every product is zero.
    """
    if tf is None:
        tf = tf_words
    # present[i, j] is True when word i occurs in document j.
    present = np.asarray(tf).T != 0

    pwt = np.asarray(pwt, dtype=np.float64)
    ptd = np.asarray(ptd, dtype=np.float64)

    # Broadcast (words, 1, K) * (1, documents, K) -> (words, documents, K).
    pt_w_d = pwt[:, np.newaxis, :] * ptd.T[np.newaxis, :, :]
    # Zero out pairs where the word does not occur in the document.
    pt_w_d *= present[:, :, np.newaxis]

    totals = pt_w_d.sum(axis=2, keepdims=True)
    # Divide only where the normaliser is non-zero so no NaN/inf is produced.
    np.divide(pt_w_d, totals, out=pt_w_d, where=totals != 0)
    return pt_w_d

In [19]:
def m_step(pwt,ptd,pt_w_d,tf):
    """M-step of pLSA: re-estimate p(w|t) and p(t|d) from the posteriors.

    Parameters
    ----------
    pwt : ndarray, shape (words, K)
        Overwritten in place with the new p(w_i | t_k).
    ptd : ndarray, shape (K, documents)
        Overwritten in place with the new p(t_k | d_j).
    pt_w_d : ndarray, shape (words, documents, K)
        Posterior p(t_k | w_i, d_j) from the E-step.
    tf : array-like, shape (documents, words)
        Term frequencies; ``tf[j][i]`` is the count of word i in document j.

    Returns
    -------
    (pwt, ptd) : the same two arrays that were passed in, now updated.

    Notes
    -----
    ``ptd[:, j]`` is normalised by the document's total term count
    ``sum_i tf[j][i]`` (not by the vocabulary size), so each column is a
    proper distribution over topics. A topic with no mass gets a uniform
    word distribution; a document with no counted terms gets a uniform
    topic distribution.
    """
    tf = np.asarray(tf, dtype=np.float64)
    words_len = len(pwt)
    K = len(ptd)

    # word_topic[i, k] = sum_j tf[j, i] * pt_w_d[i, j, k]
    word_topic = np.einsum('ji,ijk->ik', tf, pt_w_d)
    topic_mass = word_topic.sum(axis=0)
    has_mass = topic_mass != 0
    word_topic[:, has_mass] /= topic_mass[has_mass]
    if not has_mass.all():
        word_topic[:, ~has_mass] = 1 / words_len
    pwt[...] = word_topic

    # topic_doc[k, j] = sum_i tf[j, i] * pt_w_d[i, j, k]
    topic_doc = np.einsum('ji,ijk->kj', tf, pt_w_d)
    doc_len = tf.sum(axis=1)
    has_terms = doc_len != 0
    topic_doc[:, has_terms] /= doc_len[has_terms]
    if not has_terms.all():
        topic_doc[:, ~has_terms] = 1 / K
    ptd[...] = topic_doc

    return pwt,ptd

In [26]:
def get_pwd_pwbg(words_len,documents_len):
    """Build the per-document and background unigram models used for smoothing.

    Reads the notebook-level ``doc_terms``, ``tf_words``, ``filtered_word`` and
    ``id2word``.

    Parameters
    ----------
    words_len : int
        Number of vocabulary ids to cover (rows of ``pwd``).
    documents_len : int
        Number of documents to cover (columns of ``pwd``).

    Returns
    -------
    (pwd, pwbg) : tuple of ndarray
        ``pwd[i, j]`` = count of word i in document j divided by the document's
        total token count (all zeros for an empty document).
        ``pwbg[i]`` = corpus count of word i divided by the total count of all
        words in the filtered vocabulary.
    """
    pwd = np.zeros((words_len,documents_len))
    for j in range(documents_len):
        # Document length is the total number of tokens, i.e. the sum of the
        # per-term counts -- not the number of distinct terms.
        dj_len = sum(doc_terms[j].values())
        if dj_len != 0:
            pwd[:, j] = tf_words[j][:words_len] / dj_len

    pwbg = np.zeros(words_len)
    # Loop-invariant normaliser, computed once.
    total_count = sum(filtered_word.values())
    for i in range(words_len):
        pwbg[i] = filtered_word[id2word[i]] / total_count
    return pwd,pwbg

In [8]:
def em_step(pwt,ptd,pt_w_d,tf,max_iter = 30):
    """Alternate E- and M-steps ``max_iter`` times and return the fitted matrices.

    ``pt_w_d`` is accepted for interface compatibility only: its incoming value
    is never read, because the posterior is recomputed by ``e_step`` at the
    start of every iteration. Progress is shown with a tqdm bar.

    Returns the ``(pwt, ptd)`` pair produced by the last ``m_step`` call (or the
    inputs unchanged when ``max_iter`` is 0).
    """
    for _ in tqdm(range(max_iter)):
        posterior = e_step(pwt,ptd)
        pwt,ptd = m_step(pwt,ptd,posterior,tf)
    return pwt,ptd

In [9]:
def get_pwt_ptd(tf,K = 8,max_iter = 30):
    """Initialise and fit the pLSA parameters for the notebook's corpus.

    Parameters
    ----------
    tf : ndarray, shape (documents, words)
        Term-frequency matrix handed to the M-step.
    K : int, default 8
        Number of latent topics.
    max_iter : int, default 30
        Number of EM iterations to run.

    Returns
    -------
    (pwt, ptd) : the fitted p(w|t) and p(t|d) matrices.

    The matrix sizes come from the notebook-level ``filtered_word`` and
    ``doc_terms``.
    """
    pwt,ptd = initial_p(len(filtered_word),len(doc_terms),K)
    # em_step's positional order is (pwt, ptd, pt_w_d, tf, max_iter). The third
    # slot is a placeholder that em_step recomputes itself, so pass None there
    # and make sure max_iter lands in its own slot.
    pwt,ptd = em_step(pwt,ptd,None,tf,max_iter)
    return pwt,ptd

In [43]:
def plsa(pwt,ptd,alpha=0.2,beta=0.4):
    """Rank documents for every query and write the ranking to ``vsm_result.txt``.

    Each document's score for a query is the product, over the query's distinct
    terms, of

        alpha * p(w|d) + beta * sum_k p(w|t_k) p(t_k|d) + (1 - alpha - beta) * p(w|BG)

    Parameters
    ----------
    pwt : ndarray, shape (words, K)
        p(w_i | t_k).
    ptd : ndarray, shape (K, documents)
        p(t_k | d_j).
    alpha : float, default 0.2
        Weight of the per-document unigram model.
    beta : float, default 0.4
        Weight of the topic-model estimate.

    Output
    ------
    Writes a header line ``Query,RetrievedDocuments`` followed by one line per
    query: the query name, a comma, and up to 1000 document names separated by
    spaces, best score first. Returns None.

    Reads the notebook-level ``filtered_word``, ``doc_terms``, ``doc_names``,
    ``query_terms``, ``query_names`` and ``word2id``.
    """
    # Initial
    pwd,pwbg = get_pwd_pwbg(len(filtered_word),len(doc_terms))
    dot_wd = np.dot(pwt,ptd)
    background_weight = 1 - alpha - beta

    # The context manager guarantees the file is flushed and closed, even if
    # scoring raises part-way through.
    with open("vsm_result.txt","w+") as f:
        f.write("Query,RetrievedDocuments"+'\n')

        for i in range(len(query_terms)):
            pqd = np.ones(len(doc_terms))
            for word in query_terms[i]:
                wordid = word2id.get(word)
                if wordid is None:
                    # A query term that never occurs in the corpus has no
                    # vocabulary id; it cannot discriminate between documents.
                    continue
                # Score all documents for this term at once.
                pqd *= (alpha * pwd[wordid] + beta * dot_wd[wordid]
                        + background_weight * pwbg[wordid])

            pqd_sort = pqd.argsort()[::-1]
            # Never ask for more documents than exist.
            top_n = min(1000, len(pqd_sort))
            ranked = ' '.join(doc_names[d] for d in pqd_sort[:top_n])
            f.write(query_names[i] + ',' + ranked + '\n')

In [45]:
# Choose the topic count and the number of EM iterations, then fit pwt and ptd
NUM_TOPICS = 10
EM_ITERATIONS = 30
pwt,ptd = get_pwt_ptd(tf_words,K=NUM_TOPICS,max_iter=EM_ITERATIONS)
# Score every query against every document and write the result file
plsa(pwt,ptd)

100%|██████████| 30/30 [00:14<00:00,  2.10it/s]
