In [None]:
import os
import hydra
import pytorch_lightning as pl
from contrib.eeg.solvers import *
from torch.optim import Adam
import optuna

os.environ['HYDRA_FULL_ERROR'] = "1"  # show full Hydra tracebacks instead of the truncated default
with hydra.initialize(config_path="config", version_base=None):
    cfg = hydra.compose("config_global",
                        overrides=[
                            'datamodule.dataset_kw.to_load=10000',
                            'datamodule.dl_kw.batch_size=32',
                            #'datamodule.dataset_kw.simu_name=_dl_ses_fsaverage_st1020_',
                            'datamodule.subset_name=none',
                            ])

SEED = 333
pl.seed_everything(SEED)  # seed for reproducibility

# Data: built once and shared by every Optuna trial.
dm = hydra.utils.call(cfg.datamodule)
dm.setup("train")

# Logged metric each trial is scored on (minimized by the study).
# NOTE(review): assumes EsiLitModule logs "val_loss" via self.log -- confirm and adjust.
OBJECTIVE_METRIC = "val_loss"


def build_litmodel(prior_ponde: float, lr_grad: float, Lambda: float):
    """Build the EsiGradSolver_n + EsiLitModule pair for one hyperparameter set.

    Args:
        prior_ponde: weight passed to the solver as `prior_ponde`.
        lr_grad: value passed to the solver as `lr_grad`.
        Lambda: value passed to EsiLitModule as `Lambda`.

    Returns:
        A freshly constructed EsiLitModule wrapping a fresh solver.
    """
    solver = EsiGradSolver_n(fwd=hydra.utils.call(cfg.fwd), n_step=10,
                             prior_cost=hydra.utils.call(cfg.prior_cost),
                             obs_cost=hydra.utils.call(cfg.obs_cost),
                             grad_mod=hydra.utils.call(cfg.grad_mod),
                             prior_ponde=prior_ponde,
                             lr_grad=lr_grad)
    return EsiLitModule(solver=solver, opt_fn=Adam, lr=5e-5,
                        loss_fn=hydra.utils.call(cfg.cost_functions.train_cost_fn),
                        Lambda=Lambda)


def objective(trial: optuna.trial.Trial) -> float:
    """Sample hyperparameters, train one model on `dm`, and return its score.

    Samples `prior_ponde`, `lr_grad` and `Lambda`, builds a fresh model and a
    fresh Trainer from `cfg.trainer`, fits it, and returns the value of
    OBJECTIVE_METRIC from `trainer.callback_metrics`.

    No intermediate values are reported with `trial.report`, so a pruner
    attached to the study cannot stop these trials early.
    TODO: add a pruning callback if early stopping of bad trials is wanted.

    Raises:
        KeyError: if OBJECTIVE_METRIC was not logged during training.
    """
    # Re-seed so every trial starts from the same RNG state; score differences
    # then come from the hyperparameters rather than from initialization.
    pl.seed_everything(SEED)

    # Hyperparameters searched by Optuna
    prior_ponde = trial.suggest_float("prior_ponde", 1e-2, 1e2, log=True)
    lr_grad = trial.suggest_float("lr_grad", 1e-12, 1e-6, log=True)
    Lambda = trial.suggest_float("Lambda", 1e-2, 1e1, log=True)

    litmodel = build_litmodel(prior_ponde, lr_grad, Lambda)

    # A Trainer carries fit state (counters, callbacks, loggers), so each
    # trial gets its own instead of reusing one across fits.
    trainer = hydra.utils.call(cfg.trainer)
    trainer.fit(litmodel, datamodule=dm)

    metrics = trainer.callback_metrics
    if OBJECTIVE_METRIC not in metrics:
        raise KeyError(
            f"Metric {OBJECTIVE_METRIC!r} was not logged; available: {sorted(metrics)}"
        )
    return float(metrics[OBJECTIVE_METRIC])

In [None]:
# Run the hyperparameter search and summarize it.
# Note: Optuna's MedianPruner only prunes trials that report intermediate
# values (trial.report); without them every trial runs to completion.
pruner = optuna.pruners.MedianPruner()
# Seed the sampler (same seed as pl.seed_everything above) so the sequence of
# sampled hyperparameters is reproducible across re-runs.
sampler = optuna.samplers.TPESampler(seed=333)
study = optuna.create_study(direction="minimize", pruner=pruner, sampler=sampler)
study.optimize(objective, n_trials=100, timeout=600)

completed_trials = study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,))
pruned_trials = study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.PRUNED,))

print("Number of finished trials: {}".format(len(study.trials)))
print("  Complete: {}".format(len(completed_trials)))
print("  Pruned: {}".format(len(pruned_trials)))

# study.best_trial raises ValueError when no trial completed (all failed/pruned
# or the timeout hit first), so report that case explicitly instead of crashing.
if not completed_trials:
    print("No trial completed; there is no best trial to report.")
else:
    print("Best trial:")
    trial = study.best_trial

    print("  Value: {}".format(trial.value))

    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))