https://spark.apache.org/docs/latest/api/python/pyspark.ml.html#pyspark.ml.feature.CountVectorizer

https://spark.apache.org/docs/latest/ml-features#countvectorizer

https://spark.apache.org/docs/latest/ml-clustering.html#latent-dirichlet-allocation-lda

https://spark.apache.org/docs/latest/api/python/pyspark.ml.html#pyspark.ml.clustering.LDA

https://www.zstat.pl/2018/02/07/scala-spark-get-topics-words-from-lda-model/

https://stackoverflow.com/questions/51456838/match-index-from-pyspark-dataframe-in-pandas/51457137#51457137

https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/3741049972324885/3783546674231782/4413065072037724/latest.html

In [1]:
# START SPARKSESSION
import findspark
findspark.init()

In [2]:
import pyspark
#from pyspark.mllib.linalg import Vector, Vectors
from pyspark.sql import SparkSession,SQLContext
from pyspark.sql.types import *
from pyspark.ml.feature import Tokenizer, CountVectorizer, StopWordsRemover
from pyspark.ml.clustering import LDA
from pyspark.sql.functions import udf
#from pyspark.sql.functions import asc, count, col, collect_list

In [3]:
spark = SparkSession.builder.master("local[*]") \
                    .appName('Topic Modelling') \
                    .config("spark.driver.memory", "10g") \
                    .master("local[*]") \
                    .getOrCreate()

In [4]:
# PHASE 1: TOPIC MODELLING WITH SPARK

# Auxiliar functions
def equivalent_type(f):
    if f == 'datetime64[ns]': return TimestampType()
    elif f == 'int64': return LongType()
    elif f == 'int32': return IntegerType()
    elif f == 'float64': return FloatType()
    else: return StringType()

def define_structure(string, format_type):
    try: typo = equivalent_type(format_type)
    except: typo = StringType()
    return StructField(string, typo)

def pandas_to_spark(pandas_df):
    columns = list(pandas_df.columns)
    types = list(pandas_df.dtypes)
    struct_list = []
    for column, typo in zip(columns, types): 
      struct_list.append(define_structure(column, typo))
    p_schema = StructType(struct_list)
    return spark.createDataFrame(pandas_df, p_schema)

#data_df = pandas_to_spark(data)

In [114]:
# LOAD PROCESSED DATA 
data_df_full = spark.read.option('header', True).csv('/home/giangvdq/data/NIPS Papers/papers_processed.csv')

In [115]:
data_df = data_df_full.limit(3000)

In [116]:
# STEP: TOKENIZE

# source: https://gist.github.com/Bergvca/a59b127afe46c1c1c479

tokenizer = Tokenizer(inputCol="lemmatize_joined", outputCol="words")
wordsDataFrame = tokenizer.transform(data_df)

In [117]:
# STEP: REMOVE X MOST OCCURING WORDS
cv_tmp = CountVectorizer(inputCol="words", outputCol="tmp_vectors")
cv_tmp_model = cv_tmp.fit(wordsDataFrame)

topWords = list(cv_tmp_model.vocabulary[0:500])

remover = StopWordsRemover(inputCol="words", outputCol="filtered", stopWords = topWords)
wordsDataFrame = remover.transform(wordsDataFrame)

In [118]:
# STEP: COUNTVECTORIZER

cv = CountVectorizer(inputCol="filtered", outputCol="vectors")
cvmodel = cv.fit(wordsDataFrame)
df_vect = cvmodel.transform(wordsDataFrame)

In [119]:
#transform the dataframe to a format that can be used as input for LDA.train. 
#LDA train expects a RDD with lists,
#where the list consists of a uid and (sparse) Vector
def parseVectors(line):
    return [ int(line[2]), line[0] ]

sparsevector = (df_vect.select('vectors', 'lemmatize_joined', 'id')
                .rdd.map(parseVectors) )

In [120]:
sparsevector = sparsevector.toDF()

In [132]:
#Train the LDA model

lda = LDA(k=10, maxIter=50, featuresCol='_2', seed=1, optimizer='em')
model = lda.fit(sparsevector)

In [133]:
model.vocabSize()

75334

In [134]:
cvmodel.vocabulary

