In [None]:
%load_ext autoreload

In [None]:
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence

## データセット california_housing

In [None]:
# READ
# Access the Bunch by attribute rather than unpacking `.values()`: tuple-unpacking
# a dict view depends on the Bunch's internal key order, which is not part of
# sklearn's documented interface.
housing = fetch_california_housing(as_frame=True)
X, Y = housing.data, housing.target  # features (DataFrame) and target (Series)
feature_list = list(X.columns)

# SPLIT
# 70/30 train/test split; random_state fixes the shuffle so the split is reproducible.
X_TRAIN, X_TEST, Y_TRAIN, Y_TEST = train_test_split(X, Y, test_size=0.3, random_state=0)

# NOTE(review): the features are not standardized here, and the raw columns appear
# to be on quite different scales (e.g. Population vs. MedInc) — confirm whether
# dataLoader / regressionTrainer normalizes inputs; if not, consider scaling with
# statistics computed on X_TRAIN only.

# DEVICE_PUT
# Convert to float32 explicitly. JAX canonicalizes float64 to float32 unless
# jax_enable_x64 is set, so this only makes the resulting dtype visible here.
X_TRAIN, X_TEST, Y_TRAIN, Y_TEST = jax.device_put((
    X_TRAIN.to_numpy(dtype="float32"),
    X_TEST.to_numpy(dtype="float32"),
    Y_TRAIN.to_numpy(dtype="float32"),
    Y_TEST.to_numpy(dtype="float32"),
))

# SIZE
print("TRAIN:", X_TRAIN.shape, Y_TRAIN.shape)
print("TEST:", X_TEST.shape, Y_TEST.shape)

## モデル

In [None]:
class MLP(nn.Module):
    """Fully connected network used as a regressor.

    Each entry of ``layer_sizes`` becomes one ``nn.Dense`` layer, applied in order.
    Hidden layers are followed by leaky ReLU. The final layer is left linear so the
    regression output is not passed through the activation. The result is flattened
    with ``reshape(-1)``. Assuming a (batch, features) input and a final layer size
    of 1, that gives one prediction per row (TODO confirm against the trainer).
    """

    # Output width of each Dense layer, in order; the last entry is the output width.
    layer_sizes: Sequence[int]

    @nn.compact
    def __call__(self, X):
        last_index = len(self.layer_sizes) - 1

        for index, features in enumerate(self.layer_sizes):
            X = nn.Dense(features=features)(X)
            # Skip the activation on the output layer. A linear head is standard for
            # regression; leaky_relu there would scale negative predictions by its
            # small slope and shrink their gradients.
            if index < last_index:
                X = nn.leaky_relu(X)

        return X.reshape(-1)

## 学習

In [None]:
%autoreload
from trainer.dataLoader import dataLoader # データローダ
from trainer.regressionTrainer import regressionTrainer # 学習器

In [None]:
# Hyperparameters, named so they can be found and tuned in one place.
LAYER_SIZES = [8, 4, 2, 1]  # Dense widths; the final 1 yields a single regression output per row
EPOCH_NUMS = 128            # passed to regressionTrainer as epoch_nums
LEARNING_RATE = 0.001       # passed to regressionTrainer as learning_rate

model = MLP(layer_sizes=LAYER_SIZES)

# NOTE(review): no PRNG seed is passed here — confirm that regressionTrainer seeds
# parameter initialization (and any shuffling) so that re-runs are reproducible.
trainer = regressionTrainer(
    model=model,
    dataLoader=dataLoader,
    epoch_nums=EPOCH_NUMS,
    learning_rate=LEARNING_RATE,
)
# The test split is passed alongside the training data; presumably it is used to
# track held-out loss during training — verify in regressionTrainer.fit.
state = trainer.fit(X_TRAIN, Y_TRAIN, X_TEST=X_TEST, Y_TEST=Y_TEST)
trainer.plot_loss_history()