In [1]:
pip install polars pyarrow numpy tqdm scikit-learn gensim implicit

Note: you may need to restart the kernel to use updated packages.


In [1]:
import os
from typing import List

import polars as pl
import numpy as np

In [2]:
# --- Run configuration -------------------------------------------------------
# Input/output locations. The literal paths are the defaults; set the environment
# variables MEGAMARKET_RAW_PATH / MEGAMARKET_OUT_DIR to run on another machine
# without editing the notebook.
raw_path = os.environ.get(
    "MEGAMARKET_RAW_PATH",
    r"C:\IMPORTANT\ACTIVITIES\sirius\megamarket\megamarket.parquet",
)
out_dir = os.environ.get(
    "MEGAMARKET_OUT_DIR",
    r"C:\IMPORTANT\ACTIVITIES\sirius\megamarket\emb_out",
)

# None -> process the full dataset lazily; an int -> materialize only that many rows.
read_limit_rows = None
# Compression codec passed to every parquet write/sink below.
parquet_compression = "zstd"

# ALS interaction filtering (applied only in full-dataset mode).
als_sampling = True      # keep only sufficiently active users / popular items
min_user_events = 20     # minimum interaction rows per user
min_item_users = 20      # minimum interaction rows per item
max_users_cap = None     # optional hard cap on number of users (None = no cap)
max_items_cap = None     # optional hard cap on number of items (None = no cap)

# Validate the input before touching the filesystem, and do it with a real
# exception so the check is not stripped when Python runs with -O.
if not os.path.exists(raw_path):
    raise FileNotFoundError(f"Not found: {raw_path}")
os.makedirs(out_dir, exist_ok=True)
print({"raw_path": raw_path, "out_dir": out_dir, "read_limit_rows": read_limit_rows})

{'raw_path': 'C:\\IMPORTANT\\ACTIVITIES\\sirius\\megamarket\\megamarket.parquet', 'out_dir': 'C:\\IMPORTANT\\ACTIVITIES\\sirius\\megamarket\\emb_out', 'read_limit_rows': None}


In [3]:
#n_rows = pl.scan_parquet(raw_path).select(pl.count()).collect(engine="streaming").item()
#print(n_rows)

In [3]:
# Open the raw events lazily, inspect them, and prepare the base lazy pipeline
# (`base_scan`) that the aggregation cells below consume.
scan = pl.scan_parquet(raw_path)

# Resolve the schema once. Accessing `LazyFrame.schema` / `LazyFrame.columns`
# re-resolves it on every access and emits a PerformanceWarning.
scan_schema = scan.collect_schema()
print("Columns:", scan_schema)

head_df = scan.limit(5).collect(engine="streaming")
print("Head shape:", head_df.shape)
print(head_df)

cols = scan_schema.names()
has_dt = "datetime" in cols

if has_dt:
    # Drop every event whose calendar month is May (regardless of year).
    # NOTE(review): presumably May is held out for evaluation - confirm.
    scan_filt = scan.filter(pl.col("datetime").dt.month() != 5)
    ts_df = scan_filt.select([
        pl.col("datetime").min().alias("min_ts"),
        pl.col("datetime").max().alias("max_ts"),
    ]).collect(engine="streaming")
    print("Time range:", ts_df.to_dict(as_series=False))
else:
    scan_filt = scan

# Only these columns are needed downstream; "datetime" is optional.
need = [c for c in ["user_id", "item_id", "datetime"] if c in cols]
assert ("user_id" in need) and ("item_id" in need), "Dataset must contain user_id and item_id"
base_scan = scan_filt.select(need)

if read_limit_rows is not None:
    # Sample mode: materialize a bounded slice so later cells can work eagerly.
    base_df = base_scan.limit(read_limit_rows).collect(engine="streaming")
    print("Base slice shape:", base_df.shape)
    print(base_df.head(5))
else:
    # Full mode: nothing is materialized here; later cells stream from `base_scan`.
    base_df = None
    print("Full dataset mode: using lazy pipelines (no base_df materialization).")

  print("Columns:", scan.schema)


