In [1]:
import os

import mlflow
import optax
import pandas as pd
import polars as pl
from flax import nnx
from flax_trainer.loss_fn import mean_squared_error
from flax_trainer.trainer import Trainer
from sklearn.model_selection import train_test_split

from flax_recsys.encoder import ColumnEncoder
from flax_recsys.evaluator import GeneralRatingEvaluator
from flax_recsys.loader import GeneralRatingLoader
from flax_recsys.model.FM import FM

### READ

In [2]:
# Declare which columns play the user / item / rating roles for encoding;
# `timestamp` is loaded below but is not handed to the encoder.
column_encoder = ColumnEncoder(user_id="user_id", item_id="item_id", rating="rating")

# The ratings file has no header row and separates fields with "::"; a
# multi-character separator requires pandas' Python parsing engine.
# The pandas frame is converted to polars straight away.
dataset_df = pl.from_pandas(
    pd.read_csv(
        "../../dataset/ML1M/ml-1m/ratings.dat",
        sep="::",
        header=None,
        names=["user_id", "item_id", "rating", "timestamp"],
        engine="python",
    )
)
display(dataset_df)

# Hold out 10% of the shuffled rows for validation; the fixed seed keeps the
# split reproducible across kernel restarts.
train_df, valid_df = train_test_split(
    dataset_df, test_size=0.1, shuffle=True, random_state=0
)

user_id,item_id,rating,timestamp
i64,i64,i64,i64
1,1193,5,978300760
1,661,3,978302109
1,914,3,978301968
1,3408,4,978300275
1,2355,5,978824291
…,…,…,…
6040,1091,1,956716541
6040,1094,5,956704887
6040,562,5,956704746
6040,1096,4,956715648


In [3]:
# Training split: fit the encoder and encode in one step, then wrap the
# resulting arrays in the loader that serves batches of 512 rows.
categorical_X, numerical_X, y = column_encoder.fit_transform(train_df)
loader = GeneralRatingLoader(
    rngs=nnx.Rngs(0),
    batch_size=512,
    categorical_X=categorical_X,
    numerical_X=numerical_X,
    y=y,
)

# Validation split: only `transform` is called, so the encoder is not refit
# here (presumably it reuses what was fitted on the training split — confirm).
# NOTE: `categorical_X`, `numerical_X` and `y` are rebound below, so after this
# cell they refer to the validation arrays, not the training ones.
categorical_X, numerical_X, y = column_encoder.transform(valid_df)
evaluator = GeneralRatingEvaluator(
    categorical_X=categorical_X,
    numerical_X=numerical_X,
    y=y,
)

### 学習 (Training)

In [4]:
# Build the FM model (presumably a factorization machine — see
# flax_recsys.model.FM). The categorical cardinalities and the number of
# numerical features are read from the encoder fitted earlier; the embedding
# size is 30 and parameter initialisation is seeded with 0.
model = FM(
    categorical_feature_cardinalities=[*column_encoder.cardinality_map.values()],
    numerical_feature_num=column_encoder.numerical_column_num,
    embed_dim=30,
    rngs=nnx.Rngs(0),
)

In [5]:
# Log to the MLflow server named by the MLFLOW_TRACKING_URI environment
# variable when it is set and non-empty; otherwise fall back to the local
# server this notebook was originally run against.
mlflow.set_tracking_uri(
    uri=os.environ.get("MLFLOW_TRACKING_URI") or "http://localhost:8080"
)
# NOTE(review): the experiment is named "MF_RATING" although the model trained
# here is FM — confirm this is intentional and not a copy-paste leftover.
mlflow.set_experiment("MF_RATING")

# One MLflow run wraps the whole training session; the run handle is passed to
# the trainer via `active_run`.
with mlflow.start_run() as run:
    trainer = Trainer(
        model=model,
        optimizer=optax.adamw(learning_rate=0.001, weight_decay=0.001),
        train_loader=loader,
        loss_fn=mean_squared_error,
        valid_evaluator=evaluator,
        # Presumably training stops once the validation loss has not improved
        # for this many epochs — TODO confirm against flax_trainer.
        early_stopping_patience=10,
        # Upper bound on the number of training epochs.
        epoch_num=64,
        active_run=run,
    )
    # `fit()` returns a value that replaces the `trainer` binding.
    trainer = trainer.fit()

[VALID 000]: loss=12.460426330566406, metrics={'mse': 12.460426330566406}


