In [1]:
from datetime import datetime
from dateutil.relativedelta import relativedelta
import argparse
from typing import Iterator

import pandas as pd
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.types import (
    StructType,
    StructField,
    DateType,
    StringType,
    IntegerType,
    DecimalType,
)
from pyspark.sql.functions import (
    col,
    when,
    first,
    last_day,
    sum,
    count,
    countDistinct,
    from_json,
    lit,
    udf,
)
from pyspark.sql.pandas.functions import pandas_udf
from pyspark.sql.window import Window

In [2]:
# Run parameters: a date inside the month to process and the currency in which
# the data mart reports prices.
DATE_PARAM = "2018-12-31"
DM_CURRENCY_PARAM = "BYN"

# Boundaries of the month that contains DATE_PARAM.
filter_date = datetime.strptime(DATE_PARAM, "%Y-%m-%d")
DATE_BEGIN = filter_date.replace(day=1)
# An absolute day=31 is clamped by dateutil to the last valid day of the month.
DATE_END = filter_date + relativedelta(day=31)

# Input files and the output location (named after the month's last day).
TRANSACTIONS_FILE = "/home/jovyan/work/data/transactions_train_with_currency.csv"
ARTICLES_FILE = "/home/jovyan/work/data/articles.csv"
CUSTOMERS_FILE = "/home/jovyan/work/data/customers.csv"
OUTPUT_FILE = f"/home/jovyan/work/data/dm_transactions/{DATE_END.strftime('%Y-%m-%d')}.csv"

# Decimal precision/scale used for prices and for exchange rates.
PRICE_PRECISION = 22
PRICE_SCALE = 16
RATE_PRECISION = 6
RATE_SCALE = 2

In [3]:
# Build (or reuse) the session step by step: cluster master, application name,
# then the Arrow setting for Spark <-> pandas data exchange.
session_builder = SparkSession.builder.master("spark://spark-master:7077")
session_builder = session_builder.appName("transactions-etl")
session_builder = session_builder.config(
    "spark.sql.execution.arrow.pyspark.enabled", "true"
)
spark = session_builder.getOrCreate()

In [4]:
# Fields of the transactions schema as (name, Spark type) pairs; every field is
# nullable. `current_exchange_rate` is read as raw text here and parsed from
# JSON in a later cell.
_transactions_fields = [
    ("id", IntegerType()),
    ("t_dat", DateType()),
    ("customer_id", StringType()),
    ("article_id", IntegerType()),
    ("price", DecimalType(PRICE_PRECISION, PRICE_SCALE)),
    ("sales_channel_id", IntegerType()),
    ("currency", StringType()),
    ("current_exchange_rate", StringType()),
]

transactions_schema = StructType(
    [StructField(name, data_type, True) for name, data_type in _transactions_fields]
)

In [5]:
# Fields of the articles schema as (name, Spark type) pairs; every field is
# nullable.
_articles_fields = [
    ("article_id", IntegerType()),
    ("product_code", IntegerType()),
    ("prod_name", StringType()),
    ("product_type_no", IntegerType()),
    ("product_type_name", StringType()),
    ("product_group_name", StringType()),
    ("graphical_appearance_no", IntegerType()),
    ("graphical_appearance_name", StringType()),
    ("colour_group_code", IntegerType()),
    ("colour_group_name", StringType()),
    ("perceived_colour_value_id", IntegerType()),
    ("perceived_colour_value_name", StringType()),
    ("perceived_colour_master_id", IntegerType()),
    ("perceived_colour_master_name", StringType()),
    ("department_no", IntegerType()),
    ("department_name", StringType()),
    ("index_code", StringType()),
    ("index_name", StringType()),
    ("index_group_no", IntegerType()),
    ("index_group_name", StringType()),
    ("section_no", IntegerType()),
    ("section_name", StringType()),
    ("garment_group_no", IntegerType()),
    ("garment_group_name", StringType()),
    ("detail_desc", StringType()),
]

articles_schema = StructType(
    [StructField(name, data_type, True) for name, data_type in _articles_fields]
)

