# In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import lead, col, broadcast, collect_set, size, array_contains, row_number, to_date, min, datediff, avg, expr


class Transformer:
    """Common interface for the DataFrame transformations in this file.

    Concrete subclasses override ``transform``. The base implementation does
    nothing and returns ``None``.
    """

    def __init__(self):
        """Create a transformer; the base class holds no state."""

    def transform(self, inputDFs):
        """Transform ``inputDFs``; the base version is a no-op returning ``None``."""
        return None

class AirpodsAfterIphone(Transformer):

    def transform(self, inputDFs):
        """
        Customers whose purchase immediately following an iPhone purchase was AirPods.

        ``inputDFs`` is read with ``.get`` for two entries:
          - "transactionInputDF": transactions with (at least) customer_id,
            transaction_date and product_name columns.
          - "customerInputDF": customers with (at least) customer_id,
            customer_name and location columns.

        Returns a DataFrame of (customer_id, customer_name, location) with one row
        per qualifying customer (assuming customer_id is unique in the customer
        table -- TODO confirm).
        """

        transactionInputDF = inputDFs.get("transactionInputDF")

        # Per-customer window in purchase order.
        # NOTE(review): purchases sharing a transaction_date have no defined
        # relative order, so lead() across such ties is arbitrary.
        windowSpec = Window.partitionBy("customer_id").orderBy("transaction_date")

        # lead() attaches the customer's next purchased product to each row
        transformedDF = transactionInputDF.withColumn(
            "next_product_name", lead("product_name").over(windowSpec)
        )

        transformedDF.orderBy("customer_id", "transaction_date", "product_name").show()

        filteredDF = transformedDF.filter(
            (col("product_name") == "iPhone") & (col("next_product_name") == "AirPods")
        )
        print("AirPods after buying iPhone:")
        filteredDF.orderBy("customer_id", "transaction_date", "product_name").show()

        # A customer can match several times (e.g. iPhone, AirPods, iPhone, AirPods).
        # Keep one row per customer so the join below does not duplicate customers.
        matchedCustomersDF = filteredDF.select("customer_id").distinct()

        customerInputDF = inputDFs.get("customerInputDF")

        print("Cusotomer Table:")
        customerInputDF.show()

        # The de-duplicated match list is expected to be small, so broadcast it
        joinDF = customerInputDF.join(
            broadcast(matchedCustomersDF),
            "customer_id"
        )

        print("Customers who bought AirPods after iPhone:")
        joinDF.show()

        # Return only the customer details
        return joinDF.select(
            "customer_id",
            "customer_name",
            "location"
        )



class OnlyAirpodsAfterIphone(Transformer):

    def transform(self, inputDFs):
        """
        Customers whose set of distinct purchased products is exactly {iPhone, AirPods}.

        Purchase order is not considered; only which products were bought.

        ``inputDFs`` is read with ``.get`` for "transactionInputDF" (needs
        customer_id and product_name) and "customerInputDF" (needs customer_id,
        customer_name and location).

        Returns a DataFrame of (customer_id, customer_name, location).
        """
        transactions = inputDFs.get("transactionInputDF")

        # One row per customer carrying the distinct products they bought
        perCustomer = (
            transactions
            .groupBy("customer_id")
            .agg(collect_set("product_name").alias("products"))
        )
        print("Grouped DF")
        perCustomer.show()

        products = col("products")
        boughtIphone = array_contains(products, "iPhone")
        boughtAirpods = array_contains(products, "AirPods")
        # Containing both items plus having exactly two distinct entries means nothing else
        nothingElse = size(products) == 2
        onlyPair = perCustomer.filter(boughtIphone & boughtAirpods & nothingElse)

        print("Only AirPods and iPhone:")
        onlyPair.show()

        customers = inputDFs.get("customerInputDF")
        print("Cusotomer Table:")
        customers.show()

        # The qualifying-customer list is expected to be small, so broadcast it
        matched = customers.join(broadcast(onlyPair), "customer_id")

        print("Customers who bought only AirPods and iPhone:")
        matched.show()

        detailColumns = ("customer_id", "customer_name", "location")
        return matched.select(*detailColumns)



class ProductsAfterInitialPurchase(Transformer):

    def transform(self, inputDFs):
        """
        List all products bought by customers after their initial purchase.

        Each customer's single earliest transaction is dropped and every other
        transaction is kept. When several transactions share the earliest
        transaction_date, the one with the alphabetically smallest product_name
        is treated as the initial purchase, so results are reproducible.

        ``inputDFs`` is read with ``.get`` for "transactionInputDF", which needs
        customer_id, transaction_date and product_name columns.

        Returns a DataFrame of (customer_id, transaction_date, product_name).
        """

        transactionInputDF = inputDFs.get("transactionInputDF")

        # Order each customer's transactions by date. product_name breaks ties on
        # the same date; without it row_number() picks an arbitrary row, so the
        # dropped "first purchase" could change between runs.
        windowSpec = (
            Window.partitionBy("customer_id")
            .orderBy("transaction_date", "product_name")
        )

        rankedDF = transactionInputDF.withColumn("row_num", row_number().over(windowSpec))

        # Drop the first purchase (row_num == 1) and the helper column
        finalDF = rankedDF.filter(col("row_num") > 1).drop("row_num")

        print("Products bought after first purchase:")
        finalDF.show()

        # Return each customer with their subsequent purchases
        return finalDF.select(
            "customer_id",
            "transaction_date",
            "product_name",
        )



class AirpodsPurchaseDelayTransformer(Transformer):

    def transform(self, inputDFs):
        """
        Calculate the average time delay of buying AirPods after purchasing an iPhone.

        For every customer, the delay is the number of days (``datediff``)
        between their first iPhone purchase and their first AirPods purchase
        strictly after that date. AirPods bought on the same transaction_date
        as the first iPhone, or earlier, are not counted. The result averages
        that delay over all customers who have one.

        ``inputDFs`` is read with ``.get`` for "transactionInputDF", which needs
        customer_id, transaction_date and product_name columns.

        Returns a single-row DataFrame with column ``avg_delay_days``. Spark's
        ``avg`` yields null when no customer qualifies.
        """

        transactionDF = inputDFs.get("transactionInputDF")

        # First iPhone purchase date per customer. A plain aggregation replaces the
        # previous min()-over-ordered-window + distinct(). That version depended on
        # the window's default running frame and needed an extra sort and dedupe.
        iphonePurchaseDF = (
            transactionDF
            .filter(col("product_name") == "iPhone")
            .groupBy("customer_id")
            .agg(min("transaction_date").alias("iphone_purchase_date"))
        )

        # First AirPods purchase strictly after each customer's first iPhone purchase
        airpodsPurchaseDF = (
            transactionDF
            .filter(col("product_name") == "AirPods")
            .join(iphonePurchaseDF, on="customer_id", how="inner")
            .filter(col("transaction_date") > col("iphone_purchase_date"))
            .groupBy("customer_id", "iphone_purchase_date")
            .agg(min("transaction_date").alias("airpods_purchase_date"))
        )

        print("iPhone and Airpods purchase dates")
        airpodsPurchaseDF.show()

        # Delay in days between the iPhone and AirPods purchases
        delayDF = airpodsPurchaseDF.withColumn(
            "delay_days", datediff(col("airpods_purchase_date"), col("iphone_purchase_date"))
        )

        # Average the delay across customers
        avgDelayDF = delayDF.agg(avg(col("delay_days")).alias("avg_delay_days"))

        avgDelayDF.show()

        return avgDelayDF