In [7]:
# Following https://spark.apache.org/docs/2.2.0/ml-collaborative-filtering.html
import pyspark
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.sql import Row
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
import time

# Local Spark session on two cores; the app name is what appears in the Spark UI.
# If a session already exists in this kernel, getOrCreate() hands that one back.
conf = pyspark.SparkConf().setAll([
    ('spark.master', 'local[2]'),
    ('spark.app.name', 'Recommender results'),
])
spark = SparkSession.builder.config(conf=conf).getOrCreate()

# Both splits have a header row; column types are inferred by Spark
# (inferSchema costs an extra pass over each file).
csv_reader = spark.read.option("header", True).option("inferSchema", True)
train = csv_reader.csv("train_baselines.csv")
val = csv_reader.csv("val_baselines.csv")

In [8]:
# Train ALS on the residual target `dual_bayesian_avg_delta`. In the displayed
# rows it equals `rating - dual_bayesian_avg`. The baseline is then added back
# so that `als_prediction` is on the rating scale.
#
# Explicit seed: without it, pyspark's ALS seed defaults to hash() of the class
# name, and Python string hashing is randomized per interpreter, so the learned
# factors (and the RMSE computed afterwards) would differ after every kernel restart.
ALS_SEED = 42

start = time.perf_counter()  # monotonic clock intended for measuring intervals
als = ALS(
    rank=200,
    maxIter=10,
    regParam=0.125,
    userCol="user_id",
    itemCol="recipe_id",
    ratingCol="dual_bayesian_avg_delta",
    # Drop validation rows whose user or recipe never appears in `train`
    # (they would otherwise get NaN predictions). NOTE: any metric computed on
    # `predictions` therefore covers only users/recipes seen during training.
    coldStartStrategy="drop",
    seed=ALS_SEED,
)
model = als.fit(train)

# Score the validation split (not a test split). `transform` is lazy, so nothing
# is computed here; the RMSE evaluation / toPandas calls trigger the work.
normalized_predictions = model.transform(val)
predictions = normalized_predictions.withColumn(
    "als_prediction", col("prediction") + col("dual_bayesian_avg")
)

# Elapsed seconds. Because transform is lazy, this is essentially the fit time.
print(time.perf_counter() - start)

4214.9910752773285


In [None]:
# RMSE on the rating scale: the raw `rating` label vs. the re-baselined ALS output.
# Only rows kept by the ALS coldStartStrategy="drop" setting are scored.
evaluator = RegressionEvaluator(
    metricName="rmse",
    predictionCol="als_prediction",
    labelCol="rating",
)
rmse = evaluator.evaluate(predictions)
print(rmse)

0.9029956512423212


In [12]:
# Collect the scored validation rows to the driver as a pandas DataFrame
# (about 62.7k rows in the displayed result) for inspection and CSV export.
# NOTE(review): `predictions` is not cached, so this action presumably re-runs the
# ALS transform that the RMSE evaluation already computed.
pandas_preds = predictions.toPandas()

In [13]:
# Preview the collected predictions; pandas truncates the rendered rows ("...").
pandas_preds

Unnamed: 0,_c0,user_id,recipe_id,date,rating,u,i,global_avg,user_avg,user_bayesian_avg,user_bayesian_avg_delta,recipe_avg,recipe_bayesian_avg,dual_bayesian_avg,dual_bayesian_avg_delta,prediction,als_prediction
0,199139,222564,243,2007-03-06,4.0,132,53241,4.573867,4.700667,4.700161,-0.700161,5.000000,4.612606,4.750939,-0.750939,-0.038688,4.712250
1,529052,296050,271,2010-04-27,5.0,6140,103878,4.573867,4.839080,4.821970,0.178030,4.900000,4.682578,4.842231,0.157769,0.038522,4.880754
2,290364,703740,916,2008-01-22,5.0,22810,104686,4.573867,4.700000,4.652700,0.347300,3.933333,4.130421,4.642462,0.357538,-0.036666,4.605796
3,687092,1924722,916,2015-09-13,5.0,15776,104686,4.573867,4.000000,4.344320,0.655680,3.933333,4.130421,4.334082,0.665918,-0.101093,4.232989
4,503521,934536,916,2010-01-16,5.0,6387,104686,4.573867,5.000000,4.744320,0.255680,3.933333,4.130421,4.734082,0.265918,0.081901,4.815983
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
62705,628574,822358,481830,2012-07-01,4.0,15468,161477,4.573867,4.357143,4.422160,-0.422160,5.000000,4.644889,4.456482,-0.456482,-0.097492,4.358990
62706,695192,2123645,502824,2017-08-15,5.0,3644,48148,4.573867,4.488055,4.489777,0.510223,5.000000,4.725075,4.504070,0.495930,0.049679,4.553749
62707,676508,323186,505384,2014-05-26,5.0,172,107667,4.573867,4.947834,4.945860,0.054140,5.000000,4.594159,4.980961,0.019039,0.012248,4.993209
62708,676074,895132,507434,2014-05-14,5.0,262,121026,4.573867,4.881013,4.878697,0.121303,5.000000,4.594159,4.902451,0.097549,0.014039,4.916490


In [14]:
# Persist the validation predictions for downstream comparison.
# index=False: the positional pandas index is not data. Writing it adds a nameless
# first column that later reloads surface as `_c0` (Spark) or `Unnamed: 0` (pandas);
# this frame already carries both, which look like artifacts of earlier round-trips.
# NOTE(review): confirm no consumer reads this file with index_col=0.
pandas_preds.to_csv('val_als.csv', index=False)