In [6]:
# Fields of the customers schema as (name, Spark type) pairs; every field is
# nullable.
_customers_fields = [
    ("customer_id", StringType()),
    ("FN", DecimalType(2, 1)),
    ("Active", DecimalType(2, 1)),
    ("club_member_status", StringType()),
    ("fashion_news_frequency", StringType()),
    ("age", IntegerType()),
    ("postal_code", StringType()),
]

customers_schema = StructType(
    [StructField(name, data_type, True) for name, data_type in _customers_fields]
)

In [7]:
# Read the transactions CSV (header row, comma-separated) with the explicit
# schema instead of relying on type inference.
transactions_df = spark.read.load(
    TRANSACTIONS_FILE,
    format="csv",
    schema=transactions_schema,
    header="true",
    delimiter=",",
)

In [8]:
# Read the articles CSV (header row, comma-separated) with the explicit schema.
articles_df = spark.read.load(
    ARTICLES_FILE,
    format="csv",
    schema=articles_schema,
    header="true",
    delimiter=",",
)

In [9]:
# Read the customers CSV (header row, comma-separated) with the explicit schema.
customers_df = spark.read.load(
    CUSTOMERS_FILE,
    format="csv",
    schema=customers_schema,
    header="true",
    delimiter=",",
)

In [10]:
# Keep only the transactions of the selected month (both bounds inclusive).
#
# `t_dat` is a DateType column, so it is compared against plain `date` values.
# Comparing it against the `datetime` bounds directly would turn them into
# timestamp literals and force an implicit date -> timestamp cast whose result
# depends on the session time zone; with mismatched driver/session time zones
# that can silently drop the first or the last day of the month.
filtered_transactions_df = transactions_df.where(
    col("t_dat").between(DATE_BEGIN.date(), DATE_END.date())
)

In [11]:
# @udf(DecimalType(PRICE_PRECISION, PRICE_SCALE))
# def convert_price(rates: dict, key: str, price: float):
#     return price * rates.get(key, 1)

# @udf(StringType())
# def to_upper_case(value: str) -> str:
#     return value.upper()

@pandas_udf(StringType())
def to_upper_case(value: pd.Series) -> pd.Series:
    """Return the input strings converted to upper case, one batch at a time.

    Args:
        value: batch of string values.

    Returns:
        A series of the same length holding the upper-cased strings.
    """
    upper_cased = value.str.upper()
    return upper_cased

@pandas_udf(DecimalType(PRICE_PRECISION, PRICE_SCALE))
def convert_price(rates_map: pd.Series, dm_currency: pd.Series, price: pd.Series) -> pd.Series:
    """Convert each price with the exchange rate of its target currency.

    The conversion is done row by row, so it stays correct when the batch is
    empty, when a row has no rate map, and when target currencies differ
    between rows of the same batch.

    Args:
        rates_map: per-row mapping of currency code to exchange rate; a row
            may be None when no rates were parsed for it.
        dm_currency: per-row target currency code.
        price: per-row price to convert; may contain None.

    Returns:
        A series aligned with ``price`` holding ``price * rate``. The rate
        falls back to 1 when the row has no rate map, when the target currency
        is absent from it, or when its rate is null. A null price stays null.
    """

    def convert_row(rates, currency, amount):
        # A missing price cannot be converted; propagate the null.
        if amount is None:
            return None
        # Spark does not guarantee that the UDF is skipped for rows the
        # surrounding `when` rejects, so a null map must not raise here.
        rate = rates.get(currency) if rates is not None else None
        if rate is None:
            rate = 1
        return amount * rate

    # NOTE(review): the product can carry more fractional digits than
    # PRICE_SCALE - confirm the Arrow conversion to the declared decimal type
    # handles that for the real data.
    converted = [
        convert_row(rates, currency, amount)
        for rates, currency, amount in zip(rates_map, dm_currency, price)
    ]
    return pd.Series(converted, index=price.index, dtype=object)

In [12]:
# Spark type used to parse the JSON exchange-rate text: currency code -> rate.
exchange_rate_schema = f"map<string, decimal({RATE_PRECISION}, {RATE_SCALE})>"

