In [1]:
#LINK TO AMAZON DATASET: https://nijianmo.github.io/amazon/index.html#sample-metadata

In [2]:
import numpy as np
import pandas as pd
from random import randint

In [3]:
#Convert ReviewerID to unique integers
#new_reviewerID = random.sample(range(0, 4607047), 4607047)

#new_reviewerID = spark.createDataFrame(new_reviewerID, IntegerType()).collect()

In [4]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructField, DoubleType, IntegerType
from pyspark.sql import functions as F

from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder



spark = SparkSession.builder.getOrCreate()
spark.conf.set('spark.sql.repl.eagerEval.enabled', True)

In [5]:
df = spark.read.csv('/home/luca/Downloads/ratings_Movies_and_TV.csv')

In [6]:
df.rdd.id()

14

In [7]:
df.schema

StructType(List(StructField(_c0,StringType,true),StructField(_c1,StringType,true),StructField(_c2,StringType,true),StructField(_c3,StringType,true)))

In [8]:
df.tail(10)

[Row(_c0='A17W587EH23J0Q', _c1='B00LT1JHLW', _c2='5.0', _c3='1405641600'),
 Row(_c0='A3E4Q2YOYCKXON', _c1='B00LT1JHLW', _c2='5.0', _c3='1405987200'),
 Row(_c0='A1U1UNV1RLCKRL', _c1='B00LT1JHLW', _c2='3.0', _c3='1406073600'),
 Row(_c0='A14THKG1X8861X', _c1='B00LT1JHLW', _c2='5.0', _c3='1405555200'),
 Row(_c0='A3DE438TF1A958', _c1='B00LT1JHLW', _c2='5.0', _c3='1405728000'),
 Row(_c0='AHCV1RTGY3PJ8', _c1='B00LT1JHLW', _c2='5.0', _c3='1405641600'),
 Row(_c0='A2RWCXDMANY0LW', _c1='B00LT1JHLW', _c2='5.0', _c3='1405987200'),
 Row(_c0='A3V9PIFRME2XCW', _c1='B00LT1JHLW', _c2='5.0', _c3='1405900800'),
 Row(_c0='A3ROPC55BE2OM9', _c1='B00LT1JHLW', _c2='5.0', _c3='1405728000'),
 Row(_c0='A2ARBNMH5Q5YM1', _c1='B00LVGP8EA', _c2='5.0', _c3='1405641600')]

In [9]:
#Changing column names - https://stackoverflow.com/questions/34077353/how-to-change-dataframe-column-names-in-pyspark

df = df.selectExpr("_c0 as ReviewerID", "_c1 as ProductID", "_c2 as Rating", "_c3 as unixReviewTime")

In [10]:
#Items of interest are ReviewerID and ProductID

df.show()

+--------------+----------+------+--------------+
|    ReviewerID| ProductID|Rating|unixReviewTime|
+--------------+----------+------+--------------+
|A3R5OBKS7OM2IR|0000143502|   5.0|    1358380800|
|A3R5OBKS7OM2IR|0000143529|   5.0|    1380672000|
| AH3QC2PC1VTGP|0000143561|   2.0|    1216252800|
|A3LKP6WPMP9UKX|0000143588|   5.0|    1236902400|
| AVIY68KEPQ5ZD|0000143588|   5.0|    1232236800|
|A1CV1WROP5KTTW|0000589012|   5.0|    1309651200|
| AP57WZ2X4G0AA|0000589012|   2.0|    1366675200|
|A3NMBJ2LCRCATT|0000589012|   5.0|    1393804800|
| A5Y15SAOMX6XA|0000589012|   2.0|    1307404800|
|A3P671HJ32TCSF|0000589012|   5.0|    1393718400|
|A3VCKTRD24BG7K|0000589012|   5.0|    1378425600|
| ANF0AGIV0JCH2|0000589012|   5.0|    1308182400|
|A3LDEBLV6MVUBE|0000589012|   5.0|    1208995200|
|A1R2XZWQ6NM5M1|0000589012|   5.0|    1224979200|
|A36L1XGA5AQIJY|0000589012|   1.0|    1393113600|
|A2HWI21H23GDS4|0000589012|   4.0|    1338681600|
|A1DNYFL3RSXRMO|0000589012|   5.0|    1208908800|


In [11]:
#MUST FIT INTEGER RANGE: -2147483648 to 2147483647 - https://spark.apache.org/docs/latest/sql-ref-datatypes.html

df.count()

4607047

In [12]:
new_df = df.take(100000)

In [13]:
new_df = spark.createDataFrame(new_df)

