# In [None]:
from pyspark.sql.window import Window
from pyspark.sql.functions import lead, col, broadcast, collect_set,array_contains,sum, split
from pyspark.sql.types import IntegerType

class Transformer:
    """Base class for transformations that operate on a set of input DataFrames.

    Subclasses override :meth:`transform` to implement their own logic.
    """

    def __init__(self, InputDfs):
        """Store the inputs on the instance.

        Args:
            InputDfs: Object holding the input DataFrames. Subclasses in this
                file look entries up with ``.get(<name>)``.
        """
        self.InputDfs = InputDfs

    def transform(self):
        """Do nothing in the base class; subclasses provide the real work."""
        return None

class AirpodAfterIphoneTransform(Transformer):
    """Select customers whose transaction right after an iPhone was AirPods."""

    def transform(self):
        """Join customers to their "iPhone followed by AirPods" transactions.

        Reads ``'TransInputDf'`` and ``'CustomerInputDf'`` from
        ``self.InputDfs``. Each transaction is paired with the product name of
        the same customer's next transaction (ordered by ``transaction_date``),
        and only the rows where an ``'iPhone'`` is immediately followed by
        ``'AirPods'`` are kept.

        Side effect: prints the joined result, sorted by ``product_name``,
        via ``show()``.

        Returns:
            The customer DataFrame inner-joined on ``customer_id`` with the
            matching transaction rows.
        """
        transactions = self.InputDfs.get('TransInputDf')
        customers = self.InputDfs.get('CustomerInputDf')

        # One ordered purchase timeline per customer.
        per_customer_timeline = Window.partitionBy("customer_id").orderBy("transaction_date")
        following_product = lead("product_name").over(per_customer_timeline)

        bought_iphone = col("product_name") == 'iPhone'
        then_airpods = col("next_product_name") == 'AirPods'

        matches = (
            transactions
            .withColumn("next_product_name", following_product)
            .filter(bought_iphone & then_airpods)
        )

        result = customers.join(broadcast(matches), "customer_id", "inner")

        result.orderBy("product_name").show()

        return result
    
class AirpodAndIphoneTransform(Transformer):
    """Select customers whose purchases include both an iPhone and AirPods."""

    def transform(self):
        """Join customers to the set of products they bought, keeping dual buyers.

        Reads ``'TransInputDf'`` and ``'CustomerInputDf'`` from
        ``self.InputDfs``. Transactions are collapsed to one row per
        ``customer_id`` holding the distinct ``product_name`` values in a
        ``products`` array; only rows whose array contains both ``"iPhone"``
        and ``"AirPods"`` are kept.

        Side effect: prints the joined result, sorted by ``customer_id``,
        via ``show()``.

        Returns:
            The customer DataFrame inner-joined on ``customer_id`` with the
            qualifying per-customer product sets.
        """
        transactions = self.InputDfs.get('TransInputDf')
        customers = self.InputDfs.get('CustomerInputDf')

        products_per_customer = transactions.groupBy("customer_id").agg(
            collect_set("product_name").alias("products")
        )

        has_iphone = array_contains(col("products"), "iPhone")
        has_airpods = array_contains(col("products"), "AirPods")
        buyers_of_both = products_per_customer.filter(has_iphone & has_airpods)

        result = customers.join(broadcast(buyers_of_both), "customer_id", "inner")

        result.orderBy("customer_id").show()

        return result
    

class Top3ProductsByRevenueTransform(Transformer):
    """Compute the three highest-revenue groups from transactions and products."""

    def transform(self):
        """Return the top three ``category`` groups ranked by summed price.

        Reads ``'TransInputDf'`` and ``'ProductsInputDf'`` from
        ``self.InputDfs``. The product table is normalised, joined to the
        transactions on ``product_name``, aggregated per ``category`` and
        ranked by total price, highest first.

        Side effect: prints the three resulting rows via ``show()``.

        Returns:
            A DataFrame with columns ``category`` and ``total_price`` holding
            at most three rows, sorted by ``total_price`` descending.
        """
        TransInputDf = self.InputDfs.get('TransInputDf')
        ProductsInputDf = self.InputDfs.get('ProductsInputDf')

        # Make price numeric so it can be summed; the integer cast drops any
        # fractional part.
        ProductsInputDf = ProductsInputDf.withColumn(
            "price", ProductsInputDf.price.cast(IntegerType()))
        # Keep only the first space-separated word of the product name so it
        # can be used as the join key against the transactions.
        # NOTE(review): presumably transactions carry the short name
        # (e.g. "iPhone") while products carry a longer one - confirm.
        ProductsInputDf = ProductsInputDf.withColumn(
            "product_name",
            split(ProductsInputDf['product_name'], ' ').getItem(0))

        JoinDf = TransInputDf.join(
            broadcast(ProductsInputDf), "product_name", "inner")

        # NOTE(review): the class name says "products" but revenue is
        # aggregated per category - confirm which grouping is intended.
        groupedDf = JoinDf.groupBy("category").agg(
            sum("price").alias("total_price"))

        # Sort descending: the default ascending order would return the three
        # LOWEST totals instead of the top three.
        FilteredDf = groupedDf.orderBy(col("total_price").desc()).limit(3)

        FilteredDf.show()

        return FilteredDf
