In [1]:
import pickle
from pathlib import Path

import jax
import optax
import pandas as pd
import polars as pl
from flax import nnx
from flax_trainer.loss_fn import mean_squared_error
from flax_trainer.trainer import Trainer
from sklearn.model_selection import train_test_split

from flax_recsys.encoder import ColumnEncoder
from flax_recsys.evaluator import GeneralRatingEvaluator
from flax_recsys.loader import GeneralRatingLoader
from flax_recsys.model.DPLRFwFM import DPLRFwFM

### Read data

In [2]:
# Root directory of the MovieLens-1M files; the three tables below share the
# same "::"-delimited, header-less layout.
ML1M_DIR = Path("../../dataset/ML1M/ml-1m")


def read_ml1m_table(
    file_name: str, column_names: "list[str]", encoding: "str | None" = None
) -> pl.DataFrame:
    """Read one header-less, "::"-delimited ML-1M ``.dat`` file into polars.

    Args:
        file_name: File name inside ``ML1M_DIR`` (e.g. ``"ratings.dat"``).
        column_names: Names assigned to the columns, in file order.
        encoding: Text encoding forwarded to ``pandas.read_csv``; ``None``
            keeps pandas' default.

    Returns:
        The parsed table converted with ``pl.from_pandas``.
    """
    # A multi-character separator such as "::" requires pandas' python
    # parsing engine, so the file is parsed with pandas and then converted.
    table = pd.read_csv(
        ML1M_DIR / file_name,
        delimiter="::",
        engine="python",
        header=None,
        names=column_names,
        encoding=encoding,
    )
    return pl.from_pandas(table)


rating_df = read_ml1m_table(
    "ratings.dat", ["user_id", "item_id", "rating", "timestamp"]
)
# movies.dat is decoded as ISO-8859-1 (Latin-1); presumably some titles are not
# valid UTF-8. `genres` is a "|"-separated string and is split into a list.
movie_df = read_ml1m_table(
    "movies.dat", ["item_id", "title", "genres"], encoding="ISO-8859-1"
).with_columns(pl.col("genres").str.split("|"))
user_df = read_ml1m_table(
    "users.dat", ["user_id", "gender", "age", "occupation", "zip_code"]
)

# Build the modelling table:
#   1. attach movie and user side information to every rating (left joins),
#   2. drop the list-valued `genres` column and move the target `rating` last,
#   3. replace every non-target column with its 0-based dense rank, i.e.
#      re-index its distinct values to 0..n_unique-1.
# NOTE(review): step 3 also dense-ranks `timestamp`, which keeps only the
# ordering of events and discards the actual time gaps — confirm intended.
dataset_df = (
    rating_df.join(movie_df, how="left", on="item_id")
    .join(user_df, how="left", on="user_id")
    .select(pl.all().exclude("rating", "genres"), pl.col("rating"))
    .with_columns(pl.all().exclude("rating").rank("dense") - 1)
)

# Invariants: the left joins must keep exactly one row per rating, and every
# rating must have matched its movie and its user. An unmatched key would leave
# null feature values that flow silently into the encoder.
assert dataset_df.height == rating_df.height, "left joins changed the row count"
assert not any(dataset_df.null_count().row(0)), "unmatched item_id/user_id after join"

display(dataset_df)

# Side-information columns handed to ColumnEncoder through its `one_hot` role.
ONE_HOT_COLUMNS = ["title", "gender", "age", "occupation", "zip_code"]

# Assign every column of `dataset_df` to exactly one ColumnEncoder role; the
# keyword order matches the original call.
column_encoder = ColumnEncoder(
    user_id="user_id",
    item_id="item_id",
    timestamp="timestamp",
    one_hot=ONE_HOT_COLUMNS,
    rating="rating",
)

# Randomly hold out 10% of the rating rows for validation. The fixed
# `random_state` makes the shuffled split reproducible across runs.
train_df, valid_df = train_test_split(
    dataset_df,
    test_size=0.1,
    random_state=0,
    shuffle=True,
)

user_id,item_id,timestamp,title,gender,age,occupation,zip_code,rating
u32,u32,u32,u32,u32,u32,u32,u32,i64
0,1104,397442,2452,0,0,10,1588,5
0,639,397457,1739,0,0,10,1588,3
0,853,397454,2289,0,0,10,1588,3
0,3177,397440,1054,0,0,10,1588,4
0,2162,400250,557,0,0,10,1588,5
…,…,…,…,…,…,…,…,…
6039,1019,382,3574,1,2,6,466,1
6039,1022,21,814,1,2,6,466,5
6039,548,17,3578,1,2,6,466,5
6039,1024,346,3090,1,2,6,466,4


In [3]:
# Training split: the encoder is fitted on `train_df` only. The validation split
# below goes through `transform`, so (assuming the usual fit/transform
# contract) nothing is fitted on validation data.
train_categorical_X, train_numerical_X, train_y = column_encoder.fit_transform(
    train_df
)
loader = GeneralRatingLoader(
    categorical_X=train_categorical_X,
    numerical_X=train_numerical_X,
    y=train_y,
    batch_size=512,
    rngs=nnx.Rngs(0),
)