In [14]:
df.dtypes

[('ReviewerID', 'string'),
 ('ProductID', 'string'),
 ('Rating', 'string'),
 ('unixReviewTime', 'string')]

In [15]:
df = df.withColumn("ReviewerID", df["ReviewerID"].cast(IntegerType()))
df = df.withColumn("ProductID", df["ProductID"].cast(IntegerType()))
df = df.withColumn("Rating", df["Rating"].cast(IntegerType()))
df = df.withColumn("unixReviewTime", df["unixReviewTime"].cast(IntegerType()))

In [16]:
#Items of interest are ReviewerID and ProductID

df.show()

+----------+---------+------+--------------+
|ReviewerID|ProductID|Rating|unixReviewTime|
+----------+---------+------+--------------+
|      null|   143502|     5|    1358380800|
|      null|   143529|     5|    1380672000|
|      null|   143561|     2|    1216252800|
|      null|   143588|     5|    1236902400|
|      null|   143588|     5|    1232236800|
|      null|   589012|     5|    1309651200|
|      null|   589012|     2|    1366675200|
|      null|   589012|     5|    1393804800|
|      null|   589012|     2|    1307404800|
|      null|   589012|     5|    1393718400|
|      null|   589012|     5|    1378425600|
|      null|   589012|     5|    1308182400|
|      null|   589012|     5|    1208995200|
|      null|   589012|     5|    1224979200|
|      null|   589012|     1|    1393113600|
|      null|   589012|     4|    1338681600|
|      null|   589012|     5|    1208908800|
|      null|   589012|     1|    1218412800|
|      null|   589012|     5|    1322956800|
|      nul

In [17]:
# Taken from: https://stackoverflow.com/questions/44153575/fill-na-with-random-numbers-in-pyspark
new_df = df.withColumn('ReviewerID', F.coalesce(F.col('ReviewerID'), (F.round(F.rand()*100000)))).collect()
new_df = spark.createDataFrame(new_df)

In [18]:
new_df.show()

+----------+---------+------+--------------+
|ReviewerID|ProductID|Rating|unixReviewTime|
+----------+---------+------+--------------+
|   20897.0|   143502|     5|    1358380800|
|    7518.0|   143529|     5|    1380672000|
|   53908.0|   143561|     2|    1216252800|
|   81400.0|   143588|     5|    1236902400|
|   11893.0|   143588|     5|    1232236800|
|   25607.0|   589012|     5|    1309651200|
|   79693.0|   589012|     2|    1366675200|
|    6297.0|   589012|     5|    1393804800|
|   79307.0|   589012|     2|    1307404800|
|    3650.0|   589012|     5|    1393718400|
|   72012.0|   589012|     5|    1378425600|
|   85204.0|   589012|     5|    1308182400|
|   60006.0|   589012|     5|    1208995200|
|   19893.0|   589012|     5|    1224979200|
|   21909.0|   589012|     1|    1393113600|
|   13983.0|   589012|     4|    1338681600|
|   97826.0|   589012|     5|    1208908800|
|   48180.0|   589012|     1|    1218412800|
|    3207.0|   589012|     5|    1322956800|
|   14854.

In [19]:
del(df)

In [20]:
new_df.dtypes

[('ReviewerID', 'double'),
 ('ProductID', 'bigint'),
 ('Rating', 'bigint'),
 ('unixReviewTime', 'bigint')]

In [21]:
new_df = new_df.withColumn("ReviewerID", new_df["ReviewerID"].cast(IntegerType()))
new_df = new_df.withColumn("ProductID", new_df["ProductID"].cast(IntegerType()))
new_df = new_df.withColumn("Rating", new_df["Rating"].cast(DoubleType()))
new_df = new_df.withColumn("unixReviewTime", new_df["unixReviewTime"].cast(IntegerType()))

In [22]:
new_df.dtypes

[('ReviewerID', 'int'),
 ('ProductID', 'int'),
 ('Rating', 'double'),
 ('unixReviewTime', 'int')]

In [23]:
new_df.show()

