In [None]:
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state
from typing import Tuple, List
from optax import ScalarOrSchedule
from syuron import dataset
from syuron import mlp


class MLP(nn.Module):
    """Fully connected feed-forward network defined with Flax Linen.

    Each entry of ``hidden_sizes`` contributes one ``Dense`` layer followed by
    a ReLU. A last ``Dense`` layer with ``output_size`` features produces the
    result; no activation is applied to it.
    """

    hidden_sizes: List[int]
    output_size: int

    @nn.compact
    def __call__(self, x):
        """Pass ``x`` through the hidden stack and then the output layer."""
        hidden = x
        for width in self.hidden_sizes:
            # Dense -> ReLU for every hidden layer, in declaration order.
            hidden = nn.relu(nn.Dense(features=width)(hidden))
        return nn.Dense(features=self.output_size)(hidden)


def use_state(learning_rate: ScalarOrSchedule, input_size: int, hidden_sizes: List[int], output_size: int) -> mlp.ModelState:
    """
    Build an MLP and wrap it, together with an Adam optimizer, in a TrainState.

    The model variables are initialised with a fixed PRNG key (seed 0) by
    feeding a dummy all-ones input of shape ``(1, input_size)``. Whatever
    ``init`` returns is stored unchanged as ``params`` of the returned state,
    and the model's ``apply`` method becomes the state's ``apply_fn``.
    """
    network = MLP(hidden_sizes=hidden_sizes, output_size=output_size)
    init_variables = network.init(
        jax.random.PRNGKey(0),
        jnp.ones([1, input_size]),
    )
    return train_state.TrainState.create(
        apply_fn=network.apply,
        params=init_variables,
        tx=optax.adam(learning_rate),
    )


def loss_fn(params: mlp.ModelParams, batch: dataset.Batch, apply_fn: mlp.ApplyFn) -> mlp.Loss:
    """
    Compute the mean cross-entropy loss for one batch.

    The logits come from ``apply_fn(params, batch.inputs)``. ``log_softmax``
    is applied to them, the result is multiplied element-wise with
    ``batch.outputs`` (the one-hot labels) and summed over the last axis, and
    the negated mean of those per-example sums is returned.
    """
    log_probs = jax.nn.log_softmax(apply_fn(params, batch.inputs))
    # Per-example log-likelihood of the labelled class.
    per_example = jnp.sum(batch.outputs * log_probs, axis=-1)
    return -jnp.mean(per_example)  # type: ignore


def train_step(state: mlp.ModelState, batch: dataset.Batch) -> Tuple[mlp.ModelState, mlp.Loss]:
    """
    Run a single optimisation step on one batch.

    The loss and its gradient with respect to ``state.params`` are obtained
    with ``jax.value_and_grad`` applied to ``loss_fn``; the gradients are then
    handed to ``state.apply_gradients``. Returns the updated state together
    with the loss that was computed before the update.
    """
    # value_and_grad differentiates with respect to the first positional
    # argument, i.e. the parameters.
    value_and_grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = value_and_grad_fn(state.params, batch, state.apply_fn)
    return state.apply_gradients(grads=grads), loss


# Baseline run with hand-picked hyperparameters.
# NOTE(review): the trailing positional 5 is presumably the epoch count --
# confirm against mlp.train_and_eval.
final_state, final_loss = mlp.train_and_eval(
    dataset.load_mnist, use_state, train_step, loss_fn,
    mlp.OptimizableParams(learning_rate=1e-3, hidden_sizes=[128, 64]),
    5,
)

# Header and value go on separate lines.
print("Training completed. Final model loss:", final_loss, sep="\n")

# Hyperparameter search; this rebinds final_state / final_loss, so the
# baseline results above are no longer reachable through these names.
print("Baysian optimization start.")
final_state, final_loss = mlp.bayesian_optim(
    dataset.load_mnist,
    use_state,
    train_step,
    loss_fn,
)

