In [None]:
import PytorchOptuna
import torch
from torchmetrics.classification import BinaryAccuracy
from Model import Model
from PMDataset import PMDataset
from torch.optim import AdamW
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader

In [None]:
# Prefer the GPU when PyTorch can see one. The model (create_model) and the
# metric (objective) are both moved onto this device later on.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(device)

In [None]:
# Optuna tuning data. Each trial fits its model on the *train* split and is
# scored on the *valid* split, so each dataset must load its own files.
# NOTE(review): the study is resumed with load_if_exists=True. Any trials
# already stored there were run with these two splits swapped; consider a
# fresh study_name before comparing scores.
OPTUNA_DATA_DIR = "PreProcessing/optunaData"
BATCH_SIZE = 128

train_dataset = PMDataset(f"{OPTUNA_DATA_DIR}/X_train_optuna.npy", f"{OPTUNA_DATA_DIR}/y_train_optuna.npy")
valid_dataset = PMDataset(f"{OPTUNA_DATA_DIR}/X_valid_optuna.npy", f"{OPTUNA_DATA_DIR}/y_valid_optuna.npy")

# Only the training loader is shuffled; validation is iterated in file order.
train_dataLoader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataLoader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
def create_model(trial):
    """Build a ``Model`` whose architecture is sampled from an Optuna ``trial``.

    Searches over:
        num_blocks: integer in [1, 4] (inclusive).
        width: integer in [64, 2048] (inclusive).

    Returns:
        A single-output ``Model`` already moved to the module-level ``device``.
    """
    # The dict literal evaluates in order, so 'num_blocks' is still suggested
    # before 'width'.
    architecture = {
        "num_blocks": trial.suggest_int('num_blocks', 1, 4),
        "width": trial.suggest_int('width', 64, 2048),
    }
    # NOTE(review): input_size=138 presumably equals the feature count of the
    # X_*_optuna.npy arrays; confirm if preprocessing changes.
    return Model(input_size=138, output_size=1, **architecture).to(device)

In [None]:
def objective(trial, model):
    """Sample AdamW hyperparameters, train ``model``, and return its score.

    Args:
        trial: Optuna trial used to sample three values, all log-scaled:
            ``lr`` in [1e-6, 1e-1], ``beta_1`` in [0.8, 0.999] and
            ``beta_2`` in [0.9, 0.9999].
        model: the network built for this trial.

    Returns:
        The value returned by ``PytorchOptuna.train_model``. This is presumably
        the validation ``BinaryAccuracy``; the study is created with
        direction='maximize'.
    """
    lr = trial.suggest_float('lr', 1e-6, 1e-1, log=True)
    betas = (
        trial.suggest_float('beta_1', 0.8, 0.999, log=True),
        trial.suggest_float('beta_2', 0.9, 0.9999, log=True),
    )
    optimizer = AdamW(model.parameters(), lr=lr, betas=betas)

    # earlyStoppingArgs is interpreted by PytorchOptuna.train_model.
    # Presumably it is (min_delta, patience, flag) -- TODO confirm.
    return PytorchOptuna.train_model(
        model=model,
        maxEpochs=100,
        dataLoaderTrain=train_dataLoader,
        dataLoaderValid=valid_dataLoader,
        lossFn=BCEWithLogitsLoss(),
        optimizer=optimizer,
        metric=BinaryAccuracy().to(device),
        device=device,
        trial=trial,
        earlyStoppingArgs=[1e-5, 10, False],
    )

In [None]:

# Create the study, or resume it from study2.db because load_if_exists=True,
# then run 50 more trials.
opt = PytorchOptuna.PytorchOptuna(
    'Models/model2',
    direction='maximize',
    create_model=create_model,
    objective=objective,
    study_name='study2',
    storage='sqlite:///study2.db',
    load_if_exists=True,
)
opt.optimize(n_trials=50)

# Report the best trial stored in the study.
print("Best trial:")
trial = opt.study.best_trial
print(f"  Value:  {trial.value}")
print("  Params: ")
for param_name, param_value in trial.params.items():
    print(f"    {param_name}: {param_value}")

In [None]:
import optuna

# Rank the tuned hyperparameters by their estimated influence on the objective.
fig = optuna.visualization.plot_param_importances(study=opt.study)
fig.show()