In [5]:
import ast
import importlib.util
import os

def _is_optimizer_base(base: ast.expr) -> bool:
    """Return True if a class-base AST node refers to a class named ``Optimizer``.

    Matches a bare name (``class X(Optimizer)``) as well as a dotted reference
    (``class X(torch.optim.Optimizer)``).
    """
    if isinstance(base, ast.Name):
        return base.id == "Optimizer"
    if isinstance(base, ast.Attribute):
        return base.attr == "Optimizer"
    return False


def _default_value(node: ast.expr):
    """Evaluate a default-value AST node as a Python literal, else return its source text."""
    try:
        return ast.literal_eval(node)
    except Exception:
        # Deliberate best-effort: non-literal defaults (names, calls, arithmetic)
        # are reported as source text instead of being evaluated.
        return ast.unparse(node) if hasattr(ast, "unparse") else str(node)


def _init_defaults(init: ast.FunctionDef) -> dict:
    """Map each defaulted ``__init__`` parameter to its default, skipping the first two positionals."""
    args = init.args
    # ``defaults`` aligns with the *tail* of all positional parameters
    # (positional-only followed by regular), so index into the full list.
    positional = list(getattr(args, "posonlyargs", [])) + list(args.args)
    first_default = len(positional) - len(args.defaults)

    defaults = {}
    for index, default in enumerate(args.defaults, start=first_default):
        # The first two positional parameters are assumed to be ``self`` and the
        # parameter iterable (``params``); they are not hyperparameters.
        if index < 2:
            continue
        defaults[positional[index].arg] = _default_value(default)

    # Keyword-only parameters: ``kw_defaults`` holds None for those with no default.
    for arg, default in zip(args.kwonlyargs, args.kw_defaults):
        if default is not None:
            defaults[arg.arg] = _default_value(default)
    return defaults


def extract_optimizer_info_from_file(file_path: str):
    """Locate the ``Optimizer`` subclass defined in a Python file and describe it.

    The file is parsed statically to find the first class (in ``ast.walk``
    order) that has a base named ``Optimizer`` (bare or dotted, e.g.
    ``torch.optim.Optimizer``). The defaults of that class's ``__init__`` are
    read from the AST, skipping the first two positional parameters
    (``self`` and the parameter iterable). The file is then imported so the
    live class object can be returned.

    WARNING: importing executes the file's top-level code. Only call this on
    trusted files.

    Args:
        file_path: Path to a ``.py`` source file.

    Returns:
        dict with keys:
            ``"optimizer_name"``: name of the class found.
            ``"optimizer_class"``: the class object from the imported module.
            ``"hyperparameters"``: parameter name -> default value. Literal
                defaults are evaluated; other defaults are kept as source text.

    Raises:
        ValueError: if no class deriving from ``Optimizer`` is found.
        ImportError: if the file cannot be loaded as a Python module.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        tree = ast.parse(f.read(), filename=file_path)

    optimizer_node = next(
        (
            node
            for node in ast.walk(tree)
            if isinstance(node, ast.ClassDef)
            and any(_is_optimizer_base(base) for base in node.bases)
        ),
        None,
    )
    if optimizer_node is None:
        # Fail before executing the module rather than with an opaque
        # getattr(mod, None) TypeError afterwards.
        raise ValueError(f"No subclass of Optimizer found in {file_path!r}")

    hyperparams = {}
    for item in optimizer_node.body:
        if isinstance(item, ast.FunctionDef) and item.name == "__init__":
            hyperparams.update(_init_defaults(item))

    # Dynamically import the module (executes it) and fetch the class object.
    module_name = os.path.splitext(os.path.basename(file_path))[0]
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        # e.g. a path without a recognised Python source suffix.
        raise ImportError(f"Cannot load {file_path!r} as a Python module")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)

    optimizer_class = getattr(mod, optimizer_node.name)

    return {
        "optimizer_name": optimizer_node.name,
        "optimizer_class": optimizer_class,
        "hyperparameters": hyperparams,
    }



# Entry point. In a notebook kernel ``__name__`` is also "__main__", so this
# runs whenever the cell is executed, not only when run as a script.
if __name__ == "__main__":
    file_path = "opt.py"  # optimizer source file, relative to the working directory
    info = extract_optimizer_info_from_file(file_path=file_path)
    print(info)


{'optimizer_name': 'SoftSign', 'optimizer_class': <class 'opt.SoftSign'>, 'hyperparameters': {'lr': [-5.0, -2.0], 'beta': [0.6, 0.99], 'eps': [-12.3, -1.0], 'weight_decay': [-3.0, -0.7], 'correct_bias': True}}


In [7]:
# Hyperparameter defaults as printed by the previous cell. The list values are
# two-element [low, high] pairs (presumably search ranges; several, e.g. lr,
# look like log10 bounds — confirm against opt.py). ``correct_bias`` is a plain flag.
d = dict(
    lr=[-5.0, -2.0],
    beta=[0.6, 0.99],
    eps=[-12.3, -1.0],
    weight_decay=[-3.0, -0.7],
    correct_bias=True,
)

# Keep only the list-valued (range) entries, preserving their original order.
list_only = dict(filter(lambda item: isinstance(item[1], list), d.items()))
print(list_only)


{'lr': [-5.0, -2.0], 'beta': [0.6, 0.99], 'eps': [-12.3, -1.0], 'weight_decay': [-3.0, -0.7]}