2025-03-17 10:32:30.394433: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742207550.415509   31762 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742207550.422004   31762 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1742207550.438547   31762 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1742207550.438568   31762 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1742207550.438570   31762 computation_placer.cc:177] computation placer alr

Initial loss: 2.3123507


Epoch 1/5: 100%|██████████| 118/118 [00:13<00:00,  8.63it/s]


Epoch 1: Average Loss = 0.6157775521278381


Epoch 2/5: 100%|██████████| 118/118 [00:08<00:00, 13.30it/s]


Epoch 2: Average Loss = 0.22774720191955566


Epoch 3/5: 100%|██████████| 118/118 [00:08<00:00, 13.47it/s]


Epoch 3: Average Loss = 0.17137962579727173


Epoch 4/5: 100%|██████████| 118/118 [00:08<00:00, 13.63it/s]


Epoch 4: Average Loss = 0.13709883391857147


Epoch 5/5: 100%|██████████| 118/118 [00:08<00:00, 13.61it/s]


Epoch 5: Average Loss = 0.11209221184253693
Training completed. Final model state:
0.11209221
Baysian optimization start.
|   iter    |  target   | hidden... | hidden... | learni... |
-------------------------------------------------------------
Initial loss: 2.3137894


Epoch 1/1: 100%|██████████| 118/118 [00:12<00:00,  9.13it/s]


