## PLSA

In [1]:
import re
import numpy as np
import pandas as pd
import nltk
from pyspark import SparkContext
from nltk.stem import SnowballStemmer
from tokenize import tokenize
from nltk.corpus import stopwords,words

In [2]:
# coding:utf8
from pyspark import SparkContext
from pyspark import RDD
import numpy as np
from numpy.random import RandomState

import sys
if sys.version[0] == '2':
    reload(sys)
    sys.setdefaultencoding("utf-8")



class PLSA:

    def __init__(self, data, sc, k, is_test=False, max_itr=1000, eta=1e-6):

        """
        init the algorithm

        :type data RDD
        :param data: document rdd
        :type max_itr int
        :param max_itr: maximum EM iter
        :type is_test bool
        :param is_test: test or not,if yes, rd = RandomState(1)，otherwise rd = RandomState()
        :type sc SparkContext
        :param sc: spark context
        :type k int
        :param k : number of theme
        :type eta float
        :param : threshold，when the changement of log likelyhood<eta, stop iteration
        :return : PLSA object
        """

        self.max_itr = max_itr
        self.k = sc.broadcast(k)
        self.ori_data = data.map(lambda x: x.split(' '))
        self.data = data
        self.sc = sc
        self.eta = eta
        self.rd = sc.broadcast(RandomState(1) if is_test else RandomState())

    def train(self):
        #get the dictionary words
        self.word_dict_b = self._init_dict_()
        #transform the words in the documents into the indexes in the dictionary
        self._convert_docs_to_word_index()
        #initialization, the distribution under each theme
        self._init_probility_word_topic_()

        pre_l= self._log_likelyhood_()

        print("L(%d)=%.5f" %(0,pre_l))

        for i in range(self.max_itr):
            #update the posterior distribution
            self._E_step_()
            #maimize the lower bound
            self._M_step_()
            now_l = self._log_likelyhood_()

            improve = np.abs((pre_l-now_l)/pre_l)
            pre_l = now_l

            print("L(%d)=%.5f with %.6f%% improvement" %(i+1,now_l,improve*100))
            if improve <self.eta:
                break

    def _M_step_(self):
        """
        update: p(z=k|d),p(w|z=k)
        :return: None
        """
        k = self.k
        v = self.v

        def update_probility_of_doc_topic(doc):
            """
            update the distribution of the documents of the themes
            """
            topic_doc = doc['topic'] - doc['topic']
            words = doc['words']
            for (word_index,word) in words.items():
                topic_doc += word['count']*word['topic_word']
            topic_doc /= np.sum(topic_doc)

            return {'words':words,'topic':topic_doc}

        self.data = self.data.map(update_probility_of_doc_topic)
        """
        rdd相当于一系列操作过程的结合，且前面的操作过程嵌套在后面的操作过程里，当这个嵌套超过大约60，spark会报错，
        这里每次M step都通过cache将前面的操作执行掉
        """
        self.data.cache()

        def update_probility_word_given_topic(doc):
            """
            up date the distribution of the words of the themes
            """
            probility_word_given_topic = np.matrix(np.zeros((k.value,v.value)))

            words = doc['words']
            for (word_index,word) in words.items():
                probility_word_given_topic[:,word_index] += np.matrix(word['count']*word['topic_word']).T

            return probility_word_given_topic

        probility_word_given_topic = self.data.map(update_probility_word_given_topic).sum()
        probility_word_given_topic_row_sum = np.matrix(np.sum(probility_word_given_topic,axis=1))

        #使每个主题下单词概率和为1
        probility_word_given_topic = np.divide(probility_word_given_topic,probility_word_given_topic_row_sum)

        self.probility_word_given_topic = self.sc.broadcast(probility_word_given_topic)

    def _E_step_(self):
        """
        update the latent viariable:  p(z|w,d)-给定文章，和单词后，该单词的主题分布
        :return: None
        """
        probility_word_given_topic = self.probility_word_given_topic
        k = self.k

        def update_probility_of_word_topic_given_word(doc):
            topic_doc = doc['topic']
            words = doc['words']

            for (word_index,word) in words.items():
                topic_word = word['topic_word']
                for i in range(k.value):
                    topic_word[i] = probility_word_given_topic.value[i,word_index]*topic_doc[i]
                #使该单词各主题分布概率和为1
                topic_word /= np.sum(topic_word)
                word['topic_word'] = topic_word # added
            return {'words':words,'topic':topic_doc}

        self.data = self.data.map(update_probility_of_word_topic_given_word)

    def  _init_probility_word_topic_(self):
        """
        init p(w|z=k)
        :return: None
        """
        #dict length(words in dict)
        m = self.v.value

        probility_word_given_topic = self.rd.value.uniform(0,1,(self.k.value,m))
        probility_word_given_topic_row_sum = np.matrix(np.sum(probility_word_given_topic,axis=1)).T

        #使每个主题下单词概率和为1
        probility_word_given_topic = np.divide(probility_word_given_topic,probility_word_given_topic_row_sum)

        self.probility_word_given_topic = self.sc.broadcast(probility_word_given_topic)

    def _convert_docs_to_word_index(self):

        word_dict_b = self.word_dict_b
        k = self.k
        rd = self.rd
        '''
        I wonder is there a better way to execute function with broadcast varible
        '''
        def _word_count_doc_(doc):
            print(doc)
            wordcount ={}
            word_dict = word_dict_b.value
            for word in doc:
                if word_dict[word] in wordcount:
                    wordcount[word_dict[word]]['count'] += 1
                else:
                    #first one is the number of word occurance
                    #second one is p(z=k|w,d)
                    wordcount[word_dict[word]] = {'count':1,'topic_word': rd.value.uniform(0,1,k.value)}

            topics = rd.value.uniform(0, 1, k.value)
            topics = topics/np.sum(topics)
            return {'words':wordcount,'topic':topics}
        self.data = self.ori_data.map(_word_count_doc_)

    def _init_dict_(self):
        """
        init word dict of the documents,
        and broadcast it
        :return: None
        """
        words = self.ori_data.flatMap(lambda d: d).distinct().collect()
        word_dict = {w: i for w, i in zip(words, range(len(words)))}
        self.v = self.sc.broadcast(len(word_dict))
        return self.sc.broadcast(word_dict)

    def _log_likelyhood_(self):
        
        probility_word_given_topic = self.probility_word_given_topic
        k = self.k
        def likelyhood(doc):
            print("succ")
            l = 0.0
            topic_doc = doc['topic']
            words = doc['words']
            for (word_index,word) in words.items():
                print(word)
                l += word['count']*np.log(np.matrix(topic_doc)*probility_word_given_topic.value[:,word_index])
            return l
        return self.data.map(likelyhood).sum()



    def save(self):
        """
        保存模型结果 TODO 添加分布式保存结果
        :param f_word_given_topic: 文件路径，用于给定主题下词汇分布
        :param f_doc_topic: 文件路径，用于保存文档的主题分布
        :return:
        """
        doc_topic = self.data.map(lambda x:' '.join([str(q) for q in x['topic'].tolist()])).collect()
        probility_word_given_topic = self.probility_word_given_topic.value

        word_dict = self.word_dict_b.value
        word_given_topic = []

        for w,i in word_dict.items():
            word_given_topic.append('%s %s' %(w,' '.join([str(q[0]) for q in probility_word_given_topic[:,i].tolist()])))
        return word_given_topic

