### Импорты

In [1]:
from LDANewsModel import LDANewsModel
from pyspark import SparkContext
from pyspark.sql import DataFrame, SparkSession

### Сессия Spark

In [2]:
# Build (or reuse) the local Spark session through the builder API.
# Constructing SparkContext(...) directly fails with
# "Cannot run multiple SparkContexts at once" when this cell is re-run in a
# live kernel; getOrCreate() makes the cell safe to execute repeatedly.
spark = (
    SparkSession.builder
    .master('local')
    .appName('LDAModel')
    .getOrCreate()
)
# Keep `sc` bound for any cell that needs the low-level SparkContext.
sc = spark.sparkContext

# LDA на экономических новостях

## Загружаем данные

In [3]:
# Read the news corpus from parquet and keep only the first 100 rows
# so the rest of the notebook runs quickly.
all_news = spark.read.parquet('data/russian_news.parquet')
news_data = all_news.limit(100)

# Preview the first 10 rows of the sample.
news_data.show(n=10)

+---+--------------------+
| id|                text|
+---+--------------------+
|  1|МОСКВА, 17 ноября...|
|  3|ДУБАЙ, 17 ноя - Р...|
|  4|МИНСК, 17 ноя - Р...|
|  5|МИНСК, 17 ноя — Р...|
|  6|МОСКВА, 17 ноя — ...|
|  7|ДУБАЙ, 17 ноя - Р...|
|  8|ДУБАЙ, 17 ноя - Р...|
|  9|ДУБАЙ, 17 ноя - Р...|
| 10|МОСКВА, 17 ноября...|
| 52|МОСКВА, 15 ноя - ...|
+---+--------------------+
only showing top 10 rows



## LDA модель


In [4]:
# Wrap the sampled news DataFrame in the project's LDA helper.
lda_model = LDANewsModel(news_data)

# fit() returns a pair: the object later queried with describeTopics() and
# the object whose `.vocabulary` maps term indices back to words.
# NOTE(review): `thresholds=(2, 8)` is presumably a pair of vocabulary
# frequency cut-offs -- confirm against LDANewsModel.fit.
fit_output = lda_model.fit(num_topics=20, thresholds=(2, 8))
lda_results, vocab_model = fit_output

Смотрим, что получилось

In [5]:
# Up to 15 highest-weighted terms per topic (columns: topic, termIndices, termWeights).
topics = lda_results.describeTopics(maxTermsPerTopic=15)
vocabulary = vocab_model.vocabulary
topics.show()

# termIndices are positions in `vocabulary`; translate every topic's indices
# into the corresponding words and bring the result back to the driver.
topics_rdd = topics.rdd
topics_words = topics_rdd.map(
    lambda row: [vocabulary[term_idx] for term_idx in row['termIndices']]
).collect()

# Print each topic as a header, a separator, one word per line, and a closing separator.
separator = '----------------'
for topic_number, topic_terms in enumerate(topics_words, start=1):
    print(f'Тема {topic_number}', separator, *topic_terms, separator, sep='\n')

+-----+--------------------+--------------------+
|topic|         termIndices|         termWeights|
+-----+--------------------+--------------------+
|    0|[48, 50, 146, 72,...|[0.00718912031797...|
|    1|[1, 22, 33, 65, 7...|[0.02427821462681...|
|    2|[54, 90, 3, 17, 1...|[0.03090192130058...|
|    3|[39, 64, 115, 28,...|[0.02205850371371...|
|    4|[3, 9, 15, 13, 0,...|[0.03404766584271...|
|    5|[37, 20, 150, 33,...|[0.00675269823948...|
|    6|[5, 48, 12, 92, 1...|[0.03038398633806...|
|    7|[159, 88, 80, 133...|[0.00712910799651...|
|    8|[58, 156, 37, 125...|[0.00707809328024...|
|    9|[150, 0, 40, 172,...|[0.00684135297012...|
|   10|[131, 66, 50, 59,...|[0.03547224267652...|
|   11|[16, 122, 130, 14...|[0.00715297859102...|
|   12|[138, 8, 121, 170...|[0.00682087985238...|
|   13|[26, 40, 59, 82, ...|[0.00740468872810...|
|   14|[139, 53, 98, 157...|[0.00694238964215...|
|   15|[21, 68, 163, 171...|[0.05678237605699...|
|   16|[134, 175, 145, 1...|[0.00705744174400...|


## Загружаем сохранённую модель

In [6]:
# In pyspark.ml, `load` is a classmethod: it reads a model from disk and returns
# it as a NEW object, leaving `lda_results` untouched. Bind the result to a name
# so the loaded model can actually be used instead of being discarded.
loaded_lda_model = lda_results.load('weights/LDA_model_1')

# Display the loaded model's summary as the cell output.
loaded_lda_model

LocalLDAModel: uid=LDA_248eb2d509dd, k=15, numFeatures=12139

In [9]:
# Run the project helper's predict() on the same news sample; it returns a pair
# that is unpacked into the predictions DataFrame and a second model object.
# NOTE(review): `cv_model` is presumably the fitted CountVectorizer model -- confirm
# against LDANewsModel.predict.
prediction_output = lda_model.predict(news_data)
predictions, cv_model = prediction_output

In [10]:
# Preview the first 10 rows of the per-document predictions.
predictions.show(n=10)

+---+--------------------+--------------------+--------------------+
| id|       stemmed_words|                  tf|   topicDistribution|
+---+--------------------+--------------------+--------------------+
|  1|[белорусс, рассчи...|(138,[5,8,12,23,3...|[0.00171352511751...|
|  3|[дуба, министр, п...|(138,[1,6,16,22,3...|[0.00234666679561...|
|  4|[минск, белорусс,...|(138,[1,5,6,8,12,...|[9.27253502530128...|
|  5|[минск, напряжен,...|(138,[6,8,12,14,1...|[0.00134943991615...|
|  6|[финансов, рынк, ...|(138,[5,14,24,41,...|[0.00278819298990...|
|  7|[дуба, ежегодн, э...|(138,[5,19,40,51,...|[0.00222900809951...|
|  8|[дуба, саудовск, ...|(138,[2,12,16,20,...|[0.00234666679526...|
|  9|[дуба, объединен,...|(138,[9,22,30,39,...|[0.00262364632520...|
| 10|[французск, издан...|(138,[3,6,14,20,2...|[0.00178221808550...|
| 52|[бизнесм, нешет, ...|(138,[0,1,12,13,1...|[0.00164993096153...|
+---+--------------------+--------------------+--------------------+
only showing top 10 rows

