In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from utils.train_stft import train_cnn_stft
import matplotlib.pyplot as plt
import numpy as np
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch
from ray.air import RunConfig

CONFIG = {
    "dropout": tune.choice([0.1,0.2,0.3,0.4]),
    "f_start": tune.uniform(0, 1000),
    "batch_size": tune.choice([16,32,64]),
    "learning_rate": tune.loguniform(1e-5, 1e-2),
    "input_channels": tune.choice([32,64,128]),
}

reporter = CLIReporter(
        parameter_columns=["input_channels", "dropout", "learning_rate", "batch_size", "f_start"],
        metric_columns=["loss", "accuracy", "training_iteration"])
algo = OptunaSearch()

tuner = tune.Tuner(
        tune.with_resources(
                tune.with_parameters(train_cnn_stft),
                resources={"cpu": 16, "gpu": 2}
        ),
        tune_config=tune.TuneConfig(
                metric='accuracy',
                mode="max",
                search_alg=algo,
                num_samples=5,
        ),
        run_config=RunConfig(
                stop={"training_iteration": 60},
                verbose=2,
        ),
        param_space=CONFIG,
)
results = tuner.fit()

# model = models.resnet18(False)
# model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# model.fc = torch.nn.Linear(512, 3)

best_result = results.get_best_result("accuracy", "max")
print("Best config is:", results.get_best_result().config)

  from .autonotebook import tqdm as notebook_tqdm
  warn(f"Failed to load image Python extension: {e}")