Columns: Schema([('user_id', Int32), ('datetime', Datetime(time_unit='ms', time_zone=None)), ('event', Int32), ('item_id', Int32), ('category_id', Int32), ('price', Float32)])
Head shape: (5, 6)
shape: (5, 6)
┌─────────┬─────────────────────────┬───────┬─────────┬─────────────┬───────────┐
│ user_id ┆ datetime                ┆ event ┆ item_id ┆ category_id ┆ price     │
│ ---     ┆ ---                     ┆ ---   ┆ ---     ┆ ---         ┆ ---       │
│ i32     ┆ datetime[ms]            ┆ i32   ┆ i32     ┆ i32         ┆ f32       │
╞═════════╪═════════════════════════╪═══════╪═════════╪═════════════╪═══════════╡
│ 1199174 ┆ 2023-02-13 03:12:19.131 ┆ 2     ┆ 1861088 ┆ 5395        ┆ -0.042062 │
│ 3124963 ┆ 2023-02-13 09:38:55.674 ┆ 2     ┆ 2084441 ┆ 1531        ┆ 0.079817  │
│ 2106055 ┆ 2023-02-13 08:26:05.004 ┆ 2     ┆ 2586689 ┆ 9294        ┆ -0.03899  │
│ 4169844 ┆ 2023-02-13 16:55:08.470 ┆ 2     ┆ 2586689 ┆ 9294        ┆ -0.03899  │
│ 7472260 ┆ 2023-02-13 03:30:13.093 ┆ 2     ┆ 2586689

  cols = scan.columns


Time range: {'min_ts': [datetime.datetime(2023, 1, 15, 0, 0, 0, 708000)], 'max_ts': [datetime.datetime(2023, 4, 30, 23, 59, 59, 470000)]}
Full dataset mode: using lazy pipelines (no base_df materialization).


In [4]:
# Per-user activity aggregates, written to <out_dir>/user_agg.parquet.
#
# Output columns: user_id, events (row count), n_items (distinct items),
# n_days (distinct calendar dates), last_ts (latest event), recency_days
# (whole days between the latest event in the filtered data and last_ts).
# When the dataset has no "datetime" column, ts/date are null placeholders and
# recency_days is null.
ts_available = has_dt and ("datetime" in need)
user_out = os.path.join(out_dir, "user_agg.parquet")

if ts_available:
    user_ts_cols = [
        pl.col("datetime").alias("ts"),
        pl.col("datetime").dt.date().alias("date"),
    ]
else:
    user_ts_cols = [
        pl.lit(None).alias("ts"),
        pl.lit(None).alias("date"),
    ]

user_agg_exprs = [
    pl.len().alias("events"),
    pl.n_unique("item_id").alias("n_items"),
    pl.n_unique("date").alias("n_days"),
    pl.col("ts").max().alias("last_ts"),
]

# Reference "now" for recency: the latest timestamp in the data being aggregated.
# Only queried when a datetime column actually exists (previously the lazy branch
# selected "datetime" unconditionally and failed on datasets without it).
if not ts_available:
    latest_ts = None
elif read_limit_rows is not None:
    latest_ts = base_df.select(pl.col("datetime").max()).item()
else:
    latest_ts = scan_filt.select(pl.col("datetime").max()).collect(engine="streaming").item()

if latest_ts is not None:
    user_recency_expr = (
        (pl.lit(latest_ts) - pl.col("last_ts")).dt.total_days().cast(pl.Int64).alias("recency_days")
    )
else:
    user_recency_expr = pl.lit(None).cast(pl.Int64).alias("recency_days")

if read_limit_rows is not None:
    # Sample mode: eager aggregation over the materialized slice.
    user_agg = (
        base_df.with_columns(user_ts_cols)
        .group_by("user_id", maintain_order=True)
        .agg(user_agg_exprs)
        .with_columns(user_recency_expr)
    )
    user_agg.write_parquet(user_out, compression=parquet_compression)
else:
    # Full mode: stream the whole pipeline (including recency) straight to disk
    # in a single write, then load the much smaller aggregate back.
    user_agg_lazy = (
        base_scan.with_columns(user_ts_cols)
        .group_by("user_id")
        .agg(user_agg_exprs)
        .with_columns(user_recency_expr)
    )
    user_agg_lazy.sink_parquet(user_out, compression=parquet_compression)
    user_agg = pl.read_parquet(user_out)

print("user_agg shape:", user_agg.shape, "->", user_out)
print(user_agg.head(5))

