## PLSA

In [1]:
import re
import numpy as np
import pandas as pd
import nltk
from pyspark import SparkContext
from nltk.stem import SnowballStemmer
from tokenize import tokenize
from nltk.corpus import stopwords,words

In [14]:
# coding:utf8
from pyspark import SparkContext
from pyspark import RDD
import numpy as np
from numpy.random import RandomState

import sys
if sys.version[0] == '2':
    reload(sys)
    sys.setdefaultencoding("utf-8")



class PLSA:

    def __init__(self, data, sc, k, is_test=False, max_itr=1000, eta=1e-6):

        """
        init the algorithm

        :type data RDD
        :param data: document rdd
        :type max_itr int
        :param max_itr: maximum EM iter
        :type is_test bool
        :param is_test: test or not,if yes, rd = RandomState(1)，otherwise rd = RandomState()
        :type sc SparkContext
        :param sc: spark context
        :type k int
        :param k : number of theme
        :type eta float
        :param : threshold，when the changement of log likelyhood<eta, stop iteration
        :return : PLSA object
        """

        self.max_itr = max_itr
        self.k = sc.broadcast(k)
        self.ori_data = data.map(lambda x: x.split(' '))
        self.data = data
        self.sc = sc
        self.eta = eta
        self.rd = sc.broadcast(RandomState(1) if is_test else RandomState())

    def train(self):
        #get the dictionary words
        self.word_dict_b = self._init_dict_()
        #transform the words in the documents into the indexes in the dictionary
        self._convert_docs_to_word_index()
        #initialization, the distribution under each theme
        self._init_probility_word_topic_()

        pre_l= self._log_likelyhood_()

        print("L(%d)=%.5f" %(0,pre_l))

        for i in range(self.max_itr):
            #update the posterior distribution
            self._E_step_()
            #maimize the lower bound
            self._M_step_()
            now_l = self._log_likelyhood_()

            improve = np.abs((pre_l-now_l)/pre_l)
            pre_l = now_l

            print("L(%d)=%.5f with %.6f%% improvement" %(i+1,now_l,improve*100))
            if improve <self.eta:
                break

    def _M_step_(self):
        """
        update: p(z=k|d),p(w|z=k)
        :return: None
        """
        k = self.k
        v = self.v

        def update_probility_of_doc_topic(doc):
            """
            update the distribution of the documents of the themes
            """
            topic_doc = doc['topic'] - doc['topic']
            words = doc['words']
            for (word_index,word) in words.items():
                topic_doc += word['count']*word['topic_word']
            topic_doc /= np.sum(topic_doc)

            return {'words':words,'topic':topic_doc}

        self.data = self.data.map(update_probility_of_doc_topic)
        """
        rdd相当于一系列操作过程的结合，且前面的操作过程嵌套在后面的操作过程里，当这个嵌套超过大约60，spark会报错，
        这里每次M step都通过cache将前面的操作执行掉
        """
        self.data.cache()

        def update_probility_word_given_topic(doc):
            """
            up date the distribution of the words of the themes
            """
            probility_word_given_topic = np.matrix(np.zeros((k.value,v.value)))

            words = doc['words']
            for (word_index,word) in words.items():
                probility_word_given_topic[:,word_index] += np.matrix(word['count']*word['topic_word']).T

            return probility_word_given_topic

        probility_word_given_topic = self.data.map(update_probility_word_given_topic).sum()
        probility_word_given_topic_row_sum = np.matrix(np.sum(probility_word_given_topic,axis=1))

        #使每个主题下单词概率和为1
        probility_word_given_topic = np.divide(probility_word_given_topic,probility_word_given_topic_row_sum)

        self.probility_word_given_topic = self.sc.broadcast(probility_word_given_topic)

    def _E_step_(self):
        """
        update the latent viariable:  p(z|w,d)-给定文章，和单词后，该单词的主题分布
        :return: None
        """
        probility_word_given_topic = self.probility_word_given_topic
        k = self.k

        def update_probility_of_word_topic_given_word(doc):
            topic_doc = doc['topic']
            words = doc['words']

            for (word_index,word) in words.items():
                topic_word = word['topic_word']
                for i in range(k.value):
                    topic_word[i] = probility_word_given_topic.value[i,word_index]*topic_doc[i]
                #使该单词各主题分布概率和为1
                topic_word /= np.sum(topic_word)
                word['topic_word'] = topic_word # added
            return {'words':words,'topic':topic_doc}

        self.data = self.data.map(update_probility_of_word_topic_given_word)

    def  _init_probility_word_topic_(self):
        """
        init p(w|z=k)
        :return: None
        """
        #dict length(words in dict)
        m = self.v.value

        probility_word_given_topic = self.rd.value.uniform(0,1,(self.k.value,m))
        probility_word_given_topic_row_sum = np.matrix(np.sum(probility_word_given_topic,axis=1)).T

        #使每个主题下单词概率和为1
        probility_word_given_topic = np.divide(probility_word_given_topic,probility_word_given_topic_row_sum)

        self.probility_word_given_topic = self.sc.broadcast(probility_word_given_topic)

    def _convert_docs_to_word_index(self):

        word_dict_b = self.word_dict_b
        k = self.k
        rd = self.rd
        '''
        I wonder is there a better way to execute function with broadcast varible
        '''
        def _word_count_doc_(doc):
            print(doc)
            wordcount ={}
            word_dict = word_dict_b.value
            for word in doc:
                if word_dict[word] in wordcount:
                    wordcount[word_dict[word]]['count'] += 1
                else:
                    #first one is the number of word occurance
                    #second one is p(z=k|w,d)
                    wordcount[word_dict[word]] = {'count':1,'topic_word': rd.value.uniform(0,1,k.value)}

            topics = rd.value.uniform(0, 1, k.value)
            topics = topics/np.sum(topics)
            return {'words':wordcount,'topic':topics}
        self.data = self.ori_data.map(_word_count_doc_)

    def _init_dict_(self):
        """
        init word dict of the documents,
        and broadcast it
        :return: None
        """
        words = self.ori_data.flatMap(lambda d: d).distinct().collect()
        word_dict = {w: i for w, i in zip(words, range(len(words)))}
        self.v = self.sc.broadcast(len(word_dict))
        return self.sc.broadcast(word_dict)

    def _log_likelyhood_(self):
        
        probility_word_given_topic = self.probility_word_given_topic
        k = self.k
        def likelyhood(doc):
            print("succ")
            l = 0.0
            topic_doc = doc['topic']
            words = doc['words']
            for (word_index,word) in words.items():
                print(word)
                l += word['count']*np.log(np.matrix(topic_doc)*probility_word_given_topic.value[:,word_index])
            return l
        return self.data.map(likelyhood).sum()



    def save(self):
        """
        保存模型结果 TODO 添加分布式保存结果
        :param f_word_given_topic: 文件路径，用于给定主题下词汇分布
        :param f_doc_topic: 文件路径，用于保存文档的主题分布
        :return:
        """
        doc_topic = self.data.map(lambda x:' '.join([str(q) for q in x['topic'].tolist()])).collect()
        probility_word_given_topic = self.probility_word_given_topic.value

        word_dict = self.word_dict_b.value
        word_given_topic = []

        for w,i in word_dict.items():
            word_given_topic.append('%s %s' %(w,' '.join([str(q[0]) for q in probility_word_given_topic[:,i].tolist()])))
        return word_given_topic

