In [1]:
import os

import mlflow
import optax
import pandas as pd
import polars as pl
from flax import nnx
from flax_trainer.loss_fn import mean_squared_error
from flax_trainer.trainer import Trainer
from sklearn.model_selection import train_test_split

from flax_recsys.encoder import ColumnEncoder
from flax_recsys.evaluator import GeneralRatingEvaluator
from flax_recsys.loader import GeneralRatingLoader
from flax_recsys.model.DPLRFwFM import DPLRFwFM

### Read data

In [2]:
ML1M_DIR = "../../dataset/ML1M/ml-1m"


def read_ml1m_table(file_name: str, names: list[str], **read_csv_kwargs) -> pl.DataFrame:
    """Read one `::`-delimited, header-less MovieLens-1M table into polars.

    Args:
        file_name: file inside ``ML1M_DIR`` (e.g. ``"ratings.dat"``).
        names: column names to assign, in file order.
        **read_csv_kwargs: extra ``pd.read_csv`` options (e.g. ``encoding``).

    Returns:
        The table as a polars DataFrame.
    """
    # pandas parses the file: a multi-character separator such as "::" is only
    # supported by its python engine. The result is then converted to polars.
    return pl.from_pandas(
        pd.read_csv(
            f"{ML1M_DIR}/{file_name}",
            delimiter="::",
            engine="python",
            header=None,
            names=names,
            **read_csv_kwargs,
        )
    )


rating_df = read_ml1m_table("ratings.dat", ["user_id", "item_id", "rating", "timestamp"])
movie_df = read_ml1m_table(
    "movies.dat", ["item_id", "title", "genres"], encoding="ISO-8859-1"
).with_columns(pl.col("genres").str.split("|"))
user_df = read_ml1m_table(
    "users.dat", ["user_id", "gender", "age", "occupation", "zip_code"]
)

# Attach movie and user attributes to every rating, drop the list-typed
# `genres` column, move `rating` to the end, and re-code every other column to
# contiguous 0-based integers via an ascending dense rank (this also turns
# `timestamp` into a chronological ordinal).
dataset_df = (
    rating_df.join(movie_df, how="left", on="item_id")
    .join(user_df, how="left", on="user_id")
    .select(pl.all().exclude("rating", "genres"), pl.col("rating"))
    .with_columns(pl.all().exclude("rating").rank("dense") - 1)
)
# The left joins must neither duplicate ratings nor leave attributes missing.
assert len(dataset_df) == len(rating_df), "joins changed the number of ratings"
assert not any(dataset_df.null_count().row(0)), "joins produced missing values"
display(dataset_df)

# Column roles for the encoder; the listed attribute columns are declared as
# one-hot (categorical) features.
column_encoder = ColumnEncoder(
    user_id="user_id",
    item_id="item_id",
    timestamp="timestamp",
    one_hot=["title", "gender", "age", "occupation", "zip_code"],
    rating="rating",
)

# Random 90/10 hold-out split with a fixed seed for reproducibility.
train_df, valid_df = train_test_split(
    dataset_df, test_size=0.1, random_state=0, shuffle=True
)
assert len(train_df) + len(valid_df) == len(dataset_df)

user_id,item_id,timestamp,title,gender,age,occupation,zip_code,rating
u32,u32,u32,u32,u32,u32,u32,u32,i64
0,1104,397442,2452,0,0,10,1588,5
0,639,397457,1739,0,0,10,1588,3
0,853,397454,2289,0,0,10,1588,3
0,3177,397440,1054,0,0,10,1588,4
0,2162,400250,557,0,0,10,1588,5
…,…,…,…,…,…,…,…,…
6039,1019,382,3574,1,2,6,466,1
6039,1022,21,814,1,2,6,466,5
6039,548,17,3578,1,2,6,466,5
6039,1024,346,3090,1,2,6,466,4


In [3]:
# Fit the encoder on the training split only, then reuse the fitted encoder for
# the validation split. Split-prefixed names keep the training arrays from being
# silently rebound to validation data.
train_categorical_X, train_numerical_X, train_y = column_encoder.fit_transform(train_df)
loader = GeneralRatingLoader(
    categorical_X=train_categorical_X,
    numerical_X=train_numerical_X,
    y=train_y,
    batch_size=512,
    rngs=nnx.Rngs(0),  # fixed seed; presumably drives batch shuffling — confirm
)

# NOTE(review): category codes were dense-ranked over the full dataset before
# the split, so validation may contain codes never seen in training. Confirm
# that ColumnEncoder sizes its cardinalities to cover them.
valid_categorical_X, valid_numerical_X, valid_y = column_encoder.transform(valid_df)
evaluator = GeneralRatingEvaluator(
    categorical_X=valid_categorical_X,
    numerical_X=valid_numerical_X,
    y=valid_y,
)

### Training (学習)

In [4]:
# DPLR-FwFM rating model, sized from the fitted column encoder.
# NOTE(review): cardinalities are passed in `cardinality_map` iteration order;
# this presumably matches the column order of the encoder's categorical
# output — confirm.
model = DPLRFwFM(
    categorical_feature_cardinalities=[*column_encoder.cardinality_map.values()],
    numerical_feature_num=column_encoder.numerical_column_num,
    embed_dim=30,  # embedding width
    rho=3,  # presumably the rank of the low-rank interaction term — confirm
    rngs=nnx.Rngs(0),  # fixed seed for reproducible initialization
)

