In [106]:
import argparse
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.functions import col, count, collect_set, size, array, flatten, array_distinct, explode, rand
from pyspark.sql.window import Window

In [107]:
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

rdd = [("A", "1", 4), ("C", "1", 4), ("B", "1", 4), ("C", "1", 2), ("D", "1", 4),
        ("B", "2", 4), ("C", "2", 4), ("A", "2", 4), 
        ("A", "3", 4), ("C", "3", 4),
        ("A", "4", 4), ("B", "4", 4), ("C", "4", 4),
        ("A", "5", 4), ("B", "5", 4), ("C", "5", 4),
        ("A", "6", 4), ("B", "6", 4), ("C", "6", 4), ("D", "6", 4), ("E", "6", 4), ("F", "6", 4),
        ("A", "7", 4), ("B", "7", 4), ("C", "7", 4),
        ("A", "8", 4), ("B", "8", 4), ("C", "8", 4),
        ("A", "9", 4), ("B", "9", 4), ("C", "9", 4),
        ("A", "10", 4), ("B", "10", 4), ("C", "10", 4),
        ("A", "11", 4), ("B", "11", 4), ("C", "11", 4), ("D", "11", 4),
        ("A", "12", 4), ("B", "12", 4), ("C", "12", 4), ("D", "12", 4),
        ("A", "13", 4), ("B", "13", 4), ("C", "13", 4), ("D", "13", 4), ("E", "13", 4),
        ("A", "14", 4), ("B", "14", 4), ("C", "14", 4), ("D", "14", 4), ("E", "14", 4),
        ("A", "15", 4), ("B", "15", 4), ("C", "15", 4), ("D", "15", 4), ("E", "15", 4),
        ("A", "16", 4), ("B", "16", 4), ("C", "16", 4), ("D", "16", 4), ("E", "16", 4), ("F", "16", 4)
        ]
schema = StructType([
    StructField("productId", StringType(), True),
    StructField("userId", StringType(), True),
    StructField("score", IntegerType(), True)
])

input_df = spark.createDataFrame(rdd, schema=schema).cache()

In [108]:
input_df.show()

+---------+------+-----+
|productId|userId|score|
+---------+------+-----+
|        A|     1|    4|
|        C|     1|    4|
|        B|     1|    4|
|        C|     1|    2|
|        D|     1|    4|
|        B|     2|    4|
|        C|     2|    4|
|        A|     2|    4|
|        A|     3|    4|
|        C|     3|    4|
|        A|     4|    4|
|        B|     4|    4|
|        C|     4|    4|
|        A|     5|    4|
|        B|     5|    4|
|        C|     5|    4|
|        A|     6|    4|
|        B|     6|    4|
|        C|     6|    4|
|        D|     6|    4|
+---------+------+-----+
only showing top 20 rows



In [109]:
input_df = input_df.filter(input_df["score"] >= 4)
input_df.show()

+---------+------+-----+
|productId|userId|score|
+---------+------+-----+
|        A|     1|    4|
|        C|     1|    4|
|        B|     1|    4|
|        D|     1|    4|
|        B|     2|    4|
|        C|     2|    4|
|        A|     2|    4|
|        A|     3|    4|
|        C|     3|    4|
|        A|     4|    4|
|        B|     4|    4|
|        C|     4|    4|
|        A|     5|    4|
|        B|     5|    4|
|        C|     5|    4|
|        A|     6|    4|
|        B|     6|    4|
|        C|     6|    4|
|        D|     6|    4|
|        E|     6|    4|
+---------+------+-----+
only showing top 20 rows



In [110]:
joined_df = input_df.alias("df1").join(input_df.alias("df2"), ["productId"]) \
    .where(col("df1.userId") < col("df2.userId")) \
    .select(col("df1.userId").alias("userId1"), col("df2.userId").alias("userId2"), col("df1.productId").alias("productId")) \
    .distinct() 

joined_df.show()