In [3]:
sc = SparkContext()

In [18]:
#data = sc.textFile("/Users/panxiao/IdeaProjects/lyric_project/input/lyrics.csv")
data = sc.textFile("../lyr.txt")
data1 = sc.parallelize(data.take(500))

In [19]:
stemmer = SnowballStemmer('english')
def token_processor(token):
    return stemmer.stem(token)
wordss = words.words()
#stopwds = ['oh','yeah']

In [20]:
data2 = data1.map(lambda x: re.sub(r"\W+"," ", str(x).lower()))
data3 = data2.map(lambda x: x.split(' '))
data4 = data3.map(lambda l: [x for x in l if x and x in wordss and x not in stopwords.words("english")])
data5 = data4.map(lambda l: " ".join(x for x in l))

In [22]:
data5.take(1)

['oh baby know cut right chase made like think special purpose know special feel baby let get lost need call work cause boss real want show feel consider lucky big deal well got key heart need rather open body show know inside need lie big wide strong fit much tough talk like cause back got big ego huge ego love big ego much walk like cause back usually humble right choose leave could blues call arrogant call confident decide find working damn know killing better yet matter fact smile maybe boy site see kind something like big wide strong fit much tough talk like cause back got big ego huge ego love big ego much walk like cause back walk like cause back talk like cause back back back walk like cause back big wide strong fit much tough talk like cause back got big ego huge ego huge ego love big ego much walk like cause back ego big must admit got every reason feel like bitch ego strong know need beat sing piano']

In [23]:
plsa = PLSA(data=data5,sc=sc,k=5,max_itr=10,is_test=True)

In [24]:
plsa.train()
#plsa.save('topic_word.txt','doc_topic.txt')
ls = plsa.save()

