In [0]:
# Load the October 2019 e-commerce event log (CSV) into a Spark DataFrame.
# header=true: column names come from the first row of the file.
# inferSchema=true: Spark derives column types from the data instead of
# reading every column as a string.
df = (
    spark.read.format("csv")
    .option("header", "true")
    .option("inferSchema", "true")
    .load("/Volumes/workspace/ecommerce/ecommerce_data/2019-Oct.csv")
)

# The former `df_sample = df.limit(300)` line was removed: the next cell
# reassigns `df_sample` (seeded 1% sample) before anything reads it, so the
# 300-row version was dead and only obscured which sample the model uses.

In [0]:
# Working sample: a seeded random sample of the raw events without
# replacement (fraction=0.01, seed=42).
df_sample = df.sample(withReplacement=False, fraction=0.01, seed=42)

In [0]:
from pyspark.sql import functions as F

# Build the (user, item, rating) triples consumed by ALS.
# The rating is derived from the event type:
#   purchase -> 3, cart -> 2, every other event -> 1.
# NOTE(review): rows are not aggregated here, so a user who triggers several
# events on the same product contributes several rows — confirm that is the
# intended input for the model.
interaction_df = df_sample.select(
    "user_id",
    "product_id",
    F.when(F.col("event_type") == "purchase", 3)
    .when(F.col("event_type") == "cart", 2)
    .otherwise(1)
    .alias("rating"),
)

In [0]:
from pyspark.ml.recommendation import ALS

# Configure the ALS estimator, then fit it on the interaction triples.
als = ALS(
    # Which columns of `interaction_df` hold the user, item and rating.
    userCol="user_id",
    itemCol="product_id",
    ratingCol="rating",
    # Model size / optimisation settings.
    rank=5,
    maxIter=5,
    regParam=0.1,
    # "drop": rows that cannot be scored are dropped from transform() output.
    coldStartStrategy="drop",
    # Fixed seed for reproducible factor initialisation.
    seed=42,
)

als_model = als.fit(interaction_df)

In [0]:
# Distinct user ids and distinct product ids present in the interaction data.
users, items = (
    interaction_df.select(id_column).distinct()
    for id_column in ("user_id", "product_id")
)

# Each count() is a Spark action and triggers a job.
print("Users:", users.count())
print("Items:", items.count())

# Every user paired with every item.
# NOTE(review): with the counts recorded below (320926 users x 59280 items)
# this grid is roughly 1.9e10 rows. It is only a lazy plan at this point, and
# `user_item_pairs` is reassigned later to a users x top-200-items grid.
user_item_pairs = users.crossJoin(items)

Users: 320926
Items: 59280


In [0]:
# Score every (user, item) pair in `user_item_pairs` with the fitted model.
# NOTE(review): the same two statements are repeated in a later cell after
# `user_item_pairs` is narrowed, which reassigns both names defined here.
predictions = als_model.transform(user_item_pairs)

# Keep only the rows whose prediction is not null.
predictions_clean = predictions.where(F.col("prediction").isNotNull())

In [0]:
# The 200 products with the most interaction rows; used as the candidate set
# for recommendations.
# `product_id` is a secondary sort key so the cutoff is deterministic when
# several products share the same count. `top_items` is lazy and may be
# re-evaluated by each downstream action, so without the tie-breaker
# different runs (or different actions) could select different products at
# the boundary.
top_items = (
    interaction_df.groupBy("product_id")
    .count()
    .orderBy(F.col("count").desc(), F.col("product_id").asc())
    .limit(200)
)

In [0]:
# Candidate grid: every user paired with each product id in `top_items`.
# This reassigns `user_item_pairs`, replacing the users x all-items grid.
user_item_pairs = users.crossJoin(
    top_items.select(F.col("product_id"))
)


In [0]:
# Score the current `user_item_pairs` grid with the fitted ALS model.
predictions = als_model.transform(user_item_pairs)

# Discard any row without a prediction value.
predictions_clean = predictions.where(
    F.col("prediction").isNotNull()
)

In [0]:
# Redistribute the scored rows by user id ahead of the per-user ranking.
# The name is rebound to the repartitioned DataFrame.
predictions_clean = predictions_clean.repartition(F.col("user_id"))

In [0]:
from pyspark.sql.window import Window

# Rank each user's candidate items from highest to lowest predicted score.
# `product_id` is a tie-breaker: row_number() over an ordering on
# `prediction` alone assigns tied scores an arbitrary rank, so the top-5 cut
# could differ between runs.
window = Window.partitionBy("user_id").orderBy(
    F.col("prediction").desc(),
    F.col("product_id").asc(),
)

# Keep the five best-ranked items per user.
# NOTE(review): items the user has already interacted with are not excluded
# from the candidates — confirm whether they should be filtered out.
top_recs = predictions_clean.withColumn(
    "rank",
    F.row_number().over(window),
).filter(F.col("rank") <= 5)

In [0]:
# Preview the first 20 recommendation rows without truncating column values.
top_recs.show(n=20, truncate=False)

+---------+----------+----------+----+
|user_id  |product_id|prediction|rank|
+---------+----------+----------+----+
|380905505|5100816   |0.93112576|1   |
|380905505|4804295   |0.91285634|2   |
|380905505|1005102   |0.7812248 |3   |
|380905505|15100367  |0.7603223 |4   |
|380905505|1004833   |0.7163775 |5   |
|383633259|4803977   |0.93878865|1   |
|383633259|1004739   |0.8734791 |2   |
|383633259|1004958   |0.85349154|3   |
|383633259|3601485   |0.83700615|4   |
|383633259|1004886   |0.82647336|5   |
|384989212|4804056   |0.87851334|1   |
|384989212|1004856   |0.8117305 |2   |
|384989212|1005006   |0.80037004|3   |
|384989212|1801860   |0.74684644|4   |
|384989212|1004709   |0.7456994 |5   |
|388330497|12703494  |0.776767  |1   |
|388330497|2601299   |0.76487327|2   |
|388330497|1004785   |0.7469665 |3   |
|388330497|15100370  |0.7090553 |4   |
|388330497|1005239   |0.66718894|5   |
+---------+----------+----------+----+
only showing top 20 rows


In [0]:
# Spot-check: display the recommendations produced for one specific user id.
top_recs.where(F.col("user_id") == 512365995).show()

+---------+----------+----------+----+
|  user_id|product_id|prediction|rank|
+---------+----------+----------+----+
|512365995|   1004849| 1.1926501|   1|
|512365995|   1005073| 1.1642694|   2|
|512365995|   1005031| 1.1254295|   3|
|512365995|   1005217| 1.1163355|   4|
|512365995|   1004653| 1.1148835|   5|
+---------+----------+----------+----+