In [3]:
sc = SparkContext()

In [4]:
#data = sc.textFile("/Users/panxiao/IdeaProjects/lyric_project/input/lyrics.csv")
data = sc.textFile("../lyr.txt")
data1 = sc.parallelize(data.take(100))

In [5]:
stemmer = SnowballStemmer('english')
def token_processor(token):
    return stemmer.stem(token)
wordss = words.words()
#stopwds = ['oh','yeah']

In [6]:
data2 = data1.map(lambda x: re.sub(r"\W+"," ", str(x).lower()))
data3 = data2.map(lambda x: x.split(' '))
data4 = data3.map(lambda l: [x for x in l if x and x in wordss and x not in stopwords.words("english")])
data5 = data4.map(lambda l: " ".join(x for x in l))

In [8]:
data5.take(1)

['oh baby know cut right chase made like think special purpose know special feel baby let get lost need call work cause boss real want show feel consider lucky big deal well got key heart need rather open body show know inside need lie big wide strong fit much tough talk like cause back got big ego huge ego love big ego much walk like cause back usually humble right choose leave could blues call arrogant call confident decide find working damn know killing better yet matter fact smile maybe boy site see kind something like big wide strong fit much tough talk like cause back got big ego huge ego love big ego much walk like cause back walk like cause back talk like cause back back back walk like cause back big wide strong fit much tough talk like cause back got big ego huge ego huge ego love big ego much walk like cause back ego big must admit got every reason feel like bitch ego strong know need beat sing piano']

In [15]:
plsa = PLSA(data=data5,sc=sc,k=5,max_itr=10,is_test=True)

In [16]:
plsa.train()
#plsa.save('topic_word.txt','doc_topic.txt')
ls = plsa.save()

