In [1]:
import optax
import pandas as pd
import polars as pl
import jax
import pickle
from flax import nnx
from flax_trainer.evaluator import RegressionEvaluator
from flax_trainer.loader import MiniBatchLoader
from flax_trainer.loss_fn import mean_squared_error
from flax_trainer.trainer import Trainer
from sklearn.model_selection import train_test_split

from flax_recsys.model.FM import FM

### READ

In [2]:
# Directory holding the raw MovieLens-1M ``.dat`` files (relative to this notebook).
ML1M_DIR = "../../dataset/ML1M/ml-1m"


def read_ml1m_table(file_name, columns, encoding=None):
    """Read one MovieLens-1M ``.dat`` file into a polars DataFrame.

    The file is parsed without a header row, using ``::`` as the field
    delimiter and pandas' Python parser engine.

    Args:
        file_name: File name inside ``ML1M_DIR`` (e.g. ``"ratings.dat"``).
        columns: Column names to assign, in file order.
        encoding: Text encoding forwarded to ``pd.read_csv``; ``None`` keeps
            the pandas default.

    Returns:
        The parsed table, converted with ``pl.from_pandas``.
    """
    return pl.from_pandas(
        pd.read_csv(
            f"{ML1M_DIR}/{file_name}",
            delimiter="::",
            engine="python",
            header=None,
            names=columns,
            encoding=encoding,
        )
    )


rating_df = read_ml1m_table(
    "ratings.dat", ["user_id", "item_id", "rating", "timestamp"]
)
# movies.dat is decoded as ISO-8859-1; the "|"-joined genre string is split
# into a list column.
movie_df = read_ml1m_table(
    "movies.dat", ["item_id", "title", "genres"], encoding="ISO-8859-1"
).with_columns(pl.col("genres").str.split("|"))
user_df = read_ml1m_table(
    "users.dat", ["user_id", "gender", "age", "occupation", "zip_code"]
)

# Attach movie and user attributes to every rating, drop the list-typed
# "genres" column, move the "rating" target to the last position, and re-code
# every other column as a dense rank shifted to start at 0.
dataset_df = (
    rating_df.join(movie_df, on="item_id", how="left")
    .join(user_df, on="user_id", how="left")
    .select(pl.exclude("rating", "genres"), pl.col("rating"))
    .with_columns(pl.exclude("rating").rank("dense") - 1)
)

# Feature columns are declared by name and resolved to positions in
# dataset_df, so the indices stay correct if the column order produced by the
# joins/select above ever changes.
user_categorical_feature_names = ["user_id", "gender", "age", "occupation", "zip_code"]
item_categorical_feature_names = ["item_id", "title"]

user_categorical_feature_indices = [
    dataset_df.columns.index(name) for name in user_categorical_feature_names
]
# Number of distinct values in each categorical column, in the same order as
# the index list.
user_categorical_feature_cardinalities = [
    dataset_df[name].n_unique() for name in user_categorical_feature_names
]
# No numerical user features are used.
user_numerical_feature_indices = []

item_categorical_feature_indices = [
    dataset_df.columns.index(name) for name in item_categorical_feature_names
]
item_categorical_feature_cardinalities = [
    dataset_df[name].n_unique() for name in item_categorical_feature_names
]
# No numerical item features are used.
item_numerical_feature_indices = []

# Hold out 10% of the rows for validation; the fixed random_state keeps the
# shuffled split reproducible across runs.
train_df, valid_df = train_test_split(
    dataset_df,
    shuffle=True,
    random_state=0,
    test_size=0.1,
)
# Display the full encoded frame as the cell output.
dataset_df

user_id,item_id,timestamp,title,gender,age,occupation,zip_code,rating
u32,u32,u32,u32,u32,u32,u32,u32,i64
0,1104,397442,2452,0,0,10,1588,5
0,639,397457,1739,0,0,10,1588,3
0,853,397454,2289,0,0,10,1588,3
0,3177,397440,1054,0,0,10,1588,4
0,2162,400250,557,0,0,10,1588,5
…,…,…,…,…,…,…,…,…
6039,1019,382,3574,1,2,6,466,1
6039,1022,21,814,1,2,6,466,5
6039,548,17,3578,1,2,6,466,5
6039,1024,346,3090,1,2,6,466,4


### 学習 (Training)

In [3]:
# Build the factorization machine from the feature layout computed above.
model = FM(
    # Model size and parameter-initialisation RNG (fixed seed).
    embed_dim=30,
    rngs=nnx.Rngs(0),
    # User-side feature columns.
    user_categorical_feature_indices=user_categorical_feature_indices,
    user_categorical_feature_cardinalities=user_categorical_feature_cardinalities,
    user_numerical_feature_indices=user_numerical_feature_indices,
    # Item-side feature columns.
    item_categorical_feature_indices=item_categorical_feature_indices,
    item_categorical_feature_cardinalities=item_categorical_feature_cardinalities,
    item_numerical_feature_indices=item_numerical_feature_indices,
)

