In [None]:
# src/optimizer.ipynb
"""
Build a Keras optimizer from the settings specified in config.json.
"""

from tensorflow.keras.optimizers import Adam, SGD, RMSprop, Nadam, Adagrad

# pulling variables from config
import json

# Load the settings from config.json. The path is relative, so it is resolved
# against the current working directory. JSON text is UTF-8 by specification,
# so decode it explicitly rather than relying on the platform's
# locale-dependent default encoding.
with open('config.json', encoding='utf-8') as f:
    cfg = json.load(f)

# Every setting below is a required key of the "optimize" section: a missing
# entry raises KeyError when this code runs.
_optimize_section = cfg["optimize"]
optimizer = _optimize_section["optimizer"]
learning_rate = _optimize_section["learning_rate"]
loss = _optimize_section["loss"]
metrics = _optimize_section["metrics"]
epochs = _optimize_section["epochs"]
batch_size = _optimize_section["batch_size"]
weight_decay = _optimize_section["weight_decay"]
beta_1 = _optimize_section["beta_1"]
beta_2 = _optimize_section["beta_2"]
epsilon = _optimize_section["epsilon"]
momentum = _optimize_section["momentum"]
rho = _optimize_section["rho"]
amsgrad = _optimize_section["amsgrad"]

import inspect
from tensorflow.keras.optimizers import Adam, SGD, RMSprop, Nadam, Adagrad

def get_optimizer(opt_cfg: dict):
    """
    Create a Keras optimizer from a config dict, safely ignoring unknown params.

    The optimizer class is selected, case-insensitively, as follows:

    - by ``opt_cfg["type"]``;
    - if that key is absent, by ``opt_cfg["optimizer"]``, the key name used
      in the "optimize" section of config.json, so that section can be
      passed in directly;
    - if neither key is present, ``"adam"`` is used.

    Every other key is forwarded to the optimizer constructor only if it
    names an explicit (non-``*args``/``**kwargs``) parameter of that
    constructor. All remaining keys are silently dropped.

    Parameters
    ----------
    opt_cfg : dict
        Example:
        {
          "type": "adam",
          "learning_rate": 0.001,
          "beta_1": 0.9,
          "beta_2": 0.999,
          "epsilon": 1e-7,
          "weight_decay": 0.0,
          "momentum": 0.0
        }

    Returns
    -------
    optimizer : keras optimizer instance

    Raises
    ------
    ValueError
        If the selected type is not one of "adam", "sgd", "rmsprop",
        "nadam" or "adagrad".
    """
    optimizer_map = {
        "adam": Adam,
        "sgd": SGD,
        "rmsprop": RMSprop,
        "nadam": Nadam,
        "adagrad": Adagrad,
    }

    # Prefer the documented "type" key, but fall back to "optimizer". Without
    # this fallback, the config.json "optimize" section would silently get
    # the Adam default no matter which optimizer it names.
    opt_type = opt_cfg.get("type", opt_cfg.get("optimizer", "adam")).lower()
    if opt_type not in optimizer_map:
        raise ValueError(f"Unknown optimizer type '{opt_type}'")

    OptimizerClass = optimizer_map[opt_type]

    # Collect the constructor's explicitly named keyword-capable parameters.
    # Catch-all *args / **kwargs entries also appear in the signature under
    # their declared names (e.g. "kwargs"). A config key with that name must
    # not be forwarded, so only these two parameter kinds are accepted.
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    sig = inspect.signature(OptimizerClass.__init__)
    valid_params = {
        name
        for name, param in sig.parameters.items()
        if name != "self" and param.kind in keyword_kinds
    }

    # Keep only recognised constructor arguments. Selector keys such as
    # "type"/"optimizer" and unrelated training settings are dropped here.
    kwargs = {k: v for k, v in opt_cfg.items() if k in valid_params}

    return OptimizerClass(**kwargs)