L(0)=-110909.55902
L(1)=-91462.08496 with 17.534534% improvement
L(2)=-90441.77898 with 1.115551% improvement
L(3)=-89204.87577 with 1.367624% improvement
L(4)=-87844.96764 with 1.524477% improvement
L(5)=-86587.70984 with 1.431223% improvement
L(6)=-85553.80703 with 1.194053% improvement
L(7)=-84755.59226 with 0.932997% improvement
L(8)=-84170.31000 with 0.690553% improvement
L(9)=-83750.51584 with 0.498744% improvement
L(10)=-83457.14553 with 0.350291% improvement


In [17]:
topic_word = pd.DataFrame([sub.split(" ") for sub in ls])
topic_word_1 = topic_word[topic_word.columns[:6]]
topic_word_1.head(4)

Unnamed: 0,0,1,2,3,4,5
0,right,0.0081690120547807,0.0036052470660925,0.0046087763505936,0.0002166960317816,0.003821495523717
1,like,0.0056395178406721,0.0134497120266915,0.0015976668484767,0.0468821223268225,0.0147340346082277
2,think,1.0663123845065648e-05,0.0041543068885823,0.0006545273644415,0.0098923239846797,0.0023623738013863
3,feel,0.0182695396782924,0.0018300227020574,0.0126621284215963,0.0006609615624248,0.0055921231426698


In [18]:
topic_word.sort_values(topic_word.columns[1],ascending=False)

Unnamed: 0,0,1,2,3,4,5
1056,three,9.962209050552621e-09,0.0021353944351227287,6.171477087078391e-06,1.7748935039954166e-05,0.0002625587260974951
1207,cave,9.962035265137298e-18,8.00990268695055e-20,0.00030345757500115897,2.2541268197021487e-09,9.333846742822569e-11
865,dream,9.946920325885371e-05,1.4043807171930419e-05,0.0014341572587929508,5.2299288403503666e-09,0.0008754034684200555
1014,hide,9.94490119392053e-09,2.390915316238893e-07,4.576621309866887e-14,4.170139892050165e-14,0.0014830601784230756
783,hill,9.923067496969288e-08,1.0821503346047442e-07,7.152088942093597e-07,4.531796215898954e-11,0.0002957864797834418
1259,poppin,9.920925842909726e-27,0.0004233323544095974,2.93343925025728e-10,2.7162886164412994e-18,5.48128482515086e-17
170,illusion,9.918746493172246e-06,9.8165746941058e-05,3.560530559080408e-15,0.0006136897395914706,9.749624315699093e-14
258,dice,9.915100449792384e-12,4.8240458410327874e-08,1.8056870774322212e-24,4.144019283998348e-06,0.0002931295443930155
1289,choose,9.87614549697593e-06,0.0004436809042849039,0.00027838612577178696,1.5240362846659795e-06,4.3498342545135006e-07
205,towards,9.854108351734434e-05,1.0522369552918306e-07,1.1617402712844486e-17,2.5688910945798273e-13,0.0002114084690641458


In [19]:
topic_word.sort_values(topic_word.columns[2],ascending=False)

Unnamed: 0,0,1,2,3,4,5
1663,doesnt,3.8594873656884405e-11,9.98068832658232e-23,2.2287228396374685e-12,1.2962980364738869e-20,0.0002966472274807554
1231,handle,1.5550766541463682e-29,9.970279073822527e-14,0.000303457855375921,2.039055407470484e-09,1.3482000970523787e-26
464,clock,0.00034323061193913763,9.952583569073159e-09,1.5784461468462287e-12,2.059894356373157e-29,1.4245077973768514e-19
1052,satisfaction,4.084356186911517e-07,9.952501215895359e-10,0.0010875938838763055,3.747513325005179e-18,0.00012305680446099586
288,political,1.1658779976175094e-07,9.941691172956935e-20,1.358553150784533e-12,0.00020727399415435861,0.00012228980712580782
201,ar,1.4700714056541912e-14,9.939020019368329e-12,1.2488493956409826e-16,8.995536251440924e-13,0.0008899417813125038
321,punish,0.0006779536553472432,9.926984247402558e-06,4.190051190624689e-07,9.978744446969041e-10,1.433849575753732e-22
1086,volar,3.430505518768611e-12,9.918375036379847e-09,1.1111404709237214e-24,1.969597879510202e-06,0.0002949844551498762
1181,sneak,5.704100593662629e-05,9.917761159207047e-11,0.0006081160168891936,4.742023137238859e-15,0.00024617903529017266
313,peak,0.0013461295504712677,9.911547467460712e-06,1.646250518770639e-05,1.7310976324578853e-07,1.8941550473116793e-20


In [20]:
topic_word.sort_values(topic_word.columns[3],ascending=False)

