In [1]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('meal').master('local').getOrCreate()

In [2]:
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS

In [3]:
data = spark.read.csv('Meal_Info.csv', inferSchema=True, header=True)

In [4]:
data.printSchema()

root
 |-- movieId: integer (nullable = true)
 |-- rating: double (nullable = true)
 |-- userId: integer (nullable = true)
 |-- mealskew: double (nullable = true)
 |-- meal_name: string (nullable = true)



In [5]:
data.show()

+-------+------+------+--------+--------------------+
|movieId|rating|userId|mealskew|           meal_name|
+-------+------+------+--------+--------------------+
|      2|   3.0|     0|     2.0|       Chicken Curry|
|      3|   1.0|     0|     3.0|Spicy Chicken Nug...|
|      5|   2.0|     0|     5.0|           Hamburger|
|      9|   4.0|     0|     9.0|       Taco Surprise|
|     11|   1.0|     0|    11.0|            Meatloaf|
|     12|   2.0|     0|    12.0|        Ceaser Salad|
|     15|   1.0|     0|    15.0|            BBQ Ribs|
|     17|   1.0|     0|    17.0|         Sushi Plate|
|     19|   1.0|     0|    19.0|Cheesesteak Sandw...|
|     21|   1.0|     0|    21.0|             Lasagna|
|     23|   1.0|     0|    23.0|      Orange Chicken|
|     26|   3.0|     0|    26.0|    Spicy Beef Plate|
|     27|   1.0|     0|    27.0|Salmon with Mashe...|
|     28|   1.0|     0|    28.0| Penne Tomatoe Pasta|
|     29|   1.0|     0|    29.0|        Pork Sliders|
|     30|   1.0|     0|    3

In [11]:
data = data.dropna()

In [12]:
data.show()

+-------+------+------+--------+--------------------+
|movieId|rating|userId|mealskew|           meal_name|
+-------+------+------+--------+--------------------+
|      2|   3.0|     0|     2.0|       Chicken Curry|
|      3|   1.0|     0|     3.0|Spicy Chicken Nug...|
|      5|   2.0|     0|     5.0|           Hamburger|
|      9|   4.0|     0|     9.0|       Taco Surprise|
|     11|   1.0|     0|    11.0|            Meatloaf|
|     12|   2.0|     0|    12.0|        Ceaser Salad|
|     15|   1.0|     0|    15.0|            BBQ Ribs|
|     17|   1.0|     0|    17.0|         Sushi Plate|
|     19|   1.0|     0|    19.0|Cheesesteak Sandw...|
|     21|   1.0|     0|    21.0|             Lasagna|
|     23|   1.0|     0|    23.0|      Orange Chicken|
|     26|   3.0|     0|    26.0|    Spicy Beef Plate|
|     27|   1.0|     0|    27.0|Salmon with Mashe...|
|     28|   1.0|     0|    28.0| Penne Tomatoe Pasta|
|     29|   1.0|     0|    29.0|        Pork Sliders|
|     30|   1.0|     0|    3

In [13]:
(training, test) = data.randomSplit([0.8, 0.2])

In [14]:
als = ALS(maxIter=5, regParam=0.01, userCol='userId', itemCol='mealskew', ratingCol='rating')
model = als.fit(training)

In [16]:
predictions = model.transform(test)
predictions.orderBy('prediction', ascending=False).show()

+-------+------+------+--------+--------------------+----------+
|movieId|rating|userId|mealskew|           meal_name|prediction|
+-------+------+------+--------+--------------------+----------+
|     22|   3.0|    29|    22.0|   Pulled Pork Plate|  5.037647|
|      4|   1.0|     7|     4.0|Pretzels and Chee...| 4.6381054|
|      2|   1.0|    25|     2.0|       Chicken Curry| 4.4864473|
|     28|   1.0|     0|    28.0| Penne Tomatoe Pasta|  4.381492|
|     27|   3.0|    24|    27.0|Salmon with Mashe...|  4.210078|
|     19|   1.0|    22|    19.0|Cheesesteak Sandw...|   4.15603|
|      4|   1.0|     5|     4.0|Pretzels and Chee...| 3.9602718|
|     30|   1.0|     4|    30.0| Vietnamese Sandwich| 3.9115555|
|     19|   1.0|     8|    19.0|Cheesesteak Sandw...|  3.711784|
|     23|   3.0|    28|    23.0|      Orange Chicken| 3.6741145|
|     25|   1.0|    15|    25.0| Roast Beef Sandwich| 3.5628712|
|     29|   3.0|     3|    29.0|        Pork Sliders| 3.4298608|
|      6|   1.0|    12|  

In [17]:
evaluator = RegressionEvaluator(metricName='rmse', 
                labelCol='rating', predictionCol='prediction')
rmse = evaluator.evaluate(predictions)
print('Root mean square error = ' + str(rmse))

Root mean square error = 2.2722298656622915
