In [1]:
from threading import Thread

class StreamingThread(Thread):
    def __init__(self, ssc):
        Thread.__init__(self)
        self.ssc = ssc
    def run(self):
        ssc.start()
        ssc.awaitTermination()
    def stop(self):
        print('----- Stopping... this may take a few seconds -----')
        self.ssc.stop(stopSparkContext=False, stopGraceFully=True)

In [2]:
sc

## Data Loading and exploration

In [3]:
# start with easy implemetation: only consider the content of the 2 fields review_title and review_text
# concantenate them in one new field "review_concat"from pyspark.sql import SQLContext
from pyspark.sql import functions as fn
from pyspark.sql.types import IntegerType
import pandas as pd

filepath = 'data_processed/ExctractedData.json'
# load JSON file
s_df = spark.read.json(filepath)
s_df.count()
s_df = s_df.drop_duplicates(subset=['review_id'])
pd_df = s_df.groupBy('review_id').count().toPandas().set_index("count").sort_index(ascending=False)

In [4]:
# control no duplicate
pd_df.head()

Unnamed: 0_level_0,review_id
count,Unnamed: 1_level_1
1,R15DG6BI3K1I78
1,R1UU50BM0S4LPY
1,R27KEMBTEQ4MHI
1,R1HMP34XP1V9BE
1,R22I2JYOOXA3PP


In [5]:
# concatenate review text and title in one field
s_df = s_df.withColumn('review_concat',fn.concat(fn.col('review_title'),fn.lit(' '), fn.col('review_text')))
# review_score is of type String ==> cast it from String to Integer
s_df = s_df.withColumn("review_score", s_df["review_score"].cast(IntegerType()))
s_df.printSchema()

root
 |-- book_id: string (nullable = true)
 |-- book_title: string (nullable = true)
 |-- review_id: string (nullable = true)
 |-- review_score: integer (nullable = true)
 |-- review_text: string (nullable = true)
 |-- review_title: string (nullable = true)
 |-- review_user: string (nullable = true)
 |-- timestamp: long (nullable = true)
 |-- review_concat: string (nullable = true)



In [6]:
print('Total # of rows: ' + str(s_df.count()))
print('# of rows per class:')
s_df.groupBy("review_score") \
    .count() \
    .orderBy(fn.col("count").desc()) \
    .show()

Total # of rows: 11573
# of rows per class:
+------------+-----+
|review_score|count|
+------------+-----+
|           5| 9383|
|           4| 1529|
|           3|  346|
|           2|  170|
|           1|  145|
+------------+-----+



# Binary Classification
1 to 2 stars = 0 and 3 to 5 stars = 1
## Logistic regression

In [7]:
# add new field bin_score with 0 or 1
from pyspark.sql.functions import udf
def scoreToBin(value):
   if   value < 3: return 0
   else : return 1
udfScoreToBin = udf(scoreToBin, IntegerType())
s_df_bin = s_df.withColumn("bin_score", udfScoreToBin("review_score"))

In [8]:
# control that the function is properly working
s_df_bin.where(fn.col('review_score') == 3).first()

Row(book_id='0062678426', book_title='The Woman in the Window: A Novel', review_id='R1HMP34XP1V9BE', review_score=3, review_text='I wanted this to be better, it started so strong and then lost itself in the last third -to predictability.', review_title='Good, but not as good as the hype', review_user='Amazon Customer', timestamp=1557521653, review_concat='Good, but not as good as the hype I wanted this to be better, it started so strong and then lost itself in the last third -to predictability.', bin_score=1)

In [9]:
# Now make a new stratified split 80-20% with same proportion of bin_score 0 and 1
training_strat_df_bin = s_df_bin.sampleBy("bin_score", fractions={0: 0.8, 1: 0.8}, seed=42)
test_strat_df_bin = s_df_bin.subtract(training_strat_df_bin)

print('# rows training set: ' + str(training_strat_df_bin.count()))
print('# rows per class')
training_strat_df_bin.groupBy("bin_score") \
    .count() \
    .orderBy(fn.col("count").desc()) \
    .show()
# test set
print('# rows test set: ' + str(test_strat_df_bin.count()))
print('# rows per class')
test_strat_df_bin.groupBy("bin_score") \
    .count() \
    .orderBy(fn.col("count").desc()) \
    .show()

# rows training set: 9191
# rows per class
+---------+-----+
|bin_score|count|
+---------+-----+
|        1| 8940|
|        0|  251|
+---------+-----+

# rows test set: 2382
# rows per class
+---------+-----+
|bin_score|count|
+---------+-----+
|        1| 2318|
|        0|   64|
+---------+-----+



In [10]:
# define pre-processing and classification pipeline

from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import IDF, RegexTokenizer, StringIndexer, StopWordsRemover, CountVectorizer, VectorAssembler
from pyspark.ml.classification import LogisticRegression
import nltk
from nltk.corpus import stopwords 

nltk.download('stopwords')
stop_words = list(set(stopwords.words('english')))

