In [0]:
%run "./reader_factory"

In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import lead, col, broadcast, collect_set, size, array_contains

In [0]:
# Abstract base class for transformers
class AbstractTransformer():
    """
    Common base for the transformer classes in this file.

    Subclasses override ``transform`` to turn a dictionary of input
    dataframes into a result dataframe.
    """

    def __init__(self):
        """Create the transformer; the base class holds no state."""
        pass

    def transform(self, inputDFs=None):
        """
        Placeholder transform that subclasses are expected to override.

        Parameters:
            inputDFs: Optional dictionary of input dataframes. It is accepted
                so that the base signature matches the subclass overrides,
                which are called as ``transform(inputDFs)``; the base
                implementation ignores it.

        Returns:
            None. The base implementation performs no work.
        """
        pass

# To find the customers who bought airpods after buying an iphone
class AirpodsAfteriPhoneTransformer(AbstractTransformer):
    """Transformer that finds customers whose next purchase after an iPhone was AirPods."""

    def transform(self, inputDFs):
        """
        Return the details of customers who bought AirPods right after an iPhone.

        Parameters:
            inputDFs: Dictionary of dataframes. This method reads
                "transactionInputDF" (uses the columns customer_id,
                transaction_date and product_name) and "customerInputDF"
                (uses the columns customer_id, customer_name and location).

        Returns:
            A dataframe with the columns customer_id, customer_name and
            location. A customer who made the iPhone -> AirPods sequence
            several times is matched only once by the join.

        Side effects:
            Prints progress messages and calls ``show()`` on the intermediate
            dataframes.
        """
        print("Transformer for customers who bought airpods right after buying an iphone")

        # Transactions dataframe stored in the inputDFs dictionary
        transactionInputDF = inputDFs.get("transactionInputDF")
        print("Transactions dataframe before the tranform: -")
        transactionInputDF.show()

        # One window per customer, ordered by transaction date, so that lead()
        # can look at the row following the current one.
        # NOTE(review): if a customer has several transactions with the same
        # transaction_date, this spec does not define their relative order -
        # confirm whether a tie-breaker column is needed.
        windowSpec = Window.partitionBy("customer_id").orderBy("transaction_date")

        # lead() returns the value of product_name in the next row of the window;
        # withColumn() stores it in the new column "next_product_name".
        transformedDF = transactionInputDF.withColumn(
            "next_product_name", lead("product_name").over(windowSpec)
        )

        print("After finding the next product for each customer:-")
        transformedDF.orderBy("customer_id", "transaction_date").show()

        print("Showing only the transactions of customers who bought \'AirPods\' after \'iPhone\': -")
        filteredDF = transformedDF.filter(
            (col("product_name") == "iPhone") & (col("next_product_name") == "AirPods")
        )
        filteredDF.show()

        # filteredDF has one row per matching transaction, so a customer who
        # repeated the iPhone -> AirPods sequence appears more than once.
        # Keep only the distinct customer ids so the join below does not
        # repeat that customer once per matching transaction.
        matchedCustomerIdsDF = filteredDF.select("customer_id").distinct()

        # Customers dataframe stored in the inputDFs dictionary
        customerInputDF = inputDFs.get("customerInputDF")

        # Inner join on "customer_id" to fetch the matching customers' details.
        # broadcast() is applied to the matched ids (the argument it wraps),
        # not to customerInputDF; presumably the matched ids are the smaller
        # side - verify against real data volumes.
        joinedDF = customerInputDF.join(broadcast(matchedCustomerIdsDF), "customer_id", "inner")

        # Keep only the customer detail columns
        resultDF = joinedDF.select("customer_id", "customer_name", "location")

        return resultDF
    
class OnlyAirpodsAndiPhoneTransformer(AbstractTransformer):
    """Transformer that finds customers whose purchases are exactly iPhone and AirPods."""

    def transform(self, inputDFs):
        """
        Return the details of customers who bought an iPhone and AirPods and nothing else.

        Parameters:
            inputDFs: Dictionary of dataframes. This method reads
                "transactionInputDF" (uses the columns customer_id and
                product_name) and "customerInputDF" (uses the columns
                customer_id, customer_name and location).

        Returns:
            A dataframe with the columns customer_id, customer_name and location.

        Side effects:
            Prints progress messages and calls ``show()`` on the intermediate
            dataframes.
        """
        transactionsDF = inputDFs.get("transactionInputDF")
        print("Transactions dataframe before the tranform: -")
        transactionsDF.show()

        # collect_set() gathers the unique product names of each customer into
        # one array column, named "products" through alias().
        uniqueProducts = collect_set("product_name").alias("products")
        productsPerCustomerDF = transactionsDF.groupBy("customer_id").agg(uniqueProducts)

        print("Products bought by each customer:")
        productsPerCustomerDF.show()

        print("Showing the customers who bought \'AirPods\' and \'iPhone\' only: -")
        productsCol = col("products")
        boughtIphone = array_contains(productsCol, "iPhone")
        boughtAirpods = array_contains(productsCol, "AirPods")
        # With both products present, a set size of 2 means nothing else was bought
        boughtExactlyTwo = size(productsCol) == 2
        onlyBothDF = productsPerCustomerDF.filter(boughtIphone & boughtAirpods & boughtExactlyTwo)
        onlyBothDF.show()

        customersDF = inputDFs.get("customerInputDF")

        # Inner join on "customer_id" to fetch the matching customers' details.
        # broadcast() is applied to the filtered dataframe (the argument it
        # wraps), not to the customers dataframe.
        return customersDF.join(
            broadcast(onlyBothDF), on="customer_id", how="inner"
        ).select("customer_id", "customer_name", "location")