In [5]:
# Track the run in MLflow. An MLFLOW_TRACKING_URI set in the environment wins;
# otherwise fall back to the local tracking server used so far. Calling
# set_tracking_uri unconditionally would silently override the env var.
mlflow.set_tracking_uri(
    uri=os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:8080")
)
mlflow.set_experiment("DPLRFwFM_RATING")

# NOTE(review): `model` and `loader` come from earlier cells and are presumably
# updated in place during training (nnx modules and Rngs are stateful). If so,
# re-running only this cell resumes from the trained state; re-run the loader
# and model cells first for a from-scratch run.
with mlflow.start_run() as run:
    trainer = Trainer(
        model=model,
        optimizer=optax.adamw(learning_rate=0.001, weight_decay=0.001),
        train_loader=loader,
        loss_fn=mean_squared_error,
        valid_evaluator=evaluator,
        early_stopping_patience=10,  # epochs without validation improvement before stopping
        epoch_num=64,  # maximum number of training epochs
        active_run=run,
    )
    trainer = trainer.fit()

[VALID 000]: loss=79.80770111083984, metrics={'mse': 79.80770111083984}


[TRAIN 001]: 100%|██████████| 1759/1759 [00:10<00:00, 169.36it/s, batch_loss=0.9553055] 


[VALID 001]: loss=1.0663902759552002, metrics={'mse': 1.0663902759552002}


[TRAIN 002]: 100%|██████████| 1759/1759 [00:10<00:00, 165.73it/s, batch_loss=1.0358295] 


[VALID 002]: loss=0.8617352843284607, metrics={'mse': 0.8617352843284607}


[TRAIN 003]: 100%|██████████| 1759/1759 [00:10<00:00, 160.11it/s, batch_loss=0.7073412] 


[VALID 003]: loss=0.8436973690986633, metrics={'mse': 0.8436973690986633}


[TRAIN 004]: 100%|██████████| 1759/1759 [00:10<00:00, 170.88it/s, batch_loss=0.6688315] 


[VALID 004]: loss=0.8339711427688599, metrics={'mse': 0.8339711427688599}


[TRAIN 005]: 100%|██████████| 1759/1759 [00:11<00:00, 157.92it/s, batch_loss=0.7168115] 


[VALID 005]: loss=0.8242194056510925, metrics={'mse': 0.8242194056510925}


[TRAIN 006]: 100%|██████████| 1759/1759 [00:10<00:00, 162.50it/s, batch_loss=0.56539845]


[VALID 006]: loss=0.811716616153717, metrics={'mse': 0.811716616153717}


[TRAIN 007]: 100%|██████████| 1759/1759 [00:10<00:00, 160.88it/s, batch_loss=0.6493333] 


[VALID 007]: loss=0.8061832785606384, metrics={'mse': 0.8061832785606384}


[TRAIN 008]: 100%|██████████| 1759/1759 [00:11<00:00, 154.11it/s, batch_loss=0.92218494]


[VALID 008]: loss=0.8108053803443909, metrics={'mse': 0.8108053803443909}


[TRAIN 009]: 100%|██████████| 1759/1759 [00:11<00:00, 154.95it/s, batch_loss=0.654537]  


[VALID 009]: loss=0.8127299547195435, metrics={'mse': 0.8127299547195435}


[TRAIN 010]: 100%|██████████| 1759/1759 [00:11<00:00, 157.76it/s, batch_loss=0.72396827]


[VALID 010]: loss=0.8180336952209473, metrics={'mse': 0.8180336952209473}


[TRAIN 011]: 100%|██████████| 1759/1759 [00:11<00:00, 153.68it/s, batch_loss=0.7559954] 


[VALID 011]: loss=0.8219563364982605, metrics={'mse': 0.8219563364982605}


[TRAIN 012]: 100%|██████████| 1759/1759 [00:11<00:00, 153.70it/s, batch_loss=0.5905463] 


[VALID 012]: loss=0.8249061703681946, metrics={'mse': 0.8249061703681946}


[TRAIN 013]: 100%|██████████| 1759/1759 [00:11<00:00, 158.21it/s, batch_loss=0.6761055] 


[VALID 013]: loss=0.8311976790428162, metrics={'mse': 0.8311976790428162}


[TRAIN 014]: 100%|██████████| 1759/1759 [00:11<00:00, 157.29it/s, batch_loss=0.6049273] 


[VALID 014]: loss=0.8330389857292175, metrics={'mse': 0.8330389857292175}


[TRAIN 015]: 100%|██████████| 1759/1759 [00:12<00:00, 141.58it/s, batch_loss=0.49870542]


[VALID 015]: loss=0.8406245112419128, metrics={'mse': 0.8406245112419128}


[TRAIN 016]: 100%|██████████| 1759/1759 [00:11<00:00, 158.64it/s, batch_loss=0.60878026]


[VALID 016]: loss=0.8426607251167297, metrics={'mse': 0.8426607251167297}


[TRAIN 017]: 100%|██████████| 1759/1759 [00:11<00:00, 154.82it/s, batch_loss=0.44367683]


[VALID 017]: loss=0.846129298210144, metrics={'mse': 0.846129298210144}
🏃 View run bouncy-gnu-90 at: http://localhost:8080/#/experiments/307500353090437060/runs/daacf95739e64bd38ce7a7687208e30b
🧪 View experiment at: http://localhost:8080/#/experiments/307500353090437060
