In [1]:
from pyspark.sql import SparkSession

In [2]:
# Obtain (or reuse) a Spark session with Hive support, using the "local" master.
sparkSession = (
    SparkSession.builder
    .enableHiveSupport()
    .master("local")
    .getOrCreate()
)

In [3]:
from pyspark.sql import Window
from pyspark.sql.functions import row_number, sum, col, abs, count, desc, asc, when, expr, lit, rank, round as rnd

In [4]:
# Listening log; the notebook reads its userId, trackId, artistId and timestamp columns.
data = sparkSession.read.format("parquet").load("/data/sample264")
# Lookup table joined on "id" at the end to obtain the Name and Artist columns.
meta = sparkSession.read.format("parquet").load("/data/meta")

In [None]:
def norm(df, key1, key2, field, n):
    """Keep the top ``n`` rows per ``key1`` and normalise ``field`` within each group.

    Rows are numbered inside every ``key1`` partition by ``field`` descending
    (``row_number``), rows numbered above ``n`` are discarded, and ``field`` is
    divided by the per-``key1`` sum of ``field`` over the rows that were kept.

    Args:
        df: Spark DataFrame containing at least the ``key1`` and ``field`` columns.
        key1: Name of the grouping column.
        key2: Name of the second key column. Not used by the body; kept so that
            existing call sites remain valid.
        field: Name of the numeric column that is ranked and normalised.
        n: Maximum number of rows retained per ``key1`` value.

    Returns:
        A cached DataFrame with the columns of ``df`` plus ``"sum_" + field``
        (group total) and ``"norm_" + field`` (``field`` / group total).

    Note:
        ``row_number`` gives tied rows distinct numbers, so when several rows
        share the same ``field`` value at the cut-off, which of them is kept is
        not fixed by the ordering used here.
    """
    ranking = Window.partitionBy(key1).orderBy(col(field).desc())

    # "row_number" is only a helper column for the top-n cut and is dropped again.
    topsDF = (df
              .withColumn("row_number", row_number().over(ranking))
              .filter(col("row_number") <= n)
              .drop("row_number"))

    # groupBy() already keeps key1 in its output; repeating col(key1) inside
    # agg() duplicated that column right before the join on key1 below.
    # `sum` is pyspark.sql.functions.sum (imported at the top of the notebook),
    # not the Python builtin.
    tmpDF = topsDF.groupBy(key1).agg(sum(col(field)).alias("sum_" + field))

    normalizedDF = (topsDF
                    .join(tmpDF, key1, "inner")
                    .withColumn("norm_" + field, col(field) / col("sum_" + field))
                    .cache())

    return normalizedDF

In [None]:
# Id of the user the ranking is personalised for.
user = 776748

# alpha multiplies the restart vector in the iteration; (1 - alpha) multiplies
# the mass propagated along the edges.
alpha = 0.15

# Multipliers applied to the normalised edge weights, one per edge type.
beta_track_track = 1
beta_artist_track = 1
beta_user_track = 0.5
beta_user_artist = 0.5

In [None]:
# Track -> track edges: two different tracks played by the same user with
# timestamps at most 420 apart (presumably seconds, i.e. 7 minutes; confirm
# the unit of `timestamp`).

# Two projections of the same log with distinct column names so the self-join
# on userId has unambiguous track/timestamp columns.
data1 = data.select(
    col('userId'),
    col('trackId').alias('trackId1'),
    col('timestamp').alias('timestamp1'),
)

data2 = data.select(
    col('userId'),
    col('trackId').alias('trackId2'),
    col('timestamp').alias('timestamp2'),
)

# Only the aggregated pair counts are cached. The filtered self-join is not
# reused anywhere, so caching it mid-chain only held a large intermediate.
# GroupedData.count() already names its output column "count"; the former
# trailing .alias('count') aliased the DataFrame, not the column.
weights = (data1.join(data2, "userId")
           .filter(col('trackId1') != col('trackId2'))
           .filter(abs(col('timestamp1') - col('timestamp2')) <= 420)
           .groupBy('trackId1', 'trackId2')
           .count()
          ).cache()

# Keep the 1000 heaviest neighbours per trackId1, normalise, scale by the
# edge-type multiplier.
norm_weights = (norm(weights, "trackId1", "trackId2", "count", 1000)
                .withColumn("nxt_value", col("norm_count") * beta_track_track)
               )

track_track = norm_weights.select(
    col("trackId1").alias("source"),
    col("trackId2").alias("target"),
    col("nxt_value"),
).cache()


In [None]:
# User -> track edges: number of plays per (userId, trackId).
# GroupedData.count() already names its output column "count"; the former
# trailing .alias("count") aliased the DataFrame, not the column, so it is gone.
weights = (data
           .groupBy("userId", "trackId")
           .count()
          ).cache()