book_stringIdx = StringIndexer(inputCol="book_id", outputCol="book_label")

regex_tokenizer = RegexTokenizer()\
    .setGaps(False)\
    .setPattern("\\p{L}+")\
    .setInputCol("review_concat")\
    .setOutputCol("words")

stopword_remover = StopWordsRemover()\
    .setStopWords(stop_words)\
    .setCaseSensitive(False)\
    .setInputCol("words")\
    .setOutputCol("filtered")

count_vectorizer = CountVectorizer(minDF=5)\
    .setInputCol("filtered")\
    .setOutputCol("tf")

idf = IDF()\
    .setInputCol("tf")\
    .setOutputCol("tfidf")

#assembler = VectorAssembler(inputCols=['tfidf','book_label'],outputCol="tfidf_book")

lr = LogisticRegression(featuresCol=idf.getOutputCol(), labelCol="bin_score")

pipeline = Pipeline(stages=[
    book_stringIdx,
    regex_tokenizer,
    stopword_remover,
    count_vectorizer,
    idf,
    #assembler,
    lr])

[nltk_data] Downloading package stopwords to /Users/admin/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [11]:
# utility function to calculate prediction results
def printClassPredictions(predictions):
    predictions.select(fn.expr('float(prediction = bin_score)').alias('correct')).\
        select(fn.avg('correct')).show()
    print('bin_score = 0')
    predictions.filter(predictions['bin_score'] == 0).\
        select(fn.expr('float(prediction = bin_score)').alias('correct')).\
        select(fn.avg('correct')).show()
    print('bin_score = 1')
    predictions.filter(predictions['bin_score'] == 1).\
        select(fn.expr('float(prediction = bin_score)').alias('correct')).\
        select(fn.avg('correct')).show()

In [None]:
# Train the model and save it locally
model = pipeline.fit(training_strat_df_bin)
model.write().overwrite().save("lrm_bin_2.model")

In [66]:
# run the model on the test set to generate predictions
predictionsBin = model.transform(test_strat_df_bin)

In [67]:
predictionsBin.printSchema()

root
 |-- book_id: string (nullable = true)
 |-- book_title: string (nullable = true)
 |-- review_id: string (nullable = true)
 |-- review_score: integer (nullable = true)
 |-- review_text: string (nullable = true)
 |-- review_title: string (nullable = true)
 |-- review_user: string (nullable = true)
 |-- timestamp: long (nullable = true)
 |-- review_concat: string (nullable = true)
 |-- bin_score: integer (nullable = true)
 |-- book_label: double (nullable = false)
 |-- words: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- filtered: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- tf: vector (nullable = true)
 |-- tfidf: vector (nullable = true)
 |-- tfidf_book: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)



In [68]:
predictionsBin.filter(predictionsBin['bin_score'] == 1).\
    select("review_score","bin_score","prediction", "probability"). \
    show(30, False)

+------------+---------+----------+------------------------------------------+
|review_score|bin_score|prediction|probability                               |
+------------+---------+----------+------------------------------------------+
|5           |1        |1.0       |[6.5480971361332815E-37,1.0]              |
|4           |1        |1.0       |[4.695704333450706E-30,1.0]               |
|5           |1        |1.0       |[4.999160310346102E-34,1.0]               |
|4           |1        |1.0       |[4.836040797894084E-221,1.0]              |
|4           |1        |1.0       |[2.517938810152065E-33,1.0]               |
|5           |1        |1.0       |[1.6618670213179988E-70,1.0]              |
|5           |1        |1.0       |[3.1176918326909973E-121,1.0]             |
|5           |1        |1.0       |[1.596307693894593E-54,1.0]               |
|5           |1        |1.0       |[9.110905948359848E-75,1.0]               |
|5           |1        |1.0       |[2.23862551866414

In [28]:
predictionsBin.select(fn.expr('float(prediction == bin_score)').alias('correct')).\
        select(fn.avg('correct')).show()

Py4JJavaError: An error occurred while calling o1439.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 11 in stage 217.0 failed 1 times, most recent failure: Lost task 11.0 in stage 217.0 (TID 15237, localhost, executor driver): org.apache.spark.SparkException: Failed to execute user defined function($anonfun$9: (string) => double)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage8.agg_doAggregateWithKeysOutput_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage8.agg_doAggregateWithoutKey_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage8.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$1.hasNext(WholeStageCodegenExec.scala:619)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:125)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:55)
	at org.apache.spark.scheduler.Task.run(Task.scala:121)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:402)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:408)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Unseen label: 1328589358.  To handle unseen labels, set Param handleInvalid to keep.
	at org.apache.spark.ml.feature.StringIndexerModel$$anonfun$9.apply(StringIndexer.scala:260)
	at org.apache.spark.ml.feature.StringIndexerModel$$anonfun$9.apply(StringIndexer.scala:246)
	... 16 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1887)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1875)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1874)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1874)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2108)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2057)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2046)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:365)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3384)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
	at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3365)
	at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3364)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2545)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2759)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:255)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:292)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$9: (string) => double)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage8.agg_doAggregateWithKeysOutput_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage8.agg_doAggregateWithoutKey_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage8.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$1.hasNext(WholeStageCodegenExec.scala:619)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:125)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:55)
	at org.apache.spark.scheduler.Task.run(Task.scala:121)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:402)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:408)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more
