In [1]:
import argparse
import calendar
from datetime import datetime

from dateutil.relativedelta import relativedelta
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.types import (
    StructType,
    StructField,
    DateType,
    StringType,
    IntegerType,
    DecimalType,
)
from pyspark.sql.functions import (
    col,
    when,
    first,
    last_day,
    sum,
    count,
    countDistinct,
)
from pyspark.sql.window import Window

In [2]:
# Reporting date: the pipeline processes the whole calendar month containing it.
DATE_PARAM = "2018-12-31"
TRANSACTIONS_FILE = "/home/jovyan/work/data/transactions_train.csv"
ARTICLES_FILE = "/home/jovyan/work/data/articles.csv"
CUSTOMERS_FILE = "/home/jovyan/work/data/customers.csv"
OUTPUT_FILE = "/home/jovyan/work/data/data_mart.csv"

# Month boundaries are plain `date` objects (not `datetime`) so that filtering the
# DateType column `t_dat` is a date-to-date comparison. A naive `datetime` literal is
# converted to a timestamp using the driver's local time zone, while `t_dat` would be
# cast to a timestamp in the Spark session time zone; when those differ, boundary days
# can be silently dropped or included.
filter_date = datetime.strptime(DATE_PARAM, "%Y-%m-%d").date()
DATE_BEGIN = filter_date.replace(day=1)
# monthrange(...)[1] is the number of days in the month, i.e. the last day of the month.
DATE_END = filter_date.replace(
    day=calendar.monthrange(filter_date.year, filter_date.month)[1]
)

In [3]:
# Connect to the standalone cluster master; getOrCreate() reuses an already active
# session instead of starting a second one.
spark = SparkSession.builder.appName("transactions-etl").master(
    "spark://spark-master:7077"
).getOrCreate()

In [4]:
# Explicit schema for transactions_train.csv (avoids a costly inference pass).
# Spark applies a user-supplied CSV schema by position, so the order below must match
# the file's column order. Every field is nullable.
transactions_schema = StructType(
    [
        StructField(field_name, field_type, True)
        for field_name, field_type in [
            ("t_dat", DateType()),
            ("customer_id", StringType()),
            ("article_id", IntegerType()),
            ("price", DecimalType(22, 20)),
            ("sales_channel_id", IntegerType()),
        ]
    ]
)

In [5]:
# Explicit schema for articles.csv. Spark applies a user-supplied CSV schema by
# position, so the order below must match the file's column order. Every field is nullable.
articles_schema = StructType(
    [
        StructField(field_name, field_type, True)
        for field_name, field_type in [
            ("article_id", IntegerType()),
            ("product_code", IntegerType()),
            ("prod_name", StringType()),
            ("product_type_no", IntegerType()),
            ("product_type_name", StringType()),
            ("product_group_name", StringType()),
            ("graphical_appearance_no", IntegerType()),
            ("graphical_appearance_name", StringType()),
            ("colour_group_code", IntegerType()),
            ("colour_group_name", StringType()),
            ("perceived_colour_value_id", IntegerType()),
            ("perceived_colour_value_name", StringType()),
            ("perceived_colour_master_id", IntegerType()),
            ("perceived_colour_master_name", StringType()),
            ("department_no", IntegerType()),
            ("department_name", StringType()),
            ("index_code", StringType()),
            ("index_name", StringType()),
            ("index_group_no", IntegerType()),
            ("index_group_name", StringType()),
            ("section_no", IntegerType()),
            ("section_name", StringType()),
            ("garment_group_no", IntegerType()),
            ("garment_group_name", StringType()),
            ("detail_desc", StringType()),
        ]
    ]
)

In [6]:
# Explicit schema for customers.csv. Spark applies a user-supplied CSV schema by
# position, so the order below must match the file's column order. Every field is nullable.
customers_schema = StructType(
    [
        StructField(field_name, field_type, True)
        for field_name, field_type in [
            ("customer_id", StringType()),
            ("FN", DecimalType(2, 1)),
            ("Active", DecimalType(2, 1)),
            ("club_member_status", StringType()),
            ("fashion_news_frequency", StringType()),
            ("age", IntegerType()),
            ("postal_code", StringType()),
        ]
    ]
)

In [7]:
# Lazily read the raw transactions with the explicit schema; the first line is a header.
transactions_df = spark.read.csv(
    TRANSACTIONS_FILE,
    schema=transactions_schema,
    header=True,
    sep=",",
)

In [8]:
# Lazily read the article dimension with the explicit schema; the first line is a header.
articles_df = spark.read.csv(
    ARTICLES_FILE,
    schema=articles_schema,
    header=True,
    sep=",",
)

In [9]:
# Lazily read the customer dimension with the explicit schema; the first line is a header.
customers_df = spark.read.csv(
    CUSTOMERS_FILE,
    schema=customers_schema,
    header=True,
    sep=",",
)

In [10]:
# Restrict to the reporting month: DATE_BEGIN <= t_dat <= DATE_END (both ends inclusive).
filtered_transactions_df = transactions_df.where(
    col("t_dat").between(DATE_BEGIN, DATE_END)
)

In [11]:
# Attach the customer's age and the article's product group to every transaction.
# Inner joins: transactions without a matching customer or article are dropped.
# Only the dimension columns used downstream are carried into the joins.
enriched_transactions_df = (
    filtered_transactions_df
    .join(customers_df.select("customer_id", "age"), on="customer_id", how="inner")
    .join(articles_df.select("article_id", "product_group_name"), on="article_id", how="inner")
    .select("t_dat", "customer_id", "article_id", "price", "age", "product_group_name")
)

In [12]:
# Ordering used to pick each customer's most expensive article within a reporting month:
# highest price first, then earliest purchase date, then lowest article_id. The final
# article_id key breaks ties, so the chosen article is deterministic across runs.
# Partitioning by month (part_date) as well as customer keeps the pick consistent with
# the later groupBy(part_date, customer_id), even if several months are processed.
window_spec_most_expensive_article = (
    Window
    .partitionBy("customer_id", "part_date")
    .orderBy(
        col("price").desc(),
        col("t_dat"),
        col("article_id"),
    )
)

transformed_transactions_df = (
    enriched_transactions_df
    # Month-end date of the transaction; used as the reporting partition key. It must
    # exist before the window above is applied.
    .withColumn(
        "part_date",
        last_day(col("t_dat"))
    )
    # Age buckets: "S" (< 23), "A" (23..59), "R" (>= 60). A missing age yields null;
    # previously it fell through to otherwise("R") and was mislabelled.
    .withColumn(
        "customer_group_by_age",
        when(col("age") < 23, "S")
        .when(col("age") < 60, "A")
        .when(col("age") >= 60, "R")
    )
    # first() over an ordered window returns the first row of the partition, i.e. the
    # top-ranked article for that customer and month, repeated on every row.
    .withColumn(
        "most_exp_article_id",
        first("article_id")
        .over(window_spec_most_expensive_article)
    )
    .select(
        "part_date",
        "customer_id",
        "article_id",
        "price",
        "product_group_name",
        "customer_group_by_age",
        "most_exp_article_id"
    )
)

In [13]:
# One row per (month, customer). Outputs:
# - transaction_amount: total spend.
# - number_of_articles: number of purchased line items.
# - number_of_product_groups: number of distinct product groups.
# most_exp_article_id is constant per customer/month, so grouping by it only carries
# the value through. Note: `sum` here is pyspark.sql.functions.sum (it shadows the
# builtin via the import cell).
aggregated_transactions_df = (
    transformed_transactions_df
    .groupBy(["part_date", "customer_id", "customer_group_by_age", "most_exp_article_id"])
    .agg(
        count("article_id").alias("number_of_articles"),
        sum("price").alias("transaction_amount"),
        countDistinct("product_group_name").alias("number_of_product_groups"),
    )
    .select(
        [
            "part_date",
            "customer_id",
            "customer_group_by_age",
            "transaction_amount",
            "most_exp_article_id",
            "number_of_articles",
            "number_of_product_groups",
        ]
    )
)

In [14]:
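# Alternative, fully distributed write (kept for reference, not executed).
# Note: Spark writes OUTPUT_FILE as a *directory* of part files, not as a single CSV file.
# (
#     aggregated_transactions_df
#     .coalesce(1)
#     .write
#     .mode("overwrite")
#     .format("csv")
#     .option("header", "true")
#     .option("delimiter", ",")
#     .save(OUTPUT_FILE)
# )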

In [15]:
# Collect the aggregated result to the driver and write it as one CSV file with a header
# row and no pandas index (assumes the aggregate fits in driver memory).
aggregated_transactions_df.toPandas().to_csv(OUTPUT_FILE, index=False, sep=",")

In [16]:
# Shut down the Spark session; `spark` must not be used by any cell after this one.
spark.stop()