In [5]:
from __future__ import print_function
from time import time

from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.decomposition import NMF, LatentDirichletAllocation
from sklearn.datasets import fetch_20newsgroups

import pickle
from preprocess import *

In [6]:
n_components = 51

In [34]:
def getQLPSortedIndexList(question, cv, tf, lda,cwlist, lam, u):
    # should return list of 
    simiList = []
    qes = cv.transform([question.content])[0].toarray()[0]
    for row in tf:
        prob = 1
        answerDis = lda.transform([row]).tolist()[0]
        for i in range(len(qes)):
            if(qes[i] == 0): continue
            pseudo = (row[i] + u * cwlist[i]) / (u + sum(row))
            
            # calculate plda
            plda = 0
            
            # traverse topic
            for topic_idx, topic in enumerate(lda.components_):
                plda += answerDis[topic_idx] * topic.item(i)*qes[i]
                
            prob *= lam*pseudo + (1-lam) * plda
        
        simiList.append(prob)
    # sort the similarity list, and return the index list.
    res = list(range(len(simiList)))
    return sorted(res, key = lambda i : simiList[i], reverse= True)
    

In [8]:
X, cv, answers, word_ratio = generate_count_vectorizer()

In [50]:
lda = LatentDirichletAllocation(n_components=n_components, max_iter=300,
                                learning_method='online',
                                learning_offset=50.,
                                random_state=0)

In [51]:
lda.fit(X)

LatentDirichletAllocation(batch_size=128, doc_topic_prior=None,
             evaluate_every=-1, learning_decay=0.7,
             learning_method='online', learning_offset=50.0,
             max_doc_update_iter=100, max_iter=300, mean_change_tol=0.001,
             n_components=51, n_jobs=1, n_topics=None, perp_tol=0.1,
             random_state=0, topic_word_prior=None,
             total_samples=1000000.0, verbose=0)

In [33]:
row = X[2]
print(row)
d = lda.transform([row]).tolist()[0]
print(d)

[0 0 0 ... 0 0 0]
[0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.7898537555278093, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.00025799793601675997, 0.19750434560736937, 0.00

In [12]:
def print_top_words(model, feature_names, n_top_words):
    for topic_idx, topic in enumerate(model.components_):
        message = "Topic #%d: " % topic_idx
        message += " ".join([feature_names[i]
                             for i in topic.argsort()[:-n_top_words - 1:-1]])
        print(message)

In [52]:
ques = answers[17]

In [53]:
l = getQLPSortedIndexList(ques, cv, X, lda, word_ratio, 0, 1)
print(l)

[17, 75, 484, 557, 10, 436, 386, 401, 553, 443, 390, 545, 224, 477, 379, 488, 324, 270, 173, 492, 550, 578, 326, 422, 430, 426, 513, 406, 522, 199, 434, 566, 455, 423, 320, 415, 554, 352, 323, 402, 487, 429, 463, 269, 360, 504, 437, 441, 399, 514, 466, 534, 570, 520, 421, 439, 531, 593, 395, 509, 201, 8, 460, 467, 416, 473, 292, 310, 497, 480, 266, 451, 479, 242, 380, 456, 502, 410, 470, 474, 227, 580, 594, 501, 582, 392, 599, 458, 211, 425, 495, 273, 419, 296, 475, 457, 568, 483, 260, 440, 462, 397, 579, 325, 503, 222, 530, 393, 248, 468, 308, 432, 572, 560, 563, 231, 573, 461, 302, 542, 215, 527, 362, 405, 418, 486, 555, 450, 548, 435, 589, 538, 261, 588, 567, 347, 590, 596, 576, 539, 581, 528, 333, 376, 544, 396, 218, 369, 367, 281, 295, 583, 398, 312, 510, 587, 481, 449, 284, 446, 597, 523, 547, 283, 506, 533, 512, 552, 372, 404, 517, 453, 303, 304, 489, 519, 464, 595, 348, 389, 354, 476, 253, 592, 459, 412, 543, 403, 518, 330, 252, 365, 516, 341, 318, 431, 272, 275, 317, 417, 549,

In [42]:
sum(cv.transform([ques.content])[0].toarray()[0])

227

In [None]:
tf_feature_names = cv.get_feature_names()
print_top_words(lda, tf_feature_names, 10)

In [46]:
simiList = [3,2,1,5,9]
res = list(range(len(simiList)))
print(sorted(res, key = lambda i : simiList[i], reverse= True))

[4, 3, 0, 1, 2]
