In [1]:
from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark topic model") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()

In [2]:
from pyspark.ml.feature import HashingTF, IDF, Tokenizer, CountVectorizer
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.ml.linalg import Vectors, SparseVector
from pyspark.ml.clustering import LDA, BisectingKMeans
from pyspark.sql.functions import monotonically_increasing_id
import re

In [4]:
rawdata = spark.read.format('com.databricks.spark.csv').\
                               options(header='true', \
                               inferschema='true').\
                     load("../data/airlines.csv",header=True)

In [5]:
rawdata = rawdata.fillna({'review': ''})                               # Replace nulls with blank string
rawdata = rawdata.withColumn("uid", monotonically_increasing_id())     # Create Unique ID
rawdata = rawdata.withColumn("year_month", rawdata.date.substr(1,7))   # Generate YYYY-MM variable

# Show rawdata (as DataFrame)
rawdata.show(10)

+-----+---------------+---------+--------+------+--------+-----+-----------+--------------------+---+----------+
|   id|        airline|     date|location|rating|   cabin|value|recommended|              review|uid|year_month|
+-----+---------------+---------+--------+------+--------+-----+-----------+--------------------+---+----------+
|10001|Delta Air Lines|21-Jun-14|Thailand|     7| Economy|    4|        YES|Flew Mar 30 NRT t...|  0|   21-Jun-|
|10002|Delta Air Lines|19-Jun-14|     USA|     0| Economy|    2|         NO|Flight 2463 leavi...|  1|   19-Jun-|
|10003|Delta Air Lines|18-Jun-14|     USA|     0| Economy|    1|         NO|Delta Website fro...|  2|   18-Jun-|
|10004|Delta Air Lines|17-Jun-14|     USA|     9|Business|    4|        YES|"I just returned ...|  3|   17-Jun-|
|10005|Delta Air Lines|17-Jun-14| Ecuador|     7| Economy|    3|        YES|"Round-trip fligh...|  4|   17-Jun-|
|10006|Delta Air Lines|17-Jun-14|     USA|     9|Business|    5|        YES|Narita - Bangkok ...

In [6]:
# Print data types
for type in rawdata.dtypes:
    print type

target = rawdata.select(rawdata['rating'].cast(IntegerType()))
target.dtypes


('id', 'int')
('airline', 'string')
('date', 'string')
('location', 'string')
('rating', 'int')
('cabin', 'string')
('value', 'int')
('recommended', 'string')
('review', 'string')
('uid', 'bigint')
('year_month', 'string')


[('rating', 'int')]

In [7]:
def cleanup_text(record):
    text  = record[8]
    uid   = record[9]
    words = text.split()
    
    # Default list of Stopwords
    stopwords_core = ['a', u'about', u'above', u'after', u'again', u'against', u'all', u'am', u'an', u'and', u'any', u'are', u'arent', u'as', u'at', 
    u'be', u'because', u'been', u'before', u'being', u'below', u'between', u'both', u'but', u'by', 
    u'can', 'cant', 'come', u'could', 'couldnt', 
    u'd', u'did', u'didn', u'do', u'does', u'doesnt', u'doing', u'dont', u'down', u'during', 
    u'each', 
    u'few', 'finally', u'for', u'from', u'further', 
    u'had', u'hadnt', u'has', u'hasnt', u'have', u'havent', u'having', u'he', u'her', u'here', u'hers', u'herself', u'him', u'himself', u'his', u'how', 
    u'i', u'if', u'in', u'into', u'is', u'isnt', u'it', u'its', u'itself', 
    u'just', 
    u'll', 
    u'm', u'me', u'might', u'more', u'most', u'must', u'my', u'myself', 
    u'no', u'nor', u'not', u'now', 
    u'o', u'of', u'off', u'on', u'once', u'only', u'or', u'other', u'our', u'ours', u'ourselves', u'out', u'over', u'own', 
    u'r', u're', 
    u's', 'said', u'same', u'she', u'should', u'shouldnt', u'so', u'some', u'such', 
    u't', u'than', u'that', 'thats', u'the', u'their', u'theirs', u'them', u'themselves', u'then', u'there', u'these', u'they', u'this', u'those', u'through', u'to', u'too', 
    u'under', u'until', u'up', 
    u'very', 
    u'was', u'wasnt', u'we', u'were', u'werent', u'what', u'when', u'where', u'which', u'while', u'who', u'whom', u'why', u'will', u'with', u'wont', u'would', 
    u'y', u'you', u'your', u'yours', u'yourself', u'yourselves']
    
    # Custom List of Stopwords - Add your own here
    stopwords_custom = ['']
    stopwords = stopwords_core + stopwords_custom
    stopwords = [word.lower() for word in stopwords]    
    
    text_out = [re.sub('[^a-zA-Z0-9]','',word) for word in words]                                       # Remove special characters
    text_out = [word.lower() for word in text_out if len(word)>2 and word.lower() not in stopwords]     # Remove stopwords and words under X length
    return text_out

