In [2]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as f


# Create the Spark session for this notebook, or reuse one that is already
# running; the later cells read data through this `spark` handle.
spark = (
    SparkSession.builder
    .appName("Chapter4-3")
    .getOrCreate()
)

In [15]:
# Read the ratings CSV with an explicit schema so that no type inference
# is needed. The `timestamp` column is parsed but dropped straight away,
# leaving only userId, movieId and rating. The result is cached because
# it is reused by the fitting and scoring cells.
ratings = (
    spark.read
    .schema("userId INT, movieId INT, rating DOUBLE, timestamp INT")
    .csv(path="ratings.csv", sep=",", header=True, quote='"')
    .drop("timestamp")
    .cache()
)

The `ALS` estimator has the following constructor signature (the values shown are the defaults):

```python
class pyspark.ml.recommendation.ALS(
    rank=10,
    maxIter=10,
    regParam=0.1,
    numUserBlocks=10,
    numItemBlocks=10,
    implicitPrefs=False,
    alpha=1.0,
    userCol="user",
    itemCol="item",
    seed=None,
    ratingCol="rating",
    nonnegative=False,
    checkpointInterval=10,
    intermediateStorageLevel="MEMORY_AND_DISK",
    finalStorageLevel="MEMORY_AND_DISK",
    coldStartStrategy="nan",
)
```

In [7]:
from pyspark.ml.recommendation import ALS

In [17]:
# Fit an ALS matrix-factorisation model on the ratings DataFrame.
# Only the column names are overridden; every other hyperparameter keeps
# its ALS default (rank=10, maxIter=10, regParam=0.1, ...).
model = (
    ALS(
        userCol="userId",
        itemCol="movieId",
        ratingCol="rating",
        # ALS starts from randomly initialised factors. Pinning the seed
        # (the default is None) makes repeated runs start from the same
        # factors, so the predictions and factor tables can be reproduced.
        # Re-run the downstream cells after changing it.
        seed=42,
    ).fit(ratings)
)

In [18]:
# Score the same DataFrame the model was fitted on; `transform` appends a
# `prediction` column next to the observed `rating`.
predictions = model.transform(ratings)

# Show the first 10 rows without truncating cell contents.
predictions.show(n=10, truncate=False)

+------+-------+------+----------+
|userId|movieId|rating|prediction|
+------+-------+------+----------+
|191   |148    |5.0   |4.9315734 |
|133   |471    |4.0   |3.3586264 |
|597   |471    |2.0   |3.8164203 |
|385   |471    |4.0   |3.3726332 |
|436   |471    |3.0   |3.655964  |
|602   |471    |4.0   |3.5622292 |
|91    |471    |1.0   |2.1123495 |
|409   |471    |3.0   |3.8733768 |
|372   |471    |3.0   |3.239497  |
|599   |471    |2.5   |2.5592706 |
+------+-------+------+----------+
only showing top 10 rows



In [45]:
# Peek at the first five rows of the learned user factors
# (one `id` and one `features` vector per row).
model.userFactors.show(n=5)

+---+--------------------+
| id|            features|
+---+--------------------+
| 10|[-0.011906751, 1....|
| 20|[-0.59221053, 0.6...|
| 30|[-0.63435304, 0.6...|
| 40|[-0.686187, 0.679...|
| 50|[-0.57855225, 0.3...|
+---+--------------------+
only showing top 5 rows



In [41]:
# Peek at the first five rows of the learned item (movie) factors
# (one `id` and one `features` vector per row).
model.itemFactors.show(n=5)

+---+--------------------+
| id|            features|
+---+--------------------+
| 10|[-0.6802544, 0.43...|
| 20|[-1.0189537, 0.31...|
| 30|[-0.90857005, 0.4...|
| 40|[-0.9740444, 0.50...|
| 50|[-0.89618796, 0.8...|
+---+--------------------+
only showing top 5 rows