user_agg shape: (2521538, 6) -> C:\IMPORTANT\ACTIVITIES\sirius\megamarket\emb_out\user_agg.parquet
shape: (5, 6)
┌─────────┬────────┬─────────┬────────┬─────────────────────┬──────────────┐
│ user_id ┆ events ┆ n_items ┆ n_days ┆ last_ts             ┆ recency_days │
│ ---     ┆ ---    ┆ ---     ┆ ---    ┆ ---                 ┆ ---          │
│ i32     ┆ u32    ┆ u32     ┆ u32    ┆ datetime[ms]        ┆ i64          │
╞═════════╪════════╪═════════╪════════╪═════════════════════╪══════════════╡
│ 9451783 ┆ 4      ┆ 4       ┆ 2      ┆ 2023-02-28 12:05:51 ┆ 61           │
│ 9090022 ┆ 1      ┆ 1       ┆ 1      ┆ 2023-02-17 17:08:25 ┆ 72           │
│ 4845635 ┆ 2      ┆ 2       ┆ 1      ┆ 2023-02-21 08:07:37 ┆ 68           │
│ 3472058 ┆ 33     ┆ 29      ┆ 1      ┆ 2023-02-19 11:45:19 ┆ 70           │
│ 1817437 ┆ 3      ┆ 3       ┆ 2      ┆ 2023-04-11 16:29:05 ┆ 19           │
└─────────┴────────┴─────────┴────────┴─────────────────────┴──────────────┘


In [5]:
# Per-item popularity aggregates, written to <out_dir>/item_agg.parquet.
#
# Output columns: item_id, events (row count), n_users (distinct users),
# avg_events_per_user (events / n_users), last_ts (latest event), recency_days
# (whole days between the latest event in the filtered data and last_ts).
# When the dataset has no "datetime" column, ts is a null placeholder and
# recency_days is null.
ts_available = has_dt and ("datetime" in need)
item_out = os.path.join(out_dir, "item_agg.parquet")

if ts_available:
    item_ts_cols = [pl.col("datetime").alias("ts")]
else:
    item_ts_cols = [pl.lit(None).alias("ts")]

item_agg_exprs = [
    pl.len().alias("events"),
    pl.n_unique("user_id").alias("n_users"),
    (pl.len() / pl.n_unique("user_id")).cast(pl.Float64).alias("avg_events_per_user"),
    pl.col("ts").max().alias("last_ts"),
]

# Reference "now" for recency: the latest timestamp in the data being aggregated.
# Only queried when a datetime column actually exists (previously the lazy branch
# selected "datetime" unconditionally and failed on datasets without it).
if not ts_available:
    latest_ts = None
elif read_limit_rows is not None:
    latest_ts = base_df.select(pl.col("datetime").max()).item()
else:
    latest_ts = scan_filt.select(pl.col("datetime").max()).collect(engine="streaming").item()

if latest_ts is not None:
    item_recency_expr = (
        (pl.lit(latest_ts) - pl.col("last_ts")).dt.total_days().cast(pl.Int64).alias("recency_days")
    )
else:
    item_recency_expr = pl.lit(None).cast(pl.Int64).alias("recency_days")

if read_limit_rows is not None:
    # Sample mode: eager aggregation over the materialized slice.
    item_agg = (
        base_df.with_columns(item_ts_cols)
        .group_by("item_id", maintain_order=True)
        .agg(item_agg_exprs)
        .with_columns(item_recency_expr)
    )
    item_agg.write_parquet(item_out, compression=parquet_compression)
else:
    # Full mode: stream the whole pipeline (including recency) straight to disk
    # in a single write, then load the much smaller aggregate back.
    item_agg_lazy = (
        base_scan.with_columns(item_ts_cols)
        .group_by("item_id")
        .agg(item_agg_exprs)
        .with_columns(item_recency_expr)
    )
    item_agg_lazy.sink_parquet(item_out, compression=parquet_compression)
    item_agg = pl.read_parquet(item_out)

print("item_agg shape:", item_agg.shape, "->", item_out)
print(item_agg.head(5))

item_agg shape: (3340545, 6) -> C:\IMPORTANT\ACTIVITIES\sirius\megamarket\emb_out\item_agg.parquet
shape: (5, 6)
┌─────────┬────────┬─────────┬─────────────────────┬─────────────────────────┬──────────────┐
│ item_id ┆ events ┆ n_users ┆ avg_events_per_user ┆ last_ts                 ┆ recency_days │
│ ---     ┆ ---    ┆ ---     ┆ ---                 ┆ ---                     ┆ ---          │
│ i32     ┆ u32    ┆ u32     ┆ f64                 ┆ datetime[ms]            ┆ i64          │
╞═════════╪════════╪═════════╪═════════════════════╪═════════════════════════╪══════════════╡
│ 2575937 ┆ 1      ┆ 1       ┆ 1.0                 ┆ 2023-01-15 08:25:28.120 ┆ 105          │
│ 2161189 ┆ 40     ┆ 12      ┆ 3.333333            ┆ 2023-04-30 15:33:15.491 ┆ 0            │
│ 2012557 ┆ 19     ┆ 7       ┆ 2.714286            ┆ 2023-04-29 17:30:30     ┆ 1            │
│ 3353771 ┆ 3      ┆ 2       ┆ 1.5                 ┆ 2023-03-25 23:11:26.647 ┆ 36           │
│ 3353119 ┆ 16     ┆ 7       ┆ 2.285714  

