In [1]:
import pickle
from pathlib import Path

import jax
import optax
import pandas as pd
import polars as pl
from flax import nnx
from flax_trainer.loss_fn import mean_squared_error
from flax_trainer.trainer import Trainer
from sklearn.model_selection import train_test_split

from flax_recsys.encoder import ColumnEncoder
from flax_recsys.evaluator import GeneralRatingEvaluator
from flax_recsys.loader import GeneralRatingLoader
from flax_recsys.model.FM import FM

### Read data

In [2]:
ML1M_DIR = Path("../../dataset/ML1M/ml-1m")


def read_ml1m_table(
    file_name: str, names: list[str], encoding: str | None = None
) -> pl.DataFrame:
    """Read one `::`-delimited MovieLens-1M ``.dat`` file as a polars DataFrame.

    Args:
        file_name: File name inside ``ML1M_DIR`` (e.g. ``"ratings.dat"``).
        names: Column names to assign; the files are read with ``header=None``.
        encoding: Text encoding forwarded to ``pd.read_csv``; ``None`` keeps the
            pandas default.

    Returns:
        The parsed table, converted with ``pl.from_pandas``.
    """
    # Parsed with pandas' python engine because the separator "::" is more than
    # one character; the result is converted to polars afterwards.
    return pl.from_pandas(
        pd.read_csv(
            ML1M_DIR / file_name,
            delimiter="::",
            engine="python",
            header=None,
            names=names,
            encoding=encoding,
        )
    )


rating_df = read_ml1m_table(
    "ratings.dat", ["user_id", "item_id", "rating", "timestamp"]
)
# movies.dat is decoded as ISO-8859-1; "genres" is "|"-separated and is split
# into a list column.
movie_df = read_ml1m_table(
    "movies.dat", ["item_id", "title", "genres"], encoding="ISO-8859-1"
).with_columns(pl.col("genres").str.split("|"))
user_df = read_ml1m_table(
    "users.dat", ["user_id", "gender", "age", "occupation", "zip_code"]
)

dataset_df = (
    rating_df.join(movie_df, how="left", on="item_id")
    .join(user_df, how="left", on="user_id")
    # Drop the list-valued "genres" column and move the target "rating" last.
    .select(pl.all().exclude("rating", "genres"), pl.col("rating"))
    # Map every non-target column to 0-based dense integer codes (dense rank - 1).
    # NOTE(review): this also rank-encodes "timestamp" and "age", and is computed
    # on the full dataset before the train/valid split — confirm that is what
    # ColumnEncoder expects.
    .with_columns(pl.all().exclude("rating").rank("dense") - 1)
)
# Each rating should match at most one movie and one user; a duplicated key in
# movies.dat/users.dat would otherwise silently duplicate ratings.
assert dataset_df.height == rating_df.height, "join changed the number of ratings"
display(dataset_df)

# Column roles for the encoder; it is fitted on train_df only (see the loader cell).
column_encoder = ColumnEncoder(
    user_id="user_id",
    item_id="item_id",
    timestamp="timestamp",
    one_hot=["title", "gender", "age", "occupation", "zip_code"],
    rating="rating",
)

# Random 90/10 train/validation split with a fixed seed.
train_df, valid_df = train_test_split(
    dataset_df, test_size=0.1, random_state=0, shuffle=True
)

user_id,item_id,timestamp,title,gender,age,occupation,zip_code,rating
u32,u32,u32,u32,u32,f64,u32,u32,i64
0,1104,397442,2452,0,0.01,10,1588,5
0,639,397457,1739,0,0.01,10,1588,3
0,853,397454,2289,0,0.01,10,1588,3
0,3177,397440,1054,0,0.01,10,1588,4
0,2162,400250,557,0,0.01,10,1588,5
…,…,…,…,…,…,…,…,…
6039,1019,382,3574,1,2.01,6,466,1
6039,1022,21,814,1,2.01,6,466,5
6039,548,17,3578,1,2.01,6,466,5
6039,1024,346,3090,1,2.01,6,466,4


In [3]:
# Fit the encoder on the training split only, then apply the fitted encoder to
# the validation split. Separate train_/valid_ names keep the training arrays
# from being silently rebound to validation data.
train_categorical_X, train_numerical_X, train_y = column_encoder.fit_transform(
    train_df
)
loader = GeneralRatingLoader(
    categorical_X=train_categorical_X,
    numerical_X=train_numerical_X,
    y=train_y,
    batch_size=512,
    rngs=nnx.Rngs(0),  # fixed seed; presumably drives batch shuffling
)

valid_categorical_X, valid_numerical_X, valid_y = column_encoder.transform(valid_df)
evaluator = GeneralRatingEvaluator(
    categorical_X=valid_categorical_X,
    numerical_X=valid_numerical_X,
    y=valid_y,
)

### Training (学習)

