In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

In [2]:
spark = SparkSession.builder \
      .appName("denorm") \
      .getOrCreate()

JAVA_HOME is not set


RuntimeError: Java gateway process exited before sending its port number

In [3]:
users = spark.read.option("header",True).csv("test_data/users.csv")
portfolios = spark.read.option("header",True).csv("test_data/portfolios.csv")
portfolio_items = spark.read.option("header",True).csv("test_data/portfolio_items.csv")

                                                                                

In [4]:
users.printSchema()
portfolios.printSchema()
portfolio_items.printSchema()

root
 |-- customer_id: string (nullable = true)
 |-- user_id: string (nullable = true)

root
 |-- portfolio_id: string (nullable = true)
 |-- user_id: string (nullable = true)
 |-- name: string (nullable = true)

root
 |-- portfolio_id: string (nullable = true)
 |-- item_id: string (nullable = true)



In [5]:
denormalised = users \
    .join(portfolios, users.user_id == portfolios.user_id, "left") \
    .join(portfolio_items, portfolios.portfolio_id == portfolio_items.portfolio_id, "left")

print(portfolio_items.count())
print(denormalised.count())
denormalised.show()

20
20
+-----------+-------+------------+-------+------+------------+-------+
|customer_id|user_id|portfolio_id|user_id|  name|portfolio_id|item_id|
+-----------+-------+------------+-------+------+------------+-------+
|         C1|     U1|          P1|     U1|  sdas|          P1|     I5|
|         C1|     U1|          P1|     U1|  sdas|          P1|     I4|
|         C1|     U1|          P1|     U1|  sdas|          P1|     I3|
|         C1|     U1|          P1|     U1|  sdas|          P1|     I2|
|         C1|     U1|          P1|     U1|  sdas|          P1|     I1|
|         C1|     U2|          P2|     U2| efwer|          P2|     I9|
|         C1|     U2|          P2|     U2| efwer|          P2|     I8|
|         C1|     U2|          P2|     U2| efwer|          P2|     I7|
|         C1|     U2|          P2|     U2| efwer|          P2|     I6|
|         C1|     U2|          P2|     U2| efwer|          P2|     I5|
|         C1|     U2|          P2|     U2| efwer|          P2|     I4|


In [6]:
aggregated_by_customer = denormalised.groupby("customer_id") \
    .agg(F.collect_set("item_id").alias("portfolio"))

aggregated_by_customer.show()

+-----------+--------------------+
|customer_id|           portfolio|
+-----------+--------------------+
|         C1|[I4, I10, I2, I3,...|
|         C2|[I10, I13, I12, I...|
+-----------+--------------------+



In [7]:
aggregated_by_item = denormalised.groupby("item_id") \
    .agg(F.collect_set("customer_id").alias("interested_customers"))

aggregated_by_item.show()

+-------+--------------------+
|item_id|interested_customers|
+-------+--------------------+
|     I5|                [C1]|
|    I12|                [C2]|
|    I13|                [C2]|
|     I4|                [C1]|
|     I2|                [C1]|
|     I3|                [C1]|
|     I1|            [C1, C2]|
|    I11|                [C1]|
|     I8|                [C1]|
|    I10|            [C1, C2]|
|     I9|                [C1]|
|     I7|                [C1]|
|     I6|            [C1, C2]|
+-------+--------------------+

