In [1]:
import optuna
import pywt
import numpy as np

from main__wavelet_cnn import train

In [3]:
# Search space for the study: the names of every continuous wavelet family that
# PyWavelets provides (consumed by `objective` through `suggest_categorical`).
# NOTE(review): this includes parameter-less names such as 'cmor'; the
# `wavelet = DiscreteContinuousWavelet(wavelet)` lines in the training logs look
# like PyWavelets warnings about those — consider listing explicit
# bandwidth/centre-frequency variants (e.g. 'cmor1.5-1.0') instead.
wavelet_list = pywt.wavelist(family=None, kind="continuous")

In [5]:
def objective(trial):
    """Optuna objective: train the wavelet CNN with one suggested wavelet.

    Samples a wavelet name from ``wavelet_list``, trains with otherwise fixed
    hyper-parameters and returns the validation accuracy reported by ``train``
    (the study is created with ``direction='maximize'``).

    A fresh training seed is drawn for every trial and stored on the trial as
    the ``"seed"`` user attribute, so any trial (in particular the best one)
    can be re-run afterwards with exactly the same seed.

    Args:
        trial: the Optuna trial supplying the wavelet suggestion.

    Returns:
        ``metrics["accuracy_validation"]`` from ``train``.
    """
    wavelet = trial.suggest_categorical('wavelet', wavelet_list)

    data_file = "../../data/william/dataset2/preprocessed_data__no_decimate.csv"

    # Draw as int64: the default dtype of ``np.random.randint`` is the platform
    # C long, which is 32-bit on Windows (NumPy < 2) and cannot hold the upper
    # bound 2**32 - 1. Convert to a plain int so it is JSON-serialisable as an
    # Optuna user attribute.
    seed = int(np.random.randint(2**32 - 1, dtype=np.int64))
    # Validation accuracy varies strongly between trials of the same wavelet,
    # so keep the seed to be able to reproduce an individual trial.
    trial.set_user_attr("seed", seed)

    metrics = train(
        wavelet=wavelet,
        data_file=data_file,
        fig_folder=None,
        seed=seed,
        n_epochs=500,
        learning_rate=0.005,
        # 20 integer scales, geometrically spaced between 10 and 520.
        scales=np.geomspace(10, 520, num=20, dtype=int),
        dt=1,
        decimate=5,
        select_every=5,
        verbose=False)

    return metrics["accuracy_validation"]

# Fix the sampler's seed so the sequence of wavelets proposed by TPE can be
# replayed; the training seed itself is drawn inside `objective`.
# NOTE(review): in the recorded run the same wavelet scored very differently
# across trials (e.g. 'cmor': 0.842 in trial 1 vs 0.573 in trial 12), so
# `best_params` reflects a single run rather than a robust ranking of wavelets —
# consider comparing per-wavelet averages over the completed trials.
study = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=0),
)
study.optimize(objective, n_trials=100)

study.best_params

