In [3]:
import re
from collections import defaultdict

from pyspark import SparkConf, SparkContext
from pyspark.mllib.clustering import LDA, LDAModel
from pyspark.mllib.linalg import Vector, Vectors
from pyspark.sql import SQLContext

# Tunable parameters for the topic-modelling run.
n_stopwords = 50        # Number of most common words to remove, trying to eliminate stop words
n_topics = 3            # Number of topics we are looking for
n_wordspertopic = 10    # Number of words to display for each topic
maxiter = 35            # Max number of times to iterate before finishing

# Reuse an already-running SparkContext when there is one: constructing a
# second SparkContext in the same kernel raises ValueError, which made this
# cell fail whenever it was executed again without restarting the kernel.
spark_conf = SparkConf().setMaster('local').setAppName('PySPARK LDA Example')
sc = SparkContext.getOrCreate(spark_conf)
sql_context = SQLContext(sc)

In [4]:
def _tokenize(document):
    """Turn one document's text into a list of candidate topic words.

    The text is stripped and lower-cased, split on any whitespace character
    or one of ';', ',' and '#', and then only the pieces that are purely
    alphabetic and longer than three characters are kept.

    :param document: full text of one document, as a string
    :return: list of tokens in their original order (duplicates preserved)
    """
    # Raw string: "\s" in a normal string literal is an invalid escape
    # sequence and triggers a warning on current Python versions.
    pieces = re.split(r"[\s;,#]", document.strip().lower())
    return [piece for piece in pieces if piece.isalpha() and len(piece) > 3]


# wholeTextFiles yields one pair per file; only the second element
# (the file content) is kept, the first one is discarded.
data = sc.wholeTextFiles('cookbook_text/*').map(lambda x: x[1])

# One token list per document.
tokens = data.map(_tokenize)

In [5]:
# Count how often each token occurs across all documents. The result is an
# RDD of (count, word) pairs sorted by count in descending order.
termCounts = tokens \
    .flatMap(lambda document: document) \
    .map(lambda word: (word, 1)) \
    .reduceByKey(lambda x, y: x + y) \
    .map(lambda word_count: (word_count[1], word_count[0])) \
    .sortByKey(False)

# threshold_value is the count of the n_stopwords-th most frequent word.
# take() may return fewer than n_stopwords items (small corpus) or none at
# all (n_stopwords == 0 or empty corpus); indexing [n_stopwords - 1] raised
# IndexError in those cases. With no items, an infinite threshold means that
# a "count < threshold_value" filter keeps every word.
top_terms = termCounts.take(n_stopwords)
threshold_value = top_terms[-1][0] if top_terms else float('inf')

In [6]:
# Build the word -> index dictionary from every (count, word) pair whose
# count is strictly below the stop-word threshold; the index is the word's
# position in the filtered, count-sorted sequence.
vocabulary = (
    termCounts
    .filter(lambda count_word: count_word[0] < threshold_value)
    .map(lambda count_word: count_word[1])
    .zipWithIndex()
    .collectAsMap()
)

def document_vector(document):
    """Convert a (tokens, index) pair into an (index, sparse count vector) pair.

    Tokens that are not keys of the module-level ``vocabulary`` dict are
    ignored. The vector has ``len(vocabulary)`` dimensions; position
    ``vocabulary[token]`` holds the number of times ``token`` occurs in the
    document.

    :param document: sequence whose element 0 is an iterable of tokens and
        whose element 1 is the document's identifier
    :return: tuple ``(identifier, sparse vector of term counts)``
    """
    doc_id = document[1]
    term_freq = defaultdict(int)
    for term in document[0]:
        term_id = vocabulary.get(term)
        # Index 0 is a valid vocabulary entry, so test against None explicitly.
        if term_id is not None:
            term_freq[term_id] += 1

    # Emit indices in ascending order, each paired with its count.
    indices = []
    frequencies = []
    for term_id, freq in sorted(term_freq.items()):
        indices.append(term_id)
        frequencies.append(freq)
    return (doc_id, Vectors.sparse(len(vocabulary), indices, frequencies))

In [7]:
# Pair every token list with its position, vectorise it, and turn each
# resulting (id, vector) tuple into a two-element list.
documents = (
    tokens
    .zipWithIndex()
    .map(document_vector)
    .map(list)
)

In [8]:
# Reverse lookup: vocabulary index -> word.
inv_voc = {value: key for (key, value) in vocabulary.items()}

# Run every Spark job before the output file is opened, so that a failing
# job does not leave a truncated output.txt behind.
lda_model = LDA.train(documents, k=n_topics, maxIterations=maxiter)
topic_indices = lda_model.describeTopics(maxTermsPerTopic=n_wordspertopic)
n_documents = documents.count()

# The file is opened as UTF-8 text and words are written as str. Calling
# .encode('utf-8') inside str.format() on Python 3 writes the bytes repr
# (e.g. b'word') into the report instead of the word itself.
with open("output.txt", 'w', encoding='utf-8') as f:
    # Write topics, showing the top-weighted n_wordspertopic terms for each topic.
    # Each entry of topic_indices holds the term indices at position 0 and
    # the matching weights at position 1.
    for topic_number, topic in enumerate(topic_indices, start=1):
        f.write("Topic #{0}\n".format(topic_number))
        for term_id, weight in zip(topic[0], topic[1]):
            f.write("{0}\t{1}\n".format(inv_voc[term_id], weight))

    f.write("{0} topics distributed over {1} documents and {2} unique words\n"
            .format(n_topics, n_documents, len(vocabulary)))