# Upper-case the raw JSON text first, then parse it into a map column.
parsed_exchange_rates = from_json(
    to_upper_case(col("current_exchange_rate")),
    exchange_rate_schema,
)

# Convert the price only when the transaction currency differs from the
# data-mart currency; otherwise keep the original price.
price_in_dm_currency = when(
    col("currency") != col("dm_currency"),
    convert_price(
        col("current_exchange_rate"),
        col("dm_currency"),
        col("price"),
    ),
).otherwise(col("price"))

converted_price_transactions_df = (
    filtered_transactions_df
    .withColumn("dm_currency", lit(DM_CURRENCY_PARAM))
    .withColumn("current_exchange_rate", parsed_exchange_rates)
    .withColumn("price", price_in_dm_currency)
)

In [13]:
# Columns carried forward after attaching customer and article attributes.
enriched_columns = [
    "t_dat",
    "customer_id",
    "article_id",
    "price",
    "age",
    "product_group_name",
    "dm_currency",
]

# Inner joins: transactions without a matching customer or article are dropped.
enriched_transactions_df = (
    converted_price_transactions_df
    .join(customers_df, on="customer_id", how="inner")
    .join(articles_df, on="article_id", how="inner")
    .select(*enriched_columns)
)

In [14]:
# Per customer: most expensive purchase first, earlier date first on equal price.
window_spec_most_expensive_article = Window.partitionBy("customer_id").orderBy(
    col("price").desc(),
    col("t_dat"),
)

# Age bucket: "S" below 23, "A" below 60, "R" for everything else.
# NOTE(review): a null age fails both comparisons and therefore also lands in
# "R" - confirm this is intended.
age_group = (
    when(col("age") < 23, "S")
    .when(col("age") < 60, "A")
    .otherwise("R")
)

# Article of the first row in the customer's ordered window.
most_expensive_article = first("article_id").over(window_spec_most_expensive_article)

# Partition date: last day of the month of the transaction date.
month_end = last_day(col("t_dat"))

transformed_columns = [
    "part_date",
    "customer_id",
    "article_id",
    "price",
    "product_group_name",
    "customer_group_by_age",
    "most_exp_article_id",
    "dm_currency",
]

transformed_transactions_df = (
    enriched_transactions_df
    .withColumn("customer_group_by_age", age_group)
    .withColumn("most_exp_article_id", most_expensive_article)
    .withColumn("part_date", month_end)
    .select(*transformed_columns)
)

In [15]:
# One output row per combination of these columns.
grouping_columns = [
    "part_date",
    "customer_id",
    "customer_group_by_age",
    "most_exp_article_id",
    "dm_currency",
]

# Per-group metrics: total spend, number of purchased articles and number of
# distinct product groups. `sum` and `count` are the Spark column functions
# imported at the top of the notebook, not the Python builtins.
group_metrics = [
    sum("price").alias("transaction_amount"),
    count("article_id").alias("number_of_articles"),
    countDistinct("product_group_name").alias("number_of_product_groups"),
]

# Final column order of the data mart.
output_columns = [
    "part_date",
    "customer_id",
    "customer_group_by_age",
    "transaction_amount",
    "most_exp_article_id",
    "number_of_articles",
    "number_of_product_groups",
    "dm_currency",
]

aggregated_transactions_df = (
    transformed_transactions_df
    .groupBy(*grouping_columns)
    .agg(*group_metrics)
    .select(*output_columns)
)

In [16]:
# Collapse the result into a single partition and write it as CSV with a
# header row, replacing any previous output at the same location.
aggregated_transactions_df.repartition(1).write.save(
    OUTPUT_FILE,
    format="csv",
    mode="overwrite",
    header="true",
    delimiter=",",
)

In [17]:
# (
#     aggregated_transactions_df
#     .toPandas()
#     .to_csv(OUTPUT_FILE, sep=",", index=False)
# )

In [18]:
# Stop the Spark session once the output has been written; the `spark` object
# cannot be used for further queries after this call.
spark.stop()