+----------+---------+------+--------------+
|ReviewerID|ProductID|Rating|unixReviewTime|
+----------+---------+------+--------------+
|     20897|   143502|   5.0|    1358380800|
|      7518|   143529|   5.0|    1380672000|
|     53908|   143561|   2.0|    1216252800|
|     81400|   143588|   5.0|    1236902400|
|     11893|   143588|   5.0|    1232236800|
|     25607|   589012|   5.0|    1309651200|
|     79693|   589012|   2.0|    1366675200|
|      6297|   589012|   5.0|    1393804800|
|     79307|   589012|   2.0|    1307404800|
|      3650|   589012|   5.0|    1393718400|
|     72012|   589012|   5.0|    1378425600|
|     85204|   589012|   5.0|    1308182400|
|     60006|   589012|   5.0|    1208995200|
|     19893|   589012|   5.0|    1224979200|
|     21909|   589012|   1.0|    1393113600|
|     13983|   589012|   4.0|    1338681600|
|     97826|   589012|   5.0|    1208908800|
|     48180|   589012|   1.0|    1218412800|
|      3207|   589012|   5.0|    1322956800|
|     1485

In [24]:
(training, test) = new_df.randomSplit([0.8, 0.2])

In [25]:
als = ALS(maxIter=5, regParam=0.01, userCol="ReviewerID", itemCol="ProductID", ratingCol="Rating",
          coldStartStrategy="nan")
model = als.fit(training)

Py4JJavaError: An error occurred while calling o112.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 10.0 failed 1 times, most recent failure: Lost task 0.0 in stage 10.0 (TID 24) (luca executor driver): org.apache.spark.SparkException: Failed to execute user defined function(ALSModelParams$$Lambda$3108/0x000000084125b040: (int) => int)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:755)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at scala.collection.Iterator$SliceIterator.hasNext(Iterator.scala:266)
	at scala.collection.Iterator.foreach(Iterator.scala:941)
	at scala.collection.Iterator.foreach$(Iterator.scala:941)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1429)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:315)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:313)
	at scala.collection.AbstractIterator.to(Iterator.scala:1429)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:307)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:307)
	at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1429)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:294)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:288)
	at scala.collection.AbstractIterator.toArray(Iterator.scala:1429)
	at org.apache.spark.rdd.RDD.$anonfun$take$2(RDD.scala:1449)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2236)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:497)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:500)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: java.lang.IllegalArgumentException: ALS only supports values in Integer range for columns ReviewerID and ProductID. Value null was not numeric.
	at org.apache.spark.ml.recommendation.ALSModelParams.$anonfun$checkedCast$1(ALS.scala:104)
	at org.apache.spark.ml.recommendation.ALSModelParams.$anonfun$checkedCast$1$adapted(ALS.scala:89)
	... 33 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2258)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2207)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2206)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2206)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1079)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1079)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1079)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2445)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2387)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2376)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:868)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2196)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2217)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2236)
	at org.apache.spark.rdd.RDD.$anonfun$take$1(RDD.scala:1449)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:414)
	at org.apache.spark.rdd.RDD.take(RDD.scala:1422)
	at org.apache.spark.rdd.RDD.$anonfun$isEmpty$1(RDD.scala:1557)
	at scala.runtime.java8.JFunction0$mcZ$sp.apply(JFunction0$mcZ$sp.java:23)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:414)
	at org.apache.spark.rdd.RDD.isEmpty(RDD.scala:1557)
	at org.apache.spark.ml.recommendation.ALS$.train(ALS.scala:947)
	at org.apache.spark.ml.recommendation.ALS.$anonfun$fit$1(ALS.scala:709)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.recommendation.ALS.fit(ALS.scala:691)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: org.apache.spark.SparkException: Failed to execute user defined function(ALSModelParams$$Lambda$3108/0x000000084125b040: (int) => int)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:755)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at scala.collection.Iterator$SliceIterator.hasNext(Iterator.scala:266)
	at scala.collection.Iterator.foreach(Iterator.scala:941)
	at scala.collection.Iterator.foreach$(Iterator.scala:941)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1429)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:315)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:313)
	at scala.collection.AbstractIterator.to(Iterator.scala:1429)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:307)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:307)
	at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1429)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:294)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:288)
	at scala.collection.AbstractIterator.toArray(Iterator.scala:1429)
	at org.apache.spark.rdd.RDD.$anonfun$take$2(RDD.scala:1449)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2236)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:497)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:500)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	... 1 more
Caused by: java.lang.IllegalArgumentException: ALS only supports values in Integer range for columns ReviewerID and ProductID. Value null was not numeric.
	at org.apache.spark.ml.recommendation.ALSModelParams.$anonfun$checkedCast$1(ALS.scala:104)
	at org.apache.spark.ml.recommendation.ALSModelParams.$anonfun$checkedCast$1$adapted(ALS.scala:89)
	... 33 more


In [26]:
"""ERROR: java.lang.IllegalArgumentException: 
ALS only supports values in Integer range for columns ReviewerID and ProductID.
Value null was not numeric."""

training.filter("ReviewerID is NULL").show()
training.filter("ProductID is NULL").show()

