In [1]:
# See lib.rs for details about constants
CARD_TYPES = ("DC", "CC")
TRANSACTION_TYPES = (
    "food-and-household",
    "home",
    "uncategorized",
    "leisure-and-lifestyle",
    "health-and-beauty",
    "shopping-and-services",
    "children",
    "vacation-and-travel",
    "education",
    "insurance",
    "investments-and-savings",
    "expenses-and-other",
    "cars-and-transportation",
)
CHANNELS = ("mobile", "web")


# Required time windows
WINDOWS_IN_DAYS = (
    7,  # week
    14,  # two weeks
    21,  # three weeks
    30,  # month
    90,  # three months
    180,  # half of the year
    360,  # year
    720,  # two years
)

In [2]:
from pyspark.sql import SparkSession, functions as F, DataFrame
from pyspark.sql.column import Column

In [3]:
spark = (
    SparkSession.builder.master("local[*]")
    .config("spark.driver.memory", "8g")
    .config("spark.executor.memory", "8g")
    .config("spark.sql.shuffle.partitions", "12")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.ui.showConsoleProgress", "false")
    .config("spark.log.level", "ERROR")
    .getOrCreate()
)

24/06/03 21:58:30 WARN Utils: Your hostname, toolbox resolves to a loopback address: 127.0.0.1; using 192.168.0.29 instead (on interface wlp0s20f3)
24/06/03 21:58:30 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/06/03 21:58:30 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [4]:
df = spark.read.parquet("../test_data_tiny")

In [5]:
def transform_col(col: str, all_cols: set[str]) -> Column:
    cols_to_process = [col]
    if "_2w_" in col:
        cols_to_process.append(col.replace("_2w_", "_1w_"))
    elif "_3w_" in col:
        cols_to_process.append(col.replace("_3w_", "_1w_"))
        cols_to_process.append(col.replace("_3w_", "_2w_"))
    elif "_1m_" in col:
        cols_to_process.append(col.replace("_1m_", "_1w_"))
        cols_to_process.append(col.replace("_1m_", "_2w_"))
        cols_to_process.append(col.replace("_1m_", "_3w_"))
    elif "_3m_" in col:
        cols_to_process.append(col.replace("_3m_", "_1w_"))
        cols_to_process.append(col.replace("_3m_", "_2w_"))
        cols_to_process.append(col.replace("_3m_", "_3w_"))
        cols_to_process.append(col.replace("_3m_", "_1m_"))
    elif "_6m_" in col:
        cols_to_process.append(col.replace("_6m_", "_1w_"))
        cols_to_process.append(col.replace("_6m_", "_2w_"))
        cols_to_process.append(col.replace("_6m_", "_3w_"))
        cols_to_process.append(col.replace("_6m_", "_1m_"))
        cols_to_process.append(col.replace("_6m_", "_3m_"))
    elif "_1y_" in col:
        cols_to_process.append(col.replace("_1y_", "_1w_"))
        cols_to_process.append(col.replace("_1y_", "_2w_"))
        cols_to_process.append(col.replace("_1y_", "_3w_"))
        cols_to_process.append(col.replace("_1y_", "_1m_"))
        cols_to_process.append(col.replace("_1y_", "_3m_"))
        cols_to_process.append(col.replace("_1y_", "_6m_"))
    elif "_2y_" in col:
        cols_to_process.append(col.replace("_2y_", "_1w_"))
        cols_to_process.append(col.replace("_2y_", "_2w_"))
        cols_to_process.append(col.replace("_2y_", "_3w_"))
        cols_to_process.append(col.replace("_2y_", "_1m_"))
        cols_to_process.append(col.replace("_2y_", "_3m_"))
        cols_to_process.append(col.replace("_2y_", "_6m_"))
        cols_to_process.append(col.replace("_2y_", "_1y_"))

    cols_to_process = [c for c in cols_to_process if c in all_cols]
    if len(cols_to_process) == 1:
        return F.col(col).alias(col.replace("(trx_amnt)", ""))
    if ("_sum(" in col) or ("_count(" in col):
        return (sum([F.col(x) for x in cols_to_process])).alias(col.replace("(trx_amnt)", ""))
    elif "_max(" in col:
        return F.greatest(*cols_to_process).alias(col.replace("(trx_amnt)", ""))
    elif "_min(" in col:
        return F.least(*cols_to_process).alias(col.replace("(trx_amnt)", ""))
    else:
        return (sum([F.col(x) for x in cols_to_process]) / len(cols_to_process)).alias(col.replace("(trx_amnt)", ""))