L(0)=-519423.63439
L(1)=-414624.34012 with 20.176074% improvement
L(2)=-412472.63136 with 0.518954% improvement
L(3)=-410011.56790 with 0.596661% improvement
L(4)=-407131.43120 with 0.702453% improvement
L(5)=-404035.25694 with 0.760485% improvement
L(6)=-401016.55028 with 0.747139% improvement
L(7)=-398322.00978 with 0.671928% improvement
L(8)=-396055.06040 with 0.569125% improvement
L(9)=-394200.93908 with 0.468147% improvement
L(10)=-392710.05321 with 0.378205% improvement


In [25]:
topic_word = pd.DataFrame([sub.split(" ") for sub in ls])
topic_word_1 = topic_word[topic_word.columns[:6]]
topic_word_1.head(4)

Unnamed: 0,0,1,2,3,4,5
0,right,0.0041982482283977,0.0018460317701269,0.0058832933656716,0.0024316904024804,0.0060714750560997
1,like,0.0255764953993839,0.0054901048725914,0.031201165185119,0.0025825491836189,0.0180369293554047
2,think,2.988225359226317e-05,0.0044015303502615,0.0069643108991005,0.002531584672666,0.0035333135692125
3,feel,0.0093934114643313,0.0014197868494694,0.0079014099222479,0.0057139947853309,0.0063125802156915


In [26]:
topic_word.sort_values(topic_word.columns[1],ascending=False)

Unnamed: 0,0,1,2,3,4,5
428,silk,9.987793790512631e-20,0.00016764450189616517,1.4961907693543038e-10,1.819199153879392e-15,4.9690679623266654e-08
2970,easy,9.974889198741627e-06,0.0004864952160510638,0.0007433387998779449,0.0006377915230294228,0.00135130084152132
2572,starting,9.970872766952465e-06,0.0002476277506536763,4.31164360141329e-08,2.0361921751639067e-14,0.00017443381097943933
1312,pink,9.970599615354474e-05,2.905771648746494e-10,7.749774956131984e-05,0.0001321796310546565,5.544179635964781e-11
737,throat,9.961049031282879e-06,0.00015061971490683046,5.038090751094411e-07,9.943895192362565e-10,9.762657947649796e-05
1744,nature,9.956160597268959e-05,1.4218002720665332e-06,1.056460095177404e-11,4.466049830003593e-05,2.252032973222607e-05
1192,bey,9.952803411389545e-05,3.475007505124795e-05,6.730354816855844e-06,0.0002592380931780464,6.28409049463942e-10
900,slandering,9.947339937766497e-14,8.978386448732395e-10,3.5481210360025108e-09,7.54297833882271e-05,2.9129697379774838e-06
2497,taller,9.927202725479196e-19,8.426648015277486e-07,0.00013495126378796162,3.0922644256753563e-10,1.8246429262035765e-07
1220,wisdom,9.92622697223348e-20,8.384511709015553e-05,3.7129155102979104e-16,2.526451359618142e-10,2.7767155591863337e-17


In [27]:
topic_word.sort_values(topic_word.columns[2],ascending=False)

Unnamed: 0,0,1,2,3,4,5
3345,pace,6.156460100245986e-05,9.991954681233137e-05,4.140972383629875e-06,0.0006339223589276521,0.00035223041498812246
1454,basement,3.3190736441116e-13,9.980381790169678e-10,2.6592158000587475e-08,3.7055376307067586e-06,0.00017645350625856445
3634,arrest,1.3415796531668286e-11,9.977217732204137e-07,9.338455985419466e-11,0.0005436006973027203,1.2727621454067432e-06
524,,2.43310971950123e-80,9.969212868372406e-30,0.0,0.0,0.007412277530202364
3493,roc,0.002689672298615259,9.946726690773107e-06,3.887432538826659e-05,5.795406651195915e-05,1.827341155446946e-05
3798,curin,6.669916165976315e-05,9.937731448010704e-08,3.0230006613673715e-26,7.904438322063019e-16,1.9982944836751805e-05
1887,froze,7.191333610655678e-17,9.926812155419056e-15,6.766989057672993e-05,2.0402330911761675e-07,5.0378763292311994e-08
387,lap,7.177341153249981e-06,9.918675992162166e-05,0.0016101131395791462,8.094108839607482e-10,1.3761549495469432e-06
2725,haze,2.355694590467816e-10,9.917990711305021e-05,1.5016156752890328e-08,6.367341182524277e-05,6.0274364154528445e-12
340,risk,0.0007861053244542086,9.899983289254441e-11,1.1998321292885958e-07,0.0003162327008586878,0.00016042749734496227


In [28]:
topic_word.sort_values(topic_word.columns[3],ascending=False)