In [6]:
from scipy.sparse import coo_matrix
import implicit

In [7]:
# Build the sparse user x item interaction matrix that ALS is trained on.
#
#   1. Collapse raw events to one row per (user_id, item_id) with an event count `cnt`.
#   2. Full-dataset mode with als_sampling: keep users with >= min_user_events rows and
#      items with >= min_item_users rows. Counts are taken on the collapsed table, and
#      the filter is a single pass (it is not re-applied after rows are dropped).
#   3. Map raw ids to dense 0..n-1 indices and assemble COO / CSR matrices.
use_binary_weights = True  # True: every observed pair gets weight 1.0; False: weight = cnt

if read_limit_rows is not None:
    inter = (
        base_df.select(["user_id", "item_id"])
        .group_by(["user_id", "item_id"], maintain_order=True)
        .agg(pl.len().alias("cnt"))
    )
else:
    inter_lazy = (
        base_scan.select(["user_id", "item_id"])
        .group_by(["user_id", "item_id"])
        .agg(pl.len().alias("cnt"))
    )
    inter_out = os.path.join(out_dir, "inter.parquet")
    inter_lazy.sink_parquet(inter_out, compression=parquet_compression)

    if als_sampling:
        inter_scan = pl.scan_parquet(inter_out)
        user_cnts = inter_scan.group_by("user_id").agg(pl.len().alias("u_events"))
        item_cnts = inter_scan.group_by("item_id").agg(pl.len().alias("i_users"))

        active_users = user_cnts.filter(pl.col("u_events") >= min_user_events).select("user_id")
        active_items = item_cnts.filter(pl.col("i_users") >= min_item_users).select("item_id")

        inter_core_out = os.path.join(out_dir, "inter_core.parquet")
        (
            inter_scan
            .join(active_users, on="user_id", how="inner")
            .join(active_items, on="item_id", how="inner")
            .sink_parquet(inter_core_out, compression=parquet_compression)
        )

        inter_core_scan = pl.scan_parquet(inter_core_out)
        # Optional hard caps. `unique()` does not guarantee an order, so which ids
        # survive a cap is arbitrary.
        if max_users_cap is not None:
            keep_users = inter_core_scan.select("user_id").unique().limit(max_users_cap)
            inter_core_scan = inter_core_scan.join(keep_users, on="user_id", how="inner")
        if max_items_cap is not None:
            keep_items = inter_core_scan.select("item_id").unique().limit(max_items_cap)
            inter_core_scan = inter_core_scan.join(keep_items, on="item_id", how="inner")

        inter = inter_core_scan.collect(engine="streaming")
    else:
        inter = pl.scan_parquet(inter_out).collect(engine="streaming")

if use_binary_weights:
    inter = inter.with_columns(pl.lit(1.0).alias("weight"))
else:
    inter = inter.with_columns(pl.col("cnt").cast(pl.Float32).alias("weight"))

# Null ids cannot be mapped to matrix indices; fail with a clear message.
for _id_col in ("user_id", "item_id"):
    if inter[_id_col].null_count() > 0:
        raise ValueError(f"Column {_id_col!r} contains nulls; cannot build the interaction matrix")

# Sorted unique ids define the dense index space: index i <-> i-th smallest id.
_uniq_user_arr = inter["user_id"].unique().sort().to_numpy().astype(np.int64)
_uniq_item_arr = inter["item_id"].unique().sort().to_numpy().astype(np.int64)
uniq_users = _uniq_user_arr.tolist()
uniq_items = _uniq_item_arr.tolist()
user2idx = {u: i for i, u in enumerate(uniq_users)}
item2idx = {it: i for i, it in enumerate(uniq_items)}

user_arr = inter["user_id"].to_numpy().astype(np.int64)
item_arr = inter["item_id"].to_numpy().astype(np.int64)
# Every id is present in the sorted unique array, so its insertion point equals its
# dense index. This is the same mapping as user2idx / item2idx, computed in
# vectorised form instead of one Python dict lookup per interaction row.
rows = np.searchsorted(_uniq_user_arr, user_arr).astype(np.uint32)
# NOTE: `cols` rebinds the column-name list created in the data-loading cell.
cols = np.searchsorted(_uniq_item_arr, item_arr).astype(np.uint32)
data = inter["weight"].to_numpy().astype(np.float32)

