In [1]:
from pyspark.sql import SparkSession

In [3]:
spark = SparkSession.builder.appName('movie').getOrCreate()

In [8]:
movie = spark.read.options(header=True,inferschema=True).csv('data/movies.csv') 
rating = spark.read.options(header=True,inferschema=True).csv('data/ratings.csv') 


In [9]:
movie.show()

+-------+--------------------+--------------------+
|movieId|               title|              genres|
+-------+--------------------+--------------------+
|      1|    Toy Story (1995)|Adventure|Animati...|
|      2|      Jumanji (1995)|Adventure|Childre...|
|      3|Grumpier Old Men ...|      Comedy|Romance|
|      4|Waiting to Exhale...|Comedy|Drama|Romance|
|      5|Father of the Bri...|              Comedy|
|      6|         Heat (1995)|Action|Crime|Thri...|
|      7|      Sabrina (1995)|      Comedy|Romance|
|      8| Tom and Huck (1995)|  Adventure|Children|
|      9| Sudden Death (1995)|              Action|
|     10|    GoldenEye (1995)|Action|Adventure|...|
|     11|American Presiden...|Comedy|Drama|Romance|
|     12|Dracula: Dead and...|       Comedy|Horror|
|     13|        Balto (1995)|Adventure|Animati...|
|     14|        Nixon (1995)|               Drama|
|     15|Cutthroat Island ...|Action|Adventure|...|
|     16|       Casino (1995)|         Crime|Drama|
|     17|Sen

In [10]:
rating.show()

+------+-------+------+---------+
|userId|movieId|rating|timestamp|
+------+-------+------+---------+
|     1|      1|   4.0|964982703|
|     1|      3|   4.0|964981247|
|     1|      6|   4.0|964982224|
|     1|     47|   5.0|964983815|
|     1|     50|   5.0|964982931|
|     1|     70|   3.0|964982400|
|     1|    101|   5.0|964980868|
|     1|    110|   4.0|964982176|
|     1|    151|   5.0|964984041|
|     1|    157|   5.0|964984100|
|     1|    163|   5.0|964983650|
|     1|    216|   5.0|964981208|
|     1|    223|   3.0|964980985|
|     1|    231|   5.0|964981179|
|     1|    235|   4.0|964980908|
|     1|    260|   5.0|964981680|
|     1|    296|   3.0|964982967|
|     1|    316|   3.0|964982310|
|     1|    333|   5.0|964981179|
|     1|    349|   4.0|964982563|
+------+-------+------+---------+
only showing top 20 rows



In [11]:
movie.printSchema()
rating.printSchema()

root
 |-- movieId: integer (nullable = true)
 |-- title: string (nullable = true)
 |-- genres: string (nullable = true)

root
 |-- userId: integer (nullable = true)
 |-- movieId: integer (nullable = true)
 |-- rating: double (nullable = true)
 |-- timestamp: integer (nullable = true)



In [12]:
display(movie)

DataFrame[movieId: int, title: string, genres: string]

In [16]:
rate = rating.join(movie,'movieId','left')

In [19]:
train,test = rate.randomSplit([0.8,0.2])

In [26]:
from pyspark.ml.recommendation import ALS
als = ALS(userCol='userId',itemCol='movieId',ratingCol='rating',nonnegative=True,coldStartStrategy="drop",implicitPrefs=False)

In [23]:
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder,CrossValidator

In [27]:
param_grid = ParamGridBuilder() \
    .addGrid(als.rank,[10, 50, 100, 150]) \
        .addGrid(als.regParam,[0.01,0.05,0.1,0.15]) \
            .build()

In [28]:
regressor = RegressionEvaluator(
    metricName='rmse',
    labelCol='rating',
    predictionCol='prediction'
)

In [29]:
cv = CrossValidator(estimator=als,estimatorParamMaps=param_grid,evaluator=regressor,numFolds=5)

In [30]:
model = cv.fit(train)
# best_model = cv.bestModel
# testPrediction = best_model.transform(test)
# Rmse = regressor.evaluate(test)
# print(Rmse)

