In [None]:
%load_ext autoreload

In [None]:
import pandas as pd
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence

## データセット ml-100k

In [None]:
# Load the MovieLens-100k ratings file, keeping only the first three columns
# (user id, item id, rating).
ML100K = pd.read_table("~/data/ml-100k/raw/u.data", usecols=[0, 1, 2], header=None)
ML100K.columns = ["user_id", "item_id", "rating"]

# Build the model inputs and the regression targets.
# IDs are shifted by one so they can be used directly as 0-based embedding
# indices (assumes the raw IDs start at 1 -- TODO confirm for other datasets).
X = ML100K[["user_id", "item_id"]].to_numpy() - 1
Y = ML100K[["rating"]].to_numpy()

# Size the embedding tables from the largest index rather than from the number
# of distinct IDs. Both agree when the IDs are contiguous, but if the ID set has
# gaps the distinct count is too small, and JAX does not raise on out-of-range
# lookups, so the model would silently read wrong embedding rows.
user_num = int(X[:, 0].max()) + 1
item_num = int(X[:, 1].max()) + 1

# Hold out 30% of the ratings for evaluation (fixed seed for reproducibility).
X_TRAIN, X_TEST, Y_TRAIN, Y_TEST = train_test_split(X, Y, test_size=0.3, random_state=0)

# Transfer all four arrays to the default JAX device in one call.
X_TRAIN, X_TEST, Y_TRAIN, Y_TEST = jax.device_put((
    X_TRAIN,
    X_TEST,
    Y_TRAIN,
    Y_TEST,
))

# Report the split sizes.
print("TRAIN:", X_TRAIN.shape, Y_TRAIN.shape)
print("TEST:", X_TEST.shape, Y_TEST.shape)

## モデル

In [None]:
class MatrixFactorization(nn.Module):
    """Matrix-factorization rating model.

    The prediction for a (user, item) pair is the dot product of the user's
    embedding vector and the item's embedding vector.
    """

    # Number of rows in the user embedding table.
    user_num: int
    # Number of rows in the item embedding table.
    item_num: int
    # Width of each embedding vector.
    embed_dims: int = 100

    def setup(self):
        """Create the user and item embedding tables."""
        self.userEmbed = nn.Embed(num_embeddings=self.user_num, features=self.embed_dims)
        self.itemEmbed = nn.Embed(num_embeddings=self.item_num, features=self.embed_dims)

    # No @nn.compact here: every submodule is declared in setup(), so the
    # decorator was redundant and mixed Flax's two module-definition styles.
    def __call__(self, X):
        """Predict one score per row of ``X``.

        Args:
            X: 2-D integer array; column 0 is used as the user index and
                column 1 as the item index.

        Returns:
            Array with one row per input row and a single column holding the
            user/item embedding dot product.
        """
        user_vecs = self.userEmbed(X[:, 0])
        item_vecs = self.itemEmbed(X[:, 1])
        # Element-wise product summed over the embedding axis == dot product.
        return (user_vecs * item_vecs).sum(axis=1).reshape(-1, 1)

## 学習

In [None]:
%autoreload
from trainer.dataLoader import dataLoader # data loader
from trainer.regressionTrainer import regressionTrainer # trainer (learner)

In [None]:
# Train the factorization model with 10-dimensional embeddings, evaluating on
# the held-out split, then plot the recorded loss history.
model = MatrixFactorization(
    user_num=user_num,
    item_num=item_num,
    embed_dims=10,
)

trainer = regressionTrainer(
    model=model,
    dataLoader=dataLoader,
    epoch_nums=128,
    learning_rate=0.001,
)
state = trainer.fit(
    X_TRAIN,
    Y_TRAIN,
    X_TEST=X_TEST,
    Y_TEST=Y_TEST,
)
trainer.plot_loss_history()

In [None]:
# Same experiment with 100-dimensional embeddings, for comparison against the
# smaller model; all other hyperparameters are unchanged.
model = MatrixFactorization(
    user_num=user_num,
    item_num=item_num,
    embed_dims=100,
)

trainer = regressionTrainer(
    model=model,
    dataLoader=dataLoader,
    epoch_nums=128,
    learning_rate=0.001,
)
state = trainer.fit(
    X_TRAIN,
    Y_TRAIN,
    X_TEST=X_TEST,
    Y_TEST=Y_TEST,
)
trainer.plot_loss_history()