[32m[I 2022-11-08 18:04:45,070][0m A new study created in memory with name: no-name-afbd9d24-b93a-4934-8203-800b2bf7ff91[0m
100%|████████████████████████████████████████████████████████████| 500/500 [02:44<00:00,  3.05it/s, acc_training=1, acc_validation=0.752]
[32m[I 2022-11-08 18:07:45,764][0m Trial 0 finished with value: 0.7515151500701904 and parameters: {'wavelet': 'cgau6'}. Best is trial 0 with value: 0.7515151500701904.[0m
  wavelet = DiscreteContinuousWavelet(wavelet)
100%|████████████████████████████████████████████████████████████| 500/500 [03:40<00:00,  2.26it/s, acc_training=1, acc_validation=0.842]
[32m[I 2022-11-08 18:11:57,001][0m Trial 1 finished with value: 0.842424213886261 and parameters: {'wavelet': 'cmor'}. Best is trial 1 with value: 0.842424213886261.[0m
100%|████████████████████████████████████████████████████████████| 500/500 [03:46<00:00,  2.21it/s, acc_training=1, acc_validation=0.833]
[32m[I 2022-11-08 18:15:55,338][0m Trial 2 finished with value:

100%|████████████████████████████████████████████████████████| 500/500 [02:55<00:00,  2.85it/s, acc_training=0.539, acc_validation=0.515]
[32m[I 2022-11-08 18:47:53,897][0m Trial 11 finished with value: 0.5151515007019043 and parameters: {'wavelet': 'morl'}. Best is trial 1 with value: 0.842424213886261.[0m
  wavelet = DiscreteContinuousWavelet(wavelet)
100%|█████████████████████████████████████████████████████████| 500/500 [02:53<00:00,  2.88it/s, acc_training=0.63, acc_validation=0.573]
[32m[I 2022-11-08 18:51:13,429][0m Trial 12 finished with value: 0.5727272629737854 and parameters: {'wavelet': 'cmor'}. Best is trial 1 with value: 0.842424213886261.[0m
100%|████████████████████████████████████████████████████████████| 500/500 [02:56<00:00,  2.83it/s, acc_training=1, acc_validation=0.836]
[32m[I 2022-11-08 18:54:26,351][0m Trial 13 finished with value: 0.8363636136054993 and parameters: {'wavelet': 'cgau4'}. Best is trial 1 with value: 0.842424213886261.[0m
100%|███████████

100%|████████████████████████████████████████████████████████| 500/500 [02:55<00:00,  2.85it/s, acc_training=0.995, acc_validation=0.827]
[32m[I 2022-11-08 19:26:04,093][0m Trial 23 finished with value: 0.8272727131843567 and parameters: {'wavelet': 'cgau1'}. Best is trial 1 with value: 0.842424213886261.[0m
100%|████████████████████████████████████████████████████████████| 500/500 [02:56<00:00,  2.84it/s, acc_training=1, acc_validation=0.733]
[32m[I 2022-11-08 19:29:06,738][0m Trial 24 finished with value: 0.7333333492279053 and parameters: {'wavelet': 'gaus3'}. Best is trial 1 with value: 0.842424213886261.[0m
100%|████████████████████████████████████████████████████████████| 500/500 [02:55<00:00,  2.84it/s, acc_training=1, acc_validation=0.745]
[32m[I 2022-11-08 19:32:12,298][0m Trial 25 finished with value: 0.7454545497894287 and parameters: {'wavelet': 'mexh'}. Best is trial 1 with value: 0.842424213886261.[0m
  wavelet = DiscreteContinuousWavelet(wavelet)
100%|██████████

100%|████████████████████████████████████████████████████████| 500/500 [02:55<00:00,  2.85it/s, acc_training=0.486, acc_validation=0.494]
[32m[I 2022-11-08 20:05:04,790][0m Trial 35 finished with value: 0.4939393997192383 and parameters: {'wavelet': 'cgau4'}. Best is trial 1 with value: 0.842424213886261.[0m
100%|████████████████████████████████████████████████████████████| 500/500 [02:47<00:00,  2.98it/s, acc_training=1, acc_validation=0.764]
[32m[I 2022-11-08 20:07:59,071][0m Trial 36 finished with value: 0.7636363506317139 and parameters: {'wavelet': 'gaus7'}. Best is trial 1 with value: 0.842424213886261.[0m
100%|████████████████████████████████████████████████████████████| 500/500 [02:55<00:00,  2.85it/s, acc_training=1, acc_validation=0.788]
[32m[I 2022-11-08 20:11:11,176][0m Trial 37 finished with value: 0.7878788113594055 and parameters: {'wavelet': 'cgau6'}. Best is trial 1 with value: 0.842424213886261.[0m
100%|████████████████████████████████████████████████████████

100%|████████████████████████████████████████████████████████| 500/500 [02:58<00:00,  2.81it/s, acc_training=0.507, acc_validation=0.533]
[32m[I 2022-11-08 20:43:03,631][0m Trial 47 finished with value: 0.5333333611488342 and parameters: {'wavelet': 'shan'}. Best is trial 38 with value: 0.8636363744735718.[0m
100%|████████████████████████████████████████████████████████████| 500/500 [02:56<00:00,  2.83it/s, acc_training=1, acc_validation=0.821]
[32m[I 2022-11-08 20:46:16,696][0m Trial 48 finished with value: 0.821212112903595 and parameters: {'wavelet': 'cgau4'}. Best is trial 38 with value: 0.8636363744735718.[0m
100%|████████████████████████████████████████████████████████████| 500/500 [02:57<00:00,  2.82it/s, acc_training=1, acc_validation=0.739]
[32m[I 2022-11-08 20:49:30,262][0m Trial 49 finished with value: 0.739393949508667 and parameters: {'wavelet': 'cgau3'}. Best is trial 38 with value: 0.8636363744735718.[0m
100%|█████████████████████████████████████████████████████

100%|████████████████████████████████████████████████████████████| 500/500 [03:38<00:00,  2.29it/s, acc_training=1, acc_validation=0.752]
[32m[I 2022-11-08 21:26:24,948][0m Trial 60 finished with value: 0.7515151500701904 and parameters: {'wavelet': 'morl'}. Best is trial 59 with value: 0.8727272748947144.[0m
100%|█████████████████████████████████████████████████████████████| 500/500 [03:38<00:00,  2.29it/s, acc_training=1, acc_validation=0.83]
[32m[I 2022-11-08 21:30:15,011][0m Trial 61 finished with value: 0.8303030133247375 and parameters: {'wavelet': 'morl'}. Best is trial 59 with value: 0.8727272748947144.[0m
100%|████████████████████████████████████████████████████████████| 500/500 [03:15<00:00,  2.55it/s, acc_training=1, acc_validation=0.812]
[32m[I 2022-11-08 21:33:38,419][0m Trial 62 finished with value: 0.8121212124824524 and parameters: {'wavelet': 'gaus3'}. Best is trial 59 with value: 0.8727272748947144.[0m
100%|████████████████████████████████████████████████████

100%|████████████████████████████████████████████████████████████| 500/500 [02:55<00:00,  2.84it/s, acc_training=1, acc_validation=0.779]
[32m[I 2022-11-08 22:03:08,633][0m Trial 71 finished with value: 0.7787878513336182 and parameters: {'wavelet': 'cmor'}. Best is trial 59 with value: 0.8727272748947144.[0m
100%|████████████████████████████████████████████████████████████| 500/500 [02:57<00:00,  2.82it/s, acc_training=1, acc_validation=0.824]
[32m[I 2022-11-08 22:06:12,723][0m Trial 72 finished with value: 0.8242424130439758 and parameters: {'wavelet': 'gaus5'}. Best is trial 59 with value: 0.8727272748947144.[0m
  wavelet = DiscreteContinuousWavelet(wavelet)
100%|████████████████████████████████████████████████████████████| 500/500 [02:56<00:00,  2.84it/s, acc_training=1, acc_validation=0.833]
[32m[I 2022-11-08 22:09:34,543][0m Trial 73 finished with value: 0.8333333134651184 and parameters: {'wavelet': 'cmor'}. Best is trial 59 with value: 0.8727272748947144.[0m
100%|█████

100%|████████████████████████████████████████████████████████████| 500/500 [03:00<00:00,  2.77it/s, acc_training=1, acc_validation=0.724]
[32m[I 2022-11-08 22:43:42,822][0m Trial 84 finished with value: 0.7242424488067627 and parameters: {'wavelet': 'cgau8'}. Best is trial 59 with value: 0.8727272748947144.[0m
  wavelet = DiscreteContinuousWavelet(wavelet)
100%|██████████████████████████████████████████████████████████████| 500/500 [02:58<00:00,  2.80it/s, acc_training=1, acc_validation=0.8]
[32m[I 2022-11-08 22:47:06,914][0m Trial 85 finished with value: 0.800000011920929 and parameters: {'wavelet': 'cmor'}. Best is trial 59 with value: 0.8727272748947144.[0m
100%|████████████████████████████████████████████████████████████| 500/500 [02:56<00:00,  2.84it/s, acc_training=1, acc_validation=0.858]
[32m[I 2022-11-08 22:50:19,680][0m Trial 86 finished with value: 0.8575757741928101 and parameters: {'wavelet': 'cgau8'}. Best is trial 59 with value: 0.8727272748947144.[0m
100%|█████

100%|████████████████████████████████████████████████████████████| 500/500 [02:56<00:00,  2.84it/s, acc_training=1, acc_validation=0.803]
[32m[I 2022-11-08 23:22:42,275][0m Trial 96 finished with value: 0.8030303120613098 and parameters: {'wavelet': 'cgau8'}. Best is trial 59 with value: 0.8727272748947144.[0m
100%|████████████████████████████████████████████████████████████| 500/500 [02:50<00:00,  2.94it/s, acc_training=1, acc_validation=0.739]
[32m[I 2022-11-08 23:25:38,853][0m Trial 97 finished with value: 0.739393949508667 and parameters: {'wavelet': 'gaus5'}. Best is trial 59 with value: 0.8727272748947144.[0m
100%|████████████████████████████████████████████████████████████| 500/500 [02:55<00:00,  2.84it/s, acc_training=1, acc_validation=0.836]
[32m[I 2022-11-08 23:28:51,225][0m Trial 98 finished with value: 0.8363636136054993 and parameters: {'wavelet': 'cgau8'}. Best is trial 59 with value: 0.8727272748947144.[0m
100%|███████████████████████████████████████████████████

{'wavelet': 'morl'}

In [6]:
def objective(trial):
    """Optuna objective: train the wavelet CNN with one suggested wavelet.

    NOTE(review): this redefines (and shadows) the identical ``objective`` from
    the earlier cell; keeping a single definition would avoid the two copies
    drifting apart.

    Samples a wavelet name from ``wavelet_list``, trains with otherwise fixed
    hyper-parameters and returns the validation accuracy reported by ``train``
    (the study is created with ``direction='maximize'``).

    A fresh training seed is drawn for every trial and stored on the trial as
    the ``"seed"`` user attribute, so any trial can be re-run afterwards with
    exactly the same seed.

    Args:
        trial: the Optuna trial supplying the wavelet suggestion.

    Returns:
        ``metrics["accuracy_validation"]`` from ``train``.
    """
    wavelet = trial.suggest_categorical('wavelet', wavelet_list)

    data_file = "../../data/william/dataset2/preprocessed_data__no_decimate.csv"

    # Draw as int64: the default dtype of ``np.random.randint`` is the platform
    # C long, which is 32-bit on Windows (NumPy < 2) and cannot hold the upper
    # bound 2**32 - 1. Convert to a plain int so it is JSON-serialisable as an
    # Optuna user attribute.
    seed = int(np.random.randint(2**32 - 1, dtype=np.int64))
    # Validation accuracy varies strongly between trials of the same wavelet,
    # so keep the seed to be able to reproduce an individual trial.
    trial.set_user_attr("seed", seed)

    metrics = train(
        wavelet=wavelet,
        data_file=data_file,
        fig_folder=None,
        seed=seed,
        n_epochs=500,
        learning_rate=0.005,
        # 20 integer scales, geometrically spaced between 10 and 520.
        scales=np.geomspace(10, 520, num=20, dtype=int),
        dt=1,
        decimate=5,
        select_every=5,
        verbose=False)

    return metrics["accuracy_validation"]

# Larger search over the same space (500 trials) in a separate in-memory study.
# Fix the sampler's seed so the sequence of wavelets proposed by TPE can be
# replayed; the training seed itself is drawn inside `objective`.
# NOTE(review): the recorded output shows this run was stopped manually
# (KeyboardInterrupt during trial 2), so the cell never reached
# `study2.best_params`.
study2 = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=0),
)
study2.optimize(objective, n_trials=500)

study2.best_params

[32m[I 2022-11-09 15:37:19,236][0m A new study created in memory with name: no-name-62778ac4-aea0-4740-895d-c0c66987df9c[0m
100%|████████████████████████████████████████████████████████████| 500/500 [02:56<00:00,  2.83it/s, acc_training=1, acc_validation=0.818]
[32m[I 2022-11-09 15:40:32,491][0m Trial 0 finished with value: 0.8181818127632141 and parameters: {'wavelet': 'cgau7'}. Best is trial 0 with value: 0.8181818127632141.[0m
100%|████████████████████████████████████████████████████████| 500/500 [02:49<00:00,  2.94it/s, acc_training=0.514, acc_validation=0.503]
[32m[I 2022-11-09 15:43:28,904][0m Trial 1 finished with value: 0.5030303001403809 and parameters: {'wavelet': 'gaus1'}. Best is trial 0 with value: 0.8181818127632141.[0m
 46%|███████████████████████████▊                                | 232/500 [01:18<01:31,  2.94it/s, acc_training=1, acc_validation=0.752]
[33m[W 2022-11-09 15:45:04,617][0m Trial 2 failed because of the following error: KeyboardInterrupt()[0m
T