In [6]:
%reload_ext autoreload
%autoreload 2

# Imports

In [7]:
from kret_notebook import *  # NOTE import first
from kret_lgbm._core.lgbm_nb_imports import *
from kret_lightning._core.lightning_nb_imports import *
from kret_matplotlib._core.mpl_nb_imports import *
from kret_np_pd._core.np_pd_nb_imports import *
from kret_optuna._core.optuna_nb_imports import *
from kret_polars._core.polars_nb_imports import *
from kret_rosetta._core.rosetta_nb_imports import *
from kret_sklearn._core.sklearn_nb_imports import *
from kret_torch_utils._core.torch_nb_imports import *
from kret_tqdm._core.tqdm_nb_imports import *
from kret_type_hints._core.types_nb_imports import *
from kret_utils._core.utils_nb_imports import *

# from kret_wandb._core.wandb_nb_imports import *  # NOTE this is slow to import

In [8]:
from kret_lightning.examples.cifar10_module import *
from kret_lightning.examples.cifar10_datamodule import *

# Load Data

In [9]:
# CIFAR10 lives in its own sub-folder of the shared data directory.
cifar10_root = UKS_CONSTANTS.DATA_DIR / "CIFAR10"
cifar_dm = CIFAR10DataModule(cifar10_root, batch_size=64, num_workers=6)

In [10]:
cifar_dm.data_dir

PosixPath('/Users/Akseldkw/coding/data_kretsinger/CIFAR10')

# Implementation

In [11]:
# Baseline instance: constructor defaults everywhere except the learning rate.
cifar_nn = CIFAR10ResNet(lr=1e-3)
# Show the hyperparameters the module recorded at construction time.
cifar_nn.hparams_initial

Saving hparams, ignoring ()
Saving hparams, ignoring ()


"dropout_rate":     0.3
"l1_penalty":       0.0
"l2_penalty":       0.0
"lr":               0.001
"num_blocks":       2
"num_classes":      10
"num_filters":      64
"patience":         10
"warmup_step_frac": 0.1

In [12]:
cifar_nn.save_load_logging_dict

{'save_dir': PosixPath('/Users/Akseldkw/coding/data_kretsinger/lightning_logs'),
 'name': 'CIFAR10ResNet',
 'version': 'v_001'}

In [13]:
# Start from the sweep preset of trainer kwargs. Working on a copy means that
# overrides applied in this cell (such as the commented-out one below) do not
# touch TrainerStaticDefaults.OPTUNA_SWEEP itself.
static_args = TrainerStaticDefaults.OPTUNA_SWEEP.copy()
# static_args["max_epochs"] = 10
static_args

{'max_epochs': 10,
 'limit_train_batches': 0.25,
 'limit_val_batches': 0.5,
 'log_every_n_steps': 50,
 'enable_model_summary': False,
 'enable_checkpointing': False,
 'gradient_clip_val': 1.0,
 'max_time': {'minutes': 5}}

In [14]:
# Trainer kwargs that are derived from the concrete model and datamodule
# instances rather than being fixed presets.
dynamic_args = TrainerDynamicDefaults.trainer_dynamic_defaults(cifar_nn, cifar_dm)
dynamic_args

{'logger': <lightning.pytorch.loggers.csv_logs.CSVLogger at 0x30a8cbce0>,
 'default_root_dir': PosixPath('/Users/Akseldkw/coding/data_kretsinger/lightning_logs/CIFAR10ResNet/v_001'),
 'callbacks': []}

In [15]:
dynamic_args["logger"].save_dir, dynamic_args["logger"].name

('/Users/Akseldkw/coding/data_kretsinger/lightning_logs', 'CIFAR10ResNet')

In [16]:
# Merge both kwarg sets; entries from dynamic_args take precedence on key clashes.
trainer_args = {**static_args, **dynamic_args}

In [17]:
# Notebook-level trainer built from the merged kwargs.
# NOTE(review): `objective` constructs its own trainer for every trial and does
# not reference this instance — confirm it is still needed.
trainer = L.Trainer(**trainer_args)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores


In [18]:
def objective(trial: optuna.trial.Trial) -> float:
    """Fit one sampled ``CIFAR10ResNet`` configuration and return its validation loss.

    Hyperparameters are drawn from ``trial``. A fresh model and a fresh
    ``L.Trainer`` are created on every call, while the notebook-level
    ``cifar_dm`` datamodule and ``static_args`` trainer settings are shared
    across trials.

    Args:
        trial: Optuna trial used to sample hyperparameters. It is also passed
            on to ``TrainerDynamicDefaults.trainer_dynamic_defaults``.

    Returns:
        The ``val_loss`` entry of ``trainer.callback_metrics`` after fitting,
        converted to a Python float.

    Raises:
        KeyError: If the fit finished without a ``val_loss`` metric being logged.
    """
    num_blocks = trial.suggest_int("num_blocks", 1, 4)  # 1=tiny, 2=small, 3=medium, 4=large
    num_filters = trial.suggest_categorical("num_filters", [32, 64, 128])
    dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.5)
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    l1 = trial.suggest_float("l1", 1e-5, 1e-2, log=True)
    l2 = trial.suggest_float("l2", 1e-5, 1e-2, log=True)

    model = CIFAR10ResNet(
        num_blocks=num_blocks,
        num_filters=num_filters,
        dropout_rate=dropout_rate,
        lr=lr,
        l1_penalty=l1,
        l2_penalty=l2,
        patience=4,
    )
    model._sweep_mode = True  # NOTE important!

    # Build the per-trial kwargs from THIS trial's model, not the notebook-level
    # `cifar_nn`: otherwise the `_sweep_mode` flag and `patience` set above never
    # reach the helper and every trial is configured from the same baseline module.
    # NOTE(review): assumes the helper reads those attributes from the module it
    # is given — confirm against its implementation.
    trial_dynamic_args = TrainerDynamicDefaults.trainer_dynamic_defaults(model, cifar_dm, trial=trial)
    trial_trainer_args = static_args | trial_dynamic_args

    trainer = L.Trainer(**trial_trainer_args)  # New trainer per trial!
    assert trainer.logger is not None
    trainer.logger.log_hyperparams(model.hparams_initial)
    trainer.fit(model, datamodule=cifar_dm, **TrainerStaticDefaults.TRAINER_FIT)

    # Fail with an explicit message when no validation metric was produced,
    # instead of an opaque KeyError('val_loss').
    val_loss = trainer.callback_metrics.get("val_loss")
    if val_loss is None:
        raise KeyError(
            f"'val_loss' was not logged for trial {trial.number}; "
            f"available metrics: {sorted(trainer.callback_metrics)}"
        )
    return val_loss.item()

In [20]:
OptunaDefaults.CREATE_STUDY_DEFAULTS

{'pruner': <optuna.pruners._hyperband.HyperbandPruner at 0x303bc87a0>,
 'load_if_exists': True}

In [21]:
# Pruner handed to the study below.
pruner = optuna.pruners.HyperbandPruner()

# Open the persistent study in the project's Optuna storage. With
# load_if_exists=True an existing study of the same name is resumed rather
# than raising, so earlier trials stay part of it.
# NOTE(review): the best trial displayed later in this notebook carries params
# `n_layers` / `n_units_l*`, which `objective` never samples — presumably left
# over from an earlier search space. Consider a fresh study_name if those
# trials should not be mixed with this sweep.
study = optuna.create_study(
    study_name="cifar10_resnet",
    storage=UKS_CONSTANTS.OPTUNA_STORAGE_DB,
    direction="minimize",
    pruner=pruner,
    load_if_exists=True,
)

[I 2026-01-22 00:14:03,845] Using an existing study with name 'cifar10_resnet' instead of creating a new one.


In [22]:
# Combine the 8-hour budget preset with the default optimisation options; on
# duplicate keys the right-hand operand's value wins.
# NOTE(review): the merged result shows n_jobs=-2. Optuna documents -1 (use the
# CPU count) as the only special negative value — confirm -2 is accepted
# before passing these kwargs to study.optimize.
optim_args = OptunaDefaults.STUDY_8_HOURS | OptunaDefaults.OPTIM_STUDY_DEF
optim_args