In [4]:
# Mini-batches are drawn from the training split; the held-out split is used
# for the "[TEST ...]" evaluation reported after each epoch.
loader = MiniBatchLoader(
    dataset_df=train_df,
    rngs=nnx.Rngs(0),
    batch_size=512,
)
evaluator = RegressionEvaluator(dataset_df=valid_df)

trainer = Trainer(
    model=model,
    train_loader=loader,
    test_evaluator=evaluator,
    loss_fn=mean_squared_error,
    optimizer=optax.adamw(weight_decay=0.001, learning_rate=0.001),
    # Upper bound on epochs; presumably training stops earlier once the test
    # loss has not improved for `early_stopping_patience` epochs — TODO confirm.
    epoch_num=64,
    early_stopping_patience=10,
)
# fit() returns the object that is later queried for `best_state_dict`.
trainer = trainer.fit()

[TEST  000] loss=9.766195297241211


[TRAIN 001]: 100%|██████████| 1759/1759 [00:06<00:00, 293.12it/s, batch_loss=0.99465626]


[TEST  001] loss=0.9966132044792175


[TRAIN 002]: 100%|██████████| 1759/1759 [00:05<00:00, 310.50it/s, batch_loss=0.9755437] 


[TEST  002] loss=0.9362316131591797


[TRAIN 003]: 100%|██████████| 1759/1759 [00:05<00:00, 324.44it/s, batch_loss=0.69498444]


[TEST  003] loss=0.9133365154266357


[TRAIN 004]: 100%|██████████| 1759/1759 [00:05<00:00, 332.51it/s, batch_loss=0.6130253] 


[TEST  004] loss=0.8884686827659607


[TRAIN 005]: 100%|██████████| 1759/1759 [00:05<00:00, 315.73it/s, batch_loss=0.7619328] 


[TEST  005] loss=0.8693138957023621


[TRAIN 006]: 100%|██████████| 1759/1759 [00:05<00:00, 335.37it/s, batch_loss=0.5505578] 


[TEST  006] loss=0.853171169757843


[TRAIN 007]: 100%|██████████| 1759/1759 [00:05<00:00, 319.13it/s, batch_loss=0.66579396]


[TEST  007] loss=0.8479346632957458


[TRAIN 008]: 100%|██████████| 1759/1759 [00:05<00:00, 319.16it/s, batch_loss=0.9705479] 


[TEST  008] loss=0.8462231755256653


[TRAIN 009]: 100%|██████████| 1759/1759 [00:06<00:00, 285.93it/s, batch_loss=0.599929]  


[TEST  009] loss=0.8461006283760071


[TRAIN 010]: 100%|██████████| 1759/1759 [00:05<00:00, 333.09it/s, batch_loss=0.6232314] 


[TEST  010] loss=0.8494939804077148


[TRAIN 011]: 100%|██████████| 1759/1759 [00:05<00:00, 328.36it/s, batch_loss=0.69625044]


[TEST  011] loss=0.8531635999679565


[TRAIN 012]: 100%|██████████| 1759/1759 [00:05<00:00, 322.31it/s, batch_loss=0.5215558] 


[TEST  012] loss=0.8549650311470032


[TRAIN 013]: 100%|██████████| 1759/1759 [00:05<00:00, 336.88it/s, batch_loss=0.67304456]


[TEST  013] loss=0.8598847389221191


[TRAIN 014]: 100%|██████████| 1759/1759 [00:05<00:00, 324.74it/s, batch_loss=0.5559666] 


[TEST  014] loss=0.8595724701881409


[TRAIN 015]: 100%|██████████| 1759/1759 [00:05<00:00, 342.16it/s, batch_loss=0.3696676] 


[TEST  015] loss=0.8666527271270752


[TRAIN 016]: 100%|██████████| 1759/1759 [00:05<00:00, 327.27it/s, batch_loss=0.7682887] 


[TEST  016] loss=0.871368408203125


[TRAIN 017]: 100%|██████████| 1759/1759 [00:05<00:00, 310.20it/s, batch_loss=0.48974738]


[TEST  017] loss=0.8742348551750183


[TRAIN 018]: 100%|██████████| 1759/1759 [00:05<00:00, 314.25it/s, batch_loss=0.65188706]


[TEST  018] loss=0.880210816860199


[TRAIN 019]: 100%|██████████| 1759/1759 [00:05<00:00, 322.08it/s, batch_loss=0.48766622]


[TEST  019] loss=0.8831638693809509


In [5]:
# Fetch the best state first, so that a failure here cannot leave a truncated
# or empty checkpoint file behind: opening with "wb" wipes any existing file.
best_state = jax.device_get(trainer.best_state_dict)

# Persist the fetched state with pickle. Only load this file from a trusted
# location: unpickling untrusted data can execute arbitrary code.
with open("FM_RATING.pickle", "wb") as checkpoint_file:
    pickle.dump(best_state, checkpoint_file)