In [1]:
import jax
import optax
import pandas as pd
import polars as pl
from flax import nnx
from flax_trainer.evaluator import RegressionEvaluator
from flax_trainer.loader import MiniBatchLoader
from flax_trainer.loss_fn import mean_squared_error
from flax_trainer.trainer import Trainer
from sklearn.model_selection import train_test_split

from flax_recsys.model.two_tower import TwoTower

### READ

In [2]:
# Directory holding the raw MovieLens-1M ".dat" files (relative to this notebook).
ML1M_DIR = "../../dataset/ML1M/ml-1m"


def read_ml1m_table(file_name, column_names, encoding=None):
    """Read one header-less, "::"-delimited ML-1M file into a polars DataFrame.

    Args:
        file_name: File name inside ``ML1M_DIR`` (e.g. ``"ratings.dat"``).
        column_names: Column names to assign, in file order (the files have
            no header row).
        encoding: Text encoding forwarded to ``pandas.read_csv``; ``None``
            keeps the pandas default.

    Returns:
        The file contents as a ``polars.DataFrame``.
    """
    return pl.from_pandas(
        pd.read_csv(
            f"{ML1M_DIR}/{file_name}",
            # "::" is a multi-character separator, hence the python engine.
            delimiter="::",
            engine="python",
            header=None,
            names=column_names,
            encoding=encoding,
        )
    )


rating_df = read_ml1m_table(
    "ratings.dat", ["user_id", "item_id", "rating", "timestamp"]
)
# movies.dat is read as ISO-8859-1 (presumably because some titles are not
# valid UTF-8 — TODO confirm); "genres" is split on "|" into a list column.
movie_df = read_ml1m_table(
    "movies.dat", ["item_id", "title", "genres"], encoding="ISO-8859-1"
).with_columns(pl.col("genres").str.split("|"))
user_df = read_ml1m_table(
    "users.dat", ["user_id", "gender", "age", "occupation", "zip_code"]
)

# Attach the movie and user attributes to every rating row.
ratings_with_attributes_df = rating_df.join(
    movie_df, on="item_id", how="left"
).join(user_df, on="user_id", how="left")

# Drop the list-valued "genres" column and move the "rating" target to the
# last position.
reordered_df = ratings_with_attributes_df.select(
    pl.all().exclude("rating", "genres"),
    pl.col("rating"),
)

# Replace every non-target column by its dense rank minus one, i.e. a
# zero-based integer code per distinct value.
zero_based_codes = pl.all().exclude("rating").rank("dense") - 1
dataset_df = reordered_df.with_columns(zero_based_codes)

# Columns of dataset_df used as categorical features by each tower. They are
# listed by name and converted to positions below, so the indices stay correct
# if the column order of dataset_df ever changes. With the current layout they
# resolve to [0, 4, 5, 6, 7] (user) and [1, 3] (item).
user_categorical_feature_names = ["user_id", "gender", "age", "occupation", "zip_code"]
item_categorical_feature_names = ["item_id", "title"]

# Positional column indices into dataset_df for the user-side features.
user_categorical_feature_indices = [
    dataset_df.columns.index(name) for name in user_categorical_feature_names
]
# Number of distinct values observed in each user-side feature column.
user_categorical_feature_cardinalities = [
    dataset_df[name].n_unique() for name in user_categorical_feature_names
]
# No numerical user features are used.
user_numerical_feature_indices = []

# Positional column indices into dataset_df for the item-side features.
item_categorical_feature_indices = [
    dataset_df.columns.index(name) for name in item_categorical_feature_names
]
# Number of distinct values observed in each item-side feature column.
item_categorical_feature_cardinalities = [
    dataset_df[name].n_unique() for name in item_categorical_feature_names
]
# No numerical item features are used.
item_numerical_feature_indices = []

# Shuffle with a fixed seed and hold out 10% of the rows for validation.
split_options = dict(test_size=0.1, shuffle=True, random_state=0)
train_df, valid_df = train_test_split(dataset_df, **split_options)

# Display the encoded dataset as the cell output.
dataset_df

user_id,item_id,timestamp,title,gender,age,occupation,zip_code,rating
u32,u32,u32,u32,u32,u32,u32,u32,i64
0,1104,397442,2452,0,0,10,1588,5
0,639,397457,1739,0,0,10,1588,3
0,853,397454,2289,0,0,10,1588,3
0,3177,397440,1054,0,0,10,1588,4
0,2162,400250,557,0,0,10,1588,5
…,…,…,…,…,…,…,…,…
6039,1019,382,3574,1,2,6,466,1
6039,1022,21,814,1,2,6,466,5
6039,548,17,3578,1,2,6,466,5
6039,1024,346,3090,1,2,6,466,4


### Training (学習)

In [3]:
# Keyword arguments for the user side of the model: which dataset columns it
# reads, how many distinct values each categorical column has, and the hidden
# layer sizes.
user_tower_kwargs = dict(
    user_categorical_feature_indices=user_categorical_feature_indices,
    user_categorical_feature_cardinalities=user_categorical_feature_cardinalities,
    user_numerical_feature_indices=user_numerical_feature_indices,
    user_hidden_layer_dims=[128, 64],
)

