In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import lead, broadcast, collect_set, col, size, array_contains

class Transformer:
    """Base class for DataFrame transformers.

    Subclasses override :meth:`transform`; the base implementation is a
    no-op that returns ``None``.
    """

    def __init__(self):
        pass

    def transform(self, inputDfs=None):
        """No-op placeholder for the transformation step.

        Args:
            inputDfs: Optional mapping of names to input DataFrames. It is
                accepted (and ignored) here so that the base signature is
                compatible with subclass overrides that take ``inputDfs``;
                a caller holding a plain ``Transformer`` can pass inputs
                without hitting a ``TypeError``.

        Returns:
            None.
        """
        return None

class AirpodsAfterIphoneTransformer(Transformer):
    """Find customers whose purchase right after an iPhone was AirPods."""

    def transform(self, inputDfs):
        """
        Join customers to their iPhone transactions that were followed by AirPods.

        Args:
            inputDfs: Mapping read with ``.get``; the "inputDF" entry is the
                transactions DataFrame and the "customerDF" entry is the
                customers DataFrame.

        Returns:
            The customers DataFrame joined on "customer_id" with the
            matching transaction rows (the full join result, not only the
            columns that are printed).

        Side effects:
            Prints two headings and shows two tables: the matching
            transactions, then selected columns of the join.
        """
        transactions = inputDfs.get("inputDF")

        # Each customer's purchases, ordered by transaction date.
        purchase_timeline = Window.partitionBy("customer_id").orderBy(
            "transaction_date"
        )

        # lead() looks one row ahead in that timeline, so every row learns
        # which product the same customer bought immediately afterwards.
        with_following = transactions.withColumn(
            "next_product_name",
            lead("product_name").over(purchase_timeline),
        )

        # Heading for the table shown below.
        print("Airpods after buying iPhone:")

        # Keep only iPhone rows whose very next purchase was AirPods.
        bought_iphone = col("product_name") == "iPhone"
        then_airpods = col("next_product_name") == "AirPods"
        iphone_then_airpods = with_following.filter(bought_iphone & then_airpods)

        iphone_then_airpods.orderBy("customer_id", "transaction_date").show()

        # Attach customer details; the filtered side carries a broadcast hint.
        customers = inputDfs.get("customerDF")
        matched_customers = customers.join(
            broadcast(iphone_then_airpods), "customer_id"
        )

        print("JOINED DF")
        matched_customers.select("customer_id", "customer_name", "location").show()
        return matched_customers
    

class OnlyAirpodsAndIphoneTransformer(Transformer):
    """Find customers whose distinct purchases are exactly iPhone and AirPods."""

    def transform(self, inputDfs):
        """
        Join customers to the set of products they bought, keeping only
        those whose set is exactly {iPhone, AirPods}.

        Args:
            inputDfs: Mapping read with ``.get``; the "inputDF" entry is the
                transactions DataFrame and the "customerDF" entry is the
                customers DataFrame.

        Returns:
            The customers DataFrame joined on "customer_id" with the
            per-customer "products" array for the qualifying customers.

        Side effects:
            Prints a heading and shows selected columns of the join.
        """
        transactions = inputDfs.get("inputDF")
        customers = inputDfs.get("customerDF")

        # One row per customer holding the distinct product names they bought.
        products_per_customer = transactions.groupBy("customer_id").agg(
            collect_set("product_name").alias("products")
        )

        # Both products present and nothing else: with a set of distinct
        # names, a size of 2 rules out any third product.
        purchased = col("products")
        has_iphone = array_contains(purchased, "iPhone")
        has_airpods = array_contains(purchased, "AirPods")
        exactly_two = size(purchased) == 2
        only_both = products_per_customer.filter(
            has_iphone & has_airpods & exactly_two
        )

        # Attach customer details; the filtered side carries a broadcast hint.
        matched_customers = customers.join(broadcast(only_both), "customer_id")

        print("JOINED DF")
        matched_customers.select("customer_id", "customer_name", "products").show()
        return matched_customers
