In [None]:
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state
from typing import NamedTuple, Tuple, Callable, List
from flax.core.scope import FrozenVariableDict
from optax import ScalarOrSchedule
import tensorflow_datasets as tfds
from syuron import dataset


# Training state threaded through every step; a Flax TrainState instance.
type ModelState = train_state.TrainState


# Forward function of a model; invoked in this file as apply_fn(params, inputs).
type ApplyFn = Callable
# Variable collection holding the model parameters handed to ApplyFn.
type ModelParams = FrozenVariableDict
# Value returned by a loss function (an array produced by jnp reductions).
type Loss = jnp.ndarray


class Batch(NamedTuple):
    """One mini-batch of training data: model inputs paired with their targets."""
    # Features passed to the model's apply function.
    inputs: jnp.ndarray
    # Targets that loss_fn multiplies element-wise with the log-probabilities
    # (presumably one-hot encoded labels -- confirm against the dataset).
    outputs: jnp.ndarray


# Factory: (learning_rate, input_size, hidden_sizes, output_size) -> initial ModelState.
type UseState = Callable[[ScalarOrSchedule, int, List[int], int], ModelState]
# One optimisation step: (state, batch) -> (updated state, loss for that batch).
type TrainStep = Callable[[ModelState, Batch], Tuple[ModelState, Loss]]
# Loss evaluation: (params, batch, apply_fn) -> loss value.
type LossFn = Callable[[ModelParams, Batch, ApplyFn], Loss]


class MLP(nn.Module):
    """Fully connected network: ReLU-activated hidden layers plus a linear head."""
    # Width of each hidden Dense layer, applied in the given order.
    hidden_sizes: List[int]
    # Width of the final Dense layer (no activation is applied after it).
    output_size: int

    @nn.compact
    def __call__(self, x):
        """Run ``x`` through every hidden layer and return the head's raw output."""
        activations = x
        for width in self.hidden_sizes:
            # Dense followed immediately by ReLU for each hidden layer.
            activations = nn.relu(nn.Dense(features=width)(activations))
        # Output layer is left linear so callers receive unnormalised values.
        return nn.Dense(features=self.output_size)(activations)


def use_state(learning_rate: ScalarOrSchedule, input_size: int, hidden_sizes: List[int], output_size: int, seed: int = 0) -> ModelState:
    """
    Build an MLP and wrap it in a TrainState that uses the Adam optimizer.

    Args:
        learning_rate: Step size (or schedule) forwarded to ``optax.adam``.
        input_size: Feature dimension of the dummy input used to initialise
            the parameters.
        hidden_sizes: Width of each hidden Dense layer, in order.
        output_size: Width of the final Dense layer.
        seed: Seed of the PRNG key used for parameter initialisation. The
            default of 0 reproduces the previously hard-coded behaviour.

    Returns:
        A TrainState whose ``params`` is the complete variable dict returned
        by ``model.init`` and whose ``apply_fn`` is ``model.apply``, so the
        two can be combined directly as ``apply_fn(state.params, inputs)``.
    """
    model = MLP(hidden_sizes=hidden_sizes, output_size=output_size)
    rng = jax.random.PRNGKey(seed)
    # A single all-ones row is enough for Flax to infer every layer's shape.
    dummy_input = jnp.ones([1, input_size])
    params = model.init(rng, dummy_input)
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx)


def loss_fn(params: ModelParams, batch: Batch, apply_fn: ApplyFn) -> Loss:
    """
    Compute the mean cross-entropy between model predictions and targets.

    The model output is treated as logits: ``log_softmax`` is applied over the
    last axis, the result is weighted by ``batch.outputs`` and summed over that
    axis, and the negated mean over the remaining entries is returned.
    """
    predictions = apply_fn(params, batch.inputs)
    weighted_log_probs = batch.outputs * jax.nn.log_softmax(predictions, axis=-1)
    per_example = jnp.sum(weighted_log_probs, axis=-1)
    return -jnp.mean(per_example)


def train_step(state: ModelState, batch: Batch) -> Tuple[ModelState, Loss]:
    """
    Perform one optimisation step on a single batch.

    The loss and its gradient with respect to ``state.params`` are obtained
    with ``jax.value_and_grad``; the gradients are then applied through
    ``state.apply_gradients``. Returns the updated state and the batch loss.
    """
    # Close over the batch and apply_fn so only the parameters are differentiated.
    value_and_grad_fn = jax.value_and_grad(
        lambda p: loss_fn(p, batch, state.apply_fn))
    batch_loss, gradients = value_and_grad_fn(state.params)
    return state.apply_gradients(grads=gradients), batch_loss