In [4]:
# Factorization Machine sized from the encoder fitted on the training split:
# one cardinality per categorical column plus the count of numerical columns.
model = FM(
    categorical_feature_cardinalities=[*column_encoder.cardinality_map.values()],
    numerical_feature_num=column_encoder.numerical_column_num,
    embed_dim=30,  # embedding width (presumably the FM latent-factor size)
    rngs=nnx.Rngs(0),  # fixed seed, presumably for parameter initialization
)

In [5]:
# Train on the MSE loss with AdamW, evaluating on the validation split.
# NOTE(review): early_stopping_patience is presumably the number of epochs
# without test-loss improvement tolerated before stopping; epoch_num is the cap.
# The Trainer is bound before fit() so an interrupted run still leaves
# `trainer` available in the kernel.
trainer = Trainer(
    model=model,
    train_loader=loader,
    test_evaluator=evaluator,
    loss_fn=mean_squared_error,
    optimizer=optax.adamw(learning_rate=1e-3, weight_decay=1e-3),
    epoch_num=64,
    early_stopping_patience=10,
)
trainer = trainer.fit()

[TEST  000] loss=13.318453788757324


[TRAIN 001]: 100%|██████████| 1759/1759 [00:05<00:00, 323.14it/s, batch_loss=1.0366287] 


[TEST  001] loss=1.0412428379058838


[TRAIN 002]: 100%|██████████| 1759/1759 [00:06<00:00, 290.43it/s, batch_loss=1.0487803] 


[TEST  002] loss=0.9677133560180664


[TRAIN 003]: 100%|██████████| 1759/1759 [00:05<00:00, 302.91it/s, batch_loss=0.6905799] 


[TEST  003] loss=0.9403350353240967


[TRAIN 004]: 100%|██████████| 1759/1759 [00:05<00:00, 336.53it/s, batch_loss=0.741958]  


[TEST  004] loss=0.9134047627449036


[TRAIN 005]: 100%|██████████| 1759/1759 [00:05<00:00, 316.78it/s, batch_loss=0.83516103]


[TEST  005] loss=0.8906211853027344


[TRAIN 006]: 100%|██████████| 1759/1759 [00:05<00:00, 334.90it/s, batch_loss=0.6897791] 


[TEST  006] loss=0.86864173412323


[TRAIN 007]: 100%|██████████| 1759/1759 [00:05<00:00, 348.90it/s, batch_loss=0.6102255] 


[TEST  007] loss=0.8572801351547241


[TRAIN 008]: 100%|██████████| 1759/1759 [00:05<00:00, 350.75it/s, batch_loss=0.9271427] 


[TEST  008] loss=0.8529515266418457


[TRAIN 009]: 100%|██████████| 1759/1759 [00:04<00:00, 355.70it/s, batch_loss=0.69459283]


[TEST  009] loss=0.85091233253479


[TRAIN 010]: 100%|██████████| 1759/1759 [00:04<00:00, 353.42it/s, batch_loss=0.67583334]


[TEST  010] loss=0.8535036444664001


[TRAIN 011]: 100%|██████████| 1759/1759 [00:05<00:00, 335.55it/s, batch_loss=0.69625974]


[TEST  011] loss=0.8537264466285706


[TRAIN 012]: 100%|██████████| 1759/1759 [00:05<00:00, 344.16it/s, batch_loss=0.4998692] 


[TEST  012] loss=0.8541572690010071


[TRAIN 013]: 100%|██████████| 1759/1759 [00:05<00:00, 328.54it/s, batch_loss=0.71430016]


[TEST  013] loss=0.8562111258506775


[TRAIN 014]: 100%|██████████| 1759/1759 [00:05<00:00, 301.59it/s, batch_loss=0.6343296] 


[TEST  014] loss=0.8596003651618958


[TRAIN 015]: 100%|██████████| 1759/1759 [00:05<00:00, 342.07it/s, batch_loss=0.4463875] 


[TEST  015] loss=0.8659738898277283


[TRAIN 016]: 100%|██████████| 1759/1759 [00:05<00:00, 332.46it/s, batch_loss=0.71196103]


[TEST  016] loss=0.8689772486686707


[TRAIN 017]: 100%|██████████| 1759/1759 [00:04<00:00, 365.30it/s, batch_loss=0.53169096]


[TEST  017] loss=0.8693222999572754


[TRAIN 018]: 100%|██████████| 1759/1759 [00:05<00:00, 300.01it/s, batch_loss=0.5354732] 


[TEST  018] loss=0.8748645782470703


[TRAIN 019]: 100%|██████████| 1759/1759 [00:05<00:00, 331.36it/s, batch_loss=0.50476164]


[TEST  019] loss=0.8793476819992065


In [6]:
# Persist trainer.best_state_dict (presumably the parameters from the epoch with
# the lowest test loss). The arrays are fetched to host memory with
# jax.device_get BEFORE the file is opened, so a failure there cannot leave an
# empty or truncated FM_RATING.pickle behind.
# NOTE: only unpickle this file from a trusted source — pickle.load can execute
# arbitrary code.
best_state_dict = jax.device_get(trainer.best_state_dict)
with open("FM_RATING.pickle", "wb") as f:
    pickle.dump(best_state_dict, f)