Unnamed: 0,0,1,2,3,4,5
1477,instance,8.575924772028526e-05,9.772812804999387e-18,9.996207763607926e-20,2.390938956547019e-19,1.5550911592056626e-12
2552,father,1.8163747564791105e-05,0.000543872417805372,9.97816005061917e-06,0.0005398365127482586,2.068907606629042e-05
338,safe,0.00011668919410660521,7.079173558172317e-07,9.966883555216088e-06,3.409577640538692e-08,0.0003148988652071716
2765,ninety,3.2273499034227156e-08,1.2865811471468574e-13,9.959404482675826e-07,7.677367001899182e-05,1.186655194531833e-15
3925,amazing,8.068203048907179e-08,9.118251080765701e-10,9.947953100382818e-10,7.786891758193526e-05,2.679551016216886e-09
2842,regard,4.931517626273373e-06,3.5996643299480053e-07,9.932772283009156e-18,7.194167226899521e-05,1.377687332313162e-06
2132,anything,1.330974954345719e-06,0.0015799816280226042,9.91410509120907e-05,6.513006050464127e-06,0.0005058893587310516
2678,dwell,1.629739606705602e-05,0.0002356004560671006,9.902103069907372e-10,7.522631377059115e-11,7.697765675523461e-10
2404,cycle,0.00024965786023460384,3.835551665482984e-06,9.894258485808653e-08,3.2461982798776987e-06,2.3394585985920416e-10
2417,slim,4.824596247490626e-07,2.6279592868662957e-11,9.877906200523734e-08,2.023882291990639e-14,0.00027054079500461937


In [29]:
topic_word.sort_values(topic_word.columns[4],ascending=False)

Unnamed: 0,0,1,2,3,4,5
3387,film,3.560152813011912e-06,1.6947682336504111e-12,4.519802313322153e-23,9.996436747532097e-06,0.0005270164860025525
2931,sighted,1.287403762733885e-09,0.00016729321541508353,2.722034601693692e-09,9.995918473971515e-10,4.224707608039084e-07
593,rip,7.214113719347581e-07,1.3896812778805708e-06,6.559666911614296e-06,9.980083092664554e-07,0.0008013920657742529
3119,mort,8.57581520488852e-05,6.993888396678325e-22,2.6757610825633827e-31,9.971962977189538e-10,1.795128618814623e-20
2533,negro,5.097719789392088e-12,0.00032248536332287327,2.7399948718924757e-13,9.968704737761752e-11,1.3903250236868477e-05
2937,nope,3.081603220191553e-05,3.482391773381884e-09,0.00031494263876923544,9.965215278932382e-08,1.8993220578905774e-09
581,gue,8.574887283430182e-05,9.19872703708682e-24,2.909003473389043e-15,9.964634076511375e-10,9.781503913793523e-09
2604,miracle,0.0002526097722204594,3.0881940472842824e-09,2.334149964216896e-06,9.96336416067886e-07,6.533967744259929e-07
3496,golly,8.9200232940758e-13,8.156973523967617e-05,2.38225685308097e-15,9.960358712383437e-11,2.453263395422553e-06
891,stell,8.574928212291784e-05,1.2588619932592038e-16,4.673033709576547e-10,9.958380046522e-16,9.883442827225742e-09


In [30]:
topic_word.sort_values(topic_word.columns[5],ascending=False)

Unnamed: 0,0,1,2,3,4,5
994,ego,2.8574868776111153e-16,3.8513620137526034e-06,0.0017618949921520868,1.4767548985657737e-15,9.989714509598533e-09
820,drill,3.340116514156032e-16,8.374419917228538e-05,8.116827616763227e-08,1.326125259837157e-11,9.965331898036602e-10
3856,nitty,1.0838607518763716e-05,1.8289798010990232e-07,5.9115597286795434e-05,4.8332975800756376e-08,9.96031049553352e-14
1992,reward,0.0005985420853775077,2.1625466812576513e-05,5.101241979356265e-05,8.808268682111516e-07,9.94623200064934e-12
1306,growing,1.598979860763117e-07,0.00017570501814376005,5.426203042977658e-05,2.504363786343615e-07,9.943489470551181e-05
1253,presenter,6.983021320588812e-11,0.00016676890543336595,7.105464017346803e-10,1.5793469591435112e-20,9.928498241696194e-07
1927,passable,2.5400316801611417e-05,5.901192237751715e-05,5.5342574473104315e-18,7.924972784314421e-18,9.92244889386265e-15
2333,flame,8.11995151508468e-05,5.049637596854215e-08,0.000588179864379619,0.00010062257563428839,9.919300000237616e-05
1892,calculating,8.575924911297247e-05,4.485826698119404e-24,1.3879828482135454e-14,5.913138060119052e-14,9.914689527339784e-17
796,poppa,1.5077247923841876e-06,1.609902899336297e-11,0.0001345482161467313,3.2570978554881107e-08,9.909778966490076e-10
