In [1]:
import mlflow
import optax
import pandas as pd
import polars as pl
from flax import nnx
from flax_trainer.loss_fn import mean_squared_error
from flax_trainer.trainer import Trainer
from sklearn.model_selection import train_test_split

from flax_recsys.encoder import ColumnEncoder
from flax_recsys.evaluator import GeneralRatingEvaluator
from flax_recsys.loader import GeneralRatingLoader
from flax_recsys.model.FwFM import FwFM

### Load the MovieLens-1M data

In [2]:
# Directory holding the raw ML-1M ``.dat`` files, relative to this notebook.
ML1M_DIR = "../../dataset/ML1M/ml-1m"


def _read_ml1m_table(file_name, column_names, encoding=None):
    """Read one header-less, ``::``-delimited ML-1M file into a polars DataFrame.

    Args:
        file_name: Name of the file inside ``ML1M_DIR`` (e.g. ``"ratings.dat"``).
        column_names: Column names to assign, in file order.
        encoding: Text encoding forwarded to ``pandas.read_csv``. ``None``
            (the pandas default) leaves the encoding choice to pandas.

    Returns:
        The parsed table, converted with ``pl.from_pandas``.
    """
    return pl.from_pandas(
        pd.read_csv(
            f"{ML1M_DIR}/{file_name}",
            delimiter="::",
            # A multi-character delimiter is only handled by the python engine.
            engine="python",
            header=None,
            names=column_names,
            encoding=encoding,
        )
    )


rating_df = _read_ml1m_table(
    "ratings.dat", ["user_id", "item_id", "rating", "timestamp"]
)

# movies.dat is read as ISO-8859-1; "genres" is a "|"-joined string that is
# split into a list column here.
movie_df = _read_ml1m_table(
    "movies.dat", ["item_id", "title", "genres"], encoding="ISO-8859-1"
).with_columns(pl.col("genres").str.split("|"))

user_df = _read_ml1m_table(
    "users.dat", ["user_id", "gender", "age", "occupation", "zip_code"]
)

# Attach movie and user attributes to every rating row, drop the list-valued
# "genres" column and move "rating" to the last position. Every column except
# "rating" is then replaced by its zero-based dense rank.
dataset_df = (
    rating_df.join(movie_df, on="item_id", how="left")
    .join(user_df, on="user_id", how="left")
    .select(pl.exclude("rating", "genres"), pl.col("rating"))
    .with_columns(pl.exclude("rating").rank(method="dense") - 1)
)
display(dataset_df)

# Tell the encoder which dataset column plays which role.
column_encoder = ColumnEncoder(
    rating="rating",
    user_id="user_id",
    item_id="item_id",
    timestamp="timestamp",
    one_hot=["title", "gender", "age", "occupation", "zip_code"],
)

# Seeded, shuffled hold-out split: 10% of the rows become the validation set.
train_df, valid_df = train_test_split(
    dataset_df,
    shuffle=True,
    random_state=0,
    test_size=0.1,
)

user_id,item_id,timestamp,title,gender,age,occupation,zip_code,rating
u32,u32,u32,u32,u32,u32,u32,u32,i64
0,1104,397442,2452,0,0,10,1588,5
0,639,397457,1739,0,0,10,1588,3
0,853,397454,2289,0,0,10,1588,3
0,3177,397440,1054,0,0,10,1588,4
0,2162,400250,557,0,0,10,1588,5
…,…,…,…,…,…,…,…,…
6039,1019,382,3574,1,2,6,466,1
6039,1022,21,814,1,2,6,466,5
6039,548,17,3578,1,2,6,466,5
6039,1024,346,3090,1,2,6,466,4


In [3]:
# Fit the encoder on the training split and wrap the encoded arrays in the
# batch loader used by the trainer.
categorical_X, numerical_X, y = column_encoder.fit_transform(train_df)
loader = GeneralRatingLoader(
    rngs=nnx.Rngs(0),
    batch_size=512,
    categorical_X=categorical_X,
    numerical_X=numerical_X,
    y=y,
)

# Encode the validation split with the already-fitted encoder.
# NOTE: this rebinds categorical_X / numerical_X / y, so after this cell those
# names refer to the validation arrays.
categorical_X, numerical_X, y = column_encoder.transform(valid_df)
evaluator = GeneralRatingEvaluator(
    y=y,
    numerical_X=numerical_X,
    categorical_X=categorical_X,
)

### Training (学習)

In [4]:
# Size the model from what the fitted encoder reports: one cardinality per
# categorical feature plus the number of numerical columns.
model = FwFM(
    rngs=nnx.Rngs(0),
    embed_dim=30,
    numerical_feature_num=column_encoder.numerical_column_num,
    categorical_feature_cardinalities=[*column_encoder.cardinality_map.values()],
)

In [5]:
# Send tracking data to the MLflow server at localhost:8080 under the
# "FwFM_RATING" experiment.
mlflow.set_tracking_uri("http://localhost:8080")
mlflow.set_experiment("FwFM_RATING")