[TRAIN 001]: 100%|██████████| 1759/1759 [00:03<00:00, 519.32it/s, batch_loss=2.9505997]


[VALID 001]: loss=2.994792938232422, metrics={'mse': 2.994792938232422}


[TRAIN 002]: 100%|██████████| 1759/1759 [00:03<00:00, 498.22it/s, batch_loss=1.4715427]


[VALID 002]: loss=1.2534188032150269, metrics={'mse': 1.2534188032150269}


[TRAIN 003]: 100%|██████████| 1759/1759 [00:03<00:00, 458.82it/s, batch_loss=0.7310248] 


[VALID 003]: loss=0.929356038570404, metrics={'mse': 0.929356038570404}


[TRAIN 004]: 100%|██████████| 1759/1759 [00:03<00:00, 496.58it/s, batch_loss=0.6395217] 


[VALID 004]: loss=0.8532370924949646, metrics={'mse': 0.8532370924949646}


[TRAIN 005]: 100%|██████████| 1759/1759 [00:03<00:00, 501.65it/s, batch_loss=0.67285424]


[VALID 005]: loss=0.8243317008018494, metrics={'mse': 0.8243317008018494}


[TRAIN 006]: 100%|██████████| 1759/1759 [00:03<00:00, 500.06it/s, batch_loss=0.52649516]


[VALID 006]: loss=0.8082016706466675, metrics={'mse': 0.8082016706466675}


[TRAIN 007]: 100%|██████████| 1759/1759 [00:03<00:00, 495.48it/s, batch_loss=0.592364]  


[VALID 007]: loss=0.8006784915924072, metrics={'mse': 0.8006784915924072}


[TRAIN 008]: 100%|██████████| 1759/1759 [00:03<00:00, 497.46it/s, batch_loss=0.81747013]


[VALID 008]: loss=0.798324704170227, metrics={'mse': 0.798324704170227}


[TRAIN 009]: 100%|██████████| 1759/1759 [00:03<00:00, 498.15it/s, batch_loss=0.5746051] 


[VALID 009]: loss=0.800514817237854, metrics={'mse': 0.800514817237854}


[TRAIN 010]: 100%|██████████| 1759/1759 [00:03<00:00, 505.58it/s, batch_loss=0.7409642] 


[VALID 010]: loss=0.804857075214386, metrics={'mse': 0.804857075214386}


[TRAIN 011]: 100%|██████████| 1759/1759 [00:03<00:00, 481.62it/s, batch_loss=0.7569204] 


[VALID 011]: loss=0.8092218041419983, metrics={'mse': 0.8092218041419983}


[TRAIN 012]: 100%|██████████| 1759/1759 [00:03<00:00, 482.58it/s, batch_loss=0.45722502]


[VALID 012]: loss=0.8143149614334106, metrics={'mse': 0.8143149614334106}


[TRAIN 013]: 100%|██████████| 1759/1759 [00:03<00:00, 491.13it/s, batch_loss=0.5932075] 


[VALID 013]: loss=0.8201317191123962, metrics={'mse': 0.8201317191123962}


[TRAIN 014]: 100%|██████████| 1759/1759 [00:03<00:00, 480.43it/s, batch_loss=0.64774686]


[VALID 014]: loss=0.8260868787765503, metrics={'mse': 0.8260868787765503}


[TRAIN 015]: 100%|██████████| 1759/1759 [00:03<00:00, 481.48it/s, batch_loss=0.46836382]


[VALID 015]: loss=0.8319261074066162, metrics={'mse': 0.8319261074066162}


[TRAIN 016]: 100%|██████████| 1759/1759 [00:03<00:00, 515.08it/s, batch_loss=0.62902176]


[VALID 016]: loss=0.8379749655723572, metrics={'mse': 0.8379749655723572}


[TRAIN 017]: 100%|██████████| 1759/1759 [00:03<00:00, 484.69it/s, batch_loss=0.50265884]


[VALID 017]: loss=0.8432089686393738, metrics={'mse': 0.8432089686393738}


[TRAIN 018]: 100%|██████████| 1759/1759 [00:03<00:00, 536.72it/s, batch_loss=0.44016466]


[VALID 018]: loss=0.8491826057434082, metrics={'mse': 0.8491826057434082}
🏃 View run bemused-shoat-685 at: http://localhost:8080/#/experiments/538067800959619176/runs/23f06a24b4134db5af487b660b768459
🧪 View experiment at: http://localhost:8080/#/experiments/538067800959619176
