# 大数据管理作业04_PLSA

郭英明 2183211376

In [33]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from sklearn.datasets import fetch_20newsgroups
from nltk.corpus import stopwords
import re
import jieba
from sklearn.feature_extraction.text import TfidfVectorizer
import collections

In [34]:
class PLSA:
    '''
    model:PLSA(EM)
    k:话题数
    language：文本语言
    method:  1:生成，2：共现 
    '''
    def __init__(self,text_list,k,language,method = 1):
        self.k = k
        self.text_list = text_list
        self.text_num = len(text_list)
        self.method = method
        self.get_X(language)

    def get_X(self,language):
        if language == 'chinese':
            self.cuted_text = [jieba.lcut(text,cut_all=True) for text in self.text_list]
            
        if language == 'english':
            news_df = pd.DataFrame({'document':self.text_list})
            news_df['clean_doc'] = news_df['document'].str.replace("[^a-zA-Z#]", " ")
            news_df['clean_doc'] = news_df['clean_doc'].apply(lambda x: ' '.join([w for w in x.split() if len(w)>3]))
            news_df['clean_doc'] = news_df['clean_doc'].apply(lambda x: x.lower())
            stop_words = stopwords.words('english')
            tokenized_doc = news_df['clean_doc'].apply(lambda x: x.split())
            self.cuted_text = tokenized_doc.apply(lambda x: [item for item in x if item not in stop_words])
            
        self.word_all = []
        for i in self.cuted_text:
            self.word_all.extend(i)   
            
        self.word_set = list(set(self.word_all))
        self.word_num = len(self.word_set)
        self.word_dict = {}
        for index,word in enumerate(self.word_set):
            self.word_dict[word] = index
        self.X = np.zeros((self.word_num,self.text_num))
        for i in range(self.text_num):
            count_ = collections.Counter(self.cuted_text[i])
            for k, v in count_.items():
                self.X[self.word_dict[k],i] = v


    def shengcheng(self,max_iter):
        self.initial_1()       
        for iter in range(max_iter):
            self.update_E_1()
            self.update_M_1()
            
    def initial_1(self):
        self.w_z  = np.random.random((self.word_num,self.k))
        self.z_d = np.random.random((self.k,self.text_num))
        
    def update_E_1(self):
        self.z_wd  = np.zeros((self.word_num,self.text_num,self.k))
        for i in range(self.word_num):
            for j in range(self.text_num):
                self.z_wd[i,j] = np.array([self.w_z[i]*self.z_d[:,j]]) / np.sum([self.w_z[i]*self.z_d[:,j]])
    
    def update_M_1(self):
        for k in range(self.k):
            for i in range(self.word_num):
                self.w_z[i,k] = np.sum(self.X[i]*self.z_wd[i,:,k])/\
                np.sum(self.X*self.z_wd[:,:,k])
            for j in range(self.text_num):
                self.z_d[k,j] = np.sum(self.X[:,j]*self.z_wd[:,j,k])/np.sum(self.X[:,j])
      
    
    def gongxian(self,max_iter):
        self.initial_2()       
        for iter in range(max_iter):
            self.update_E_2()
            self.update_M_2()
    
    def initial_2(self):
        self.w_z  = np.random.random((self.k,self.word_num))
        self.d_z = np.random.random((self.k,self.text_num))
        self.z = np.random.random((1,self.k))
    
    def update_E_2(self):
#         self.z_wd  = np.zeros((self.word_num,self.text_num,self.k))
#         for i in range(self.word_num):
#             for j in range(self.text_num):
#                 self.z_wd[i,j] = np.array([self.w_z[i]*self.z_d[:,j]]) / np.sum([self.w_z[i]*self.z_d[:,j]])
        self.z_wd  = np.zeros((self.word_num,self.text_num,self.k))
        for i in range(self.word_num):
            for j in range(self.text_num):
                self.z_wd[i,j] = np.array([self.w_z[:,i]*self.d_z[:,j]*self.z[0]]) / np.sum([self.w_z[:,i]*self.d_z[:,j]*self.z[0]])
        
    
    def update_M_2(self):
        for k in range(self.k):
            for i in range(self.word_num):
                self.w_z[k,i] = np.sum(self.X[i]*self.z_wd[i,:,k])/\
                np.sum(self.X*self.z_wd[:,:,k])
            for j in range(self.text_num):
                self.d_z[k,j] = np.sum(self.X[:,j]*self.z_wd[:,j,k])/np.sum(self.X[:,j])
            self.z[0] = np.sum(self.X*self.z_wd[:,:,k]) / np.sum(self.X)
    
    def fit(self,max_iter):
        if self.method == 1:
            self.shengcheng(max_iter)
        else:
            self.gongxian(max_iter)

## 生成

使用Scikit-learn库中导入新闻文本数据集(fetch_20newsgroups)的前10条

In [35]:
all_data = fetch_20newsgroups(subset='all')
data = all_data.data[:10]
# print(type(data))
lsa1 = PLSA(data,k=2,language = 'english')
lsa1.fit(10)
print(lsa1.w_z)
print(lsa1.z_d)

[[1.62850821e-07 1.50850220e-03]
 [3.47729525e-05 1.48682675e-03]
 [2.40885082e-03 3.85926282e-10]
 ...
 [6.22685736e-07 1.50821421e-03]
 [8.73520607e-07 1.50805712e-03]
 [6.31287669e-25 1.50860419e-03]]
[[7.45291538e-01 3.22796470e-03 5.73011770e-07 3.14248184e-02
  9.98675965e-01 9.99999980e-01 1.15790050e-04 9.99999999e-01
  2.70815136e-01 1.00000000e+00]
 [2.54708462e-01 9.96772035e-01 9.99999427e-01 9.68575182e-01
  1.32403534e-03 2.02159695e-08 9.99884210e-01 5.39113604e-10
  7.29184864e-01 6.37492650e-12]]


## 共现

In [36]:
lsa2 = PLSA(data,k=2,language = 'english',method = 2)
lsa2.fit(10)
print(lsa2.z)
print(lsa2.w_z)
print(lsa2.d_z)

[[0.64726386 0.64726386]]
[[1.63860914e-31 1.25123659e-06 1.16741506e-24 ... 2.48603069e-31
  6.93744227e-33 1.85175567e-11]
 [1.43317716e-03 1.43249528e-03 1.43317716e-03 ... 1.43317716e-03
  1.43317716e-03 1.43317715e-03]]
[[3.44560128e-01 9.99999894e-01 2.81121601e-04 5.86423719e-08
  5.62375115e-06 8.28131825e-01 7.15545504e-01 9.94091268e-01
  1.38825537e-01 9.99999988e-01]
 [6.55439872e-01 1.05712228e-07 9.99718878e-01 9.99999941e-01
  9.99994376e-01 1.71868175e-01 2.84454496e-01 5.90873228e-03
  8.61174463e-01 1.18310526e-08]]
