In [1]:
# import libs
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.types import DecimalType
from pyspark.sql.window import Window
from pyspark.sql.functions import lit, when, col, count
import pyspark.sql.functions as func
from datetime import datetime

In [6]:
# paths and part_date
# Input CSV locations (relative to the notebook) and the output directory for the report.
articles_path = '../data/articles.csv'
customers_path = '../data/customers.csv'
transactions_train_path = '../data/transactions_train.csv'
output_path = 'output_csv'
# part_date picks the calendar month to process; the cells shown in this notebook use
# only its year and month. The typed value is stripped first because strptime rejects
# surrounding whitespace (e.g. a trailing space from a copy-paste) with ValueError.
part_date = datetime.strptime(input('Enter date like "2019-12-31": ').strip(), '%Y-%m-%d')

Enter date like "2019-12-31":  2019-12-31


In [3]:
# create spark session
# builder.getOrCreate() returns the already-running session (and its SparkContext) when
# one exists, so this cell can be re-executed in a live kernel. Constructing
# SparkContext() directly a second time raises
# "ValueError: Cannot run multiple SparkContexts at once".
spark = SparkSession.builder.getOrCreate()

In [8]:
# read csv files
# The first line of each file is used as column names. No schema is supplied or
# inferred here, so columns arrive as strings and are cast where needed downstream.
articles_raw = spark.read.option('header', True).csv(articles_path)
customers_raw = spark.read.option('header', True).csv(customers_path)
transactions_train_raw = spark.read.option('header', True).csv(transactions_train_path)

In [10]:
# current month transactions
# Keep only transactions whose t_dat falls in the same year and month as part_date.
t_dat_col = transactions_train_raw['t_dat']
is_in_part_month = (func.year(t_dat_col) == part_date.year) & (func.month(t_dat_col) == part_date.month)

transactions_train_df = (
    transactions_train_raw
    .filter(is_in_part_month)
    # Left join: transactions without a matching article are kept with null article attributes.
    .join(articles_raw, 'article_id', 'left')
    # row_num is a unique id per row, used later as a tie-breaker when ordering by price.
    .withColumn('row_num', func.monotonically_increasing_id())
    # price was read as text; cast it to a decimal so it can be ordered and summed numerically.
    .withColumn('price', col('price').cast(DecimalType(38, 20)))
    .select('t_dat', 'customer_id', 'article_id', 'price', 'product_group_name', 'row_num')
)

In [11]:
# Most expensive article per customer for the selected month.
# Within each customer, rows are ordered by price (highest first); row_num breaks ties
# between equally priced rows so exactly one row gets row_number == 1.
window_most_exp_art = (
    Window
    .partitionBy('customer_id')
    .orderBy(col('price').desc(), 'row_num')
)
most_exp_articles_df = (
    transactions_train_df
    .withColumn('row_number', func.row_number().over(window_most_exp_art))
    .filter(col('row_number') == 1)
    .select('customer_id', col('article_id').alias('most_exp_article_id'))
)

In [12]:
# Number of distinct product groups each customer bought from in the selected month.
# Null product_group_name values (e.g. from transactions with no matching article in
# the left join) are not counted; a customer with only nulls gets 0.
number_of_product_groups_df = (
    transactions_train_df
    .groupBy('customer_id')
    .agg(func.countDistinct('product_group_name').alias('number_of_product_groups'))
)

In [13]:
# Per-customer totals for the selected month:
#   transaction_amount  - sum of price over the customer's rows
#   number_of_articles  - number of rows with a non-null price
grouped_transactions_df = (
    transactions_train_df
    .groupBy('customer_id')
    .agg(
        func.sum('price').alias('transaction_amount'),
        func.count('price').alias('number_of_articles'),
    )
)

In [14]:
# Assemble the per-customer report for the selected month and write it out as CSV.

# Age bucket: 'S' below 23, 'R' above 59, 'A' for everything else.
# NOTE(review): a null age makes both comparisons null, so such customers also end up
# in 'A' - confirm this is the intended handling of unknown ages.
age_col = col('age')
age_group_col = when(age_col < 23, 'S').when(age_col > 59, 'R').otherwise('A')

# The partition label is the processed month formatted as YYYY-MM.
part_month_label = part_date.strftime('%Y-%m')

result_df = (
    grouped_transactions_df
    .join(number_of_product_groups_df, 'customer_id', 'left')
    .join(most_exp_articles_df, 'customer_id', 'left')
    .join(customers_raw, 'customer_id', 'left')
    .withColumn('customer_group_by_age', age_group_col)
    .withColumn('part_date', func.lit(part_month_label))
    .select(
        'part_date',
        'customer_id',
        'customer_group_by_age',
        'transaction_amount',
        'most_exp_article_id',
        'number_of_articles',
        'number_of_product_groups',
    )
)

# Overwrite any previous output at output_path; comma-delimited with a header row.
(
    result_df.write
    .mode('overwrite')
    .options(header='True', delimiter=',')
    .csv(output_path)
)