In [1]:
from pyspark.ml.clustering import LDA

In [2]:
# Load the sample LDA corpus from a LIBSVM-format file. Each row has a
# `label` column and a sparse `features` vector (size 11 in the preview
# output), which is the input column LDA reads by default.
dataset = spark.read.load("datasets/sample_lda_libsvm_data.txt", format="libsvm")
dataset.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(11,[0,1,2,4,5,6,...|
|  1.0|(11,[0,1,3,4,7,10...|
|  2.0|(11,[0,1,2,5,6,8,...|
|  3.0|(11,[0,1,3,6,8,9,...|
|  4.0|(11,[0,1,2,3,4,6,...|
|  5.0|(11,[0,1,3,4,5,6,...|
|  6.0|(11,[0,1,3,6,8,9,...|
|  7.0|(11,[0,1,2,3,4,5,...|
|  8.0|(11,[0,1,3,4,5,6,...|
|  9.0|(11,[0,1,2,4,6,8,...|
| 10.0|(11,[0,1,2,3,5,6,...|
| 11.0|(11,[0,1,4,5,6,7,...|
+-----+--------------------+



In [3]:
# Trains an LDA model.
# LDA fitting is randomized, so pin `seed` explicitly. Without it the
# discovered topics, and every likelihood/perplexity/topic table printed in
# later cells, are not guaranteed to be the same after a kernel restart.
NUM_TOPICS = 10   # number of latent topics (k)
MAX_ITER = 10     # optimizer iterations
LDA_SEED = 1      # fixed seed for reproducible runs

lda = LDA(k=NUM_TOPICS, maxIter=MAX_ITER, seed=LDA_SEED)
model = lda.fit(dataset)

In [4]:
# Score the fitted model on `dataset`, which is the same corpus it was
# trained on. These are training-set bounds, not held-out estimates.
ll = model.logLikelihood(dataset)
lp = model.logPerplexity(dataset)
print(f"The lower bound on the log likelihood of the entire corpus: {ll}")
print(f"The upper bound on perplexity: {lp}")

The lower bound on the log likelihood of the entire corpus: -830.8200362218348
The upper bound on perplexity: 3.195461678640353


In [5]:
# Summarize every topic by its highest-weighted vocabulary terms.
# Each row lists the term indices and their weights within that topic.
top_terms_per_topic = 3
topics = model.describeTopics(maxTermsPerTopic=top_terms_per_topic)
print("The topics described by their top-weighted terms:")
topics.show(truncate=False)

The topics described by their top-weighted terms:
+-----+-----------+---------------------------------------------------------------+
|topic|termIndices|termWeights                                                    |
+-----+-----------+---------------------------------------------------------------+
|0    |[10, 6, 1] |[0.17716277365583435, 0.175093826940891, 0.14398144339477317]  |
|1    |[0, 5, 9]  |[0.10767384081908464, 0.09803424340533426, 0.09707083774679913]|
|2    |[5, 10, 9] |[0.09819703267561146, 0.09813706321638012, 0.09566066687701354]|
|3    |[5, 10, 2] |[0.1043336472136259, 0.10204514734224286, 0.09789654769297573] |
|4    |[5, 6, 8]  |[0.17117267209890458, 0.10008771147187673, 0.09380215424402512]|
|5    |[2, 1, 5]  |[0.10181812241305552, 0.09675765527782697, 0.09604418553503413]|
|6    |[6, 4, 9]  |[0.10646588514827376, 0.10135478933291643, 0.099179157965757]  |
|7    |[8, 3, 5]  |[0.10453789038581693, 0.09705020776286659, 0.09687785234996922]|
|8    |[2, 1, 5]  |[0.1120

In [6]:
# Shows the result: transform() returns the input rows plus a
# `topicDistribution` column (see the output header). The columns are shown
# untruncated so the full vectors are visible.
transformed = model.transform(dataset)
transformed.show(n=20, truncate=False)

+-----+---------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|label|features                                                       |topicDistribution                                                                                                                                                                                                     |
+-----+---------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|0.0  |(11,[0,1,2,4,5,6,7,10],[1.0,2.0,6.0,2.0,3.0,1.0,1.0,3.0])      |[0.3718280662219035,0.004731039286517012,0.004731037198720411,0.0047