In [5]:
from utils.train_stft import fit_cnn_stft
import pandas as pd
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch
from ray.air import RunConfig
import warnings

warnings.filterwarnings('ignore')

CONFIG = {
    "dropout": tune.choice([0.1,0.2]),
    "f_start": tune.choice([0, 100, 200, 300, 400, 500, 600, 700]),
    "batch_size": tune.choice([512, 256]),
    "learning_rate": tune.loguniform(1e-4, 1e-2),
    "input_channels": tune.choice([32,64,128]),
}

reporter = CLIReporter(
        parameter_columns=["input_channels", "dropout", "learning_rate", "batch_size", "f_start"],
        metric_columns=["loss", "accuracy", "training_iteration"])

algo = OptunaSearch()

tuner = tune.Tuner(
        tune.with_resources(
                tune.with_parameters(fit_cnn_stft),
                resources={"cpu": 16, "gpu": 2}
        ),
        tune_config=tune.TuneConfig(
                metric='g_mean',
                mode="max",
                search_alg=algo,
                num_samples=10,
        ),
        run_config=RunConfig(
                stop={"training_iteration": 10},
                verbose=2,
        ),
        param_space=CONFIG,
)

results = tuner.fit()

best_result = results.get_best_result("g_mean", "max")
print("Best config is:", results.get_best_result().config)

[32m[I 2023-04-22 23:21:35,647][0m A new study created in memory with name: optuna[0m


0,1
Current time:,2023-04-22 23:21:37
Running for:,00:00:02.13
Memory:,27.4/251.6 GiB

Trial name,status,loc,batch_size,dropout,f_start,input_channels,learning_rate
fit_cnn_stft_973aae98,RUNNING,192.168.96.6:325499,512,0.1,100,64,0.00682465


Trial name,accuracy,extrastole_tpr,extrastole_val_tpr,g_mean,loss,murmur_tpr,murmur_val_tpr,should_checkpoint,val_accuracy,val_loss
fit_cnn_stft_973aae98,0.630532,0.674086,0.563636,0.322438,0.00203658,0.478333,0.916667,True,0.413598,0.00333855


[2m[36m(fit_cnn_stft pid=325499)[0m Frequency from 100 to 200
[2m[36m(fit_cnn_stft pid=325499)[0m               precision    recall  f1-score   support
[2m[36m(fit_cnn_stft pid=325499)[0m 
[2m[36m(fit_cnn_stft pid=325499)[0m       normal       0.84      0.23      0.36       202
[2m[36m(fit_cnn_stft pid=325499)[0m       murmur       0.53      0.44      0.48        96
[2m[36m(fit_cnn_stft pid=325499)[0m   extrastole       0.25      1.00      0.40        55
[2m[36m(fit_cnn_stft pid=325499)[0m 
[2m[36m(fit_cnn_stft pid=325499)[0m     accuracy                           0.41       353
[2m[36m(fit_cnn_stft pid=325499)[0m    macro avg       0.54      0.56      0.42       353
[2m[36m(fit_cnn_stft pid=325499)[0m weighted avg       0.66      0.41      0.40       353
[2m[36m(fit_cnn_stft pid=325499)[0m 
[2m[36m(fit_cnn_stft pid=325499)[0m Frequency from 100 to 200
[2m[36m(fit_cnn_stft pid=325499)[0m               precision    recall  f1-score   support
[2m[

[2m[36m(fit_cnn_stft pid=325499)[0m   _warn_prf(average, modifier, msg_start, len(result))
[2m[36m(fit_cnn_stft pid=325499)[0m   _warn_prf(average, modifier, msg_start, len(result))
[2m[36m(fit_cnn_stft pid=325499)[0m   _warn_prf(average, modifier, msg_start, len(result))
2023-04-22 23:26:27,076	INFO tune.py:798 -- Total run time: 291.43 seconds (291.40 seconds for the tuning loop).


Best config is: {'dropout': 0.1, 'f_start': 100, 'batch_size': 512, 'learning_rate': 0.00682465163171436, 'input_channels': 64}


In [2]:
results.get_dataframe()

Unnamed: 0,loss,accuracy,g_mean,murmur_tpr,extrastole_tpr,val_loss,val_accuracy,murmur_val_tpr,extrastole_val_tpr,time_this_iter_s,...,time_since_restore,timesteps_since_restore,iterations_since_restore,warmup_time,config/batch_size,config/dropout,config/f_start,config/input_channels,config/learning_rate,logdir
0,0.003827,0.706567,0.690549,0.800743,0.862385,0.003698,0.695122,0.333333,0.294118,7.426954,...,82.004915,0,10,0.003143,256,0.1,100,128,0.000835,/root/ray_results/train_cnn_stft_2023-04-22_20...
1,0.002081,0.706584,0.567481,0.858017,0.661386,0.004246,0.421642,0.529412,0.416667,7.442525,...,78.328189,0,10,0.003143,512,0.1,100,32,0.005662,/root/ray_results/train_cnn_stft_2023-04-22_20...
2,0.003483,0.773137,0.664195,0.817471,0.8125,0.007189,0.616279,0.716981,0.352941,7.444022,...,78.093661,0,10,0.003143,256,0.1,300,128,0.00392,/root/ray_results/train_cnn_stft_2023-04-22_20...
3,0.001601,0.77029,0.55196,0.890663,0.619718,0.004616,0.337121,0.676471,0.25641,7.182968,...,75.942375,0,10,0.003143,512,0.2,300,128,0.007495,/root/ray_results/train_cnn_stft_2023-04-22_20...
4,0.002365,0.62506,0.235288,0.646699,0.36383,0.003864,0.751037,0.388889,0.28125,7.392783,...,77.402075,0,10,0.003143,512,0.2,50,32,0.000477,/root/ray_results/train_cnn_stft_2023-04-22_20...
5,0.00352,0.775,0.720082,0.84595,0.851211,0.008222,0.591603,0.696429,0.0,7.728003,...,79.037829,0,10,0.003143,256,0.2,200,64,0.001785,/root/ray_results/train_cnn_stft_2023-04-22_20...
6,0.004176,0.663985,0.517213,0.725138,0.713262,0.004669,0.474359,0.414286,0.454545,7.355188,...,77.61679,0,10,0.003143,256,0.2,300,64,0.00042,/root/ray_results/train_cnn_stft_2023-04-22_20...
7,0.00223,0.668816,0.579521,0.757344,0.765203,0.004237,0.593361,0.468085,0.5,7.556767,...,80.16319,0,10,0.003143,512,0.1,0,128,0.00047,/root/ray_results/train_cnn_stft_2023-04-22_20...
8,0.003214,0.761602,0.569306,0.829561,0.686275,0.004226,0.553648,0.690909,0.096774,7.158191,...,75.316018,0,10,0.003143,256,0.1,200,128,0.004523,/root/ray_results/train_cnn_stft_2023-04-22_20...
9,0.002225,0.543655,0.022956,0.547117,0.041958,0.004162,0.631179,0.638889,0.027027,6.976046,...,73.640375,0,10,0.003143,512,0.1,150,64,0.000146,/root/ray_results/train_cnn_stft_2023-04-22_20...
