In [236]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import Window
from pyspark.ml.regression import GBTRegressor
import numpy as np
import time
import datetime
import os

spark = (
    SparkSession.builder.appName("ADS project 2")
    .config("spark.sql.repl.eagerEval.enabled", True) 
    .config("spark.sql.parquet.cacheMetadata", "true")
    .config("spark.sql.session.timeZone", "Etc/UTC")
    .config("spark.executor.memory", "2g")
    .config("spark.driver.memory", "4g")
    .getOrCreate()
)

In [326]:
sdf = spark.read.parquet('../data/curated/merchant_consumer_abs')
categories = sdf.select(col("merchant_name").alias("merchant_name_copy"), "category", "take_rate").distinct()
sdf = sdf.drop(*['postcode', 'products', 'take_rate', 'category', 'fraud_group',
                 'user_id', 'consumer', 'consumer_address', 'consumer_state', 'consumer_postcode', 'order_day',
                 '__index_level_0__']).where((col("order_datetime") > "2022-02-28")).where((col("order_datetime") < "2022-10-01"))
for i in sdf.columns[:8]:
    if i != "median_age":
        sdf = sdf.drop(i)
# for i in sdf.columns[:4]:
#     sdf = sdf.withColumn(f"log_{i}", when(log(col(i)) > 0, log(col(i))).otherwise(0)).drop(i).withColumnRenamed(f"log_{i}", i)
# sdf = sdf.select(*(sdf.columns[-4:]), *(sdf.columns[:-4]))
order = sdf.columns
sdf = sdf.withColumn("year", col("order_year").cast("int")).drop("order_year").withColumnRenamed("year", "order_year").select(*order)
sdf

median_age,merchant_name,revenue_level,order_datetime,tag,dollar_value,order_year,order_month,consumer_gender
33.00000004463053,Elit Sed Consequa...,a,2022-04-27,artist supply craft,375.16773164703153,2022,4,Female
33.00000004463053,Mollis Integer Co...,b,2022-05-02,digital goods boo...,83.63154755239155,2022,5,Female
33.00000004463053,Hendrerit A Corpo...,a,2022-07-14,watch clock jewel...,100.80643267043833,2022,7,Male
33.00000004463053,Hendrerit A Corpo...,a,2022-05-19,watch clock jewel...,276.12651679041534,2022,5,Female
33.00000004463053,Hendrerit A Corpo...,a,2022-03-29,watch clock jewel...,169.4756116760845,2022,3,Undisclosed
33.00000004463053,Faucibus Leo In C...,a,2022-07-08,bicycle sales ser...,193.0553461719188,2022,7,Male
33.00000004463053,Eros Limited,c,2022-06-26,digital goods boo...,4.123199407915545,2022,6,Undisclosed
33.00000004463053,Eros Limited,c,2022-08-11,digital goods boo...,0.1130668899834874,2022,8,Female
33.00000004463053,Eget Metus In Cor...,a,2022-03-15,tent awning,24.60325749514622,2022,3,Undisclosed
33.00000004463053,Mi Lorem Inc.,b,2022-09-06,watch clock jewel...,69.96821111498734,2022,9,Male


In [327]:
to_do = {'2022': [11, 12],
         '2023': range(1,13)}
merchs = sdf.select("merchant_name", "revenue_level", "tag").distinct()
cols = ["median_age", "order_datetime_str", "dollar_value", "order_year", "order_month", "consumer_gender"]
vals = [(0, f"{year}-{str(month).rjust(2, '0')}-01", 0, int(year), month, "Female") for year in to_do for month in to_do[year]]
months = spark.createDataFrame(vals, cols)
cols.remove("order_datetime_str")
months = months.select(*cols, to_date(col("order_datetime_str"), "yyyy-MM-dd").alias("order_datetime"))
predict = merchs.join(months).select(sdf.columns)
sdf = sdf.union(predict)
sdf

                                                                                

