In [1]:
import mlflow
import optax
import pandas as pd
import polars as pl
from flax import nnx
from flax_trainer.loss_fn import mean_squared_error
from flax_trainer.trainer import Trainer
from sklearn.model_selection import train_test_split

from flax_recsys.encoder import ColumnEncoder
from flax_recsys.evaluator import GeneralRatingEvaluator
from flax_recsys.loader import GeneralRatingLoader
from flax_recsys.model.FM import FM

### READ

In [2]:
# Directory holding the raw MovieLens-1M `.dat` files (relative to this notebook).
ML1M_DIR = "../../dataset/ML1M/ml-1m"


def _read_ml1m_dat(file_name, column_names, encoding=None):
    """Load one MovieLens-1M ``.dat`` file into a polars DataFrame.

    The files are read without a header row and with ``::`` as the field
    separator; the python parser engine is requested explicitly because the
    separator is longer than one character.

    Args:
        file_name: File name inside ``ML1M_DIR`` (e.g. ``"ratings.dat"``).
        column_names: Column names to assign, in file order.
        encoding: Optional text encoding forwarded to ``pd.read_csv``.
            ``None`` (the default) leaves pandas' own default in place.

    Returns:
        A ``pl.DataFrame`` with the given column names.
    """
    return pl.from_pandas(
        pd.read_csv(
            f"{ML1M_DIR}/{file_name}",
            delimiter="::",
            engine="python",
            header=None,
            names=column_names,
            encoding=encoding,
        )
    )


rating_df = _read_ml1m_dat(
    "ratings.dat", ["user_id", "item_id", "rating", "timestamp"]
)
# movies.dat is read as ISO-8859-1; `genres` is a "|"-separated string that is
# split into a list column here.
movie_df = _read_ml1m_dat(
    "movies.dat", ["item_id", "title", "genres"], encoding="ISO-8859-1"
).with_columns(pl.col("genres").str.split("|"))
user_df = _read_ml1m_dat(
    "users.dat", ["user_id", "gender", "age", "occupation", "zip_code"]
)

# Attach movie and user attributes to every rating row (left joins on the ids).
joined_df = rating_df.join(movie_df, on="item_id", how="left").join(
    user_df, on="user_id", how="left"
)
# Drop `genres` and move `rating` to the last column.
reordered_df = joined_df.select(
    pl.all().exclude("rating", "genres"), pl.col("rating")
)
# Replace every non-rating column with its zero-based dense rank.
zero_based_rank = pl.all().exclude("rating").rank("dense") - 1
dataset_df = reordered_df.with_columns(zero_based_rank)
display(dataset_df)

# Tell the encoder which dataset column plays which role.
# NOTE(review): the exact treatment of each role is defined by ColumnEncoder,
# which is not visible from this notebook.
one_hot_columns = ["title", "gender", "age", "occupation", "zip_code"]
column_encoder = ColumnEncoder(
    rating="rating",
    one_hot=one_hot_columns,
    timestamp="timestamp",
    item_id="item_id",
    user_id="user_id",
)

# Hold out 10% of the rows for validation; the fixed random_state makes the
# shuffled split repeatable across runs.
split_options = {"test_size": 0.1, "random_state": 0, "shuffle": True}
train_df, valid_df = train_test_split(dataset_df, **split_options)

user_id,item_id,timestamp,title,gender,age,occupation,zip_code,rating
u32,u32,u32,u32,u32,u32,u32,u32,i64
0,1104,397442,2452,0,0,10,1588,5
0,639,397457,1739,0,0,10,1588,3
0,853,397454,2289,0,0,10,1588,3
0,3177,397440,1054,0,0,10,1588,4
0,2162,400250,557,0,0,10,1588,5
…,…,…,…,…,…,…,…,…
6039,1019,382,3574,1,2,6,466,1
6039,1022,21,814,1,2,6,466,5
6039,548,17,3578,1,2,6,466,5
6039,1024,346,3090,1,2,6,466,4


In [3]:
# Encode the training split (fit_transform is called on train_df only) and
# wrap the resulting arrays in a mini-batch loader.
categorical_X, numerical_X, y = column_encoder.fit_transform(train_df)
loader_options = dict(
    categorical_X=categorical_X,
    numerical_X=numerical_X,
    y=y,
    batch_size=512,
    rngs=nnx.Rngs(0),
)
loader = GeneralRatingLoader(**loader_options)

# Encode the validation split with transform() (no refit) and hand it to the
# evaluator. Note: this rebinds categorical_X / numerical_X / y.
categorical_X, numerical_X, y = column_encoder.transform(valid_df)
evaluator_inputs = {
    "categorical_X": categorical_X,
    "numerical_X": numerical_X,
    "y": y,
}
evaluator = GeneralRatingEvaluator(**evaluator_inputs)

### Training (学習)

In [4]:
# Build the FM model; the categorical cardinalities and the numerical feature
# count are both taken from the encoder.
feature_cardinalities = [*column_encoder.cardinality_map.values()]
numerical_feature_count = column_encoder.numerical_column_num
model = FM(
    categorical_feature_cardinalities=feature_cardinalities,
    numerical_feature_num=numerical_feature_count,
    embed_dim=30,
    rngs=nnx.Rngs(0),
)

In [5]:
# Point MLflow at the tracking server on localhost and select the experiment.
mlflow.set_tracking_uri(uri="http://localhost:8080")
mlflow.set_experiment("FM_RATING")