2023-04-20 19:54:37,411	INFO worker.py:1544 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
  return ot.distributions.UniformDistribution(
  return ot.distributions.LogUniformDistribution(
[32m[I 2023-04-20 19:54:38,376][0m A new study created in memory with name: optuna[0m


0,1
Current time:,2023-04-20 19:54:40
Running for:,00:00:02.24
Memory:,26.6/251.6 GiB

Trial name,status,loc,batch_size,dropout,f_start,input_channels,learning_rate
train_cnn_stft_18e76409,RUNNING,192.168.96.6:4091823,32,0.3,845.314,64,0.000200598




Trial name,accuracy,extrastole_fpr,extrastole_tpr,extrastole_val_fpr,extrastole_val_fpr_denom,extrastole_val_fpr_numer,extrastole_val_tpr,extrastole_val_tpr_denom,extrastole_val_tpr_numer,loss,murmur_fpr,murmur_tpr,murmur_val_fpr,murmur_val_tpr,should_checkpoint,val_accuracy,val_loss
train_cnn_stft_18e76409,0.563959,0,0,0,465,0,0,38,0,0.0301844,0.434548,0.0969697,0.359459,0,True,0.66004,0.0285989


[2m[36m(train_cnn_stft pid=4091823)[0m Number of data per class
[2m[36m(train_cnn_stft pid=4091823)[0m 0    210
[2m[36m(train_cnn_stft pid=4091823)[0m 1    174
[2m[36m(train_cnn_stft pid=4091823)[0m 2    128
[2m[36m(train_cnn_stft pid=4091823)[0m Name: count, dtype: int64
[2m[36m(train_cnn_stft pid=4091823)[0m Number of data per class
[2m[36m(train_cnn_stft pid=4091823)[0m 0    110
[2m[36m(train_cnn_stft pid=4091823)[0m 1     37
[2m[36m(train_cnn_stft pid=4091823)[0m 2     14
[2m[36m(train_cnn_stft pid=4091823)[0m Name: count, dtype: int64
[2m[36m(train_cnn_stft pid=4091823)[0m Number of data per class
[2m[36m(train_cnn_stft pid=4091823)[0m 0    199
[2m[36m(train_cnn_stft pid=4091823)[0m 1    204
[2m[36m(train_cnn_stft pid=4091823)[0m 2    132
[2m[36m(train_cnn_stft pid=4091823)[0m Name: count, dtype: int64
[2m[36m(train_cnn_stft pid=4091823)[0m Number of data per class
[2m[36m(train_cnn_stft pid=4091823)[0m 0    121
[2m[36m(train_cn

In [None]:
df = results.get_dataframe()
columns_to_drop = ['time_this_iter_s', 'should_checkpoint', 'done', 'timesteps_total',
       'episodes_total', 'training_iteration', 'trial_id', 'experiment_id',
       'date', 'timestamp', 'time_total_s', 'pid', 'hostname', 'node_ip',
       'time_since_restore', 'timesteps_since_restore',
       'iterations_since_restore', 'warmup_time', 'extrastole_val_tpr_numer', 
       'extrastole_val_tpr_denom', 'extrastole_val_fpr_numer', 'extrastole_val_fpr_denom']
for column in columns_to_drop:
    df = df.drop(columns=column)
path = '/'.join(df['logdir'][0].split('/')[:4])
df.to_csv(f'{path}/results.csv')

In [None]:
df

Unnamed: 0,loss,accuracy,murmur_tpr,murmur_fpr,extrastole_tpr,extrastole_fpr,murmur_val_tpr,murmur_val_fpr,val_loss,val_accuracy,extrastole_val_tpr,extrastole_val_fpr,config/batch_size,config/dropout,config/f_start,config/input_channels,config/learning_rate,logdir
0,0.021069,0.889292,0.879093,0.097872,0.71066,0.007735,0.488095,0.339667,0.032046,0.540594,0.16,0.154167,32,0.2,814.416605,128,7e-05,/root/ray_results/train_cnn_stft_2023-04-20_19...
1,0.03008,0.595496,0.480769,0.457105,0.0,0.0,0.284404,0.306407,0.027878,0.683761,0.0,0.0,32,0.3,432.228535,128,0.003189,/root/ray_results/train_cnn_stft_2023-04-20_19...
2,0.027874,0.666375,0.574413,0.30303,0.0,0.0,0.644444,0.183673,0.024706,0.785219,0.0,0.0,32,0.4,2.59172,64,0.005213,/root/ray_results/train_cnn_stft_2023-04-20_19...
3,0.027562,0.687956,0.494475,0.355586,0.0,0.0,0.288288,0.212938,0.027318,0.742739,0.0,0.0,32,0.2,192.009553,32,0.005157,/root/ray_results/train_cnn_stft_2023-04-20_19...
4,0.052627,0.858524,0.913151,0.202721,0.0,0.0,0.174419,0.303725,0.06035,0.648276,0.0,0.0,16,0.3,618.533334,32,1.2e-05,/root/ray_results/train_cnn_stft_2023-04-20_19...
5,0.027325,0.700823,0.568182,0.352227,0.0,0.0,0.095652,0.34072,0.027977,0.663866,0.0,0.0,32,0.4,746.424666,32,0.000216,/root/ray_results/train_cnn_stft_2023-04-20_19...
6,0.050537,0.836547,0.996825,0.222222,0.0,0.0,0.351351,0.305882,0.055245,0.719262,0.0,0.0,16,0.2,311.457304,64,1.1e-05,/root/ray_results/train_cnn_stft_2023-04-20_19...
7,0.010762,0.887801,0.975936,0.021362,0.301887,0.002075,0.752577,0.354108,0.016346,0.664444,0.0,0.0,64,0.3,298.123993,32,0.00022,/root/ray_results/train_cnn_stft_2023-04-20_19...
8,0.026979,0.827216,1.0,0.103185,0.0,0.0,0.903704,0.171687,0.029379,0.841542,0.0,0.0,32,0.3,36.506448,128,1e-05,/root/ray_results/train_cnn_stft_2023-04-20_19...
9,0.021268,0.875115,0.996951,0.106439,0.0,0.0,0.293233,0.472892,0.033146,0.56129,0.0,0.0,32,0.2,640.17023,32,3.8e-05,/root/ray_results/train_cnn_stft_2023-04-20_19...
