In [None]:
from pyspark.sql.window import Window
from pyspark.sql.functions import lead, col, broadcast, collect_set, array_contains, size, desc, split, to_date

class Transformer:
    """Base class for the transformers defined in this notebook.

    It holds no state and its ``transform`` is a no-op; the subclasses
    override ``transform`` with the actual logic.
    """

    def __init__(self):
        """Create the transformer; there is nothing to initialise."""

    def transform(self, input_df):
        """Placeholder transformation: ignores ``input_df`` and returns None."""
        return None


In [None]:

class AirpodsAfterIphonesTransformer(Transformer):
    """Finds customers whose purchase directly after an iPhone was AirPods."""

    def transform(self, input_dfs):
        """
        Customers who have bought airpods after buying the iPhone

        "After" means the *immediately following* transaction of the same
        customer when ordered by ``transaction_date``: a row qualifies when its
        ``product_name`` is "iPhone" and the next row's is "AirPods".

        Args:
            input_dfs: mapping holding the DataFrames under the keys
                ``"transaction_df"`` (reads ``customer_id``,
                ``transaction_date``, ``product_name``) and ``"customer_df"``
                (joined on ``customer_id``).

        Returns:
            DataFrame with the columns ``customer_id``, ``customer_name`` and
            ``location``, one row per distinct customer.

        Raises:
            KeyError: if one of the two expected keys is missing.
        """
        # Index access (like the other transformers) so that a missing input
        # fails with a clear KeyError instead of an AttributeError on None.
        transaction_df = input_dfs["transaction_df"]
        window_spec = Window.partitionBy("customer_id").orderBy("transaction_date")

        print("Transactions")
        transaction_df.show(5)

        # lead() pulls the product of the customer's next transaction onto
        # the current row (null for the customer's last transaction).
        transformed_df = transaction_df.withColumn(
            "next_product_name", lead("product_name").over(window_spec)
        )

        print("Next Product name bought by customer")
        transformed_df.show(5)

        filtered_df = transformed_df.filter(
            (col("product_name") == "iPhone") & (col("next_product_name") == "AirPods")
        )

        print("Transactions AirPods after iPhones")
        filtered_df.show(5)

        customer_df = input_dfs["customer_df"]

        print("Customers")
        customer_df.show(5)

        join_df = filtered_df.join(
            broadcast(customer_df),
            on="customer_id",
            how="inner"
        )

        # A customer can have several iPhone -> AirPods pairs; distinct()
        # keeps the result a list of customers rather than of matching
        # transactions.
        final_df = join_df.select(
            "customer_id",
            "customer_name",
            "location"
        ).distinct()

        print("Customers who have bought AirPods after iPhones")
        final_df.show(5)

        return final_df

In [None]:

class AirpodsAndIphonesTransformer(Transformer):
    def transform(self, input_dfs):
        """
        Customers who bought only airpods and iphones

        Collects the set of distinct product names per customer and keeps the
        customers whose set has exactly two entries, "iPhone" and "AirPods".
        Returns the distinct (customer_id, customer_name, location) rows for
        those customers.
        """
        transactions = input_dfs['transaction_df']

        print("Transactions")
        transactions.show(5)

        # One row per customer with the set of product names they bought.
        products_per_customer = transactions.groupBy("customer_id").agg(
            collect_set("product_name").alias("products")
        )
        bought = products_per_customer["products"]
        only_iphone_and_airpods = (
            (size(bought) == 2)
            & array_contains(bought, "iPhone")
            & array_contains(bought, "AirPods")
        )
        qualifying_customers = products_per_customer.filter(only_iphone_and_airpods)

        # Back to transaction level, restricted to the qualifying customers.
        qualifying_transactions = transactions.join(
            qualifying_customers, "customer_id", "inner"
        )
        print("Transactions of customers who bought only AirPods and iPhones")
        qualifying_transactions.show(5)

        customers = input_dfs["customer_df"]
        print("Customers")
        customers.show(5)

        with_customer_details = qualifying_transactions.join(
            broadcast(customers), "customer_id", "inner"
        )

        # Several transactions per customer collapse to a single row here.
        result = with_customer_details.select(
            "customer_id", "customer_name", "location"
        ).distinct()

        print("Customers who bought only AirPods and iPhones")
        result.show(5)

        return result

In [None]:

class ProductsAfterInitialPurchaseTransformer(Transformer):
    """Lists the catalogue products bought after a customer's first purchase."""

    def transform(self, input_dfs):
        """
        List of products after initial purchase by each customers

        Every transaction that has an earlier transaction of the same customer
        (by ``transaction_date``) counts as "after the initial purchase". The
        product names of those transactions are matched against the first
        space-separated word of ``product_name`` in the product table.

        Args:
            input_dfs: mapping holding the DataFrames under the keys
                ``"transaction_df"`` (reads ``customer_id``,
                ``transaction_date``, ``product_name``) and ``"product_df"``
                (reads ``product_id``, ``product_name``, ``category``,
                ``price``).

        Returns:
            DataFrame with the distinct ``product_id``, ``product_name``,
            ``category`` and ``price`` rows of the matched products.

        Raises:
            KeyError: if one of the two expected keys is missing.
        """
        # DataFrames are immutable: withColumn returns a new frame, so the
        # result has to be kept for the date conversion to take effect.
        # NOTE(review): to_date without a format assumes the column is already
        # a date/timestamp or a string in Spark's default date format - confirm.
        transaction_df = input_dfs['transaction_df']
        transaction_df = transaction_df.withColumn(
            "transaction_date", to_date(transaction_df["transaction_date"])
        )

        print("Transactions")
        transaction_df.show(5)

        # With newest-first ordering, lead() yields the product of the
        # customer's previous (older) transaction; it is null on the oldest
        # row, i.e. on the initial purchase.
        window_spec = Window.partitionBy("customer_id").orderBy(desc("transaction_date"))
        transformed_df = transaction_df.withColumn(
            "prev_product_name", lead("product_name").over(window_spec)
        )

        print("Prev bought Column added Transactions")
        transformed_df.show(5)

        # after initial purchase: keep rows that have a previous purchase.
        # An explicit null test replaces the comparison with the string
        # "Null", which only worked through SQL three-valued logic.
        later_purchases_df = transformed_df.where(col("prev_product_name").isNotNull())

        # column rename for better join
        filtered_df = later_purchases_df.withColumnRenamed(
            "product_name", "product_name_initial"
        )

        print("Just Transactions After Initial Purchase")
        filtered_df.show(5)

        # product df pre-processing
        product_df = input_dfs['product_df']

        print("Products")
        product_df.show(5)
        # Join key on the product side: first word of the product name.
        product_transformed_df = product_df.withColumn(
            "product_name_initial", split(product_df["product_name"], " ").getItem(0)
        )

        print("Products after preprocessing")
        product_transformed_df.show(5)

        # NOTE(review): the broadcast hint is on the transaction side; confirm
        # that it really is the smaller of the two inputs.
        join_df = product_transformed_df.join(
            broadcast(filtered_df), on="product_name_initial", how="inner"
        )

        final_df = join_df.select(
            "product_id",
            "product_name",
            "category",
            "price"
        ).distinct()

        print("List of products bought by customers after their initial purchase")
        final_df.show(5)
        return final_df