In [None]:
%load_ext autoreload

In [None]:
import pandas as pd
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence

## データセット ml-100k

In [None]:
# READ
# The file is read without a header row, so columns are addressed by position.
# Columns 0 and 1 feed the user / item embeddings below; column 2 is the
# regression target (presumably the rating -- confirm against the dataset docs).
ML100K = pd.read_table("~/data/ml-100k/u.data", header=None)
# Subtract 1 so the ids can be used directly as 0-based embedding row indices.
X, Y = (ML100K[[0, 1]] - 1), ML100K[[2]]
# Size each embedding table from the largest index actually used rather than
# from the number of distinct ids: if the id space has gaps, a distinct-count
# would be smaller than max index + 1 and lookups would run out of bounds.
# For contiguous ids both approaches give the same value.
user_nums, item_nums = int(X[0].max()) + 1, int(X[1].max()) + 1


# SPLIT
# Fixed random_state keeps the 70/30 split identical across re-runs.
X_TRAIN, X_TEST, Y_TRAIN, Y_TEST = train_test_split(X, Y, test_size=0.3, random_state=0)

# DEVICE_PUT
# Convert the pandas frames to NumPy and move all four arrays in one call.
X_TRAIN, X_TEST, Y_TRAIN, Y_TEST = jax.device_put((
    X_TRAIN.to_numpy(),
    X_TEST.to_numpy(),
    Y_TRAIN.to_numpy(),
    Y_TEST.to_numpy(),
))

# SIZE
print("TRAIN:", X_TRAIN.shape, Y_TRAIN.shape)
print("TEST:", X_TEST.shape, Y_TEST.shape)

## モデル

In [None]:
class MatrixFactorization(nn.Module):
    """Score (user, item) pairs as the dot product of two learned embeddings.

    Attributes:
        user_nums: Number of rows in the user embedding table.
        item_nums: Number of rows in the item embedding table.
        embed_dim_size: Length of each embedding vector (default 100).
    """

    user_nums: int
    item_nums: int
    embed_dim_size: int = 100

    def setup(self):
        """Declare the user and item embedding tables as submodules."""
        # The attribute names are left unchanged so that submodule naming
        # stays the same as before this cleanup.
        self.userEmbed = nn.Embed(num_embeddings=self.user_nums, features=self.embed_dim_size)
        self.itemEmbed = nn.Embed(num_embeddings=self.item_nums, features=self.embed_dim_size)

    # Deliberately NOT decorated with @nn.compact: every submodule is declared
    # in setup(), and nothing is created inline here, so the decorator was
    # redundant and mixed Flax's two module-definition styles.
    def __call__(self, X):
        """Return one predicted score per input row.

        Args:
            X: Integer index array indexed as ``X[:, 0]`` (user index) and
                ``X[:, 1]`` (item index).

        Returns:
            Array reshaped to ``(-1, 1)`` holding the user/item dot products.
        """
        user_vecs = self.userEmbed(X[:, 0])
        item_vecs = self.itemEmbed(X[:, 1])
        # Element-wise product summed over axis 1 is the row-wise dot product.
        scores = (user_vecs * item_vecs).sum(axis=1)
        return scores.reshape(-1, 1)

## 学習

In [None]:
# Trigger a reload of already-imported modules first, so that edits made to
# the local trainer package are picked up by the import below.
%autoreload
from trainer.regressionTrainer import regressionTrainer # trainer class used to fit the model

In [None]:
# Run 1: matrix factorization with 10-dimensional embeddings.
model = MatrixFactorization(
    user_nums=user_nums,
    item_nums=item_nums,
    embed_dim_size=10,
)
trainer = regressionTrainer(model=model)

# Fit on the training split; the test split is passed alongside it.
state = trainer.fit(
    X_TRAIN,
    Y_TRAIN,
    X_TEST=X_TEST,
    Y_TEST=Y_TEST,
    epoch_nums=128,
    learning_rate=0.01,
)

# Show the loss-history figure produced by the trainer.
fig = trainer.plot_loss_history(hide_init_loss=False)
fig.show()

In [None]:
# Run 2: same setup as the previous run, but with 100-dimensional embeddings.
model = MatrixFactorization(
    user_nums=user_nums,
    item_nums=item_nums,
    embed_dim_size=100,
)
trainer = regressionTrainer(model=model)

# Fit on the training split; the test split is passed alongside it.
state = trainer.fit(
    X_TRAIN,
    Y_TRAIN,
    X_TEST=X_TEST,
    Y_TEST=Y_TEST,
    epoch_nums=128,
    learning_rate=0.01,
)

# Show the loss-history figure produced by the trainer.
fig = trainer.plot_loss_history(hide_init_loss=False)
fig.show()