['cortex',
 'exact',
 'vision',
 'illustrate',
 'direct',
 'significant',
 'trajectory',
 'capture',
 'advantage',
 'category',
 'query',
 'specify',
 'separate',
 'randomly',
 'various',
 'prove',
 'cambridge',
 'transform',
 'reinforcement',
 'amount',
 'manifold',
 'dependent',
 'always',
 'online',
 'boost',
 'adaptive',
 'coordinate',
 'mechanism',
 'still',
 'fast',
 'world',
 'expression',
 'ratio',
 'generative',
 'implementation',
 'population',
 'main',
 'expectation',
 'address',
 'free',
 'can',
 'user',
 'chain',
 'department',
 'variational',
 'not',
 'see',
 'identify',
 'activation',
 'language',
 'consistent',
 'frame',
 'turn',
 'assign',
 'relationship',
 'environment',
 'validation',
 'simply',
 'conclusion',
 'run',
 'underlie',
 'complete',
 'addition',
 'whether',
 'imply',
 'uniform',
 'bay',
 'interval',
 'movement',
 'chip',
 'significantly',
 'theoretical',
 'difficult',
 'formulation',
 'among',
 'code',
 'record',
 'artificial',
 'refer',
 'link',
 'typical

In [135]:
ll = model.logLikelihood(sparsevector)
lp = model.logPerplexity(sparsevector)
print("The lower bound on the log likelihood of the entire corpus: " + str(ll))
print("The upper bound on perplexity: " + str(lp))

The lower bound on the log likelihood of the entire corpus: -175628075.89487827
The upper bound on perplexity: 83.43404272359969


In [143]:
# Describe topics.

# Number of terms in topics
numTerms = 8

topics = model.describeTopics(numTerms)
print("The topics described by their top-weighted terms:")
topics.show(5, truncate=True)

The topics described by their top-weighted terms:
+-----+--------------------+--------------------+
|topic|         termIndices|         termWeights|
+-----+--------------------+--------------------+
|    0|[20, 83, 175, 165...|[0.00918584223962...|
|    1|[95, 125, 33, 2, ...|[0.00938510614558...|
|    2|[9, 172, 10, 186,...|[0.01147312123454...|
|    3|[24, 41, 23, 232,...|[0.00992664812854...|
|    4|[0, 81, 140, 134,...|[0.00765284871033...|
+-----+--------------------+--------------------+
only showing top 5 rows



In [144]:
# DISPLAY THE TOPIC DISTRIBUTION

def indices_to_terms(vocabulary):
    def indices_to_terms(xs):
        return [vocabulary[int(x)] for x in xs]
    return udf(indices_to_terms, ArrayType(StringType()))

topics_with_terms = topics.withColumn(
    "topics_words", indices_to_terms(cvmodel.vocabulary)("termIndices"))

#topics_with_terms.select(['topic','topics_words']).show(20,False)
topics_with_terms.select(['topics_words']).show(20,False)

+---------------------------------------------------------------------------------------+
|topics_words                                                                           |
+---------------------------------------------------------------------------------------+
|[manifold, dimensionality, subspace, operator, embed, principal, coordinate, reduction]|
|[patch, segmentation, generative, vision, frame, scene, contour, pose]                 |
|[category, nearest, query, concept, split, database, retrieval, code]                  |
|[boost, user, online, player, learner, weak, equilibrium, play]                        |
|[cortex, delay, cortical, motor, synapsis, record, auditory, excitatory]               |
|[chip, receptive, attention, analog, head, adaptation, voltage, velocity]              |
|[reinforcement, trajectory, robot, plan, environment, controller, world, future]       |
|[gibbs, chain, carlo, monte, particle, variational, message, processor]                |
|[semi, sp

In [145]:
# Shows the result
docTopic = model.transform(sparsevector)
docTopic.printSchema()
docTopic.select(['_1','topicDistribution']).show(10,truncate=False)

root
 |-- _1: long (nullable = true)
 |-- _2: vector (nullable = true)
 |-- topicDistribution: vector (nullable = true)

+----+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|_1  |topicDistribution                                                                                                                                                                                            |
+----+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|1   |[0.061481134491564816,0.09124081090577164,0.164676129777718,0.02353193674583338,0.032919301265873085,0.13291080481908854,0.3639027018535085,0.047934166590357666,0.02497048465622115,0.05643252889406327]    |
|10  |[0.025387459919016874