num_users = len(uniq_users)
num_items = len(uniq_items)
user_item_coo = coo_matrix((data, (rows, cols)), shape=(num_users, num_items), dtype=np.float32)
user_item_csr = user_item_coo.tocsr()
item_user_csr = user_item_csr.T.tocsr()

print({
    "interactions_rows": inter.height,
    "num_users": num_users,
    "num_items": num_items,
})

{'interactions_rows': 35900269, 'num_users': 447207, 'num_items': 337637}


In [8]:
# Train implicit-feedback ALS on the interaction matrix.

# Thread limits for the BLAS / OpenMP backends (only set if not already defined).
# NOTE(review): these libraries typically read the variables when they are first
# loaded, and numpy is imported earlier in the notebook, so setting them here may
# have no effect in an already-running kernel - confirm, or export them before
# starting Jupyter.
os.environ.setdefault("OPENBLAS_NUM_THREADS", "4")
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("MKL_NUM_THREADS", "4")

# Model hyper-parameters.
factors = 48           # embedding dimensionality
regularization = 0.01  # L2 regularization strength
iterations = 12        # number of ALS sweeps
als = implicit.als.AlternatingLeastSquares(
    factors=factors,
    regularization=regularization,
    iterations=iterations,
    random_state=42,  # fixed seed for reproducible factors
)

# implicit >= 0.5 expects a (users x items) matrix in fit(); older releases expect
# (items x users). Passing the wrong orientation silently swaps the meaning of
# als.user_factors / als.item_factors, so pick the matrix by library version.
_implicit_version = tuple(int(part) for part in implicit.__version__.split(".")[:2])
train_matrix = user_item_csr if _implicit_version >= (0, 5) else item_user_csr
als.fit(train_matrix)

  check_blas_config()


  0%|          | 0/12 [00:00<?, ?it/s]

In [9]:
# Export the learned ALS embeddings as parquet: one row per id, columns f0..f{k-1}.
user_factors = als.user_factors
item_factors = als.item_factors

# NOTE(review): GPU-backed implicit models are believed to expose factors as a
# device matrix with a `to_numpy()` method rather than an ndarray - convert if so.
if hasattr(user_factors, "to_numpy"):
    user_factors = user_factors.to_numpy()
if hasattr(item_factors, "to_numpy"):
    item_factors = item_factors.to_numpy()

uf_rows, if_rows = user_factors.shape[0], item_factors.shape[0]
nu, ni = len(uniq_users), len(uniq_items)

# Decide which factor matrix belongs to users by matching row counts against the
# id lists; the swapped case occurs when the model was fit on the transposed matrix.
# Caveat: if nu == ni the two cases are indistinguishable and the first branch wins.
if uf_rows == nu and if_rows == ni:
    users_mat, items_mat = user_factors, item_factors
elif uf_rows == ni and if_rows == nu:
    users_mat, items_mat = item_factors, user_factors
else:
    raise RuntimeError(f"ALS factor shapes do not match ids: user_factors={user_factors.shape}, item_factors={item_factors.shape}, users={nu}, items={ni}")

user_schema = [f"f{i}" for i in range(users_mat.shape[1])]
item_schema = [f"f{i}" for i in range(items_mat.shape[1])]

# orient="row": each matrix row is one entity. Stated explicitly so polars never
# has to infer the orientation from the array shape.
users_df = pl.DataFrame(users_mat, schema=user_schema, orient="row")
users_df = (
    users_df
    .with_columns(pl.Series("user_id", np.array(uniq_users, dtype=np.int64)))
    .select(["user_id", *user_schema])
)

items_df = pl.DataFrame(items_mat, schema=item_schema, orient="row")
items_df = (
    items_df
    .with_columns(pl.Series("item_id", np.array(uniq_items, dtype=np.int64)))
    .select(["item_id", *item_schema])
)

als_user_out = os.path.join(out_dir, "als_user.parquet")
als_item_out = os.path.join(out_dir, "als_item.parquet")
# Use the notebook-wide codec, consistent with the other parquet outputs.
users_df.write_parquet(als_user_out, compression=parquet_compression)
items_df.write_parquet(als_item_out, compression=parquet_compression)
print("ALS saved:", als_user_out, als_item_out)

ALS saved: C:\IMPORTANT\ACTIVITIES\sirius\megamarket\emb_out\als_user.parquet C:\IMPORTANT\ACTIVITIES\sirius\megamarket\emb_out\als_item.parquet
