In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import lead, row_number, col, broadcast, collect_set, size, array_contains, trim, countDistinct


class Transformer:
    """Base type for transformation steps; subclasses override ``transform``."""

    def __init__(self):
        """Create the transformer; the base class keeps no state."""

    def transform(self, inputDFs):
        """Placeholder hook: ignores ``inputDFs`` and returns ``None``."""
        return None

class AirpodsAfterIphoneTransformer(Transformer):
    """Customers whose purchase right after an iPhone was AirPods."""

    def transform(self, inputDFs):
        """
        Customers who have bought Airpods after buying the iPhone

        Reads the "transcatioInputDF" and "customerInputDF" entries of
        ``inputDFs``. A transaction qualifies when its product_name is
        "iPhone" and the same customer's next transaction (ordered by
        transaction_date) has product_name "AirPods".

        Returns a DataFrame with the distinct customer_id, customer_name
        and location rows of the qualifying customers. The joined result
        is also printed with ``show()``.
        """

        transcatioInputDF = inputDFs.get("transcatioInputDF")
        customerInputDF = inputDFs.get("customerInputDF")

        # Pair every transaction with the product of the customer's
        # following transaction; lead() yields null for the last one.
        windowSpec = Window.partitionBy("customer_id").orderBy("transaction_date")

        transformedDF = transcatioInputDF.withColumn(
            "next_product_name", lead("product_name").over(windowSpec)
        )

        filteredDF = transformedDF.filter(
            (col("product_name") == "iPhone") & (col("next_product_name") == "AirPods")
        )

        joinDF = customerInputDF.join(
            broadcast(filteredDF),
            "customer_id"
        )

        joinDF.show()

        # A customer may have several iPhone -> AirPods pairs, which yields
        # one joined row per pair; distinct() reports each customer once.
        return joinDF.select(
            "customer_id",
            "customer_name",
            "location"
        ).distinct()


class OnlyAirpodsAndIphone(Transformer):
    """Customers whose set of purchased products is exactly iPhone and AirPods."""

    def transform(self, inputDFs):
        """
        Customer who have bought only iPhone and Airpods nothing else

        Reads the "transcatioInputDF" and "customerInputDF" entries of
        ``inputDFs``, prints the joined result with ``show()`` and returns
        the customer_id, customer_name and location columns.
        """
        transactions = inputDFs.get("transcatioInputDF")

        # One row per customer holding the set of distinct product names.
        products_per_customer = transactions.groupBy("customer_id").agg(
            collect_set("product_name").alias("products")
        )

        # The set must contain both products and have exactly two members.
        purchased = col("products")
        has_iphone = array_contains(purchased, "iPhone")
        has_airpods = array_contains(purchased, "AirPods")
        exactly_two = size(purchased) == 2
        matching = products_per_customer.filter(has_iphone & has_airpods & exactly_two)

        customers = inputDFs.get("customerInputDF")
        joined = customers.join(broadcast(matching), "customer_id")
        joined.show()

        return joined.select("customer_id", "customer_name", "location")

class boughtAfterInitialPurchase(Transformer):
    """Customer details for transactions made after each customer's first one."""

    def transform(self, inputDFs):
        """
        Customer details for every transaction after the initial purchase

        Reads the "transcatioInputDF" and "customerInputDF" entries of
        ``inputDFs``. Transactions are numbered per customer in
        transaction_date order and the first one is discarded; the
        remaining transactions are joined to the customer table.

        Returns the customer_id, customer_name and location columns of the
        joined result, which is also printed with ``show()``.
        """

        transcatioInputDF = inputDFs.get("transcatioInputDF")
        customerInputDF = inputDFs.get("customerInputDF")

        windowSpec = Window.partitionBy("customer_id").orderBy("transaction_date")

        transformedDF = transcatioInputDF.withColumn(
            "row_num", row_number().over(windowSpec)
        )

        # row_num 1 is the customer's earliest transaction; keep the rest.
        filteredDF = transformedDF.filter(
            (col("row_num") > 1)
        )

        joinDF = customerInputDF.join(
            broadcast(filteredDF),
            "customer_id"
        )

        joinDF.show()

        # NOTE(review): this yields one row per later transaction, so a
        # customer with several follow-up purchases appears several times;
        # confirm whether callers expect distinct customers instead.
        return joinDF.select(
            "customer_id",
            "customer_name",
            "location"
        )

class customersFromOhioAndNevada(Transformer):
    """Distinct product names bought by customers located in Nevada or Ohio."""

    def transform(self, inputDFs):
        """
        Products purchased by customers whose location is Nevada or Ohio

        Reads the "transcatioInputDF", "productInputDF" and
        "customerInputDF" entries of ``inputDFs``. Prints and returns a
        DataFrame of distinct transaction product_name values.
        """
        # customer_id is cast to int before it is compared with the
        # customer table's customer_id.
        transactions = inputDFs.get("transcatioInputDF")
        products = inputDFs.get("productInputDF")
        customers = inputDFs.get("customerInputDF")
        transactions = transactions.withColumn(
            "customer_id", col("customer_id").cast("int")
        )

        same_customer = transactions.customer_id == customers.customer_id
        same_product = transactions.product_name == products.product_name
        in_target_states = customers.location.isin("Nevada", "Ohio")

        with_customers = transactions.join(customers, same_customer)
        with_products = with_customers.join(products, same_product)

        product_names = (
            with_products.filter(in_target_states)
            .select(transactions.product_name)
            .distinct()
        )

        product_names.show()

        return product_names
    

class moreThanOneCategory(Transformer):
    """Customers who have bought products from more than one category."""

    def transform(self, inputDFs):
        """
        Customers whose transactions span more than one product category

        Reads the "transcatioInputDF" and "productInputDF" entries of
        ``inputDFs``, joins them on product_name to attach each
        transaction's category, and counts distinct categories per
        customer.

        Returns a DataFrame with a single customer_id column containing
        the customers that have more than one distinct category. The same
        DataFrame is printed with ``show()``.
        """

        transactionDF = inputDFs.get("transcatioInputDF")
        productDF = inputDFs.get("productInputDF")

        # Attach the category to every transaction; product_name is taken
        # from the transaction side to avoid an ambiguous column reference.
        joinedDF = transactionDF.join(
            productDF, transactionDF.product_name == productDF.product_name
        ).select("transaction_id", "customer_id", transactionDF.product_name, "category")

        categoryDF = joinedDF.groupBy("customer_id").agg(
            countDistinct("category").alias("distinct_categories")
        )

        finalDF = categoryDF.filter(col("distinct_categories") > 1)

        finalDF = finalDF.drop("distinct_categories")

        finalDF.show()

        # Return the filtered customers that were just shown, not the
        # intermediate transaction/product join.
        return finalDF



        