# Validation split, encoded with the encoder fitted above. Distinct `valid_*`
# names keep the training arrays from being silently shadowed by validation
# data when cells are re-run or inspected.
# NOTE(review): ids were dense-ranked on the full dataset before splitting, so
# valid_df may contain category values never seen in train_df — confirm that
# ColumnEncoder's cardinality_map (and therefore the embedding tables) covers them.
valid_categorical_X, valid_numerical_X, valid_y = column_encoder.transform(
    valid_df
)
evaluator = GeneralRatingEvaluator(
    categorical_X=valid_categorical_X,
    numerical_X=valid_numerical_X,
    y=valid_y,
)

### Training (学習)

In [4]:
# One cardinality per categorical column, taken in the iteration order of the
# encoder's `cardinality_map`.
# NOTE(review): assumes that order matches the column order of the encoded
# categorical arrays — confirm against ColumnEncoder.
feature_cardinalities = list(column_encoder.cardinality_map.values())

# embed_dim: embedding width per feature. rho: presumably the rank of the
# low-rank part of the DPLR field-interaction matrix — confirm in DPLRFwFM.
model = DPLRFwFM(
    categorical_feature_cardinalities=feature_cardinalities,
    numerical_feature_num=column_encoder.numerical_column_num,
    embed_dim=30,
    rho=3,
    rngs=nnx.Rngs(0),
)

In [5]:
# AdamW (Adam with decoupled weight decay).
optimizer = optax.adamw(learning_rate=0.001, weight_decay=0.001)

# Train with an MSE loss, evaluating on the validation split after each epoch.
# Training stops early after 10 epochs without improvement, or after at most
# 64 epochs. In the recorded run the validation loss bottomed out at epoch 7
# (~0.806) and training stopped at epoch 17.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    train_loader=loader,
    loss_fn=mean_squared_error,
    test_evaluator=evaluator,
    early_stopping_patience=10,
    epoch_num=64,
)
# Re-bind `trainer` to the object returned by `fit`; its `best_state_dict` is
# saved in the next cell.
trainer = trainer.fit()

[TEST  000] loss=79.80770111083984


[TRAIN 001]: 100%|██████████| 1759/1759 [00:06<00:00, 273.01it/s, batch_loss=0.9553055] 


[TEST  001] loss=1.0663902759552002


[TRAIN 002]: 100%|██████████| 1759/1759 [00:06<00:00, 280.78it/s, batch_loss=1.0358295] 


[TEST  002] loss=0.8617352843284607


[TRAIN 003]: 100%|██████████| 1759/1759 [00:06<00:00, 283.78it/s, batch_loss=0.7073412] 


[TEST  003] loss=0.8436973690986633


[TRAIN 004]: 100%|██████████| 1759/1759 [00:06<00:00, 283.92it/s, batch_loss=0.6688315] 


[TEST  004] loss=0.8339711427688599


[TRAIN 005]: 100%|██████████| 1759/1759 [00:06<00:00, 286.72it/s, batch_loss=0.7168115] 


[TEST  005] loss=0.8242194056510925


[TRAIN 006]: 100%|██████████| 1759/1759 [00:05<00:00, 296.32it/s, batch_loss=0.56539845]


[TEST  006] loss=0.811716616153717


[TRAIN 007]: 100%|██████████| 1759/1759 [00:06<00:00, 292.17it/s, batch_loss=0.6493333] 


[TEST  007] loss=0.8061832785606384


[TRAIN 008]: 100%|██████████| 1759/1759 [00:06<00:00, 288.36it/s, batch_loss=0.92218494]


[TEST  008] loss=0.8108053803443909


[TRAIN 009]: 100%|██████████| 1759/1759 [00:06<00:00, 272.55it/s, batch_loss=0.654537]  


[TEST  009] loss=0.8127299547195435


[TRAIN 010]: 100%|██████████| 1759/1759 [00:06<00:00, 274.62it/s, batch_loss=0.72396827]


[TEST  010] loss=0.8180336952209473


[TRAIN 011]: 100%|██████████| 1759/1759 [00:06<00:00, 269.97it/s, batch_loss=0.7559954] 


[TEST  011] loss=0.8219563364982605


[TRAIN 012]: 100%|██████████| 1759/1759 [00:06<00:00, 275.49it/s, batch_loss=0.5905463] 


[TEST  012] loss=0.8249061703681946


[TRAIN 013]: 100%|██████████| 1759/1759 [00:06<00:00, 274.06it/s, batch_loss=0.6761055] 


[TEST  013] loss=0.8311976790428162


[TRAIN 014]: 100%|██████████| 1759/1759 [00:06<00:00, 281.03it/s, batch_loss=0.6049273] 


[TEST  014] loss=0.8330389857292175


[TRAIN 015]: 100%|██████████| 1759/1759 [00:06<00:00, 274.09it/s, batch_loss=0.49870542]


[TEST  015] loss=0.8406245112419128


[TRAIN 016]: 100%|██████████| 1759/1759 [00:06<00:00, 269.63it/s, batch_loss=0.60878026]


[TEST  016] loss=0.8426607251167297


[TRAIN 017]: 100%|██████████| 1759/1759 [00:06<00:00, 276.70it/s, batch_loss=0.44367683]


[TEST  017] loss=0.846129298210144


In [6]:
# Persist the trainer's best state (presumably the epoch with the lowest
# validation loss). `jax.device_get` copies device arrays to host memory so
# they can be pickled. Only unpickle this file from a trusted source.
with open("DPLRFwFM_RATING.pickle", "wb") as checkpoint_file:
    best_state_on_host = jax.device_get(trainer.best_state_dict)
    pickle.dump(best_state_on_host, checkpoint_file)