Sources : 
[General data science article presenting the method](https://medium.com/@connectwithghosh/topic-modelling-with-latent-dirichlet-allocation-lda-in-pyspark-2cb3ebd5678e)

[Issue of an user to do LDA on SO](https://stackoverflow.com/questions/42051184/latent-dirichlet-allocation-lda-in-spark#)

[Gist of a LDA implementation](https://gist.github.com/Bergvca/a59b127afe46c1c1c479)

[SO question about retrieving topic distribution](https://stackoverflow.com/questions/33072449/extract-document-topic-matrix-from-pyspark-lda-model/41515070)

[Slides from presentation about recommended parameters for LDA](http://www.phusewiki.org/wiki/images/c/c9/Weizhong_Presentation_CDER_Nov_9th.pdf)

In [38]:
import findspark
findspark.init()

import pyspark
from pyspark.sql import *
from pyspark.sql.types import *

spark = SparkSession.builder.getOrCreate()

from nltk.corpus import stopwords, wordnet
from nltk.stem import WordNetLemmatizer
from nltk import pos_tag

#other words that need to be removed in order to avoid pollution on the topic. (experimental findings on news subreddit)
aux_stop_words = ['people', 'would', 'like', 'img', 'jpg', 'imgjpg', 'botcontact', 'picthis']

en_stop = set(stopwords.words('english')+aux_stop_words)
sc = spark.sparkContext
sc.broadcast(en_stop)

import datetime
import re as re

from pyspark.ml.feature import CountVectorizer , IDF
from pyspark.ml.clustering import LDA

In [2]:
data_trump = spark.read.load('../data/donald_comments.parquet')

In [39]:
def text_preprocessing(txt, stop_words, pos_tagging=False, use_lemmatizing=False):
    '''
    Take care of doing all the text preprocessing for LDA
    Only works on english ASCII content. (works on content with accent or such, but filter them out)
    '''
    def get_wordnet_pos(treebank_tag):
        if treebank_tag.startswith('J'):
            return wordnet.ADJ
        elif treebank_tag.startswith('V'):
            return wordnet.VERB
        elif treebank_tag.startswith('N'):
            return wordnet.NOUN
        elif treebank_tag.startswith('R'):
            return wordnet.ADV
        else:
            #this is the default behaviour for lemmatize. 
            return wordnet.NOUN
    
    #in the case of r/hillaryclinton, removes an annoying omnipresent bot
    if 'reddit-bot' in txt:
        return []
    
    def remove_https(s):
        return re.sub(r'https?:\/\/.*[\r\n]*', '', s, flags=re.UNICODE)
    
    #keeping only elements relevant to written speech.
    def keep_only_letters(s):
        return re.sub('[^a-zA-Z \']+', '', s, flags=re.UNICODE)
    
    #remove mark of genitif from speech (i.e. "'s" associated to nouns)
    def remove_genitive(s): 
        return re.sub('(\'s)+', '', s, flags=re.UNICODE)
    
    #to avoid things that should be stop words (the're, I'll, wouldn't) to pollute the topics.
    def remove_quote(s):
        return s.replace('\'', ' ')
    
    def clean_pipeline(s): 
        return remove_quote(remove_genitive(keep_only_letters(remove_https(s))))
    
    if pos_tagging:
        #cannot use pos tagging without lemmatizing
        assert(use_lemmatizing)
    
    #tokenizing the texts (removing line break, space and capitalization)
    token_comm = re.split(" ", clean_pipeline(txt).strip().lower())
    
    #to avoid empty token (caused by multiple spaces in the tokenization)
    token_comm = [t for t in token_comm if len(t) > 0]
    
    if pos_tagging:
        token_comm = pos_tag(token_comm)
    else:
        token_comm = zip(token_comm, [None]*len(token_comm))
        
    #removing all words of three letters or less
    bigger_w = [x for x in token_comm if len(x[0]) > 3]

    #removing stop_words
    wout_sw_w = [x for x in bigger_w if x[0] not in stop_words]
    
    if pos_tagging and use_lemmatizing:
        #get lemma of each word, then return result
        return [WordNetLemmatizer().lemmatize(word, get_wordnet_pos(tag)) for word, tag in wout_sw_w]
    elif use_lemmatizing:
        #get lemma of each word, then return result
        return [WordNetLemmatizer().lemmatize(word) for word, _ in wout_sw_w]
    else:
        return [word for word,_ in wout_sw_w]

def perform_lda(documents, n_topics, n_words, alphas, beta, tokens_col):
    '''
    will perform LDA on a list of documents (== list of token)
    assume that documents is a DataFrame with a column of unique id (uid).
    
    '''
    cv = CountVectorizer(inputCol=tokens_col, outputCol="raw_features")
    cvmodel = cv.fit(documents)
    result_cv = cvmodel.transform(documents)
    
    #we perform an tf-idf (term frequency inverse document frequency), to avoid threads with a lot of words to pollute the topics.
    idf = IDF(inputCol="raw_features", outputCol="features")
    idfModel = idf.fit(result_cv)
    result_tfidf = idfModel.transform(result_cv) 
    
    corpus = result_tfidf.select("uid", "features")
    
    #defining and running the lda. 
    lda = LDA(k=n_topics, docConcentration=alphas, topicConcentration=beta)
    model = lda.fit(corpus)
    
    #retrieving topics, and the vocabulary constructed by the CountVectorizer
    topics = model.describeTopics(maxTermsPerTopic=n_words)
    vocab = cvmodel.vocabulary
    
    #the topics are just numerical indices, we need to convert them to words, and associate them to their weights..
    topics_with_weights = topics.rdd.map(lambda r: (r[0], ([(vocab[t],w) for t,w in zip(r[1], r[2])]), ' '.join([vocab[t] for t in r[1]]))).toDF().selectExpr("_1 as topic_number", "_2 as topic_weight", "_3 as topic")
    
    return topics_with_weights

def display_topics(topics_n_weight):
    tops = topics_n_weight.select('topic').collect()
    for i in range(len(tops)):
        print("T %d: %s"%(i+1, tops[i][0]))
    return tops

def lda_preprocess(df, use_lemmatizing=False, use_pos_tagging=False):
    data_preprocessed = df.select('body', 'created').rdd.filter(lambda r: len(r[0]) > 50).map(lambda r: (text_preprocessing(r[0], en_stop, pos_tagging=use_pos_tagging, use_lemmatizing=use_lemmatizing), r[1])).filter(lambda r: r[0])
    data_uid = data_preprocessed.map(lambda r: r).zipWithUniqueId().map(lambda r: (r[0][0], r[1])).toDF().selectExpr('_1 as text', '_2 as uid')
    return data_uid
    
        
def lda_display_and_time(df, n_topics, n_words, beta=0.01, alpha=0.01):
    print('On %d comments, using %d topics with %d words, beta value : %.3f, alpha value: %.3f'%(df.count(), n_topics, n_words, beta, alpha))
    #print('Starting day : %s Ending day : %s'%(df.agg({"created": "min"}).collect()[0][0], df.agg({"created": "max"}).collect()[0][0]))
    print('Start time : '+ str(datetime.datetime.now()))
    alphas = [alpha]*n_topics
    topic_n_weights = perform_lda(df, n_topics, n_words, alphas, beta, 'text')
    topics_clean = display_topics(topic_n_weights)
    print('End time : \n'+ str(datetime.datetime.now()))
    return spark.createDataFrame(topics_clean, schema = StructType([StructField("Topic", StringType())]))


def grid_search(df, n_topics, n_words, betas, alphas):
    return [(a, b, lda_display_and_time(df, n_topics, n_words, b, a)) for a in alphas for b in betas]


    

## Topics on week before the election

### Filtering out the less upvoted comments

In [4]:
data_trump_score = data_trump.filter(data_trump.score > 10)
data_trump_score.count()

2450251

In [5]:
start_week = datetime.date(year=2016, month=10, day=31)
end_week = datetime.date(year=2016, month=11, day=8)
trump_week_b4_election = data_trump_score.filter(data_trump_score.created > start_week).filter(data_trump_score.created < end_week)
trump_week_b4_election.count()

77512

## The week before the election without any subsampling : 

In [10]:
trump_week_b4_election = lda_preprocess(trump_week_b4_election, use_lemmatizing=True, use_pos_tagging=False)
trump_week_b4_election.cache()
trump_week_b4_election.count()

50489

In [None]:
#grid_search(trump_week_b4_election, n_topics=8, n_words=5, betas=[0.01, 0.025, 0.05, 0.075, 0.1], alphas=[0.01, 0.025, 0.05, 0.075, 0.1])
trump_lda_res = lda_display_and_time(trump_week_b4_election, n_topics=8, n_words=5, beta=0.1, alpha=0.025)
#trump_lda_res.write.mode('overwrite').parquet('../data/trump_lda_result.parquet')
trump_lda_res

### Hillary Comments

In [40]:
hillary_data = spark.read.load('../data/hillary_comments.parquet')
hillary_election_week = hillary_data.filter(hillary_data.score > 10).filter(hillary_data.created > start_week).filter(hillary_data.created < end_week)

In [None]:
hillary_election_prepro = lda_preprocess(hillary_election_week, use_lemmatizing=True, use_pos_tagging=True)
print hillary_election_prepro.count()
hillary_election_prepro.cache()

In [None]:
#grid_search(hillary_election_prepro, 8, 5, betas=[0.01, 0.025, 0.05, 0.075, 0.1], alphas=[0.01, 0.025])
hillary_lda_res = lda_display_and_time(hillary_election_prepro, n_topics=8, n_words=5, beta=0.1, alpha=0.025)
hillary_lda_res

In [35]:
hillary_lda_res.write.mode('overwrite').parquet('../data/hillary_lda_result.parquet')

### News

In [16]:
news_data = spark.read.load('../data/news_week_b4_elec.parquet')
news_election_week = news_data.filter(news_data.created > start_week).filter(news_data.created < end_week)

In [17]:
news_election_prepro = lda_preprocess(news_election_week, use_lemmatizing=True, use_pos_tagging=False)
print news_election_prepro.count()
news_election_prepro.cache()

85750


DataFrame[text: array<string>, uid: bigint]

In [18]:
election_buzz_words = ['trump', 'donald', 'hillary', 'vote', 'voting', 'election', 'campaign', 'clinton']


def election_topic(tokens, buzz_words):
    for t in tokens:
        if t in buzz_words:
            return True
    return False


news_election_focused = news_election_prepro.rdd.filter(lambda r : election_topic(r[0], election_buzz_words)).toDF()
print news_election_focused.count()
news_election_focused.cache()
news_election_focused.show()

3945
+--------------------+-----+
|                text|  uid|
+--------------------+-----+
|[always, found, s...|  490|
|[anyone, hating, ...| 1867|
|[trump, said, ref...| 2596|
|[wait, till, hill...| 5431|
|[downvoted, stati...| 6241|
|[telling, campaig...| 7618|
|[vote, hillary, t...| 7861|
|[clinton, trying,...| 8509|
|[problem, politic...| 8671|
|[mention, thing, ...|10696|
|[think, republica...|14179|
|[public, financin...|17014|
|[argument, democr...|23008|
|[pas, democrat, e...|23251|
|[except, crazy, b...|25438|
|[predicted, trump...|27625|
|[trump, complain,...|31351|
|[thing, shameful,...|31432|
|[current, neutere...|32890|
|[reason, contain,...|35482|
+--------------------+-----+
only showing top 20 rows



In [19]:
news_election_prepro_sample = news_election_prepro.sample(False, 0.1, 4544)
print news_election_prepro_sample.count()
news_election_prepro_sample.cache()

8628


DataFrame[text: array<string>, uid: bigint]

In [24]:
news_lda_res_sample = lda_display_and_time(news_election_prepro_sample, n_topics=8, n_words=5, beta=0.1, alpha=0.025)
news_lda_res_sample.write.mode('overwrite').parquet('../data/news_sample_lda_result.parquet')

On 8628 comments, using 8 topics with 5 words, beta value : 0.100, alpha value: 0.025
Start time : 2018-12-16 15:18:02.044559
T 1: trump page supporter owner shooting
T 2: piece million lawyer drought touch
T 3: bank cost doctor fire paid
T 4: child drug world year thing
T 5: school want need really even
T 6: woman pipeline court case game
T 7: kill muslim shoot dont respect
T 8: government information request public tax
End time : 
2018-12-16 15:18:21.482671


In [25]:
news_focused_lda_res = lda_display_and_time(news_election_focused, n_topics=8, n_words=5, beta=0.1, alpha=0.025)
news_focused_lda_res.write.mode('overwrite').parquet('../data/news_focused_lda_result.parquet')

On 3945 comments, using 8 topics with 5 words, beta value : 0.100, alpha value: 0.025
Start time : 2018-12-16 15:18:21.815268
T 1: party want state muslim jew
T 2: solar power company utility florida
T 3: isi syria saudi syrian dealing
T 4: make vote law trump help
T 5: trump even black church vote
T 6: arizona association sending auto epstein
T 7: clinton email trump time know
T 8: republican vote government right state
End time : 
2018-12-16 15:18:37.138017


In [34]:
news_election_prepro_score = lda_preprocess(news_election_week.filter(news_election_week.score >10), use_lemmatizing=True, use_pos_tagging=False)
print news_election_prepro_score.count()
news_election_prepro_score.cache()

10544


DataFrame[text: array<string>, uid: bigint]

In [37]:
news_score_lda_res = lda_display_and_time(news_election_prepro_score, n_topics=8, n_words=5, beta=0.1, alpha=0.025)
news_score_lda_res.write.mode('overwrite').parquet('../data/news_score_lda_result.parquet')

On 10544 comments, using 8 topics with 5 words, beta value : 0.100, alpha value: 0.025
Start time : 2018-12-16 17:51:40.792959


Py4JJavaError: An error occurred while calling o2548.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 2 in stage 881.0 failed 1 times, most recent failure: Lost task 2.0 in stage 881.0 (TID 30312, localhost, executor driver): TaskResultLost (result lost from block manager)
Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1651)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1639)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1638)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1638)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:831)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:831)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:831)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1872)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1821)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1810)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:642)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2034)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2131)
	at org.apache.spark.rdd.RDD$$anonfun$fold$1.apply(RDD.scala:1098)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
	at org.apache.spark.rdd.RDD.fold(RDD.scala:1092)
	at org.apache.spark.rdd.RDD$$anonfun$treeAggregate$1.apply(RDD.scala:1161)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
	at org.apache.spark.rdd.RDD.treeAggregate(RDD.scala:1137)
	at org.apache.spark.mllib.clustering.OnlineLDAOptimizer.submitMiniBatch(LDAOptimizer.scala:501)
	at org.apache.spark.mllib.clustering.OnlineLDAOptimizer.next(LDAOptimizer.scala:450)
	at org.apache.spark.mllib.clustering.OnlineLDAOptimizer.next(LDAOptimizer.scala:263)
	at org.apache.spark.mllib.clustering.LDA.run(LDA.scala:336)
	at org.apache.spark.ml.clustering.LDA.fit(LDA.scala:912)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:745)