# Keep the 1000 most played tracks per user, normalise, scale by the
# edge-type multiplier.
norm_weights = (norm(weights, "userId", "trackId", "count", 1000)
                .withColumn("nxt_value", col("norm_count") * beta_user_track)
               )

user_track = norm_weights.select(
    col("userId").alias("source"),
    col("trackId").alias("target"),
    col("nxt_value"),
).cache()

In [None]:
# User -> artist edges: number of plays per (userId, artistId).
# GroupedData.count() already names its output column "count"; the former
# trailing .alias("count") aliased the DataFrame, not the column, so it is gone.
weights = (data
           .groupBy("userId", "artistId")
           .count()
          ).cache()

# Keep the 100 most played artists per user, normalise, scale by the
# edge-type multiplier.
norm_weights = (norm(weights, "userId", "artistId", "count", 100)
                .withColumn("nxt_value", col("norm_count") * beta_user_artist)
               )

user_artist = norm_weights.select(
    col("userId").alias("source"),
    col("artistId").alias("target"),
    col("nxt_value"),
).cache()

In [None]:
# Artist -> track edges: number of plays per (artistId, trackId).
# GroupedData.count() already names its output column "count"; the former
# trailing .alias("count") aliased the DataFrame, not the column, so it is gone.
weights = (data
           .groupBy("artistId", "trackId")
           .count()
          ).cache()

# Keep the 100 most played tracks per artist, normalise, scale by the
# edge-type multiplier.
norm_weights = (norm(weights, "artistId", "trackId", "count", 100)
                .withColumn("nxt_value", col("norm_count") * beta_artist_track)
               )

artist_track = norm_weights.select(
    col("artistId").alias("source"),
    col("trackId").alias("target"),
    col("nxt_value"),
).cache()

In [None]:
# Single edge list for the whole graph. All four inputs are selected as
# (source, target, nxt_value), and union() matches columns by position.
edges = (
    track_track
    .union(user_track)
    .union(user_artist)
    .union(artist_track)
    .cache()
)

In [None]:
# Rows of the log that belong to the selected user.
user_data = data.where(col("userId") == user)

# One row per user id; p is 1.0 for the selected user and 0.0 for everyone else.
users = (data
         .select(col("userId").alias("id"))
         .distinct()
         .select("id", when(col("id") == user, 1.0).otherwise(0.0).alias("p"))
        )

# One row per track id; p is 1.0 when the selected user has that track in the
# log (the left join leaves "tmp" null for all other tracks), else 0.0.
tracks = (data
          .select(col("trackId").alias("id"))
          .distinct()
          .join(user_data.select(col("trackId").alias("id"), lit(1).alias("tmp")).distinct(),
                on="id", how="left")
          .withColumn("p", when(col("tmp").isNotNull(), 1.0).otherwise(0.0))
          .drop("tmp")
         )

# Same construction for artist ids.
artists = (data
           .select(col("artistId").alias("id"))
           .distinct()
           .join(user_data.select(col("artistId").alias("id"), lit(1).alias("tmp")).distinct(),
                 on="id", how="left")
           .withColumn("p", when(col("tmp").isNotNull(), 1.0).otherwise(0.0))
           .drop("tmp")
          )

# Initial (id, p) vector over every vertex: users, then artists, then tracks.
x = users.union(artists).union(tracks).cache()

In [None]:
# Restart vector: u_prob is 1.0 on the selected user's vertex, 0.0 elsewhere.
u = x.select("id", when(col("id") == user, 1.0).otherwise(0.0).alias("u_prob")).cache()


# Five update steps: p <- alpha * u_prob + (1 - alpha) * sigma, where sigma of a
# vertex is the sum of p(source) * nxt_value over the edges pointing at it.
for _ in range(5):
    # Vertices without outgoing edges survive the left join with a null
    # nxt_value, which is replaced by 0.0 before the product is summed.
    sigma = (x
             .join(edges, on=col("id") == col("source"), how="left")
             .fillna(0.0, subset=["nxt_value"])
             .groupBy("target")
             .agg(sum(col("p") * col("nxt_value")).alias("sigma"))
            )
    # Vertices without incoming mass get sigma = 0.0.
    x = (u
         .join(sigma, on=col("id") == col("target"), how="left")
         .fillna(0.0, subset=["sigma"])
         .select(col("id"),
                 (alpha * col("u_prob") + (1 - alpha) * col("sigma")).alias("p"))
        ).cache()


In [None]:
# Final scores without the selected user's own vertex, joined to the metadata
# on "id", highest p first, with p rounded to 5 decimals.
results = (x
           .filter(col("id") != user)
           .join(meta, on="id")
           .orderBy(desc("p"))
           .select("Name", "Artist", rnd(col("p"), 5).alias("p"))
          ).cache()

In [None]:
# Print the first 40 result rows; each row holds (Name, Artist, p) in that order.
for row in results.limit(40).collect():
    print("{} {} {}".format(*row))