# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from utils.train_stft import train_cnn_stft
import matplotlib.pyplot as plt
import numpy as np
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch

# Hyperparameter search space; passed to the Tuner as `param_space`.
# Each entry is a Ray Tune sampling domain rather than a concrete value.
CONFIG = {
    # Categorical: one of the listed dropout rates.
    "dropout": tune.choice([0.1, 0.2, 0.3, 0.4]),
    # Continuous, sampled uniformly from [0, 1000].
    # NOTE(review): units/meaning of f_start are defined by train_cnn_stft — confirm.
    "f_start": tune.uniform(lower=0, upper=1000),
    # Categorical: one of the listed batch sizes.
    "batch_size": tune.choice([4, 8, 16, 32, 64]),
    # Continuous, sampled log-uniformly between 1e-5 and 1e-2.
    "learning_rate": tune.loguniform(lower=1e-5, upper=1e-2),
    # Categorical: one of the listed channel counts.
    "input_channels": tune.choice([32, 64, 128]),
}
# Command-line progress reporter: shows one column per tuned hyperparameter
# plus the listed metric columns.
# NOTE(review): `reporter` is not handed to the Tuner anywhere in the visible
# code, so it has no effect as written — confirm whether it should be supplied
# through the run configuration's `progress_reporter` option.
reporter = CLIReporter(
    parameter_columns=[
        "input_channels",
        "dropout",
        "learning_rate",
        "batch_size",
        "f_start",
    ],
    metric_columns=[
        "loss",
        "accuracy",
        "training_iteration",
    ],
)

# Number of hyperparameter configurations to sample from CONFIG.
# TuneConfig's `num_samples` defaults to 1, which would reduce the Optuna
# search to a single trial; set an explicit budget instead. Adjust as needed.
NUM_SAMPLES = 20

# Build the tuning job:
#   * every trial of `train_cnn_stft` requests 16 CPUs and 2 GPUs,
#   * Optuna proposes configurations drawn from CONFIG,
#   * trials are compared by the reported "accuracy" metric (higher is better).
tuner = tune.Tuner(
    tune.with_resources(
        tune.with_parameters(train_cnn_stft),
        resources={"cpu": 16, "gpu": 2},
    ),
    tune_config=tune.TuneConfig(
        metric="accuracy",
        mode="max",
        search_alg=OptunaSearch(),
        num_samples=NUM_SAMPLES,
    ),
    param_space=CONFIG,
)

# Run the sweep; `results` holds the per-trial outcomes queried further below.
results = tuner.fit()

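# NOTE: disabled ResNet-18 variant (1 input channel, 3 outputs), kept for reference.
# `torchvision.models` is imported in this file as `models`.
# model = models.resnet18(False)
# model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# model.fc = torch.nn.Linear(512, 3)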

# Select the trial with the highest reported "accuracy" from the finished sweep.
best_result = results.get_best_result(metric="accuracy", mode="max")