Epoch 1: Average Loss = 4.1450910568237305
Validation loss (lr=0.0731996621871987, hidden1=211, hidden2=488): 4.1450910568237305
| [39m1        [39m | [39m-4.145   [39m | [39m211.8    [39m | [39m488.3    [39m | [39m0.0732   [39m |
Initial loss: 2.3321671


Epoch 1/1: 100%|██████████| 118/118 [00:12<00:00,  9.22it/s]


Epoch 1: Average Loss = 0.32410916686058044
Validation loss (lr=0.015600296039099928, hidden1=319, hidden2=106): 0.32410916686058044
| [35m2        [39m | [35m-0.3241  [39m | [35m319.4    [39m | [35m106.9    [39m | [35m0.0156   [39m |
Initial loss: 2.3330827


Epoch 1/1: 100%|██████████| 118/118 [00:12<00:00,  9.30it/s]


Epoch 1: Average Loss = 1.6977014541625977
Validation loss (lr=0.060111900059309144, hidden1=59, hidden2=447): 1.6977014541625977
| [39m3        [39m | [39m-1.698   [39m | [39m59.88    [39m | [39m447.8    [39m | [39m0.06011  [39m |
Initial loss: 2.3163192


Epoch 1/1: 100%|██████████| 118/118 [00:12<00:00,  9.28it/s]


Epoch 1: Average Loss = 2.763275384902954
Validation loss (lr=0.09699101530634728, hidden1=371, hidden2=41): 2.763275384902954
| [39m4        [39m | [39m-2.763   [39m | [39m371.9    [39m | [39m41.88    [39m | [39m0.09699  [39m |
Initial loss: 2.3234353


Epoch 1/1: 100%|██████████| 118/118 [00:13<00:00,  9.02it/s]


Epoch 1: Average Loss = 0.41462641954421997
Validation loss (lr=0.018183314895742857, hidden1=431, hidden2=133): 0.41462641954421997
| [39m5        [39m | [39m-0.4146  [39m | [39m431.6    [39m | [39m133.9    [39m | [39m0.01818  [39m |
Initial loss: 2.2998137


Epoch 1/1: 100%|██████████| 118/118 [00:12<00:00,  9.15it/s]


Epoch 1: Average Loss = 0.4229424297809601
Validation loss (lr=0.02338204949647332, hidden1=321, hidden2=110): 0.4229424297809601
| [39m6        [39m | [39m-0.4229  [39m | [39m321.2    [39m | [39m110.2    [39m | [39m0.02338  [39m |
Initial loss: 2.2994008


Epoch 1/1: 100%|██████████| 118/118 [00:12<00:00,  9.26it/s]


Epoch 1: Average Loss = 2.204457998275757
Validation loss (lr=0.0819413272014111, hidden1=263, hidden2=97): 2.204457998275757
| [39m7        [39m | [39m-2.204   [39m | [39m264.0    [39m | [39m97.49    [39m | [39m0.08194  [39m |
Initial loss: 2.3180013


Epoch 1/1: 100%|██████████| 118/118 [00:12<00:00,  9.16it/s]


Epoch 1: Average Loss = 3.1419577598571777
Validation loss (lr=0.07316488328345236, hidden1=390, hidden2=148): 3.1419577598571777
| [39m8        [39m | [39m-3.142   [39m | [39m390.1    [39m | [39m148.9    [39m | [39m0.07316  [39m |
Initial loss: 2.3086827


Epoch 1/1: 100%|██████████| 118/118 [00:12<00:00,  9.24it/s]


Epoch 1: Average Loss = 1.671608567237854
Validation loss (lr=0.06038575609888851, hidden1=458, hidden2=125): 1.671608567237854
| [39m9        [39m | [39m-1.672   [39m | [39m458.7    [39m | [39m125.4    [39m | [39m0.06039  [39m |
Initial loss: 2.3071284


Epoch 1/1: 100%|██████████| 118/118 [00:13<00:00,  8.87it/s]


Epoch 1: Average Loss = 1.256033182144165
Validation loss (lr=0.05457824046109374, hidden1=332, hidden2=85): 1.256033182144165
| [39m10       [39m | [39m-1.256   [39m | [39m332.6    [39m | [39m85.86    [39m | [39m0.05458  [39m |
Initial loss: 2.2946875


Epoch 1/1: 100%|██████████| 118/118 [00:12<00:00,  9.10it/s]


Epoch 1: Average Loss = 0.2743774652481079
Validation loss (lr=0.009097740713809916, hidden1=424, hidden2=109): 0.2743774652481079
| [35m11       [39m | [35m-0.2744  [39m | [35m424.7    [39m | [35m109.4    [39m | [35m0.009098 [39m |
Initial loss: 2.317873


Epoch 1/1: 100%|██████████| 118/118 [00:13<00:00,  9.06it/s]


Epoch 1: Average Loss = 0.7297729849815369
Validation loss (lr=0.04121235003120042, hidden1=439, hidden2=76): 0.7297729849815369
| [39m12       [39m | [39m-0.7298  [39m | [39m439.3    [39m | [39m76.26    [39m | [39m0.04121  [39m |
Initial loss: 2.3128498


Epoch 1/1: 100%|██████████| 118/118 [00:12<00:00,  9.19it/s]


Epoch 1: Average Loss = 0.946134090423584
Validation loss (lr=0.048145650638185815, hidden1=472, hidden2=48): 0.946134090423584
| [39m13       [39m | [39m-0.9461  [39m | [39m472.7    [39m | [39m48.19    [39m | [39m0.04815  [39m |
Initial loss: 2.2876186


Epoch 1/1: 100%|██████████| 118/118 [00:13<00:00,  9.05it/s]


Epoch 1: Average Loss = 8.035195350646973
Validation loss (lr=0.09794123766630679, hidden1=298, hidden2=140): 8.035195350646973
| [39m14       [39m | [39m-8.035   [39m | [39m298.0    [39m | [39m140.1    [39m | [39m0.09794  [39m |
Initial loss: 2.317506


Epoch 1/1: 100%|██████████| 118/118 [00:13<00:00,  8.93it/s]


Epoch 1: Average Loss = 0.3839266002178192
Validation loss (lr=0.0017772023543374306, hidden1=407, hidden2=82): 0.3839266002178192
| [39m15       [39m | [39m-0.3839  [39m | [39m407.0    [39m | [39m82.83    [39m | [39m0.001777 [39m |
Initial loss: 2.2964315


Epoch 1/1: 100%|██████████| 118/118 [00:11<00:00,  9.96it/s]


Epoch 1: Average Loss = 0.36535024642944336
Validation loss (lr=0.0025030819425186997, hidden1=429, hidden2=41): 0.36535024642944336
| [39m16       [39m | [39m-0.3654  [39m | [39m429.9    [39m | [39m41.29    [39m | [39m0.002503 [39m |
Initial loss: 2.326091


Epoch 1/1: 100%|██████████| 118/118 [00:13<00:00,  9.02it/s]


Epoch 1: Average Loss = 0.5480305552482605
Validation loss (lr=0.0005747817017286947, hidden1=446, hidden2=173): 0.5480305552482605
| [39m17       [39m | [39m-0.548   [39m | [39m446.1    [39m | [39m173.4    [39m | [39m0.0005748[39m |
Initial loss: 2.3222997


Epoch 1/1: 100%|██████████| 118/118 [00:13<00:00,  9.01it/s]


Epoch 1: Average Loss = 0.2983846366405487
Validation loss (lr=0.0027493900223528466, hidden1=473, hidden2=200): 0.2983846366405487
| [39m18       [39m | [39m-0.2984  [39m | [39m474.0    [39m | [39m200.6    [39m | [39m0.002749 [39m |
Initial loss: 2.3144865


Epoch 1/1: 100%|██████████| 118/118 [00:13<00:00,  8.95it/s]


Epoch 1: Average Loss = 0.2725956439971924
Validation loss (lr=0.008772019387547438, hidden1=440, hidden2=213): 0.2725956439971924
| [35m19       [39m | [35m-0.2726  [39m | [35m440.9    [39m | [35m213.5    [39m | [35m0.008772 [39m |
Initial loss: 2.3307123


Epoch 1/1: 100%|██████████| 118/118 [00:13<00:00,  9.05it/s]


Epoch 1: Average Loss = 7.100915908813477
Validation loss (lr=0.09300431511243798, hidden1=471, hidden2=240): 7.100915908813477
| [39m20       [39m | [39m-7.101   [39m | [39m471.3    [39m | [39m240.8    [39m | [39m0.093    [39m |
Initial loss: 2.311585


Epoch 1/1: 100%|██████████| 118/118 [00:12<00:00,  9.09it/s]


Epoch 1: Average Loss = 0.2708035409450531
Validation loss (lr=0.008377656906364616, hidden1=485, hidden2=174): 0.2708035409450531
| [35m21       [39m | [35m-0.2708  [39m | [35m485.1    [39m | [35m174.1    [39m | [35m0.008378 [39m |
Initial loss: 2.3220866


Epoch 1/1: 100%|██████████| 118/118 [00:13<00:00,  8.92it/s]


Epoch 1: Average Loss = 0.6298381090164185
Validation loss (lr=0.026074908072006773, hidden1=415, hidden2=198): 0.6298381090164185
| [39m22       [39m | [39m-0.6298  [39m | [39m415.8    [39m | [39m198.6    [39m | [39m0.02607  [39m |
Initial loss: 2.3088174


Epoch 1/1: 100%|██████████| 118/118 [00:14<00:00,  8.34it/s]


Epoch 1: Average Loss = 0.27910372614860535
Validation loss (lr=0.004164792778163996, hidden1=406, hidden2=234): 0.27910372614860535
| [39m23       [39m | [39m-0.2791  [39m | [39m406.7    [39m | [39m234.1    [39m | [39m0.004165 [39m |
Initial loss: 2.324307


Epoch 1/1: 100%|██████████| 118/118 [00:13<00:00,  8.72it/s]


Epoch 1: Average Loss = 1.4212507009506226
Validation loss (lr=0.050229294385521, hidden1=370, hidden2=233): 1.4212507009506226
| [39m24       [39m | [39m-1.421   [39m | [39m370.3    [39m | [39m233.8    [39m | [39m0.05023  [39m |
Initial loss: 2.30154


Epoch 1/1: 100%|██████████| 118/118 [00:13<00:00,  8.73it/s]


Epoch 1: Average Loss = 0.3082907497882843
Validation loss (lr=0.002389326179179107, hidden1=510, hidden2=191): 0.3082907497882843
| [39m25       [39m | [39m-0.3083  [39m | [39m510.5    [39m | [39m191.8    [39m | [39m0.002389 [39m |
Initial loss: 2.3087301


Epoch 1/1: 100%|██████████| 118/118 [00:16<00:00,  7.12it/s]


Epoch 1: Average Loss = 6.22697639465332
Validation loss (lr=0.09692174218370951, hidden1=509, hidden2=151): 6.22697639465332
| [39m26       [39m | [39m-6.227   [39m | [39m509.7    [39m | [39m151.6    [39m | [39m0.09692  [39m |
Initial loss: 2.3307073


Epoch 1/1: 100%|██████████| 118/118 [00:13<00:00,  8.56it/s]


Epoch 1: Average Loss = 0.4607982635498047
Validation loss (lr=0.01991350785945427, hidden1=399, hidden2=268): 0.4607982635498047
| [39m27       [39m | [39m-0.4608  [39m | [39m399.3    [39m | [39m268.3    [39m | [39m0.01991  [39m |
Initial loss: 2.3312104


Epoch 1/1: 100%|██████████| 118/118 [00:14<00:00,  8.14it/s]


Epoch 1: Average Loss = 4.115870475769043
Validation loss (lr=0.06641698631276291, hidden1=371, hidden2=290): 4.115870475769043
| [39m28       [39m | [39m-4.116   [39m | [39m371.3    [39m | [39m290.7    [39m | [39m0.06642  [39m |
Initial loss: 2.2948873


Epoch 1/1: 100%|██████████| 118/118 [00:14<00:00,  8.38it/s]


Epoch 1: Average Loss = 1.4448083639144897
Validation loss (lr=0.050483709150074274, hidden1=376, hidden2=100): 1.4448083639144897
| [39m29       [39m | [39m-1.445   [39m | [39m376.1    [39m | [39m100.7    [39m | [39m0.05048  [39m |
Initial loss: 2.3489451


Epoch 1/1: 100%|██████████| 118/118 [00:13<00:00,  8.77it/s]


Epoch 1: Average Loss = 1.0234999656677246
Validation loss (lr=0.03909036270240408, hidden1=421, hidden2=258): 1.0234999656677246
| [39m30       [39m | [39m-1.023   [39m | [39m421.4    [39m | [39m258.6    [39m | [39m0.03909  [39m |
Best parameters found: {'target': np.float64(-0.2708035409450531), 'params': {'hidden_size_1': np.float64(485.14852190119484), 'hidden_size_2': np.float64(174.1198464063595), 'learning_rate': np.float64(0.008377656906364616)}}
Initial loss: 2.311585


Epoch 1/5: 100%|██████████| 118/118 [00:09<00:00, 13.07it/s]


Epoch 1: Average Loss = 0.2708035409450531


Epoch 2/5: 100%|██████████| 118/118 [00:10<00:00, 11.65it/s]


Epoch 2: Average Loss = 0.08949209004640579


Epoch 3/5: 100%|██████████| 118/118 [00:10<00:00, 10.79it/s]


Epoch 3: Average Loss = 0.06113035976886749


Epoch 4/5: 100%|██████████| 118/118 [00:09<00:00, 12.44it/s]


Epoch 4: Average Loss = 0.05003752186894417


Epoch 5/5: 100%|██████████| 118/118 [00:10<00:00, 11.61it/s]

Epoch 5: Average Loss = 0.046506889164447784
Final evaluation with best parameters (5 epochs):
Final loss: 0.04650689