{'n_trials': None,
 'timeout': 28800,
 'n_jobs': -2,
 'gc_after_trial': True,
 'show_progress_bar': True}

In [None]:
# Run the sweep.
# NOTE(review): `optim_args` (timeout, n_jobs, gc_after_trial, ...) assembled in
# an earlier cell is not passed here, so this call uses Optuna's defaults and
# has no trial or time limit — confirm whether
# `study.optimize(objective, **optim_args)` was intended.
study.optimize(objective)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores


Saving hparams, ignoring ()
Saving hparams, ignoring ()


/Users/Akseldkw/micromamba/envs/kret_312/lib/python3.12/site-packages/lightning/fabric/loggers/csv_logs.py:268: Experiment logs directory /Users/Akseldkw/coding/data_kretsinger/lightning_logs/CIFAR10ResNet/v_001 exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!
Loading `train_dataloader` to estimate number of stepping batches.


Output()

Metric val_loss improved. New best score: 2.057


# Evaluate

In [22]:
study.best_trial

FrozenTrial(number=0, state=<TrialState.COMPLETE: 1>, values=[0.41628187894821167], datetime_start=datetime.datetime(2026, 1, 19, 11, 4, 24, 454090), datetime_complete=datetime.datetime(2026, 1, 19, 11, 14, 19, 454888), params={'n_layers': 3, 'dropout': 0.4510958969441914, 'n_units_l0': 86, 'n_units_l1': 4, 'n_units_l2': 6}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'n_layers': IntDistribution(high=3, log=False, low=1, step=1), 'dropout': FloatDistribution(high=0.5, log=False, low=0.2, step=None), 'n_units_l0': IntDistribution(high=128, log=True, low=4, step=1), 'n_units_l1': IntDistribution(high=128, log=True, low=4, step=1), 'n_units_l2': IntDistribution(high=128, log=True, low=4, step=1)}, trial_id=1, value=None)

## Load

In [30]:
# Read the summaries of every study kept in the project's Optuna storage,
# independently of the in-memory `study` object.
study_summaries = optuna.study.get_all_study_summaries(storage=UKS_CONSTANTS.OPTUNA_STORAGE_DB)

In [31]:
len(study_summaries)

1

In [32]:
# Pick the summary by study name rather than by position, so the cells below
# keep describing the CIFAR-10 sweep once the storage holds more than one study.
studyload = next(summary for summary in study_summaries if summary.study_name == "cifar10_resnet")

In [35]:
studyload.n_trials

16

In [33]:
studyload.datetime_start

datetime.datetime(2026, 1, 19, 11, 4, 24, 454090)

In [34]:
studyload.best_trial

FrozenTrial(number=0, state=<TrialState.COMPLETE: 1>, values=[0.41628187894821167], datetime_start=datetime.datetime(2026, 1, 19, 11, 4, 24, 454090), datetime_complete=datetime.datetime(2026, 1, 19, 11, 14, 19, 454888), params={'n_layers': 3, 'dropout': 0.4510958969441914, 'n_units_l0': 86, 'n_units_l1': 4, 'n_units_l2': 6}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'n_layers': IntDistribution(high=3, log=False, low=1, step=1), 'dropout': FloatDistribution(high=0.5, log=False, low=0.2, step=None), 'n_units_l0': IntDistribution(high=128, log=True, low=4, step=1), 'n_units_l1': IntDistribution(high=128, log=True, low=4, step=1), 'n_units_l2': IntDistribution(high=128, log=True, low=4, step=1)}, trial_id=1, value=None)

In [38]:
# Evaluate hyperparameter importances for the study; the result is kept in
# `out` for inspection and is not displayed by this cell.
out = optuna.importance.get_param_importances(study)

In [None]:
optuna.visualization.plot_optimization_history(study)

In [40]:
# Plot parameter importances
optuna.visualization.plot_param_importances(study)