# Train inside a single MLflow run; the run handle is passed to the trainer.
with mlflow.start_run() as run:
    trainer = Trainer(
        model=model,
        train_loader=loader,
        valid_evaluator=evaluator,
        loss_fn=mean_squared_error,
        optimizer=optax.adamw(learning_rate=0.001, weight_decay=0.001),
        epoch_num=64,
        early_stopping_patience=10,
        active_run=run,
    )
    # fit() returns the object that `trainer` is rebound to.
    trainer = trainer.fit()

[VALID 000]: loss=30.17424964904785, metrics={'mse': 30.17424964904785}


[TRAIN 001]: 100%|██████████| 1759/1759 [00:09<00:00, 182.99it/s, batch_loss=0.87051815]


[VALID 001]: loss=0.9123357534408569, metrics={'mse': 0.9123357534408569}


[TRAIN 002]: 100%|██████████| 1759/1759 [00:09<00:00, 175.98it/s, batch_loss=0.9984347] 


[VALID 002]: loss=0.8667045831680298, metrics={'mse': 0.8667045831680298}


[TRAIN 003]: 100%|██████████| 1759/1759 [00:10<00:00, 160.86it/s, batch_loss=0.7116962] 


[VALID 003]: loss=0.8626285791397095, metrics={'mse': 0.8626285791397095}


[TRAIN 004]: 100%|██████████| 1759/1759 [00:09<00:00, 177.92it/s, batch_loss=0.71589833]


[VALID 004]: loss=0.8584514856338501, metrics={'mse': 0.8584514856338501}


[TRAIN 005]: 100%|██████████| 1759/1759 [00:10<00:00, 165.26it/s, batch_loss=0.77232915]


[VALID 005]: loss=0.8590362668037415, metrics={'mse': 0.8590362668037415}


[TRAIN 006]: 100%|██████████| 1759/1759 [00:10<00:00, 166.75it/s, batch_loss=0.59617615]


[VALID 006]: loss=0.8544833660125732, metrics={'mse': 0.8544833660125732}


[TRAIN 007]: 100%|██████████| 1759/1759 [00:10<00:00, 164.47it/s, batch_loss=0.7209101] 


[VALID 007]: loss=0.8528735637664795, metrics={'mse': 0.8528735637664795}


[TRAIN 008]: 100%|██████████| 1759/1759 [00:10<00:00, 160.62it/s, batch_loss=1.0805565] 


[VALID 008]: loss=0.8486641645431519, metrics={'mse': 0.8486641645431519}


[TRAIN 009]: 100%|██████████| 1759/1759 [00:11<00:00, 159.29it/s, batch_loss=0.6669211] 


[VALID 009]: loss=0.8476921319961548, metrics={'mse': 0.8476921319961548}


[TRAIN 010]: 100%|██████████| 1759/1759 [00:10<00:00, 170.73it/s, batch_loss=0.62571836]


[VALID 010]: loss=0.8533954620361328, metrics={'mse': 0.8533954620361328}


[TRAIN 011]: 100%|██████████| 1759/1759 [00:10<00:00, 168.82it/s, batch_loss=0.76019]   


[VALID 011]: loss=0.8536626100540161, metrics={'mse': 0.8536626100540161}


[TRAIN 012]: 100%|██████████| 1759/1759 [00:10<00:00, 168.05it/s, batch_loss=0.5222816] 


[VALID 012]: loss=0.8560672998428345, metrics={'mse': 0.8560672998428345}


[TRAIN 013]: 100%|██████████| 1759/1759 [00:10<00:00, 167.33it/s, batch_loss=0.6611078] 


[VALID 013]: loss=0.8589900135993958, metrics={'mse': 0.8589900135993958}


[TRAIN 014]: 100%|██████████| 1759/1759 [00:10<00:00, 167.62it/s, batch_loss=0.7329082] 


[VALID 014]: loss=0.8635580539703369, metrics={'mse': 0.8635580539703369}


[TRAIN 015]: 100%|██████████| 1759/1759 [00:11<00:00, 157.42it/s, batch_loss=0.44422102]


[VALID 015]: loss=0.8688892126083374, metrics={'mse': 0.8688892126083374}


[TRAIN 016]: 100%|██████████| 1759/1759 [00:11<00:00, 154.85it/s, batch_loss=0.6724228] 


[VALID 016]: loss=0.8756545186042786, metrics={'mse': 0.8756545186042786}


[TRAIN 017]: 100%|██████████| 1759/1759 [00:10<00:00, 165.37it/s, batch_loss=0.4715739] 


[VALID 017]: loss=0.8814161419868469, metrics={'mse': 0.8814161419868469}


[TRAIN 018]: 100%|██████████| 1759/1759 [00:11<00:00, 158.92it/s, batch_loss=0.5783349] 


[VALID 018]: loss=0.8919109106063843, metrics={'mse': 0.8919109106063843}


[TRAIN 019]: 100%|██████████| 1759/1759 [00:07<00:00, 229.70it/s, batch_loss=0.4435718] 


[VALID 019]: loss=0.8996073603630066, metrics={'mse': 0.8996073603630066}
🏃 View run bright-ox-408 at: http://localhost:8080/#/experiments/313596797375609562/runs/4c2f7aaf160d44e1a29dd9aa7a386aaf
🧪 View experiment at: http://localhost:8080/#/experiments/313596797375609562