# Same set of arguments for the item side.
item_tower_kwargs = dict(
    item_categorical_feature_indices=item_categorical_feature_indices,
    item_categorical_feature_cardinalities=item_categorical_feature_cardinalities,
    item_numerical_feature_indices=item_numerical_feature_indices,
    item_hidden_layer_dims=[128, 64],
)

# Build the two-tower model; parameters are initialised from a fixed seed (0).
model = TwoTower(
    **user_tower_kwargs,
    **item_tower_kwargs,
    embed_dim=30,
    output_layer_dim=30,
    rngs=nnx.Rngs(0),
)

In [4]:
# AdamW optimizer with learning rate 1e-3 and weight decay 1e-3.
optimizer = optax.adamw(learning_rate=0.001, weight_decay=0.001)

# Mini-batches of 512 rows come from the training split (seeded with 0); the
# held-out split is used to compute the per-epoch test loss.
loader = MiniBatchLoader(dataset_df=train_df, batch_size=512, rngs=nnx.Rngs(0))
evaluator = RegressionEvaluator(dataset_df=valid_df)

# Train with mean squared error for at most 32 epochs, with an early-stopping
# patience of 10. In the recorded run the best test loss is at epoch 7 and
# training ends after epoch 17.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=mean_squared_error,
    train_loader=loader,
    test_evaluator=evaluator,
    epoch_num=32,
    early_stopping_patience=10,
)
# fit() runs the training loop; its return value is bound back to `trainer`.
trainer = trainer.fit()

[TEST  000] loss=14.048372268676758


[TRAIN 001]: 100%|██████████| 1759/1759 [00:07<00:00, 242.51it/s, batch_loss=0.7955281] 


[TEST  001] loss=0.8363550901412964


[TRAIN 002]: 100%|██████████| 1759/1759 [00:06<00:00, 253.64it/s, batch_loss=0.83432454]


[TEST  002] loss=0.8051806688308716


[TRAIN 003]: 100%|██████████| 1759/1759 [00:07<00:00, 243.92it/s, batch_loss=0.6562436] 


[TEST  003] loss=0.7783756256103516


[TRAIN 004]: 100%|██████████| 1759/1759 [00:07<00:00, 250.71it/s, batch_loss=0.60549176]


[TEST  004] loss=0.7691575288772583


[TRAIN 005]: 100%|██████████| 1759/1759 [00:07<00:00, 251.27it/s, batch_loss=0.68420726]


[TEST  005] loss=0.7674286365509033


[TRAIN 006]: 100%|██████████| 1759/1759 [00:06<00:00, 253.75it/s, batch_loss=0.6035888] 


[TEST  006] loss=0.7583749890327454


[TRAIN 007]: 100%|██████████| 1759/1759 [00:07<00:00, 247.36it/s, batch_loss=0.67974174]


[TEST  007] loss=0.7574067115783691


[TRAIN 008]: 100%|██████████| 1759/1759 [00:08<00:00, 212.25it/s, batch_loss=1.0016836] 


[TEST  008] loss=0.7584105730056763


[TRAIN 009]: 100%|██████████| 1759/1759 [00:07<00:00, 242.72it/s, batch_loss=0.61300594]


[TEST  009] loss=0.7667893767356873


[TRAIN 010]: 100%|██████████| 1759/1759 [00:07<00:00, 245.63it/s, batch_loss=0.7228317] 


[TEST  010] loss=0.7695207595825195


[TRAIN 011]: 100%|██████████| 1759/1759 [00:07<00:00, 251.23it/s, batch_loss=0.84011364]


[TEST  011] loss=0.7678759694099426


[TRAIN 012]: 100%|██████████| 1759/1759 [00:07<00:00, 224.98it/s, batch_loss=0.49794954]


[TEST  012] loss=0.7723577618598938


[TRAIN 013]: 100%|██████████| 1759/1759 [00:07<00:00, 240.04it/s, batch_loss=0.7092424] 


[TEST  013] loss=0.7781623601913452


[TRAIN 014]: 100%|██████████| 1759/1759 [00:07<00:00, 244.28it/s, batch_loss=0.6780895] 


[TEST  014] loss=0.781099259853363


[TRAIN 015]: 100%|██████████| 1759/1759 [00:06<00:00, 258.15it/s, batch_loss=0.49448097]


[TEST  015] loss=0.785104513168335


[TRAIN 016]: 100%|██████████| 1759/1759 [00:06<00:00, 251.52it/s, batch_loss=0.7873649] 


[TEST  016] loss=0.786565363407135


[TRAIN 017]: 100%|██████████| 1759/1759 [00:07<00:00, 250.27it/s, batch_loss=0.4890897] 


[TEST  017] loss=0.7892701029777527


In [5]:
%%timeit
model.user_tower(valid_df[:1, :-1].to_numpy())

817 μs ± 131 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
