In [None]:
import matplotlib.pyplot as plt

In [None]:
import jax
import jax.numpy as jnp
from jax import nn
from jax.nn.initializers import glorot_normal, normal

In [None]:
# Load the Iris dataset and split it 75/25 into train/test sets (random_state=0
# keeps the split reproducible across runs).
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris_dataset = load_iris()
# Number of target classes, derived from the dataset instead of hard-coding 3.
num_classes = len(iris_dataset['target_names'])

X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset['data'], iris_dataset['target'], test_size=0.25, random_state=0
)
# Move all four NumPy arrays onto the default JAX device (device_put accepts a pytree).
X_train, X_test, y_train, y_test = jax.device_put((X_train, X_test, y_train, y_test))

# One-hot encode the training labels only (the loss takes one-hot targets);
# y_test is left as integer class ids.
y_train = jnp.eye(num_classes)[y_train]

In [None]:
# Parameter initialisation for a 4 -> 100 -> 3 MLP.
# Every weight matrix and bias vector is drawn from its own PRNG key.
rng = jax.random.PRNGKey(0)
rng1, rng2 = jax.random.split(rng)       # one key per layer
rng1w, rng1b = jax.random.split(rng1)    # layer 1: weight key, bias key
rng2w, rng2b = jax.random.split(rng2)    # layer 2: weight key, bias key

# Weights use Glorot-normal init; biases use jax's `normal()` initializer
# with its default standard deviation.
params = dict(
    linear1=dict(
        W=glorot_normal()(rng1w, (4, 100)),
        b=normal()(rng1b, (100,)),
    ),
    linear2=dict(
        W=glorot_normal()(rng2w, (100, 3)),
        b=normal()(rng2b, (3,)),
    ),
)

In [None]:
# Forward pass

@jax.jit
def Linear(params, x):
    """Dense (affine) layer: ``jnp.dot(x, W) + b``.

    Args:
        params: dict holding the weight matrix under "W" and the bias under "b".
        x: input array. Its last axis is contracted with ``W`` by ``jnp.dot``
            (assumes ``W`` is 2-D, as built in the init cell).

    Returns:
        The projected input with the bias broadcast-added.
    """
    weight = params["W"]
    bias = params["b"]
    projected = jnp.dot(x, weight)
    return projected + bias


@jax.jit
def MLP(params, x):
    """Two-layer perceptron: Linear -> ReLU -> Linear -> softmax.

    Args:
        params: dict with one parameter dict per layer under "linear1" and
            "linear2" (each in the format ``Linear`` expects).
        x: input features.

    Returns:
        Class probabilities, i.e. the softmax (over the last axis) of the
        second layer's output. These are probabilities, not logits.
    """
    hidden = nn.relu(Linear(params["linear1"], x))
    scores = Linear(params["linear2"], hidden)
    return nn.softmax(scores)

In [None]:
# Loss function
@jax.jit
def categorical_cross_entropy_loss(true_onehot, predict, eps=1e-7):
    """Mean categorical cross-entropy between one-hot targets and predicted probabilities.

    Args:
        true_onehot: one-hot targets. Presumably (batch, classes); the sum runs over axis 1.
        predict: predicted class probabilities with the same shape (e.g. softmax output).
        eps: lower clip bound applied to ``predict`` before the log. A probability
            that underflows to exactly 0 would otherwise give ``log(0) = -inf``.
            For a non-target class that becomes ``0 * -inf = nan``, which poisons
            both the loss and its gradients.

    Returns:
        Scalar: per-example ``-sum(true_onehot * log(predict))`` averaged over the batch.
    """
    safe_predict = jnp.clip(predict, eps, 1.0)
    per_example = -jnp.sum(true_onehot * jnp.log(safe_predict), axis=1)
    return jnp.mean(per_example)

In [None]:
@jax.jit
def SGD(params, grad, lr=0.1):
    """One plain SGD step, ``param - lr * grad``, applied to every leaf of the parameter tree.

    Args:
        params: parameter pytree (in this notebook a dict of per-layer {"W", "b"} dicts).
        grad: gradient pytree with exactly the same structure as ``params``.
        lr: learning rate.

    Returns:
        A new pytree of updated parameters; ``params`` itself is left untouched.
    """
    # tree_map builds a fresh tree rather than assigning into the caller's dicts.
    # The previous `params[...][...] -= ...` mutated the argument whenever jit was
    # bypassed (e.g. under jax.disable_jit). It also only handled the fixed
    # two-layer layout; tree_map works for any layer structure.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grad)

In [None]:
@jax.jit
def train_batch(params, batch_X, batch_y, lr=0.1):
    """Run one SGD step of the MLP on a single mini-batch.

    Args:
        params: current MLP parameter pytree.
        batch_X: input features of the batch.
        batch_y: one-hot targets of the batch.
        lr: learning rate forwarded to ``SGD``. The default 0.1 matches ``SGD``'s
            own default, so existing calls behave as before.

    Returns:
        The updated parameter pytree.
    """

    def loss_fn(params_):
        # MLP returns softmax probabilities (not logits); the loss takes their log.
        probs = MLP(params_, batch_X)
        return categorical_cross_entropy_loss(batch_y, probs)

    # Differentiate with respect to the parameters only; the batch is closed over.
    grad = jax.grad(loss_fn)(params)
    return SGD(params, grad, lr)

In [None]:
# Train for one epoch with mini-batch SGD.
# NOTE: `params` is overwritten, so re-running this cell continues training from
# the already-updated parameters rather than from the initialisation.
batch_size = 50

# Shuffle the training-sample indices. Use a key derived from `rng` instead of
# `rng` itself, which was already consumed by `jax.random.split` during init.
shuffle_rng = jax.random.fold_in(rng, 1)
index = jax.random.permutation(shuffle_rng, X_train.shape[0])

# Number of *full* batches, plus the size of the leftover partial batch.
# The old `ceil` count asked `dynamic_slice` for a window past the end of
# `index`. `dynamic_slice` clamps the start so the window fits, which silently
# re-trained on earlier samples instead of using only the leftover ones.
batch_length = X_train.shape[0] // batch_size
remainder = X_train.shape[0] % batch_size


def train_for_each_batch(batch_idx, params):
    """fori_loop body: one SGD step on the `batch_idx`-th full batch of `index`."""
    target_train_indices = jax.lax.dynamic_slice(index, [batch_idx * batch_size], [batch_size])
    return train_batch(params, X_train[target_train_indices], y_train[target_train_indices])


params = jax.lax.fori_loop(0, batch_length, train_for_each_batch, params)

# Train on the leftover samples (fewer than batch_size) as one final, smaller
# batch, so every training sample is used exactly once per epoch.
if remainder:
    tail_indices = index[batch_length * batch_size:]
    params = train_batch(params, X_train[tail_indices], y_train[tail_indices])