median_age,merchant_name,revenue_level,order_datetime,tag,dollar_value,order_year,order_month,consumer_gender
33.00000004463053,Elit Sed Consequa...,a,2022-04-27,artist supply craft,375.16773164703153,2022,4,Female
33.00000004463053,Mollis Integer Co...,b,2022-05-02,digital goods boo...,83.63154755239155,2022,5,Female
33.00000004463053,Hendrerit A Corpo...,a,2022-07-14,watch clock jewel...,100.80643267043833,2022,7,Male
33.00000004463053,Hendrerit A Corpo...,a,2022-05-19,watch clock jewel...,276.12651679041534,2022,5,Female
33.00000004463053,Hendrerit A Corpo...,a,2022-03-29,watch clock jewel...,169.4756116760845,2022,3,Undisclosed
33.00000004463053,Faucibus Leo In C...,a,2022-07-08,bicycle sales ser...,193.0553461719188,2022,7,Male
33.00000004463053,Eros Limited,c,2022-06-26,digital goods boo...,4.123199407915545,2022,6,Undisclosed
33.00000004463053,Eros Limited,c,2022-08-11,digital goods boo...,0.1130668899834874,2022,8,Female
33.00000004463053,Eget Metus In Cor...,a,2022-03-15,tent awning,24.60325749514622,2022,3,Undisclosed
33.00000004463053,Mi Lorem Inc.,b,2022-09-06,watch clock jewel...,69.96821111498734,2022,9,Male


In [328]:
genders = sdf.select("consumer_gender").distinct().rdd.flatMap(lambda x: x).collect()
exprs = [when(col("consumer_gender") == category, 1).otherwise(0).alias(category)
         for category in genders]
sdf = sdf.select(*(sdf.columns), *exprs).drop("consumer_gender")

                                                                                

In [329]:
train = sdf.where((col("order_datetime") < "2022-10-01")).drop("order_datetime")
test = sdf.where((col("order_datetime") > "2022-10-31")).drop("order_datetime")

In [330]:
train_group = train
w = Window.partitionBy('merchant_name')
train_group = train_group.withColumn("avg(median_age)", avg("median_age").over(w)).drop("median_age")

In [331]:
agg_cols = train_group.columns
for i in ["dollar_value", "Female", "Male", "Undisclosed"]:
    agg_cols.remove(i)
train_group = train_group.groupBy(agg_cols).agg(sum("dollar_value"), sum("Female"), sum("Male"), sum("Undisclosed"))

In [332]:
fill = {'2021': range(3,13),
        '2022': range(1,11)}
merchs = train_group.select("merchant_name", "revenue_level", "tag", "avg(median_age)").distinct()
cols = ["order_year", "order_month", "sum(dollar_value)", "sum(Female)", "sum(Male)", "sum(Undisclosed)"]
vals = [(int(year), month, 0, 0, 0, 0) for year in fill for month in fill[year]]
months = spark.createDataFrame(vals, cols)
months = merchs.join(months).select(train_group.columns)
train_agg = train_group.union(months)

In [333]:
group_cols = [i for i in train_agg.columns if "sum" not in i]
agg_cols = [i for i in train_agg.columns if "sum" in i]
train_agg = train_agg.groupBy(group_cols).max()
train_agg = train_agg.drop(*["max(order_month)", "max(avg(median_age))", "max(order_year)"])
for i in agg_cols:
    train_agg = train_agg.withColumnRenamed(f"max({i})", i)
train_agg = train_agg.withColumnRenamed("sum(dollar_value)", "dollar_value")

w = Window.partitionBy('merchant_name')
for i in genders:
    train_agg = train_agg.withColumn(f"avg(sum({i}))", avg(f"sum({i})").over(w)).drop(f"sum({i})")

