In [27]:
import polars as pl
import data_actions.data_processing.utils as utils

In [28]:
# Load the raw competition tables from disk.
RAW_DATA_DIR = "data/original_data"
data = utils.load_data(RAW_DATA_DIR)

# Called only for its effect on `data`; the return value is discarded.
# NOTE(review): presumably this adds the `days_since_start` column that later
# cells read from data["events"][...] — confirm in utils.
utils.calculate_days_since_start(data)

In [29]:
# Working dictionaries populated by the feature-building loop:
#   event_counts       -> per-source weekly aggregates for the current window
#   joined_table_weeks -> per-week joined tables for the current reference day
#   datasets_per_day   -> final per-day feature tables, keyed by day
event_counts, joined_table_weeks, datasets_per_day = {}, {}, {}

In [30]:
# Stack the day index of every event table so the latest observed day can be found.
days_since_start = pl.concat(
    [data["events"][event_name].select("days_since_start") for event_name in data["events"]]
)
max_day = days_since_start.select(pl.col("days_since_start").max()).collect().item()

# Both windows run through max_day inclusive:
# TARGET_START_DAY opens the last 28 observed days, VAL_START_DAY the last 14.
TARGET_START_DAY = max_day - (28 - 1)
VAL_START_DAY = max_day - (14 - 1)

In [31]:
# Attach product attributes (every product_properties column except `name`) to
# the product-level event tables, so later cells can aggregate by category/price.
# This cell overwrites data["events"][...] in place. The guard below makes it safe
# to re-run: without it, a second run would join the attributes again and add
# duplicate `*_right` columns.
product_info = data["products"]["product_properties"]
product_columns = [
    name for name in product_info.collect_schema().names() if name not in ("sku", "name")
]
for table in ("add_to_cart", "remove_from_cart", "product_buy"):
    events_table = data["events"][table]
    if set(product_columns).issubset(events_table.collect_schema().names()):
        continue  # already enriched by an earlier run of this cell
    data["events"][table] = events_table.join(
        product_info.select(pl.exclude("name")), on="sku", how="left"
    )

In [32]:
# Build one feature table per reference `day`. Each client's activity is
# aggregated over the four consecutive 7-day windows immediately before `day`:
#   week0 = [day-28, day-21)   week1 = [day-21, day-14)
#   week2 = [day-14, day-7)    week3 = [day-7,  day)
# Week numbers are relative to the reference day, not the absolute `start // 7`.
# That keeps the column set identical for every day, so the per-day tables can
# be stacked later. `joined_table_weeks` is rebuilt per day so windows from
# earlier days cannot leak into later ones.
# NOTE(review): day runs up to TARGET_START_DAY - 1 (exclusive bound) — confirm intended.
PRODUCT_EVENT_TABLES = ("add_to_cart", "remove_from_cart", "product_buy")
DAYS_PER_WEEK = 7
N_WEEKS = 4
HISTORY_DAYS = N_WEEKS * DAYS_PER_WEEK  # 28 days of history per reference day


def _most_common(column, alias):
    """Most frequent value of `column` within each group.

    value_counts(sort=True) orders values by descending count, so `.first()`
    picks the top one. NOTE(review): tie-breaking between equally frequent
    values is not specified here — confirm if determinism matters.
    """
    return pl.col(column).value_counts(sort=True).struct.field(column).first().alias(alias)


def _full_join_on_client(frames):
    """Full-join `frames` in order on client_id, coalescing the key column."""
    joined = frames[0]
    for frame in frames[1:]:
        joined = joined.join(frame, on=["client_id"], how="full", coalesce=True)
    return joined


for day in range(HISTORY_DAYS, TARGET_START_DAY):
    joined_table_weeks = {}  # fresh per day: only this day's four windows
    for week, start in enumerate(range(day - HISTORY_DAYS, day, DAYS_PER_WEEK)):
        in_window = (pl.col("days_since_start") >= start) & (
            pl.col("days_since_start") < start + DAYS_PER_WEEK
        )
        event_counts = {}  # insertion order fixes the join / column order below

        for table in PRODUCT_EVENT_TABLES:
            event_counts[table] = (
                data["events"][table]
                .filter(in_window)
                .group_by("client_id")
                .agg(
                    [
                        pl.col("client_id").count().alias(f"week{week}_count_{table}"),
                        _most_common("sku", f"week{week}_most_common_item_{table}"),
                        _most_common("category", f"week{week}_most_common_cat_{table}"),
                        pl.col("price").mean().alias(f"week{week}_avg_price_{table}"),
                        _most_common("price", f"week{week}_most_common_price_{table}"),
                    ]
                )
            )

        event_counts["page_visit"] = (
            data["events"]["page_visit"]
            .filter(in_window)
            .group_by("client_id")
            .agg(pl.col("client_id").count().alias(f"week{week}_page_visits_count"))
        )

        # pl.len() = rows per group; replaces the deprecated GroupBy.count() and
        # its implicit "count" column that previously had to be renamed.
        event_counts["search_query"] = (
            data["events"]["search_query"]
            .filter(in_window)
            .group_by("client_id")
            .agg(pl.len().alias(f"week{week}_search_query_count"))
        )

        joined_table_weeks[week] = _full_join_on_client(list(event_counts.values()))

    # Tag rows with their reference day so they stay distinguishable once all
    # per-day tables are stacked into one dataset.
    datasets_per_day[day] = _full_join_on_client(list(joined_table_weeks.values())).with_columns(
        pl.lit(day).alias("reference_day")
    )


In [None]:
# Materialise every per-day lazy feature table. Each one is checkpointed to its
# own parquet file as soon as it is computed, so finished days survive a crash
# and every row is written to disk once (the old cumulative checkpoints made I/O
# quadratic, and wrote nothing at all when there was only one day).
# The tables are then stacked into a single frame.
if not datasets_per_day:
    raise ValueError("datasets_per_day is empty - run the feature-building cell first")

collected_tables = []
for day, lazy_table in datasets_per_day.items():
    day_table = lazy_table.collect()
    day_table.write_parquet(f"weekly_dataset_checkpoint_day{day}.parquet", compression="gzip")
    collected_tables.append(day_table)

# how="vertical" requires identical schemas across days. Rechunk once here
# rather than accumulating many chunks through repeated vstack calls.
dataset = pl.concat(collected_tables, how="vertical", rechunk=True)
dataset.write_parquet("weekly_dataset.parquet", compression="gzip")

: 