# Imports

In [4]:
import inspect
import itertools

from hamilton import driver, graph_types
import optuna
from optuna.distributions import (
    CategoricalDistribution,
    IntDistribution,
)
from more_complex_project import (
    ingestion_variant1,
    ingestion_variant2,
    model_v1,
    model_v2,
    transform_v1,
    transform_v2,
)
%load_ext hamilton.plugins.jupyter_magic

The hamilton.plugins.jupyter_magic extension is already loaded. To reload it, use:
  %reload_ext hamilton.plugins.jupyter_magic


# Utilities

In [6]:
# useful for creating module sets
def get_list_permutations(**kwargs: list) -> list[tuple]:
    """Return the cartesian product of the keyword-argument lists, as tuples.

    Despite the name, this is a cartesian product (``itertools.product``), not
    permutations: each tuple takes one element from every list, positioned in the
    order the keyword arguments were passed. The keyword names are not kept in the
    output; they only document what each position holds. The last keyword varies
    fastest:

        get_list_permutations(a=[0, 1], b=["x", "y"])
        # [(0, "x"), (0, "y"), (1, "x"), (1, "y")]

    An empty list for any keyword yields ``[]``; no keywords at all yields ``[()]``.
    """
    return list(itertools.product(*kwargs.values()))

# useful for creating all config dictionaries
def get_dict_permutations(**kwargs: list) -> list[dict]:
    """Return one dict per combination of the keyword-argument values.

    This is a cartesian product (``itertools.product``) keyed by argument name.
    The LAST keyword varies fastest:

        get_dict_permutations(
            a=[0, 1, 2],
            b=["x", "y"],
            c=[True],
        )
        # returns
        [
            {"a": 0, "b": "x", "c": True},
            {"a": 0, "b": "y", "c": True},
            {"a": 1, "b": "x", "c": True},
            {"a": 1, "b": "y", "c": True},
            {"a": 2, "b": "x", "c": True},
            {"a": 2, "b": "y", "c": True},
        ]

    An empty list for any keyword yields ``[]``; no keywords at all yields ``[{}]``.
    """
    keys = list(kwargs)
    return [dict(zip(keys, combo)) for combo in itertools.product(*kwargs.values())]


def _find_default_value(h_node):
    """Return the default declared for ``h_node.name`` by one of its originating functions.

    Scans ``h_node.originating_functions`` in order and returns the first default found
    for a parameter named ``h_node.name``. Returns ``inspect.Parameter.empty`` when none
    is found, e.g. when ``originating_functions`` is None or no originating function
    has a parameter with that name and a default.
    """
    for origin_function in h_node.originating_functions or ():
        param = inspect.signature(origin_function).parameters.get(h_node.name)
        if param is not None and param.default is not inspect.Parameter.empty:
            return param.default
    return inspect.Parameter.empty


def collect_inputs(dr: driver.Driver) -> dict[str, dict]:
    """Collect the required and optional inputs of a built Hamilton driver.

    Returns ``{input_name: info}`` for every external input of ``dr``'s graph that is
    not already provided through the driver's config:

    - required input: ``{"type": <type>, "required": True}``
    - optional input: ``{"type": <type>, "required": False, "default": <value>}``

    An input is reported as optional only when every node that depends on it declares
    it as optional. If any node lists it among its required dependencies, it is
    reported as required. ``default`` is ``inspect.Parameter.empty`` when the value
    cannot be recovered from the originating functions' signatures.
    """
    h_graph = graph_types.HamiltonGraph.from_graph(dr.graph)

    # Gather, across all nodes, which dependencies are declared required vs optional.
    required_deps: set[str] = set()
    optional_deps: set[str] = set()
    for h_node in h_graph.nodes:
        required_deps.update(h_node.required_dependencies)
        optional_deps.update(h_node.optional_dependencies)

    config_keys = set(dr.graph.config.keys())

    inputs = {}
    for h_node in h_graph.nodes:
        # external inputs cover required inputs, optional inputs, and config values;
        # config values are already supplied at build time, so skip them uniformly
        if not h_node.is_external_input or h_node.name in config_keys:
            continue

        # optional only if no consumer requires it; `required=False` with a default
        # (possibly None, e.g. `param: Optional[int] = None`) is not the same as required
        if h_node.name in optional_deps and h_node.name not in required_deps:
            inputs[h_node.name] = dict(
                type=h_node.type,
                required=False,
                default=_find_default_value(h_node),
            )
        else:
            inputs[h_node.name] = dict(
                type=h_node.type,
                required=True,
            )

    return inputs

In [7]:
# collect input example: list the inputs one module combination expects
example_modules = (ingestion_variant1, model_v1, transform_v1)
dr = (
    driver.Builder()
    .with_modules(*example_modules)
    .build()
)
collect_inputs(dr)

{'beta': {'type': float, 'required': True},
 'gamma': {'type': float, 'required': False, 'default': 5.0}}

# Optimization example

## Search space definition

In [3]:
# for `driver.Builder.with_modules(...)`: every combination of one ingestion,
# one model, and one transform module; Optuna samples one by its index
modules_sets = get_list_permutations(
    ingestion=[ingestion_variant1, ingestion_variant2],
    model=[model_v1, model_v2],
    transform=[transform_v1, transform_v2],
)
modules_sets_ids = list(range(len(modules_sets)))

# for `@config.when()` and `driver.Builder().with_config()`
# NOTE: not part of `search_space` yet. Optuna categorical choices must be
# None/bool/int/float/str and non-empty, so config dicts can't be choices directly;
# sample an index into `config_sets` instead, like `modules_set_id` does.
config_sets = []

# for `Driver.execute(inputs={...})`
# Every value must be an Optuna distribution, because `study.ask()` only accepts
# distributions: https://optuna.readthedocs.io/en/stable/reference/distributions.html
# To pin a literal (e.g., a path or a fixed parameter), wrap it in a single-choice
# distribution, e.g. "raw_data_path": CategoricalDistribution(["/path/my_data.parquet"])
inputs_space = {
    "impute_method": CategoricalDistribution(["none", "zero", "mean"]),
    "n_iterations": IntDistribution(low=2, high=5, log=False, step=1),
    "n_rows": IntDistribution(low=10, high=1000, log=True, step=1),
}

# Keys that describe the driver's structure rather than Hamilton inputs; the
# optimization loop strips them before calling `execute`, so inputs must not reuse them.
RESERVED_SEARCH_KEYS = {"modules_set_id", "config"}
conflicting_keys = RESERVED_SEARCH_KEYS & inputs_space.keys()
if conflicting_keys:
    raise ValueError(
        f"`inputs_space` keys clash with reserved search-space keys: {sorted(conflicting_keys)}"
    )

# create the single search space dictionary used by Optuna
search_space = {
    "modules_set_id": CategoricalDistribution(modules_sets_ids),
    **inputs_space,
}

## Optimization loop
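The loop below uses Optuna's ["Ask-and-tell: Define-and-run" interface](https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/009_ask_and_tell.html#define-and-run). The search space is defined up front and passed to `study.ask()`, which returns a trial with sampled parameters. Each trial builds and executes a Hamilton driver, and its result is reported back with `study.tell()`.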

In [4]:
# Optuna has many utils for storing, reloading, sampling, pruning
# ref: https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html
# TPESampler is Optuna's default sampler for single-objective studies; seeding it
# makes the suggested trials reproducible when the notebook is re-run.
SEED = 42
study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=SEED),
)
final_var = "fit_model"

[I 2024-05-29 17:25:31,294] A new study created in memory with name: no-name-ea2ae871-60f5-4121-b4ba-54114675a8e2


In [8]:
# launch study
N_TRIALS = 3
# search-space keys that select the driver structure rather than Hamilton inputs
non_input_keys = {"modules_set_id", "config"}

for trial_idx in range(N_TRIALS):
    trial = study.ask(search_space)
    params = trial.params

    modules_set = modules_sets[params["modules_set_id"]]
    dr = driver.Builder().with_modules(*modules_set).build()
    inputs = {k: v for k, v in params.items() if k not in non_input_keys}

    try:
        results = dr.execute([final_var], inputs=inputs)
    except Exception:
        # don't leave the asked trial stuck in RUNNING; record the failure, then surface it
        study.tell(trial, state=optuna.trial.TrialState.FAIL)
        raise

    # TODO: report `results[final_var]` once `fit_model` returns a float; until then
    # the trial index is a placeholder objective so the loop runs end-to-end
    study.tell(trial, float(trial_idx))

study.best_trial

FrozenTrial(number=3, state=1, values=[0.0], datetime_start=datetime.datetime(2024, 5, 29, 17, 25, 46, 681961), datetime_complete=datetime.datetime(2024, 5, 29, 17, 25, 46, 684732), params={'modules_set_id': 0, 'impute_method': 'zero', 'n_iterations': 2, 'n_rows': 213}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'modules_set_id': CategoricalDistribution(choices=(0, 1, 2, 3, 4, 5, 6, 7)), 'impute_method': CategoricalDistribution(choices=('none', 'zero', 'mean')), 'n_iterations': IntDistribution(high=5, log=False, low=2, step=1), 'n_rows': IntDistribution(high=1000, log=True, low=10, step=1)}, trial_id=3, value=None)