In [15]:
from pyspark.sql import SparkSession
# Obtain (or reuse) a Spark session configured with Hive support and the
# "local" master.
sparkSession = (
    SparkSession.builder
    .enableHiveSupport()
    .master("local")
    .getOrCreate()
)

In [16]:
# Load both Parquet inputs; `data` feeds every aggregation in the cells below.
data, meta = (
    sparkSession.read.parquet(path)
    for path in ("/data/sample264", "/data/meta")
)

from pyspark.sql import Window
from pyspark.sql.functions import row_number, col, rank, when, sum, abs, count

In [17]:

import datetime
# Self-join the log on userId to enumerate every ordered pair of distinct
# tracks seen for the same user; both (a, b) and (b, a) are produced.
# timeDif is the timestamp gap divided by 60 -- presumably seconds converted
# to minutes; TODO confirm the timestamp unit.
# The chain is wrapped in parentheses instead of backslash continuations so
# that no trailing "\" can splice the following line into this expression.
track_pairs = (
    data.alias("df1")
    .join(data.alias("df2"), "userId", "inner")
    .filter("df1.trackId <> df2.trackId")
    .select(
        col("df1.trackId").alias("id1"),
        col("df2.trackId").alias("id2"),
        ((col("df2.timestamp") - col("df1.timestamp")) / 60).alias("timeDif"),
    )
)

In [18]:
# Edge weight of a directed pair (id1, id2) = number of pair occurrences whose
# absolute timeDif is below 7.
# Filtering before the aggregation yields exactly the rows that summing a 0/1
# flag and then dropping zero sums would, but pairs outside the time window
# never reach the groupBy, so far less data is aggregated.
track_pairs_weights = (
    track_pairs
    .filter(abs(col("timeDif")) < 7)
    .groupBy("id1", "id2")
    .agg(count("*").alias("weights"))
)

In [19]:
def norm(df, key1, field, n):
    """Keep the top rows per ``key1`` by ``field`` and normalize ``field``.

    Rows are numbered inside each ``key1`` partition in descending order of
    ``field``; rows whose number is strictly below ``n`` are kept, i.e. at
    most ``n - 1`` rows per key. Each kept value is then divided by the sum
    of the kept values for the same key.

    Args:
        df: Spark DataFrame containing the columns ``key1`` and ``field``.
        key1: Name of the column that defines the groups.
        field: Name of the numeric column used for ranking and normalization.
        n: Exclusive upper bound on the per-key row number (pass 51 to keep
            the top 50 rows).

    Returns:
        A DataFrame with the kept rows plus two extra columns:
        ``"total_" + field`` (per-key sum over the kept rows) and
        ``"norm_" + field`` (``field`` divided by that sum).
    """
    # key1 is constant inside a key1 partition, so ordering by it adds
    # nothing; only the descending value is needed.
    window = Window.partitionBy(key1).orderBy(col(field).desc())

    # NOTE: row_number() resolves ties on `field` arbitrarily, so when several
    # rows tie at the cut-off the ones that survive are not guaranteed to be
    # the same between runs.
    topsDf = (
        df.select("*", row_number().over(window).alias("row_number"))
        .filter(col("row_number") < n)
        .drop(col("row_number"))
    )

    total_name = "total_" + field
    sumOfTrackWeights = topsDf.groupBy(col(key1)).agg(sum(field).alias(total_name))

    # Attach the per-key total to every kept row and compute the share.
    normalizedDF = (
        topsDf.join(sumOfTrackWeights, key1, "inner")
        .withColumn("norm_" + field, col(field) / col(total_name))
    )

    return normalizedDF

In [20]:
# n is an exclusive row-number bound inside norm, so 51 keeps up to 50 pairs
# per id1.
normilized_weights = norm(
    df=track_pairs_weights,
    key1='id1',
    field="weights",
    n=51,
)

In [None]:
# First 40 (id1, id2) pairs by descending normalized weight; id1 then id2
# ascending break ties.
result = (
    normilized_weights
    .orderBy(col("norm_weights").desc(), "id1", "id2")
    .select(col("id1"), col("id2"))
    .take(40)
)

In [None]:
# Printing of the track-pair `result` is disabled in this cell; uncomment the
# loop to emit one "id1 id2" line per row.
#for val in result:
#    print("%s %s" % val)

798256 923706
798319 837992
798322 876562
798331 827364
798335 840741
798374 816874
798375 810685
798379 812055
798380 840113
798396 817687
798398 926302
798405 867217
798443 905923
798457 918918
798460 891840
798461 940379
798470 840814
798474 963162
798477 883244
798485 955521
798505 905671
798550 936295
798626 845438
798691 818279
798692 898823
798702 811440
798704 937570
798725 933147
798738 894170
798745 799665
798782 956938
798801 950802
798820 890393
798833 916319
798865 962662
798931 893574
798946 946408
799012 809997
799024 935246
799047 905199


In [89]:
# Number of rows in `data` for each (userId, trackId) combination, exposed as
# the column "weight".
userTrack = (
    data
    .groupBy("userId", "trackId")
    .agg(count("*").alias("weight"))
)

In [98]:
# Normalize each user's track counts (row-number bound 1001, i.e. up to 1000
# tracks per user), then take the first 40 (userId, trackId) rows by
# descending share with userId / trackId ascending as tie-breakers.
result2 = (
    norm(userTrack, "userId", "weight", 1001)
    .orderBy(col("norm_weight").desc(), "userId", "trackId")
    .select("userId", "trackId")
    .take(40)
)

In [99]:
# Emit one "userId trackId" line per collected row.
for user_id, track_id in result2:
    print("%s %s" % (user_id, track_id))

66 965774
116 867268
128 852564
131 880170
195 946408
215 860111
235 897176
300 857973
321 915545
328 943482
333 818202
346 864911
356 961308
428 943572
431 902497
445 831381
488 841340
542 815388
617 946395
649 901672
658 937522
662 881433
698 935934
708 952432
746 879259
747 879259
776 946408
784 806468
806 866581
811 948017
837 799685
901 871513
923 879322
934 940714
957 945183
989 878364
999 967768
1006 962774
1049 849484
1057 920458