udf_cleantext = udf(cleanup_text , ArrayType(StringType()))
clean_text = rawdata.withColumn("words", udf_cleantext(struct([rawdata[x] for x in rawdata.columns])))


In [9]:
# Term Frequency Vectorization  - Option 2 (CountVectorizer)    : 
cv = CountVectorizer(inputCol="words", outputCol="rawFeatures", vocabSize = 1000)
cvmodel = cv.fit(clean_text)
featurizedData = cvmodel.transform(clean_text)

vocab = cvmodel.vocabulary
vocab_broadcast = spark.sparkContext.broadcast(vocab)

idf = IDF(inputCol="rawFeatures", outputCol="features")
idfModel = idf.fit(featurizedData)
rescaledData = idfModel.transform(featurizedData)

In [10]:
lda = LDA(k=25, seed=123, optimizer="em", featuresCol="features")

ldamodel = lda.fit(rescaledData)

#model.isDistributed()
#model.vocabSize()

ldatopics = ldamodel.describeTopics()
ldatopics.show(25)

def map_termID_to_Word(termIndices):
    words = []
    for termID in termIndices:
        words.append(vocab_broadcast.value[termID])
    
    return words

udf_map_termID_to_Word = udf(map_termID_to_Word , ArrayType(StringType()))
ldatopics_mapped = ldatopics.withColumn("topic_desc", udf_map_termID_to_Word(ldatopics.termIndices))
ldatopics_mapped.select(ldatopics_mapped.topic, ldatopics_mapped.topic_desc).show(25,False)

ldaResults = ldamodel.transform(rescaledData)

ldaResults.show()

+-----+--------------------+--------------------+
|topic|         termIndices|         termWeights|
+-----+--------------------+--------------------+
|    0|[106, 301, 432, 7...|[0.02642483914460...|
|    1|[218, 257, 312, 4...|[0.02860477465189...|
|    2|[869, 639, 155, 2...|[0.01437555629232...|
|    3|[139, 155, 50, 12...|[0.02946689912509...|
|    4|[582, 640, 33, 16...|[0.01636880051698...|
|    5|[498, 251, 48, 26...|[0.01998824201487...|
|    6|[197, 791, 88, 39...|[0.03639006364794...|
|    7|[459, 248, 386, 1...|[0.01892117257323...|
|    8|[237, 761, 78, 31...|[0.01975997181346...|
|    9|[411, 8, 629, 47,...|[0.01284467745242...|
|   10|[573, 796, 723, 1...|[0.01320751612923...|
|   11|[500, 363, 392, 2...|[0.02222368774827...|
|   12|[136, 182, 5, 22,...|[0.01628267590187...|
|   13|[780, 327, 368, 7...|[0.01180130842975...|
|   14|[19, 71, 29, 4, 3...|[0.01849898845552...|
|   15|[54, 8, 26, 389, ...|[0.02528842185304...|
|   16|[601, 435, 100, 6...|[0.01486335804649...|


In [11]:
def breakout_array(index_number, record):
    vectorlist = record.tolist()
    return vectorlist[index_number]

udf_breakout_array = udf(breakout_array, FloatType())
enrichedData = ldaResults                                                                   \
        .withColumn("Topic_12", udf_breakout_array(lit(12), ldaResults.topicDistribution))  \
        .withColumn("topic_20", udf_breakout_array(lit(20), ldaResults.topicDistribution))            

enrichedData.show()

+-----+---------------+---------+---------+------+--------+-----+-----------+--------------------+---+----------+--------------------+--------------------+--------------------+--------------------+-----------+-----------+
|   id|        airline|     date| location|rating|   cabin|value|recommended|              review|uid|year_month|               words|         rawFeatures|            features|   topicDistribution|   Topic_12|   topic_20|
+-----+---------------+---------+---------+------+--------+-----+-----------+--------------------+---+----------+--------------------+--------------------+--------------------+--------------------+-----------+-----------+
|10001|Delta Air Lines|21-Jun-14| Thailand|     7| Economy|    4|        YES|Flew Mar 30 NRT t...|  0|   21-Jun-|[flew, mar, nrt, ...|(1000,[0,3,11,25,...|(1000,[0,3,11,25,...|[0.03326415802857...|0.025253229| 0.06195073|
|10002|Delta Air Lines|19-Jun-14|      USA|     0| Economy|    2|         NO|Flight 2463 leavi...|  1|   19-Jun-

In [12]:
enrichedData.createOrReplaceTempView("enrichedData")

spark.sql("SELECT id, airline, date, rating, topic_12 FROM enrichedData")

spark.sql("SELECT id, airline, date, rating, topic_20 FROM enrichedData")

DataFrame[id: int, airline: string, date: string, rating: int, topic_20: float]