In [1]:
import math
import sys
from typing import Self

import jax
import jax.numpy as jnp
import optax
import pandas as pd
import polars as pl
from flax import nnx
from flax_trainer.loss_fn import mean_squared_error
from flax_trainer.trainer import Trainer
from sklearn.model_selection import train_test_split

# Make the local project packages importable *before* importing from them.
# NOTE(review): hard-coded absolute path; consider making it configurable.
sys.path.append("/workspace")

from flax_recsys.model.MF import MF


### Read data

In [2]:
# ratings.dat holds one header-less "user::item::rating::timestamp" record per
# line. pandas only accepts the multi-character "::" separator through its
# python engine, so the file is parsed with pandas and then converted to polars.
rating_df = pl.from_pandas(
    pd.read_csv(
        "../../dataset/ML1M/ml-1m/ratings.dat",
        header=None,
        names=["user_id", "item_id", "rating", "timestamp"],
        delimiter="::",
        engine="python",
    )
)

# Re-base both ID columns so the smallest ID becomes 0; timestamp is dropped.
dataset_df = rating_df.select(
    *[pl.col(name) - pl.col(name).min() for name in ("user_id", "item_id")],
    "rating",
)

# Table sizes for the model: largest re-based ID + 1.
user_num, item_num = (
    dataset_df.get_column(name).max() + 1 for name in ("user_id", "item_id")
)

# Shuffled 90/10 train/validation split with a fixed seed.
train_df, valid_df = train_test_split(
    dataset_df, shuffle=True, random_state=0, test_size=0.1
)
dataset_df

user_id,item_id,rating
i64,i64,i64
0,1192,5
0,660,3
0,913,3
0,3407,4
0,2354,5
…,…,…
6039,1090,1
6039,1093,5
6039,561,5
6039,1095,4


### ローダー

In [3]:
class MiniBatchLoader:
    """Yields shuffled mini-batches of ``(X, y)`` from a ratings DataFrame.

    ``X`` is built from the ``user_id`` and ``item_id`` columns and ``y`` from
    the ``rating`` column. Both are converted with ``to_numpy()`` and placed on
    the default JAX device. Every call to ``iter()`` draws a new permutation
    from the RNG stream, so successive epochs see different row orders.

    Args:
        df_DATA: Frame with ``user_id``, ``item_id`` and ``rating`` columns.
        batch_size: Maximum number of rows per batch; the final batch may be
            shorter.
        seed: Seed for the shuffling RNG stream.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, df_DATA: pl.DataFrame, batch_size: int, seed: int):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.df_DATA = df_DATA
        self.batch_size = batch_size
        # Seed the shuffling stream from the caller-supplied seed.
        self.rngs = nnx.Rngs(seed)

    def __iter__(self) -> Self:
        """Shuffles the rows and resets the batch cursor for a new epoch."""
        self.data_size = self.df_DATA.height
        self.batch_num = len(self)

        # Gather the shuffled rows once, then split into inputs and targets.
        self.shuffled_indices = jax.random.permutation(self.rngs(), self.data_size)
        shuffled_df = self.df_DATA[self.shuffled_indices.tolist(), :]
        self.X_df = shuffled_df.select("user_id", "item_id")
        self.y_df = shuffled_df.select("rating")

        self.batch_index = 0
        return self

    def __len__(self) -> int:
        """Returns the number of batches per epoch.

        The value is derived from the frame itself, so it is valid even before
        the first ``iter()`` call.

        Returns:
            int: The number of batches.
        """
        return math.ceil(self.df_DATA.height / self.batch_size)

    def __next__(self) -> tuple[jax.Array, jax.Array]:
        """Returns the next batch.

        Returns:
            jax.Array: The input data (``user_id``, ``item_id`` columns).
            jax.Array: The target data (``rating`` column).

        Raises:
            StopIteration: Once all batches of the epoch have been produced.
        """
        if self.batch_index >= self.batch_num:
            raise StopIteration()

        start_index = self.batch_size * self.batch_index
        stop_index = min(start_index + self.batch_size, self.data_size)
        X = jax.device_put(self.X_df[start_index:stop_index].to_numpy())
        y = jax.device_put(self.y_df[start_index:stop_index].to_numpy())

        self.batch_index += 1
        return X, y

### 評価器

In [4]:
from flax_trainer.evaluator import BaseEvaluator


class Evaluator(BaseEvaluator):
    """Computes the RMSE of a rating model on a held-out DataFrame.

    The frame is scored in chunks of ``batch_size`` rows, so the size of each
    forward pass does not grow with the evaluation set.

    Args:
        df_DATA: Frame with ``user_id``, ``item_id`` and ``rating`` columns.
        batch_size: Number of rows scored per forward pass.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, df_DATA: pl.DataFrame, batch_size: int):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.df_DATA = df_DATA
        self.batch_size = batch_size

    @staticmethod
    @nnx.jit
    def _predict(model: nnx.Module, X: jax.Array) -> jax.Array:
        # The model is passed as an argument instead of being closed over via
        # ``nnx.jit(model)``. The jitted function is then created once and its
        # compilation cache can be reused across evaluate() calls.
        return model(X)

    def evaluate(self, model: nnx.Module) -> tuple[float, dict[str, float]]:
        """Scores ``df_DATA`` with ``model`` and returns the RMSE.

        Args:
            model: Module called on an ``(n, 2)`` array of
                ``[user_id, item_id]`` rows; its output is compared against
                the ``rating`` column.

        Returns:
            ``(rmse, {"rmse": rmse})``. The RMSE is NaN when ``df_DATA`` is
            empty.
        """
        X_all = self.df_DATA.select("user_id", "item_id").to_numpy()
        y_all = self.df_DATA.select("rating").to_numpy()

        squared_error_sum = 0.0
        element_count = 0
        for start in range(0, len(X_all), self.batch_size):
            stop = start + self.batch_size
            X = jax.device_put(X_all[start:stop])
            true_y = jax.device_put(y_all[start:stop])
            pred_y = self._predict(model, X)
            residual = pred_y - true_y
            squared_error_sum += float(jnp.sum(residual**2))
            element_count += residual.size

        # Same mean-of-squares as a whole-set jnp.mean, accumulated per chunk.
        rmse = (
            math.sqrt(squared_error_sum / element_count)
            if element_count
            else float("nan")
        )
        return rmse, {"rmse": rmse}

