In [76]:
from pyspark.sql import SparkSession
# Single-threaded local Spark session with Hive support enabled.
# getOrCreate() reuses an already-active session if the kernel has one.
sparkSession = (
    SparkSession.builder
    .master("local")
    .enableHiveSupport()
    .getOrCreate()
)

In [77]:
# `data`: listening events; later cells read its userId, trackId and timestamp
# columns. `meta`: track metadata; NOTE(review): not referenced in the cells
# shown here -- confirm it is still needed.
data, meta = (
    sparkSession.read.parquet(parquet_path)
    for parquet_path in ("/data/sample264", "/data/meta")
)

In [78]:
from pyspark.sql import Window
from pyspark.sql.functions import row_number, col, rank, when, sum, abs
import datetime
# Pair every two *different* tracks played by the same user, together with the
# signed gap between the two plays (second play minus first play).
# Dividing by 60 converts the difference to minutes, assuming `timestamp` is
# in seconds -- TODO confirm against the data source.
# The self-join yields both orientations (a, b) and (b, a), so the resulting
# edge list is symmetric.
track_pairs = (
    data.alias("df1")
    .join(data.alias("df2"), "userId", "inner")
    .filter(col("df1.trackId") != col("df2.trackId"))
    .select(
        col("df1.trackId").alias("id1"),
        col("df2.trackId").alias("id2"),
        ((col("df2.timestamp") - col("df1.timestamp")) / 60).alias("timeDif"),
    )
)

In [79]:
# Edge weight for (id1, id2) = number of same-user play pairs of the two tracks
# whose gap is strictly less than TIME_WINDOW (in timeDif units, i.e. minutes
# if timestamps are seconds), in either direction.
# Filtering before the aggregation is equivalent to summing 0/1 flags and
# dropping zero sums, but shuffles only the matching pairs instead of all of
# them. Rows with a null timeDif are excluded either way.
# `abs` here is pyspark.sql.functions.abs (imported above), not the builtin.
TIME_WINDOW = 7
track_pairs_weights = (
    track_pairs
    .filter(abs(col("timeDif")) < TIME_WINDOW)
    .groupBy("id1", "id2")
    .count()
    .withColumnRenamed("count", "weights")
)

In [80]:
def norm(df, key1, n):
    """Keep each key's strongest edges and normalise their weights per key.

    For every distinct value of ``key1``, rows are ranked by ``weights``
    (descending) and only ranks ``1 .. n-1`` are kept. Note the strict ``<``:
    ``n=51`` keeps the top 50. Each kept weight is then divided by the sum of
    the kept weights for that key, so the kept weights of a key sum to 1.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Edge table containing the ``key1`` column and a numeric ``weights``
        column.
    key1 : str
        Column to partition, rank and normalise by (e.g. ``"id1"``).
    n : int
        Exclusive upper bound on the per-key rank that is kept.

    Returns
    -------
    pyspark.sql.DataFrame
        The kept rows of ``df`` plus ``total_weights`` (per-key sum of the
        kept weights) and ``norm_weights`` (``weights / total_weights``).
    """
    # Break ties on `weights` by the remaining columns so the top-n cut is
    # deterministic. Otherwise which tied edge survives -- and therefore every
    # normalised weight of that key -- can differ between runs.
    tie_breakers = [col(c) for c in df.columns if c not in (key1, "weights")]
    window = Window.partitionBy(key1).orderBy(col("weights").desc(), *tie_breakers)

    tops_df = (
        df.select("*", row_number().over(window).alias("row_number"))
        .filter(col("row_number") < n)
        .drop("row_number")
    )

    # `sum` is pyspark.sql.functions.sum (imported above), not the builtin.
    total_per_key = tops_df.groupBy(key1).agg(sum("weights").alias("total_weights"))

    return (
        tops_df.join(total_per_key, key1, "inner")
        .withColumn("norm_weights", col("weights") / col("total_weights"))
    )

In [81]:
# norm() keeps ranks strictly below n, so n=51 retains the 50 heaviest edges
# for each id1 before normalising.
normilized_weights = norm(df=track_pairs_weights, key1="id1", n=51)

In [82]:
# Top 40 edges by normalised weight. (id1, id2) pairs are unique after the
# groupBy, so ordering by them as tie-breakers makes the list deterministic.
ranking = [col("norm_weights").desc(), col("id1"), col("id2")]
result = (
    normilized_weights
    .orderBy(*ranking)
    .select("id1", "id2")
    .take(40)
)

In [83]:
# One "id1 id2" line per collected edge.
for row in result:
    print(row.id1, row.id2)

943060 906644
916106 917005
966277 848609
906644 943060
957991 930935
924333 871513
833441 946677
826881 848807
924085 907245
912219 934393
820030 917240
832462 887023
913080 949890
802110 810825
875978 896805
900521 904430
910631 941522
914119 875262
936520 846587
799433 936403
822869 807426
878315 881915
882619 831384
894756 936409
930544 804125
901112 874270
810938 938608
836133 834366
860070 834845
864928 858352
906140 837660
911867 867312
917702 902978
947006 953666
962689 946994
963056 854480
866447 932601
846523 798768
800639 936092
803925 877664
