In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, countDistinct

# Build (or reuse, if one is already running) the Spark session that the
# later cells use to create DataFrames and run SQL.
spark = (
    SparkSession.builder
    .appName("CustomersPurchasedAllProducts")
    .getOrCreate()
)

In [2]:
# Products table: one single-column row (product_key,) per product.
products_data = [(product_key,) for product_key in (5, 6)]

# Customers table: one (customer_id, product_key) row per purchase.
customers_data = [
    (1, 5), (2, 6), (3, 5),
    (3, 6), (1, 6),
]


In [3]:
# Turn the sample rows into Spark DataFrames, naming the columns explicitly.
product_columns = ["product_key"]
customer_columns = ["customer_id", "product_key"]

products_df = spark.createDataFrame(products_data, schema=product_columns)
customers_df = spark.createDataFrame(customers_data, schema=customer_columns)

In [4]:
# Sanity check: print the Products rows to confirm the sample data loaded.
products_df.show()

+-----------+
|product_key|
+-----------+
|          5|
|          6|
+-----------+



In [5]:
# Sanity check: print the Customers purchase rows to confirm the sample data loaded.
customers_df.show()

+-----------+-----------+
|customer_id|product_key|
+-----------+-----------+
|          1|          5|
|          2|          6|
|          3|          5|
|          3|          6|
|          1|          6|
+-----------+-----------+



In [6]:
# Number of distinct products in the catalogue. countDistinct (rather than a
# plain row count) keeps the total correct even if Products has repeated rows.
total_products_count = products_df.agg(countDistinct("product_key")).first()[0]

# Keep only purchases of products that exist in Products, so that buying an
# unknown product_key cannot push a customer's distinct count up to the total.
catalog_purchases_df = customers_df.join(products_df, on="product_key", how="left_semi")

# Count the number of distinct catalogue products each customer has purchased
customer_product_count_df = (catalog_purchases_df
                              .groupBy("customer_id")
                              .agg(countDistinct("product_key").alias("product_count")))

# A customer bought every product exactly when their distinct count equals the
# catalogue size.
result_df = customer_product_count_df.filter(col("product_count") == total_products_count)

# Show the result
result_df.select("customer_id").show()

+-----------+
|customer_id|
+-----------+
|          1|
|          3|
+-----------+



Spark SQL method: the same question answered with a SQL query over temporary views.

In [7]:
# Expose each DataFrame to Spark SQL under the table name the query refers to.
for view_name, frame in (("Products", products_df), ("Customers", customers_df)):
    frame.createOrReplaceTempView(view_name)

In [8]:
# SQL query to find customers who bought all products.
# - The WHERE clause ignores purchases of product keys that are not in Products,
#   so they cannot inflate a customer's distinct count.
# - Both sides of the HAVING comparison count DISTINCT product keys, so repeated
#   rows in either table do not affect the result.
query = """
SELECT customer_id
FROM Customers
WHERE product_key IN (SELECT product_key FROM Products)
GROUP BY customer_id
HAVING COUNT(DISTINCT product_key) = (SELECT COUNT(DISTINCT product_key) FROM Products)
"""

In [9]:
# Execute the query and show the result.
# NOTE: this rebinds result_df, which the DataFrame-API cell above also assigns;
# after this cell the name refers to the SQL result.
result_df = spark.sql(query)
result_df.show()

+-----------+
|customer_id|
+-----------+
|          1|
|          3|
+-----------+