In [334]:
test_agg = test.select(test.columns[1:7])
merch_agg = train_agg.select(col("merchant_name").alias("merchant_name_copy"), *(train_agg.columns[4:])).drop("order_month", "dollar_value").distinct()
test_agg = test_agg.join(merch_agg, test_agg.merchant_name == merch_agg.merchant_name_copy, "left").drop("merchant_name_copy")
test_agg = test_agg.select(*(train_agg.columns))

In [335]:
merch_meta = train_agg.select('merchant_name', 'revenue_level', 'avg(median_age)',
                              'avg(sum(Undisclosed))', 'avg(sum(Female))', 'avg(sum(Male))').distinct()
merch_meta = merch_meta.join(categories, merch_meta.merchant_name == categories.merchant_name_copy).drop("merchant_name_copy")
merch_meta.write.parquet('../data/meta/merchant_metadata.parquet')

                                                                                

In [336]:
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler, Interaction

interact = ["tag", "revenue_level", "order_month"]
for i in interact:
    stringIndexer = StringIndexer(inputCol=i, outputCol=f"{i}_num")
    train_agg = stringIndexer.fit(train_agg).transform(train_agg).drop(i)
    test_agg = stringIndexer.fit(test_agg).transform(test_agg)
    encoder = OneHotEncoder(inputCol=f"{i}_num", outputCol=i+"_vec")
    train_agg = encoder.fit(train_agg).transform(train_agg).drop(f"{i}_num")
    test_agg = encoder.fit(test_agg).transform(test_agg).drop(f"{i}_num")
interaction = Interaction(inputCols=[f"{i}_vec" for i in interact], outputCol="interact")
train_agg = interaction.transform(train_agg)
test_agg = interaction.transform(test_agg)

                                                                                

In [337]:
from pyspark.sql import Row
from pyspark.ml.linalg import Vectors

feats = train_agg.columns
feats.remove('merchant_name')
feats.remove('dollar_value')
assembler = VectorAssembler(inputCols=feats, outputCol="features")
train_vec = assembler.transform(train_agg)
test_vec = assembler.transform(test_agg)

In [338]:
train_gbt = train_vec.select("features", col("dollar_value").alias("label"))
test_gbt = test_vec.select('merchant_name', 'order_year', 'order_month', "features")

In [339]:
from pyspark.ml.regression import GBTRegressor

GBT = GBTRegressor(featuresCol="features")
gbt = GBT.fit(train_gbt)
gbt_predict = gbt.transform(test_gbt)
# print("Root Mean Squared Error (RMSE) = %g" % rmse_evaluator.evaluate(gbt_predict))
# print("R Squared (R2) = %g" % r2_evaluator.evaluate(gbt_predict))

                                                                                

In [340]:
gbt_predict

                                                                                

merchant_name,order_year,order_month,features,prediction
Varius Orci Insti...,2022,11,"(1100,[0,1,2,3,4,...",4473.089115791361
Semper Incorporated,2022,11,"(1100,[0,1,2,3,4,...",-4162.136124455534
Hendrerit Consect...,2022,11,"(1100,[0,1,2,3,4,...",-11699.946724869978
Erat Semper Ltd,2022,11,"(1100,[0,1,2,3,4,...",-4854.805910209342
Vel Turpis Company,2022,11,"(1100,[0,1,2,3,4,...",4087.8404028149953
Curabitur Vel LLC,2022,11,"(1100,[0,1,2,3,4,...",736.6430777362441
Vulputate Velit E...,2022,11,"(1100,[0,1,2,3,4,...",681.8223443507804
Porttitor Eros In...,2022,11,"(1100,[0,1,2,3,4,...",-5715.928514035324
Dictum Mi Incorpo...,2022,11,"(1100,[0,1,2,3,4,...",2588.902135624668
Pede Praesent Ltd,2022,11,"(1100,[0,1,2,3,4,...",-637.7269496538365


In [342]:
os.makedirs("../data/results/", exist_ok=True)
gbt_predict.select("merchant_name", "order_year", "order_month", "prediction").write.parquet("../data/results/predictions.parquet")

                                                                                