Caused by: org.apache.spark.SparkException: Unseen label: 1328589358.  To handle unseen labels, set Param handleInvalid to keep.
	at org.apache.spark.ml.feature.StringIndexerModel$$anonfun$9.apply(StringIndexer.scala:260)
	at org.apache.spark.ml.feature.StringIndexerModel$$anonfun$9.apply(StringIndexer.scala:246)
	... 16 more


In [60]:
new_df = predictionsBin.select(fn.expr('float(prediction == bin_score)').alias('correct'))
new_df.show()

+-------+
|correct|
+-------+
|    1.0|
|    1.0|
|    1.0|
|    1.0|
|    1.0|
|    1.0|
|    0.0|
|    1.0|
|    1.0|
|    1.0|
|    1.0|
|    1.0|
|    1.0|
|    1.0|
|    1.0|
|    1.0|
|    1.0|
|    1.0|
|    0.0|
|    1.0|
+-------+
only showing top 20 rows



In [65]:
agg_new_df = new_df.agg(fn.avg("correct"))
agg_new_df.printSchema()
#agg_new_df.show()

root
 |-- avg(correct): double (nullable = true)



In [23]:
printClassPredictions(predictionsBin)

Py4JJavaError: An error occurred while calling o1309.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 11 in stage 189.0 failed 1 times, most recent failure: Lost task 11.0 in stage 189.0 (TID 13579, localhost, executor driver): org.apache.spark.SparkException: Failed to execute user defined function($anonfun$9: (string) => double)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage8.agg_doAggregateWithKeysOutput_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage8.agg_doAggregateWithoutKey_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage8.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$1.hasNext(WholeStageCodegenExec.scala:619)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:125)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:55)
	at org.apache.spark.scheduler.Task.run(Task.scala:121)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:402)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:408)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Unseen label: 1328589358.  To handle unseen labels, set Param handleInvalid to keep.
	at org.apache.spark.ml.feature.StringIndexerModel$$anonfun$9.apply(StringIndexer.scala:260)
	at org.apache.spark.ml.feature.StringIndexerModel$$anonfun$9.apply(StringIndexer.scala:246)
	... 16 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1887)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1875)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1874)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1874)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2108)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2057)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2046)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:365)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3384)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
	at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3365)
	at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3364)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2545)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2759)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:255)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:292)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$9: (string) => double)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage8.agg_doAggregateWithKeysOutput_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage8.agg_doAggregateWithoutKey_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage8.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$1.hasNext(WholeStageCodegenExec.scala:619)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:125)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:55)
	at org.apache.spark.scheduler.Task.run(Task.scala:121)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:402)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:408)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more
Caused by: org.apache.spark.SparkException: Unseen label: 1328589358.  To handle unseen labels, set Param handleInvalid to keep.
	at org.apache.spark.ml.feature.StringIndexerModel$$anonfun$9.apply(StringIndexer.scala:260)
	at org.apache.spark.ml.feature.StringIndexerModel$$anonfun$9.apply(StringIndexer.scala:246)
	... 16 more


## BELOW - SOME additional TESTS

In [168]:
from pyspark.sql import functions as fn

#printClassPredictions(predictionsBin)
predictionsBin.select('bin_score', 'prediction', fn.expr('float(prediction = bin_score)').alias('correct')).show(n = 30, truncate = 70)
#.\
#a.select(fn.avg('correct')).show()

+---------+----------+-------+
|bin_score|prediction|correct|
+---------+----------+-------+
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        0|       0.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        0|       1.0|    0.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        0|       1.0|    0.0|
|        1|       1.0|    1.0|
|        0|       1.0|    0.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|        1|       1.0|    1.0|
|       

In [159]:
# show some predictions for which the ground truth was score = 1
predictions_lr_bin.filter(predictions_lr_bin['bin_score'] == 1).\
    select("review_concat","review_score","bin_score","prediction", "probability"). \
    show(n = 30, truncate = 70)

+----------------------------------------------------------------------+------------+---------+----------+----------------------------------------+
|                                                         review_concat|review_score|bin_score|prediction|                             probability|
+----------------------------------------------------------------------+------------+---------+----------+----------------------------------------+
|Mostly pleased. I love the book, but it had a worn dirt mark on the...|           5|        1|       1.0|[0.05597257676497326,0.9440274232350269]|
|💖💖💖 You will not stop thinking about this book until you have re...|           4|        1|       1.0|[0.05597257676497326,0.9440274232350269]|
|Great book I really really am enjoying this read. A feel good book ...|           5|        1|       1.0|[0.05597257676497326,0.9440274232350269]|
|Falling Hard (The Blackhawk Boys, Book 4) by Lexi Ryan Lexi Ryan is...|           4|        1|       1.0|[0.055972