<a href="https://colab.research.google.com/github/RdZilla/product_category_pairs/blob/main/product_category_pairs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pyspark==3.0.1 py4j==0.10.9

Collecting pyspark==3.0.1
  Downloading pyspark-3.0.1.tar.gz (204.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m204.2/204.2 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting py4j==0.10.9
  Downloading py4j-0.10.9-py2.py3-none-any.whl (198 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m198.6/198.6 kB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.0.1-py2.py3-none-any.whl size=204612224 sha256=47b44004e631305117426388731b3c56ea04d158588633995174f1b84b36bbdc
  Stored in directory: /root/.cache/pip/wheels/19/b0/c8/6cb894117070e130fc44352c2a13f15b6c27e440d04a84fb48
Successfully built pyspark
Installing collected packages: py4j, pyspark
  Attempting uninstall: py4j
    Found existing installation: py4j 0.10.9.7
    Uninstalli

In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, collect_list, expr, lit

In [3]:
def get_product_category_pairs(df):
    product_categories = df.groupBy("product_name").agg(collect_list("category_name").alias("categories"))

    product_category_pairs = product_categories.select(col("product_name"), expr("explode_outer(categories)").alias("category_name"))
    products_without_categories = df.filter(col("category_name").isNull()).select("product_name").distinct().withColumn("category_name", lit(None))

    result_df = product_category_pairs.union(products_without_categories)
    result_df = result_df.orderBy("category_name")
    return result_df

In [13]:
spark = SparkSession.builder \
    .appName("product_category_pairs") \
    .getOrCreate()

In [14]:
spark

In [15]:
data = [("Yogurt", "Milk", "optional information", "extra_info"),
        ("Buckwheat", "Сereals", "optional information", "extra_info"),
        ("Rice", "Milk", "optional information", "extra_info"),
        ("Chicken", "Meat", "optional information", "extra_info"),
        ("Pork", "Meat", "optional information", "extra_info"),
        ("Beef", "Meat", "optional information", "extra_info"),
        ("Bear", "Drinks", "optional information", "extra_info"),
        ("Bear", "Alcohol", "optional information", "extra_info"),
        ("Juice", "Drinks", "optional information", "extra_info"),
        ("Croissants", "", "optional information", "extra_info"),
        ("Tea", "", "optional_information", "extra_info")]

columns = ["product_name", "category_name", "other_information", "extra_info"]
df = spark.createDataFrame(data, columns)
df.show()

+------------+-------------+--------------------+----------+
|product_name|category_name|   other_information|extra_info|
+------------+-------------+--------------------+----------+
|      Yogurt|         Milk|optional information|extra_info|
|   Buckwheat|      Сereals|optional information|extra_info|
|        Rice|         Milk|optional information|extra_info|
|     Chicken|         Meat|optional information|extra_info|
|        Pork|         Meat|optional information|extra_info|
|        Beef|         Meat|optional information|extra_info|
|        Bear|       Drinks|optional information|extra_info|
|        Bear|      Alcohol|optional information|extra_info|
|       Juice|       Drinks|optional information|extra_info|
|  Croissants|             |optional information|extra_info|
|         Tea|             |optional_information|extra_info|
+------------+-------------+--------------------+----------+



In [16]:
df.printSchema()

root
 |-- product_name: string (nullable = true)
 |-- category_name: string (nullable = true)
 |-- other_information: string (nullable = true)
 |-- extra_info: string (nullable = true)



In [17]:
result_df = get_product_category_pairs(df)
result_df.show()

+------------+-------------+
|product_name|category_name|
+------------+-------------+
|  Croissants|             |
|         Tea|             |
|        Bear|      Alcohol|
|        Bear|       Drinks|
|       Juice|       Drinks|
|        Beef|         Meat|
|     Chicken|         Meat|
|        Pork|         Meat|
|      Yogurt|         Milk|
|        Rice|         Milk|
|   Buckwheat|      Сereals|
+------------+-------------+



In [18]:
result_df.printSchema()

root
 |-- product_name: string (nullable = true)
 |-- category_name: string (nullable = true)



In [19]:
spark.stop()