+----------+---------+------+--------------+
|ReviewerID|ProductID|Rating|unixReviewTime|
+----------+---------+------+--------------+
+----------+---------+------+--------------+

+----------+---------+------+--------------+
|ReviewerID|ProductID|Rating|unixReviewTime|
+----------+---------+------+--------------+
|         0|     null|   1.0|     953424000|
|         1|     null|   2.0|    1345507200|
|         1|     null|   5.0|    1265760000|
|         1|     null|   5.0|    1279670400|
|         2|     null|   5.0|     962236800|
|         3|     null|   4.0|    1380240000|
|         3|     null|   5.0|    1389744000|
|         4|     null|   5.0|    1036540800|
|         4|     null|   5.0|    1384300800|
|         6|     null|   3.0|    1293667200|
|         6|     null|   5.0|     971049600|
|         6|     null|   5.0|    1363046400|
|         7|     null|   1.0|    1385942400|
|         7|     null|   3.0|    1316217600|
|        10|     null|   3.0|    1108512000|
|        

In [27]:
training = training.dropna()
training.filter("ReviewerID is NULL").show()
training.filter("ProductID is NULL").show()

+----------+---------+------+--------------+
|ReviewerID|ProductID|Rating|unixReviewTime|
+----------+---------+------+--------------+
+----------+---------+------+--------------+

+----------+---------+------+--------------+
|ReviewerID|ProductID|Rating|unixReviewTime|
+----------+---------+------+--------------+
+----------+---------+------+--------------+



In [28]:
test.filter("ReviewerID is NULL").show()
test.filter("ProductID is NULL").show()
test = test.dropna()

+----------+---------+------+--------------+
|ReviewerID|ProductID|Rating|unixReviewTime|
+----------+---------+------+--------------+
+----------+---------+------+--------------+

+----------+---------+------+--------------+
|ReviewerID|ProductID|Rating|unixReviewTime|
+----------+---------+------+--------------+
|         1|     null|   5.0|    1179619200|
|         2|     null|   5.0|     955670400|
|         6|     null|   5.0|    1373673600|
|        10|     null|   1.0|    1193097600|
|        10|     null|   3.0|    1013472000|
|        13|     null|   5.0|    1059177600|
|        17|     null|   4.0|    1096156800|
|        22|     null|   5.0|    1372982400|
|        24|     null|   4.0|    1046822400|
|        33|     null|   5.0|    1350345600|
|        34|     null|   5.0|    1357689600|
|        35|     null|   5.0|    1216944000|
|        39|     null|   5.0|    1336521600|
|        45|     null|   5.0|    1366761600|
|        48|     null|   5.0|    1397520000|
|        

In [29]:
als = ALS(userCol="ReviewerID", itemCol="ProductID", ratingCol="Rating",
          coldStartStrategy="drop")

In [30]:
param_grid = ParamGridBuilder()\
.addGrid(als.rank, [12, 13, 14])\
.addGrid(als.maxIter, [18, 19, 20])\
.addGrid(als.regParam, [.17, .18, .19])\
.build()

In [31]:
evaluator = RegressionEvaluator(metricName="rmse", labelCol="Rating", 
                                predictionCol="prediction")

In [32]:
tvs = TrainValidationSplit(estimator=als, estimatorParamMaps=param_grid,
                          evaluator=evaluator)

In [33]:
model = tvs.fit(training)

In [34]:
best_model = model.bestModel

In [35]:
predictions = best_model.transform(test)
rmse = evaluator.evaluate(predictions)

In [36]:
print("RMSE = " + str(rmse))
print("**BEST MODEL**")
print("Rank: ", best_model.rank)
print("MaxIter: ", best_model._java_obj.parent().getMaxIter())
print("RegParam: ", best_model._java_obj.parent().getRegParam())

RMSE = 1.4868219383808512
**BEST MODEL**
Rank:  14
MaxIter:  20
RegParam:  0.19


In [37]:
display(predictions.sort("ReviewerID", "Rating"))

ReviewerID,ProductID,Rating,unixReviewTime,prediction
2,792838742,3.0,970444800,1.4603155
2,1608838137,5.0,1378425600,3.4005356
4,784010218,5.0,1402790400,4.750594
6,790734680,5.0,1395705600,4.199687
7,792837746,4.0,1360713600,3.7072403
7,767839145,5.0,1369526400,4.0521116
7,1573303593,5.0,1385337600,3.705644
9,1573622990,4.0,1370563200,3.0953841
11,790732203,5.0,1391558400,4.1523786
14,790730987,5.0,1376870400,4.7076497
