In [3]:
# %% [markdown]
# # レスラー方程式(外力のある状態)
# 
# 外力を加えたレスラー方程式について，外力の $\sin$ 波に位相シフトがある場合を考える．

# %%
#必要なパッケージのインポート

import numpy as np

import matplotlib
import matplotlib.pyplot as plt

import reservoirpy as rpy

from scipy.integrate import solve_ivp
import pandas as pd
from reservoirpy.observables import nrmse, rsquare

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


# Set reservoirpy's verbosity level to 0 (quiet).
rpy.verbosity(0)

from reservoirpy.nodes import Reservoir, Ridge
from reservoirpy.datasets import mackey_glass

# Small CSS tweak to centre the plots.
# An HTML object is only rendered automatically when it is the *last* expression
# of a cell. Other statements follow it here, so it must be passed to display()
# explicitly, otherwise the style is never injected.
from IPython.display import HTML, display
display(HTML("""
<style>
.img-center {
    display: block;
    margin-left: auto;
    margin-right: auto;
    }
.output_png {
    display: table-cell;
    text-align: center;
    vertical-align: middle;
    }
</style>
"""))

# Seed reservoirpy's global random state so runs are reproducible.
rpy.set_seed(42)




# %%
# CSV holding the forced Rössler trajectory. The columns used below are x, y and
# P_shifted (presumably the phase-shifted forcing term, per the title — confirm
# against the data-generation notebook).
filename_with_force = 'rossler_data_with_shifted_force2.1.1.csv'

# Load the CSV file.
data_loaded_with_force = pd.read_csv(filename_with_force)

# Extract the selected columns into a NumPy array of shape (n_rows, 3).
X = data_loaded_with_force[['x', 'y', 'P_shifted']].to_numpy()

from reservoirpy.datasets import to_forecasting

train_len = 10000
test_len = 10000

# One-step-ahead forecasting pairs: y is the series shifted forward by one step.
x, y = to_forecasting(X, forecast=1)

# Slicing past the end of the array would silently shorten the test window,
# so fail loudly if the CSV is too short for the requested split.
if len(x) < train_len + test_len:
    raise ValueError(
        f"need at least {train_len + test_len} forecasting pairs "
        f"(train_len={train_len}, test_len={test_len}), but got {len(x)}"
    )

X_train, y_train = x[:train_len], y[:train_len]
X_test, y_test = x[train_len:train_len + test_len], y[train_len:train_len + test_len]

# The test window doubles as the validation set used by the hyperparameter search.
dataset = ((X_train, y_train), (X_test, y_test))
train_data, validation_data = dataset
X_val, y_val = validation_data


In [4]:
import optuna


In [5]:

# Optuna objective function
def objective(trial):
    """Optuna objective for tuning an ESN (``Reservoir >> Ridge``) on the forced Rössler data.

    Samples the spectral radius, leak rate, input scaling and ridge
    regularisation. It then trains several reservoirs that differ only in their
    seed and returns the mean NRMSE of their predictions on the validation split.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial that supplies the hyperparameter suggestions.

    Returns
    -------
    float
        Mean NRMSE over the reservoir instances (lower is better).

    Raises
    ------
    optuna.exceptions.TrialPruned
        If the pruner decides the trial is unpromising after an instance.
    """
    N_value = 5000   # reservoir size is fixed; only the other hyperparameters are searched
    n_instances = 3  # number of differently-seeded reservoirs averaged per trial
    sr = trial.suggest_float('sr', 1e-2, 10, log=True)
    lr = trial.suggest_float('lr', 1e-3, 1, log=True)
    iss = trial.suggest_float('iss', 0, 1)
    ridge = trial.suggest_float('ridge', 1e-9, 1e-2, log=True)

    # The NRMSE normalisation constant depends only on the training inputs,
    # so compute it once instead of on every instance.
    norm_value = np.ptp(X_train)

    losses = []
    r2s = []
    for n in range(n_instances):
        # Build the model; the seed is the only thing that changes between instances.
        reservoir = Reservoir(N_value, sr=sr, lr=lr, input_scaling=iss, seed=n)
        readout = Ridge(ridge=ridge)
        model = reservoir >> readout

        # Train on the training split, then run on the validation split.
        # NOTE: X_val/y_val are the same arrays as X_test/y_test in this notebook,
        # so the hyperparameters are tuned on the evaluation window.
        prediction = model.fit(X_train, y_train).run(X_val)

        # Evaluation metrics (computed once per instance).
        losses.append(nrmse(y_val, prediction, norm_value=norm_value))
        r2s.append(rsquare(y_val, prediction))

        # Keep the running mean R^2 on the trial so it shows up in
        # study.trials_dataframe(); it was previously computed and discarded.
        trial.set_user_attr('mean_r2', float(np.mean(r2s)))

        # Report the running mean loss so the pruner can stop bad trials early.
        trial.report(np.mean(losses), n)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return np.mean(losses)


In [6]:

# Pruner: stops a trial whose intermediate loss is worse than the median of
# earlier trials at the same step.
pruner = optuna.pruners.MedianPruner()

# Seed the sampler so the sequence of suggested hyperparameters is reproducible
# after a kernel restart. The default sampler (TPE) is otherwise seeded randomly.
sampler = optuna.samplers.TPESampler(seed=42)

# Create the study (minimising the objective's mean NRMSE) with the pruner and sampler.
study = optuna.create_study(direction='minimize', sampler=sampler, pruner=pruner)
# NOTE: each trial fits several large reservoirs, so 300 trials is long-running.
study.optimize(objective, n_trials=300)


[I 2023-11-11 19:20:00,074] A new study created in memory with name: no-name-263488af-c825-4606-bb56-30ffb673c2a6
[W 2023-11-11 19:21:13,648] Trial 0 failed with parameters: {'sr': 2.095696434533729, 'lr': 0.0025782978281261226, 'iss': 0.5354247388132112, 'ridge': 0.0016083602491766841} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/Users/manqueenmannequin/miniforge3/envs/python38gen2/lib/python3.8/site-packages/optuna/study/_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
  File "/var/folders/y4/6674dz116775fry8d67k82wm0000gn/T/ipykernel_15451/3564247350.py", line 19, in objective
    prediction = model.fit(X_train, y_train) \
  File "/Users/manqueenmannequin/miniforge3/envs/python38gen2/lib/python3.8/site-packages/reservoirpy/model.py", line 917, in run
    states_seq = self._run(
  File "/Users/manqueenmannequin/miniforge3/envs/python38gen2/lib/python3.8/site-packages/reservoirpy/model.py", line 446, in _run
 

KeyboardInterrupt: 

In [None]:
# Retrieve the best hyperparameters found by the study and store them in a variable.
# The Study attribute is `best_params`; `best_param` does not exist and raises AttributeError.
best_params = study.best_params
best_params
