# Importe de librerias

In [1]:
import lightning as L
from pytorch_lightning.loggers import CSVLogger

from sklearn.model_selection import train_test_split
import torch
import optuna
from optuna.integration import PyTorchLightningPruningCallback

torch.set_float32_matmul_precision('high')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from dataset import LitPriceData
from trainer import LitTrainer

In [3]:
from utils import splitData

In [4]:
name = "bitcoin"

In [5]:
def objective(trial):
    hidden_size = trial.suggest_int("hidden_size", 32, 256)
    num_layers = trial.suggest_int("num_layers", 2, 8)
    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.5)
    sequence_length = trial.suggest_int("sequence_legth", 12, 72, step=12)

    logger = CSVLogger("lightning_logs", name="optuna")
    model = LitTrainer(hidden_size=hidden_size, num_layers=num_layers, lr=lr, dropout=dropout)
    X, y = splitData(name, sequence_length=sequence_length)
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
    dataModule = LitPriceData(X_train, y_train, X_val, y_val)

    # Add pruning callback
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_loss")

    trainer = L.Trainer(
        max_epochs=75,
        accelerator="auto",
        logger=logger,
        enable_progress_bar=False,
        enable_model_summary=False,
        enable_checkpointing=False,
        log_every_n_steps=2,
        callbacks=[pruning_callback]
    )
    trainer.fit(model=model, datamodule=dataModule)
    return trainer.callback_metrics["val_loss"].item()

In [None]:
# Use a pruner in the Optuna study
pruner = optuna.pruners.MedianPruner(n_startup_trials=10, n_warmup_steps=5)
study = optuna.create_study(direction="minimize", pruner=pruner)
study.optimize(objective, n_trials=500)
print("Best trial:")
print(study.best_params)

In [7]:
best_params = study.best_params
best_params

{'hidden_size': 46,
 'num_layers': 2,
 'lr': 0.006910149762728569,
 'dropout': 0.13197445494029517,
 'sequence_legth': 24}

In [8]:
optuna.visualization.plot_optimization_history(study)

In [9]:
optuna.visualization.plot_parallel_coordinate(study)

In [10]:
optuna.visualization.plot_param_importances(study)

In [11]:
optuna.visualization.plot_slice(study)

In [14]:
optuna.visualization.plot_timeline(study)