In [53]:
from pyspark.sql import SparkSession
# Hive-enabled session on a local master with two worker threads
# ("local[2]"); getOrCreate() hands back an already-running session if any.
sparkSession = (
    SparkSession.builder
    .enableHiveSupport()
    .master("local[2]")
    .getOrCreate()
)

In [54]:
# Load the two parquet inputs. `data` is later joined on "userId" and read via
# its "trackId" and "timestamp" columns; `meta` is loaded here for later use.
data = sparkSession.read.format("parquet").load("/data/sample264")
meta = sparkSession.read.format("parquet").load("/data/meta")

In [55]:
from pyspark.sql import Window
from pyspark.sql.functions import *

def norm(df, key1, field, n):
    """Keep the top ``n`` rows per ``key1`` group and normalise ``field`` in each group.

    Rows are ranked inside every ``key1`` partition by ``field`` in descending
    order and only the first ``n`` of each partition are kept. ``field`` is
    then divided by the sum of ``field`` over the kept rows of the same group.

    Args:
        df: input Spark DataFrame containing the columns ``key1`` and ``field``.
        key1: name of the grouping column.
        field: name of the numeric column to rank by and to normalise.
        n: maximum number of rows to keep per ``key1`` value.

    Returns:
        A DataFrame with the kept rows plus two extra columns,
        ``"sum_" + field`` (per-group total) and ``"norm_" + field``
        (``field`` divided by that total). The result is marked with
        ``.cache()`` before it is returned.

    Note:
        The ordering uses ``field`` only, so rows with equal ``field`` values
        have no secondary tie-break when the top ``n`` are chosen.
    """
    # Import the functions module under an explicit namespace so that
    # `sum`/`col`/`row_number` do not depend on a wildcard import having
    # shadowed Python's builtins in the calling cell.
    from pyspark.sql import Window, functions as F

    ranking = Window.partitionBy(key1).orderBy(F.col(field).desc())

    # The helper column is only needed to cut each partition at n rows.
    topsDF = (
        df.withColumn("row_number", F.row_number().over(ranking))
        .filter(F.col("row_number") <= n)
        .drop("row_number")
    )

    # groupBy already emits the key column; only the aggregate is listed.
    sumsDF = topsDF.groupBy(key1).agg(F.sum(F.col(field)).alias("sum_" + field))

    normalizedDF = (
        topsDF.join(sumsDF, key1, "inner")
        .withColumn("norm_" + field, F.col(field) / F.col("sum_" + field))
        .cache()
    )

    return normalizedDF

In [57]:
# Tunables for this cell.
MAX_TIME_DIFF = 7 * 60   # largest allowed timestamp gap; presumably seconds (7 min) - confirm units
TOP_PER_TRACK = 40       # follow-up tracks kept per source track by norm()
TOP_PAIRS = 40           # number of (id, id2) pairs reported at the end

data.cache()

# Second view of the same events with renamed columns, so that after the
# self-join on userId both sides of a pair can be addressed by name.
data_copy = data.withColumnRenamed("timestamp", "timestamp2").withColumnRenamed("trackId", "trackId2")

# Count, per ordered track pair, how often the same user has an event for
# trackId2 at or after an event for a different trackId, no more than
# MAX_TIME_DIFF later.
track_to_track = (
    data.join(data_copy, "userId")
    .filter((col("trackId") != col("trackId2")) & (col("timestamp2") >= col("timestamp")))
    .withColumn("time_diff", col("timestamp2") - col("timestamp"))
    .filter(col("time_diff") <= MAX_TIME_DIFF)
    .groupBy(col("trackId"), col("trackId2"))
    .count()
)

# Turn the counts into per-track weights and expose the pair as id / id2.
normalized = norm(track_to_track, "trackId", "count", TOP_PER_TRACK).select(
    col("trackId").alias("id"),
    col("trackId2").alias("id2"),
    col("norm_count"),
)

# Global ranking by weight. The window has no partitionBy, so the ranking is
# computed over the whole DataFrame.
window = Window.orderBy(col("norm_count").desc())

# rank() gives equal weights the same position, so more than TOP_PAIRS rows
# can pass the filter; the ordered take() then keeps the first TOP_PAIRS of
# them by (id, id2).
top = (
    normalized.withColumn("pos", rank().over(window))
    .filter(col("pos") <= TOP_PAIRS)
    .orderBy(col("id").asc(), col("id2").asc())
    .select(col("id"), col("id2"))
    .take(TOP_PAIRS)
)

for item in top:
    # Single parenthesised argument: prints identically on Python 2 and 3.
    print("%s %s" % item)
    
    




801701 920990
808110 894437
809289 847119
814446 870227
819569 800325
827209 942995
828366 830694
829292 871752
830062 849304
831434 856391
832475 925631
832553 836728
836522 907798
840315 878511
841759 898484
844651 897648
844819 834559
847806 949091
852427 825116
856311 875086
857303 943835
875876 916850
878289 814956
879172 898823
879366 814475
882856 841509
889636 799651
890604 904285
890920 838812
895618 944759
903281 810518
904487 810488
907516 845402
923176 831580
926952 818440
932765 860022
933119 883990
935205 829417
940165 823397
941115 949312
