In [None]:
########################################
# ライブラリ
########################################

import jax
import jax.numpy as jnp

from flax import linen as nn
from flax.training import train_state
import optax

import pandas as pd
from tqdm import trange as tqdm_range
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

In [None]:
########################################
# モデルの定義
########################################

class MatrixFactorization(nn.Module):
    """Matrix-factorization scorer built from two bias-free dense layers.

    Each of the two inputs is projected by its own ``nn.Dense`` layer with
    ``k`` output features and no bias; the score is the dot product of the
    two projected vectors.

    Attributes:
        k: Number of features produced by each dense layer.
    """

    # Size of the latent vectors produced for both inputs.
    k: int = 20

    @nn.compact
    def __call__(self, user_onehot, item_onehot):
        """Score each (user, item) pair.

        Args:
            user_onehot: Array whose last axis is the user encoding
                (presumably one-hot, as the name suggests — confirm with caller).
            item_onehot: Array whose last axis is the item encoding
                (presumably one-hot — confirm with caller).

        Returns:
            The element-wise product of the two projected vectors summed over
            the feature axis, i.e. one score per leading index.
        """
        # Layer creation order is kept (user first, item second) so the
        # automatically assigned parameter names stay the same.
        user_vector = nn.Dense(features=self.k, use_bias=False)(user_onehot)
        item_vector = nn.Dense(features=self.k, use_bias=False)(item_onehot)
        # Dense acts on the last axis, so reduce over the last axis as well.
        # For 2-D batches this equals axis=1, and it also works for
        # unbatched (1-D) or higher-rank inputs.
        return jnp.sum(user_vector * item_vector, axis=-1)

In [None]:
########################################
# 損失関数
########################################

@jax.jit
def loss_function(params, X_USER, X_ITEM, y):
    """Return the mean squared error between model predictions and ``y``.

    Predictions come from ``model.apply`` on the module-level ``model``
    object, which must be defined before this function is first called.

    Args:
        params: Parameter tree passed to ``model.apply`` under the
            ``'params'`` key.
        X_USER: First model input.
        X_ITEM: Second model input.
        y: Target values compared against the predictions.

    Returns:
        The mean of the squared differences between prediction and target.
    """
    variables = {'params': params}
    residual = model.apply(variables, X_USER, X_ITEM) - y
    return jnp.mean(residual ** 2)

In [None]:
########################################
# データセットの読み込み
########################################

# Read the ratings table; header=None gives the columns integer labels.
ML100K = pd.read_table("/home/sugahara/data/ml-100k/row/u.data", header=None)

# Indicator-encode columns 0 and 1 and take column 2 as the target, then
# hand each array to JAX. (Columns are presumably user id, item id and
# rating, as the variable names suggest — confirm against the data file.)
X_USER = jax.device_put(pd.get_dummies(ML100K[0]).values)
X_ITEM = jax.device_put(pd.get_dummies(ML100K[1]).values)
y = jax.device_put(ML100K[2].values)

# Width of each indicator matrix = number of distinct values in that column.
user_size, item_size = X_USER.shape[1], X_ITEM.shape[1]

In [None]:
########################################
# モデル
########################################

# Build the model with 20 latent features.
model = MatrixFactorization(k=20)

# Initialise the parameters from a fixed PRNG key, using single-row dummy
# inputs that only serve to fix the input widths.
params = model.init(
    jax.random.PRNGKey(0),
    jnp.ones((1, user_size)),
    jnp.ones((1, item_size)),
)["params"]

# Optimizer: Adam (note: this is Adam, not plain SGD).
learning_rate = 0.001
tx = optax.adam(learning_rate)

# Bundle apply function, parameters and optimizer into one train state.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
)

In [None]:
########################################
# 学習
########################################

# Number of full-batch optimisation steps to run.
epoch_nums = 512

# Build the loss-and-gradient function once: it does not change between
# iterations, and jitting it lets JAX reuse one compiled computation.
loss_and_grad = jax.jit(jax.value_and_grad(loss_function))

for epoch_id in range(epoch_nums):

    # Loss and gradients with respect to the current parameters.
    loss, grads = loss_and_grad(state.params, X_USER, X_ITEM, y)

    # Apply one optimizer update; the train state is replaced, not mutated.
    state = state.apply_gradients(grads=grads)

    # Report the training loss of the *updated* parameters (this costs one
    # extra forward pass per epoch). Epochs are shown 1-based so the final
    # line reads "epoch_nums / epoch_nums".
    print(
        "訓練誤差:",
        '{:.3f}'.format(loss_function(state.params, X_USER, X_ITEM, y)),
        f"【 Epoch: {epoch_id + 1} / {epoch_nums} 】"
    )