Unnamed: 0,0,1,2,3,4,5
1384,final,5.271473775443057e-12,1.0273375972730982e-14,9.999225334818793e-14,9.021926248508754e-13,0.0005932945206115104
157,hiver,4.037365041834211e-38,0.0004214565598589641,9.977926511685816e-18,1.5638445948122374e-06,6.814358863008622e-13
1470,strip,0.00022877763225788724,2.0145160245648753e-06,9.974964733156941e-05,2.3910080926294173e-09,3.549663051235926e-18
68,bar,0.0003395417750521666,0.0004278783811495758,9.964336558729114e-09,2.008970126359147e-15,3.966936241215306e-11
259,cur,6.972110856731649e-34,1.0988739087000217e-07,9.902710493670916e-18,6.697707938700197e-22,0.00029657026024941074
57,sun,0.007730341128333872,4.8027874721417866e-08,9.878873752530699e-07,6.295693435752928e-09,0.00014086230342287586
1596,sweating,3.1094234775171886e-20,3.3333786474978297e-22,9.873099666090643e-12,0.00035285451704773345,5.940963349404876e-17
1559,forward,2.7944311333464397e-14,0.00042333276345719027,9.86470182646769e-14,1.0267876093935244e-27,1.1021991278325386e-15
496,coat,1.5450016010675818e-30,4.271298558519024e-37,9.843086971391544e-16,0.0003528545285268297,2.3741136098772425e-32
1032,indeed,1.0406386127336917e-05,3.644153948713238e-08,9.835314857840807e-18,5.818178686991513e-13,0.000287627909222297


In [21]:
topic_word.sort_values(topic_word.columns[4],ascending=False)

Unnamed: 0,0,1,2,3,4,5
321,punish,0.0006779536553472432,9.926984247402558e-06,4.190051190624689e-07,9.978744446969041e-10,1.433849575753732e-22
1337,cant,0.0010951598757899278,2.055936794301839e-06,5.448118019956243e-06,9.977203155734957e-10,0.002903144852914099
645,soy,1.0247426321221022e-13,1.4039416483867829e-09,1.2825087939589244e-21,9.952717136591498e-08,0.005042918814185039
1253,deceit,8.53514614800308e-08,1.3155160356898054e-07,1.393200730441258e-06,9.948080146540083e-11,0.0002951193049232596
415,seem,0.0013695553398094596,5.598423885421652e-05,0.0004715098423180979,9.925966717669496e-08,9.599355247995107e-05
1037,rest,1.2772318313973636e-07,3.852428822290446e-19,0.0012117490207829832,9.904247676258545e-10,0.00029857855470233677
1677,intervention,1.1254171251009659e-07,3.2028249280333504e-07,3.4070602778354267e-06,9.894457802269543e-11,0.00029299490364815227
256,sale,2.2962890125477973e-15,1.1905438957085973e-08,2.436842648958622e-18,9.760535688088921e-14,0.0011865807093305021
1357,sense,3.154600169313066e-08,0.00042329264294882886,3.1608877698034176e-11,9.747024896701825e-10,1.3008149030333748e-22
206,puzzle,0.00015702442863163026,1.0346026790171422e-07,5.877124030369195e-18,9.68391211181381e-13,0.0001608649226430081


In [22]:
topic_word.sort_values(topic_word.columns[5],ascending=False)

Unnamed: 0,0,1,2,3,4,5
1604,cried,0.000684936594574154,1.6705905698222089e-06,1.6466953795611384e-07,1.7039008354263643e-16,9.997427626041537e-16
1177,rib,1.029171440613737e-06,7.505682885633182e-12,0.0006059990899571789,1.072110347595224e-19,9.995901995557746e-09
975,dit,2.440266196683686e-37,0.001260221248425106,1.1927560047986226e-17,8.149307068786535e-06,9.983498899257043e-12
760,gangster,4.580896023608476e-30,9.157337411982453e-14,0.0003034584941801155,1.2962781484918961e-09,9.955187373128525e-27
285,aspire,1.8864536278260632e-07,9.75845947957113e-19,3.974743371557337e-12,0.0005872288039958177,9.944413263788432e-05
1681,nine,7.364897195135603e-26,0.000423332589096799,1.2510190138453729e-10,6.318193186746748e-18,9.922066555965512e-15
1159,dire,2.1772844981484168e-20,3.1572968004000793e-22,2.1004273457049316e-12,0.00035285452608553683,9.915488833101917e-17
65,getting,0.0006526388918864237,0.0024403465358053115,1.696670106842919e-09,1.754829593224718e-20,9.907258496442872e-05
702,hell,0.00047242278918791013,9.031900132978371e-08,0.002009939793527791,2.1677976224669186e-10,9.8624602516644e-12
410,humble,1.563371292170464e-07,0.00029234811708648065,1.0014598792908781e-13,0.00010901704965561056,9.861709460198081e-27