def generate_pivoted_batch(df: DataFrame, groups: list[str]) -> DataFrame:
    # Partially inpsired by https://stackoverflow.com/a/73850575

    t_groups = groups + ["_win"]
    pivot_col = F.concat_ws("_", *t_groups)
    win_cols = ["1w", "2w", "3w", "1m", "3m", "6m", "1y", "2y"]
    if groups[0] == "card_type":
        pivot_values = [f"{ct}_{tt}_{ww}" for ct in CARD_TYPES for tt in TRANSACTION_TYPES for ww in win_cols]
    else:
        pivot_values = [f"{ch}_{tt}_{ww}" for ch in CHANNELS for tt in TRANSACTION_TYPES for ww in win_cols]

    tdf = (
        df
        .withColumn(
            "_win",
            F.when(F.col("t_minus") <= F.lit(7), F.lit("1w"))
            .when(F.col("t_minus") <= F.lit(14), F.lit("2w"))
            .when(F.col("t_minus") <= F.lit(21), F.lit("3w"))
            .when(F.col("t_minus") <= F.lit(30), F.lit("1m"))
            .when(F.col("t_minus") <= F.lit(90), F.lit("3m"))
            .when(F.col("t_minus") <= F.lit(180), F.lit("6m"))
            .when(F.col("t_minus") <= F.lit(360), F.lit("1y"))
            .when(F.col("t_minus") <= F.lit(720), F.lit("2y"))
        )
        .withColumn("_pivot", pivot_col)
        .groupBy("customer_id")
        .pivot("_pivot", pivot_values)
        .agg(
            F.count("trx_amnt"),
            F.sum("trx_amnt"),
            F.min("trx_amnt"),
            F.max("trx_amnt"),
            F.mean("trx_amnt"),
        )
    )
    columns_to_select = [F.col("customer_id")]
    all_cols = set(tdf.columns)
    for col in tdf.columns:
        if col == "customer_id":
            continue
        else:
            columns_to_select.append(transform_col(col, all_cols))

    return tdf.select(*columns_to_select)

In [6]:
part1 = generate_pivoted_batch(df, ["card_type", "trx_type"])
part2 = generate_pivoted_batch(df, ["channel", "trx_type"])

result = part1.join(part2, on=["customer_id"], how="inner")
pdf_test = result.toPandas()

24/06/03 21:59:03 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
24/06/03 21:59:12 WARN DAGScheduler: Broadcasting large task binary with size 2.5 MiB


In [7]:
pdf_test

Unnamed: 0,customer_id,DC_food-and-household_1w_count,DC_food-and-household_1w_sum,DC_food-and-household_1w_min,DC_food-and-household_1w_max,DC_food-and-household_1w_avg,DC_food-and-household_2w_count,DC_food-and-household_2w_sum,DC_food-and-household_2w_min,DC_food-and-household_2w_max,...,web_cars-and-transportation_1y_count,web_cars-and-transportation_1y_sum,web_cars-and-transportation_1y_min,web_cars-and-transportation_1y_max,web_cars-and-transportation_1y_avg,web_cars-and-transportation_2y_count,web_cars-and-transportation_2y_sum,web_cars-and-transportation_2y_min,web_cars-and-transportation_2y_max,web_cars-and-transportation_2y_avg
0,0,199,1.010461e+06,110.234860,9977.113405,5077.692213,,,110.234860,9977.113405,...,,,144.292529,9979.892357,,,,106.491426,9992.602836,
1,28,144,6.791126e+05,155.356151,9986.920623,4716.059574,,,155.356151,9986.920623,...,,,111.374827,9992.945059,,,,111.374827,9992.945059,
2,30,175,8.164588e+05,114.675123,9979.929500,4665.478934,,,114.675123,9979.929500,...,,,116.680329,9991.337056,,,,111.376207,9991.337056,
3,31,200,1.039263e+06,165.532643,9952.459198,5196.317071,,,165.532643,9952.459198,...,,,129.456721,9999.783566,,,,129.456721,9999.783566,
4,32,136,6.735669e+05,135.336714,9979.919434,4952.698118,,,135.336714,9979.919434,...,,,108.850995,9993.135603,,,,108.850995,9993.135603,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,986,146,7.656383e+05,251.611819,9913.654187,5244.097603,,,251.611819,9913.654187,...,,,121.997269,9996.365634,,,,121.997269,9996.365634,
996,987,113,5.438359e+05,133.957041,9941.101411,4812.706828,,,133.957041,9941.101411,...,,,147.532580,9926.851289,,,,129.442835,9964.097175,
997,989,191,9.681422e+05,158.528208,9960.302763,5068.807270,,,158.528208,9960.302763,...,,,120.329659,9994.509707,,,,120.329659,9997.973922,
998,991,176,8.766102e+05,109.092234,9974.297278,4980.739678,,,109.092234,9974.297278,...,,,101.965210,9981.233625,,,,101.965210,9986.135070,