+-------+-------+---------+
|userId1|userId2|productId|
+-------+-------+---------+
|      1|      2|        A|
|      1|      3|        A|
|      1|      4|        A|
|      1|      5|        A|
|      1|      6|        A|
|      1|      7|        A|
|      1|      8|        A|
|      1|      9|        A|
|      1|     10|        A|
|      1|     11|        A|
|      1|     12|        A|
|      1|     13|        A|
|      1|     14|        A|
|      1|     15|        A|
|      1|     16|        A|
|      2|      3|        A|
|      2|      4|        A|
|      2|      5|        A|
|      2|      6|        A|
|      2|      7|        A|
+-------+-------+---------+
only showing top 20 rows



In [111]:
couple_df = joined_df.groupBy("userId1", "userId2").agg(collect_set("productId").alias("productsAffinity")) \
    .where(size(col("productsAffinity")) >= 3) \
    .select("userId1", "userId2", "productsAffinity") \
    .cache()

couple_df.show()

+-------+-------+----------------+
|userId1|userId2|productsAffinity|
+-------+-------+----------------+
|     10|      5|       [C, B, A]|
|     11|     14|    [C, B, A, D]|
|     15|      8|       [C, B, A]|
|     10|     16|       [C, B, A]|
|     12|      7|       [C, B, A]|
|     16|      2|       [C, B, A]|
|     15|      2|       [C, B, A]|
|      1|      4|       [C, B, A]|
|      1|     11|    [C, B, A, D]|
|      4|      6|       [C, B, A]|
|     12|     16|    [C, B, A, D]|
|      2|      8|       [C, B, A]|
|     11|      7|       [C, B, A]|
|     13|     14| [C, E, B, A, D]|
|     13|      7|       [C, B, A]|
|     11|     16|    [C, B, A, D]|
|     14|      5|       [C, B, A]|
|     10|     13|       [C, B, A]|
|      4|      8|       [C, B, A]|
|      6|      8|       [C, B, A]|
+-------+-------+----------------+
only showing top 20 rows



In [112]:
group_df = couple_df.withColumn(
    "userId",
    explode(array("userId1", "userId2"))
).groupBy("productsAffinity").agg(collect_set("userId").alias("groupUsers"))

group_df.show()

+------------------+--------------------+
|  productsAffinity|          groupUsers|
+------------------+--------------------+
|         [C, B, A]|[1, 15, 2, 5, 8, ...|
|      [C, B, A, D]|[13, 12, 1, 15, 1...|
|   [C, E, B, A, D]| [13, 15, 16, 14, 6]|
|[F, C, E, B, A, D]|             [16, 6]|
+------------------+--------------------+



In [116]:
group_df.orderBy(col("groupUsers")[0]).show()

shuffled_df = group_df.orderBy(rand(1))
shuffled_df.show()

shuffled_df.orderBy(col("groupUsers")[0]).show()

+------------------+--------------------+
|  productsAffinity|          groupUsers|
+------------------+--------------------+
|         [C, B, A]|[1, 15, 2, 5, 8, ...|
|      [C, B, A, D]|[13, 12, 1, 15, 1...|
|   [C, E, B, A, D]| [13, 15, 16, 14, 6]|
|[F, C, E, B, A, D]|             [16, 6]|
+------------------+--------------------+

+------------------+--------------------+
|  productsAffinity|          groupUsers|
+------------------+--------------------+
|      [C, B, A, D]|[13, 12, 1, 15, 1...|
|[F, C, E, B, A, D]|             [16, 6]|
|   [C, E, B, A, D]| [13, 15, 16, 14, 6]|
|         [C, B, A]|[1, 15, 2, 5, 8, ...|
+------------------+--------------------+

+------------------+--------------------+
|  productsAffinity|          groupUsers|
+------------------+--------------------+
|         [C, B, A]|[1, 15, 2, 5, 8, ...|
|      [C, B, A, D]|[13, 12, 1, 15, 1...|
|   [C, E, B, A, D]| [13, 15, 16, 14, 6]|
|[F, C, E, B, A, D]|             [16, 6]|
+------------------+------------