def train_and_eval(
    use_state: UseState,
    train_step: TrainStep,
    loss_fn: LossFn,
    *,
    epochs: int = 5,
    batch_size: int = 128,
    hidden_sizes: List[int] | None = None,
    learning_rate: ScalarOrSchedule = 1e-3,
    log_every: int = 1,
) -> ModelState:
    """
    Train an MLP on the dataset returned by ``syuron.dataset.load_mnist()``.

    Steps:
      * take the first batch to read the input and output feature sizes and
        initialise the model through ``use_state``;
      * print the loss of the untrained model on that first batch;
      * run ``train_step`` over the whole dataset for ``epochs`` epochs,
        printing the per-batch loss (see ``log_every``) and the mean loss of
        each epoch;
      * return the final model state.

    Args:
        use_state: Factory that builds the initial model state.
        train_step: Function performing one optimisation step.
        loss_fn: Function used to evaluate the initial loss.
        epochs: Number of passes over the dataset.
        batch_size: Number of examples per batch.
        hidden_sizes: Hidden layer widths; ``None`` means ``[128, 64]``.
        learning_rate: Learning rate (or schedule) handed to ``use_state``.
        log_every: Print the batch loss every this many batches; a value of
            0 or less disables per-batch printing. Formatting the loss forces
            a device-to-host transfer, so larger values also train faster.

    Returns:
        The model state after the last training step.

    Raises:
        ValueError: If the dataset yields no batches.
    """
    if hidden_sizes is None:
        hidden_sizes = [128, 64]

    # Load the stored dataset (assumed to be preprocessed already -- the code
    # below relies on each element being an (inputs, targets) pair).
    mnist = dataset.load_mnist()
    ds = mnist.batch(batch_size)
    ds_np = tfds.as_numpy(ds)  # JAX consumes NumPy arrays, so convert

    # Peek at one batch to discover the feature sizes. The arrays are batched,
    # so only the trailing axis carries the per-example dimension.
    sample_batch = next(iter(ds_np), None)
    if sample_batch is None:
        raise ValueError("dataset.load_mnist() produced no batches")
    input_sample, output_sample = sample_batch
    input_size = input_sample.shape[-1]
    output_size = output_sample.shape[-1]

    # Initialise the model.
    state = use_state(learning_rate, input_size, hidden_sizes, output_size)

    # Loss of the untrained model on the sample batch.
    first_batch = Batch(inputs=jnp.array(input_sample),
                        outputs=jnp.array(output_sample))
    init_loss = loss_fn(state.params, first_batch, state.apply_fn)
    print("Initial loss:", init_loss)

    # Training loop.
    for epoch in range(epochs):
        epoch_loss = 0.0
        count = 0
        for inputs, outputs in ds_np:
            batch = Batch(inputs=jnp.array(inputs), outputs=jnp.array(outputs))
            state, loss = train_step(state, batch)
            epoch_loss += loss
            count += 1
            if log_every > 0 and count % log_every == 0:
                print(f"Epoch {epoch+1}: Batch: {count} Average Loss = {loss}")
        if count == 0:
            # Only reachable when the dataset cannot be iterated repeatedly;
            # fail clearly instead of dividing by zero below.
            raise ValueError(
                f"dataset yielded no batches during epoch {epoch+1}")
        avg_loss = epoch_loss / count
        print(f"Epoch {epoch+1}: Average Loss = {avg_loss}")
    return state


if __name__ == "__main__":
    # Script entry point: train with the implementations defined in this file
    # and dump the resulting state for inspection.
    final_state = train_and_eval(
        use_state=use_state,
        train_step=train_step,
        loss_fn=loss_fn,
    )
    print("Training completed. Final model state:", final_state, sep="\n")

2025-03-17 06:59:23.158548: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742194763.243611   53111 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742194763.267878   53111 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1742194763.439021   53111 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1742194763.439046   53111 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1742194763.439048   53111 computation_placer.cc:177] computation placer alr

Initial loss: 2.2928033
Epoch 1: Average Loss = 0.3607730567455292
Epoch 2: Average Loss = 0.14518822729587555
Epoch 3: Average Loss = 0.1003720760345459
Epoch 4: Average Loss = 0.0749954804778099
Epoch 5: Average Loss = 0.058531466871500015
Training completed. Final model state:
TrainState(step=2345, apply_fn=<bound method Module.apply of MLP(
    # attributes
    hidden_sizes = [128, 64]
    output_size = 10
)>, params={'params': {'Dense_0': {'bias': Array([-0.03060845,  0.03132747,  0.07028603,  0.02497231,  0.0385511 ,
        0.02081833,  0.09615315,  0.05499851,  0.1201482 ,  0.06335388,
       -0.01873233,  0.02870467,  0.02579832,  0.00122235,  0.01266044,
        0.00574078,  0.05957482,  0.08619644,  0.00152224,  0.02846688,
       -0.00956561,  0.0509107 , -0.03737363,  0.03479095,  0.02441783,
       -0.01247602,  0.06261256,  0.04406655,  0.02014735,  0.00550752,
        0.07613483,  0.05022801, -0.00596691, -0.0174517 ,  0.02296995,
       -0.01050434, -0.00444727,  0.082

: 