AttributeError: 'CrossValidator' object has no attribute 'bestModel'

In [34]:
best_model=model.bestModel

In [36]:
testPrediction = best_model.transform(test)
Rmse = regressor.evaluate(testPrediction)
print(Rmse)

0.8651490998420742


In [37]:
recommendation = best_model.recommendForAllUsers(5)

In [38]:
df = recommendation

In [41]:
df.show(
    
)

+------+--------------------+
|userId|     recommendations|
+------+--------------------+
|   471|[[68945, 4.738127...|
|   463|[[170355, 4.77379...|
|   496|[[6818, 4.5608644...|
|   148|[[170355, 4.89657...|
|   540|[[170355, 5.34455...|
|   392|[[68945, 4.673846...|
|   243|[[67618, 5.613619...|
|    31|[[33649, 5.117580...|
|   516|[[4429, 4.7351284...|
|   580|[[170355, 4.72950...|
|   251|[[68945, 5.74595]...|
|   451|[[68945, 5.340492...|
|    85|[[1140, 4.850751]...|
|   137|[[68945, 4.957126...|
|    65|[[68945, 4.836951...|
|   458|[[67618, 5.161386...|
|   481|[[25906, 3.908367...|
|    53|[[68945, 6.789389...|
|   255|[[3525, 3.8541474...|
|   588|[[170355, 4.65545...|
+------+--------------------+
only showing top 20 rows



In [42]:
from pyspark.sql.functions import col,explode
df2 = df.withColumn('movieid_rating',explode('recommendations'))

In [43]:
df2.show()

+------+--------------------+-------------------+
|userId|     recommendations|     movieid_rating|
+------+--------------------+-------------------+
|   471|[[68945, 4.738127...| [68945, 4.7381277]|
|   471|[[68945, 4.738127...|[170355, 4.7381277]|
|   471|[[68945, 4.738127...|  [3379, 4.7381277]|
|   471|[[68945, 4.738127...| [33649, 4.4885697]|
|   471|[[68945, 4.738127...| [171495, 4.488084]|
|   463|[[170355, 4.77379...|[170355, 4.7737927]|
|   463|[[170355, 4.77379...| [68945, 4.7737927]|
|   463|[[170355, 4.77379...|  [3379, 4.7737927]|
|   463|[[170355, 4.77379...|  [33649, 4.572375]|
|   463|[[170355, 4.77379...| [171495, 4.529392]|
|   496|[[6818, 4.5608644...|  [6818, 4.5608644]|
|   496|[[6818, 4.5608644...|[170355, 4.3681946]|
|   496|[[6818, 4.5608644...| [68945, 4.3681946]|
|   496|[[6818, 4.5608644...|  [3379, 4.3681946]|
|   496|[[6818, 4.5608644...| [99764, 4.3654785]|
|   148|[[170355, 4.89657...|[170355, 4.8965735]|
|   148|[[170355, 4.89657...| [68945, 4.8965735]|


In [45]:
df2.select('userId',col('movieid_rating.movieId'),col('movieid_rating.rating')).show()

+------+-------+---------+
|userId|movieId|   rating|
+------+-------+---------+
|   471|  68945|4.7381277|
|   471| 170355|4.7381277|
|   471|   3379|4.7381277|
|   471|  33649|4.4885697|
|   471| 171495| 4.488084|
|   463| 170355|4.7737927|
|   463|  68945|4.7737927|
|   463|   3379|4.7737927|
|   463|  33649| 4.572375|
|   463| 171495| 4.529392|
|   496|   6818|4.5608644|
|   496| 170355|4.3681946|
|   496|  68945|4.3681946|
|   496|   3379|4.3681946|
|   496|  99764|4.3654785|
|   148| 170355|4.8965735|
|   148|  68945|4.8965735|
|   148|   3379|4.8965735|
|   148|  33649|4.7929807|
|   148| 171495|4.5731144|
+------+-------+---------+
only showing top 20 rows



In [50]:
rating.count()

100836

In [51]:
spark.stop()