### 学習

In [5]:
# Hyper-parameters for this run, named once so the model initialisation, the
# loader and the evaluator cannot silently drift apart.
SEED = 0
BATCH_SIZE = 512
EMBED_DIM = 50
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.001
EARLY_STOPPING_PATIENCE = 10
EPOCH_NUM = 64

model = MF(
    user_num=user_num, item_num=item_num, embed_dim=EMBED_DIM, rngs=nnx.Rngs(SEED)
)
loader = MiniBatchLoader(df_DATA=train_df, batch_size=BATCH_SIZE, seed=SEED)
evaluator = Evaluator(valid_df, batch_size=BATCH_SIZE)

trainer = Trainer(
    model=model,
    optimizer=optax.adamw(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY),
    train_loader=loader,
    loss_fn=mean_squared_error,
    test_evaluator=evaluator,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
    epoch_num=EPOCH_NUM,
)
trainer = trainer.fit()

[TEST  000] loss=3.7516701221466064


[TRAIN 001]: 100%|██████████| 1759/1759 [00:05<00:00, 296.14it/s, batch_loss=3.4941676] 


[TEST  001] loss=1.843126654624939


[TRAIN 002]: 100%|██████████| 1759/1759 [00:05<00:00, 325.82it/s, batch_loss=1.0742438] 


[TEST  002] loss=0.9685046076774597


[TRAIN 003]: 100%|██████████| 1759/1759 [00:05<00:00, 324.97it/s, batch_loss=0.66448736]


[TEST  003] loss=0.9224202036857605


[TRAIN 004]: 100%|██████████| 1759/1759 [00:05<00:00, 319.29it/s, batch_loss=0.6056271] 


[TEST  004] loss=0.9040776491165161


[TRAIN 005]: 100%|██████████| 1759/1759 [00:05<00:00, 312.86it/s, batch_loss=0.6473345] 


[TEST  005] loss=0.8928287625312805


[TRAIN 006]: 100%|██████████| 1759/1759 [00:05<00:00, 316.75it/s, batch_loss=0.5314449] 


[TEST  006] loss=0.8860964179039001


[TRAIN 007]: 100%|██████████| 1759/1759 [00:05<00:00, 318.95it/s, batch_loss=0.5306775] 


[TEST  007] loss=0.8851304650306702


[TRAIN 008]: 100%|██████████| 1759/1759 [00:06<00:00, 281.81it/s, batch_loss=0.8056063] 


[TEST  008] loss=0.8868416547775269


[TRAIN 009]: 100%|██████████| 1759/1759 [00:06<00:00, 279.02it/s, batch_loss=0.46339297]


[TEST  009] loss=0.8906249403953552


[TRAIN 010]: 100%|██████████| 1759/1759 [00:05<00:00, 311.85it/s, batch_loss=0.5903769] 


[TEST  010] loss=0.8968890309333801


[TRAIN 011]: 100%|██████████| 1759/1759 [00:05<00:00, 314.84it/s, batch_loss=0.6538729] 


[TEST  011] loss=0.9015406370162964


[TRAIN 012]: 100%|██████████| 1759/1759 [00:05<00:00, 315.65it/s, batch_loss=0.47154295]


[TEST  012] loss=0.9077015519142151


[TRAIN 013]: 100%|██████████| 1759/1759 [00:04<00:00, 356.87it/s, batch_loss=0.53116924]


[TEST  013] loss=0.9140112996101379


[TRAIN 014]: 100%|██████████| 1759/1759 [00:04<00:00, 380.02it/s, batch_loss=0.48564133]


[TEST  014] loss=0.9191444516181946


[TRAIN 015]: 100%|██████████| 1759/1759 [00:04<00:00, 376.56it/s, batch_loss=0.33390576]


[TEST  015] loss=0.9264206886291504


[TRAIN 016]: 100%|██████████| 1759/1759 [00:04<00:00, 393.13it/s, batch_loss=0.55270416]


[TEST  016] loss=0.9326691627502441


[TRAIN 017]: 100%|██████████| 1759/1759 [00:05<00:00, 318.77it/s, batch_loss=0.38127318]


[TEST  017] loss=0.9374197125434875