with mlflow.start_run() as run:
    # AdamW with matching learning rate and weight decay.
    adamw = optax.adamw(learning_rate=0.001, weight_decay=0.001)
    trainer_options = dict(
        model=model,
        optimizer=adamw,
        train_loader=loader,
        loss_fn=mean_squared_error,
        valid_evaluator=evaluator,
        early_stopping_patience=10,
        epoch_num=64,
        active_run=run,
    )
    trainer = Trainer(**trainer_options)
    # The value returned by fit() is rebound to `trainer`.
    trainer = trainer.fit()

[VALID 000]: loss=29.00882911682129, metrics={'mse': 29.00882911682129}


[TRAIN 001]: 100%|██████████| 1759/1759 [00:05<00:00, 324.49it/s, batch_loss=1.0517151] 


[VALID 001]: loss=1.0401686429977417, metrics={'mse': 1.0401686429977417}


[TRAIN 002]: 100%|██████████| 1759/1759 [00:06<00:00, 255.69it/s, batch_loss=1.025557]  


[VALID 002]: loss=0.9682368636131287, metrics={'mse': 0.9682368636131287}


[TRAIN 003]: 100%|██████████| 1759/1759 [00:07<00:00, 226.17it/s, batch_loss=0.69383425]


[VALID 003]: loss=0.9427712559700012, metrics={'mse': 0.9427712559700012}


[TRAIN 004]: 100%|██████████| 1759/1759 [00:08<00:00, 206.41it/s, batch_loss=0.7182006] 


[VALID 004]: loss=0.9177350997924805, metrics={'mse': 0.9177350997924805}


[TRAIN 005]: 100%|██████████| 1759/1759 [00:09<00:00, 181.62it/s, batch_loss=0.721181]  


[VALID 005]: loss=0.8966554403305054, metrics={'mse': 0.8966554403305054}


[TRAIN 006]: 100%|██████████| 1759/1759 [00:09<00:00, 193.97it/s, batch_loss=0.681755]  


[VALID 006]: loss=0.8747289776802063, metrics={'mse': 0.8747289776802063}


[TRAIN 007]: 100%|██████████| 1759/1759 [00:09<00:00, 192.43it/s, batch_loss=0.6378839] 


[VALID 007]: loss=0.8668085932731628, metrics={'mse': 0.8668085932731628}


[TRAIN 008]: 100%|██████████| 1759/1759 [00:09<00:00, 179.25it/s, batch_loss=0.974689]  


[VALID 008]: loss=0.8593384623527527, metrics={'mse': 0.8593384623527527}


[TRAIN 009]: 100%|██████████| 1759/1759 [00:08<00:00, 197.43it/s, batch_loss=0.660937]  


[VALID 009]: loss=0.8569033741950989, metrics={'mse': 0.8569033741950989}


[TRAIN 010]: 100%|██████████| 1759/1759 [00:08<00:00, 202.61it/s, batch_loss=0.6553361] 


[VALID 010]: loss=0.858298122882843, metrics={'mse': 0.858298122882843}


[TRAIN 011]: 100%|██████████| 1759/1759 [00:08<00:00, 197.65it/s, batch_loss=0.79499114]


[VALID 011]: loss=0.8582233190536499, metrics={'mse': 0.8582233190536499}


[TRAIN 012]: 100%|██████████| 1759/1759 [00:08<00:00, 198.81it/s, batch_loss=0.5339409] 


[VALID 012]: loss=0.8587896823883057, metrics={'mse': 0.8587896823883057}


[TRAIN 013]: 100%|██████████| 1759/1759 [00:08<00:00, 208.45it/s, batch_loss=0.6576497] 


[VALID 013]: loss=0.861323356628418, metrics={'mse': 0.861323356628418}


[TRAIN 014]: 100%|██████████| 1759/1759 [00:09<00:00, 186.75it/s, batch_loss=0.6642739] 


[VALID 014]: loss=0.8611533045768738, metrics={'mse': 0.8611533045768738}


[TRAIN 015]: 100%|██████████| 1759/1759 [00:08<00:00, 196.17it/s, batch_loss=0.48589537]


[VALID 015]: loss=0.8649002909660339, metrics={'mse': 0.8649002909660339}


[TRAIN 016]: 100%|██████████| 1759/1759 [00:09<00:00, 180.40it/s, batch_loss=0.654857]  


[VALID 016]: loss=0.8681081533432007, metrics={'mse': 0.8681081533432007}


[TRAIN 017]: 100%|██████████| 1759/1759 [00:09<00:00, 194.04it/s, batch_loss=0.50039285]


[VALID 017]: loss=0.8705926537513733, metrics={'mse': 0.8705926537513733}


[TRAIN 018]: 100%|██████████| 1759/1759 [00:09<00:00, 186.30it/s, batch_loss=0.6440288] 


[VALID 018]: loss=0.8749492168426514, metrics={'mse': 0.8749492168426514}


[TRAIN 019]: 100%|██████████| 1759/1759 [00:09<00:00, 190.21it/s, batch_loss=0.492442]  


[VALID 019]: loss=0.8777192831039429, metrics={'mse': 0.8777192831039429}
🏃 View run crawling-fawn-96 at: http://localhost:8080/#/experiments/417158149682049859/runs/93d16470b130451ab6fe6621c85ac972
🧪 View experiment at: http://localhost:8080/#/experiments/417158149682049859
