In [1]:
import time

from pyspark.sql import Column, SparkSession
from pyspark.sql import functions as F

# See lib.rs for details about constants
CARD_TYPES = ("DC", "CC")
TRANSACTION_TYPES = (
    "food-and-household",
    "home",
    "uncategorized",
    "leisure-and-lifestyle",
    "health-and-beauty",
    "shopping-and-services",
    "children",
    "vacation-and-travel",
    "education",
    "insurance",
    "investments-and-savings",
    "expenses-and-other",
    "cars-and-transportation",
)
CHANNELS = ("mobile", "web")


# Required time windows
WINDOWS_IN_DAYS = (
    7,  # week
    14,  # two weeks
    21,  # three weeks
    30,  # month
    90,  # three months
    180,  # half of the year
    360,  # two years
)


def get_all_aggregations(col_prefix: str, cond: Column, cols_list: list[Column]) -> None:
    # Count over group
    cols_list.append(
        F.sum(F.when(cond, F.lit(1)).otherwise(F.lit(0))).alias(f"{col_prefix}_count")
    )
    # Average over group
    cols_list.append(
        F.mean(F.when(cond, F.col("trx_amnt")).otherwise(F.lit(None))).alias(
            f"{col_prefix}_avg"
        )
    )
    # Sum over group
    cols_list.append(
        F.sum(F.when(cond, F.col("trx_amnt")).otherwise(F.lit(0))).alias(
            f"{col_prefix}_sum"
        )
    )
    # Min over group
    cols_list.append(
        F.min(F.when(cond, F.col("trx_amnt")).otherwise(F.lit(None))).alias(
            f"{col_prefix}_min"
        )
    )
    # Max over group
    cols_list.append(
        F.max(F.when(cond, F.col("trx_amnt")).otherwise(F.lit(None))).alias(
            f"{col_prefix}_max"
        )
    )

In [2]:
# Rm an output parquet folder if it exists
!rm -r "../tmp_out"

rm: cannot remove '../tmp_out': No such file or directory


In [3]:
path = "../test_data_small"
start_time = time.time()

spark = (
    SparkSession.builder.master("local[*]")
    .config("spark.sql.shuffle.partitions", 2)
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .getOrCreate()
)
# root
# |-- customer_id: long (nullable = true)
# |-- card_type: string (nullable = true)
# |-- trx_type: string (nullable = true)
# |-- channel: string (nullable = true)
# |-- trx_amnt: double (nullable = true)
# |-- t_minus: long (nullable = true)
# |-- part_col: string (nullable = true)

data = spark.read.parquet(path)

cols_list = []
for win in WINDOWS_IN_DAYS:
    # Iterate over combination card_type + trx_type
    for card_type in CARD_TYPES:
        for trx_type in TRANSACTION_TYPES:
            cond = F.lit(True)
            cond &= F.col("t_minus") <= F.lit(win)  # Is row in the window?
            cond &= F.col("card_type") == F.lit(
                card_type
            )  # Does row have needed card type?
            cond &= F.col("trx_type") == F.lit(
                trx_type
            )  # Does row have needed trx type?

            # Colname prefix
            col_prefix = f"{card_type}_{trx_type}_{win}d"

            get_all_aggregations(col_prefix, cond, cols_list)

    # Iterate over combination channel + trx_type
    for ch_type in CHANNELS:
        for trx_type in TRANSACTION_TYPES:
            cond = F.lit(True)
            cond &= F.col("t_minus") <= win  # Is row in the window?
            cond &= F.col("channel") == F.lit(
                ch_type
            )  # Does row have needed channel type?
            cond &= F.col("trx_type") == F.lit(
                trx_type
            )  # Does row have needed trx type?

            # Colname prefix
            col_prefix = f"{ch_type}_{trx_type}_{win}d"

            get_all_aggregations(col_prefix, cond, cols_list)

result = data.groupBy("customer_id").agg(*cols_list)

result.write.mode("overwrite").parquet("../tmp_out")

end_time = time.time()

24/05/11 19:12:13 WARN Utils: Your hostname, toolbox resolves to a loopback address: 127.0.0.1; using 192.168.0.29 instead (on interface wlp0s20f3)
24/05/11 19:12:13 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/11 19:12:14 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/05/11 19:12:23 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
24/05/11 19:12:24 WARN DAGScheduler: Broadcasting large task binary with size 1651.5 KiB
[Stage 1:>                                                        (0 + 12) / 15]



24/05/11 19:39:56 WARN DAGScheduler: Broadcasting large task binary with size 3.3 MiB
                                                                                

In [4]:
print(f"Total time: {end_time - start_time} seconds")

Total time: 1692.9119520187378 seconds
