# rv_args: Arguments are random variables

> Use Python dataclasses and optuna distributions to define the arguments of a function, providing a documentable, easy and Pythonic way to handle hyperparameter optimization.

In [None]:
#| default_exp rv_args.nucleus

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

我们使用dataclass，要求传入函数的参数是强类型，而且有一个随机概率分布，这样方便后面定义调参。

In [None]:
#| export
from dataclasses import dataclass, field, MISSING, _MISSING_TYPE, fields, asdict
from typing import List, Dict, Any, Type, Optional, Callable, Union
from optuna.distributions import BaseDistribution, distribution_to_json, json_to_distribution

rv_dataclass_metadata_key = "thu_rv"  # metadata key under which a field's RandomVariable is stored
rv_missing_value = "thu_rv_missing"  # sentinel meaning "argument not supplied" for PythonField


import sys
assert sys.version_info >= (3, 7), "Python version >= 3.7 is required."


def _is_rv_missing(value: Any) -> bool:
    """Return True if *value* is the ``rv_missing_value`` sentinel.

    The ``isinstance`` guard avoids calling ``==`` on arbitrary user defaults
    (e.g. array-likes whose ``==`` is elementwise and has no single truth value).
    """
    return isinstance(value, str) and value == rv_missing_value


@dataclass
class PythonField:
    """A reusable, inspectable specification of a ``dataclasses.field``.

    The attributes mirror the parameters of :func:`dataclasses.field`. Any of
    ``default``, ``default_factory`` and ``kw_only`` left at ``rv_missing_value``
    is normalised to ``dataclasses.MISSING`` in ``__post_init__``. Calling the
    instance (or applying ``~``) produces the actual ``dataclasses.Field``.
    """
    default: Any = rv_missing_value  # The default value of the field
    default_factory: Callable[[], Any] = rv_missing_value  # Zero-argument callable producing the default
    init: bool = True
    repr: bool = True
    hash: Union[None, bool] = None
    compare: bool = True
    metadata: Union[Dict[str, Any], None] = None  # Base metadata; merged with call-time kwargs
    # Only forwarded on Python >= 3.10, where dataclasses.field gained ``kw_only``.
    kw_only: Union[None, bool] = rv_missing_value

    def __post_init__(self):
        # Translate our own sentinel into the one dataclasses.field understands.
        if _is_rv_missing(self.default):
            self.default = MISSING
        if _is_rv_missing(self.default_factory):
            self.default_factory = MISSING
        if _is_rv_missing(self.kw_only):
            self.kw_only = MISSING

    def __call__(self, **kwargs: Any) -> Any:
        """Build a ``dataclasses.field`` from this specification.

        Args:
            **kwargs: Extra metadata entries; they are merged over ``self.metadata``
                (call-time entries win on key clashes).

        Returns:
            The ``dataclasses.Field`` returned by :func:`dataclasses.field`.
        """
        # Always bind `metadata`, whether or not self.metadata was supplied, and
        # never mutate self.metadata itself.
        metadata = {**(self.metadata or {}), **kwargs}
        field_kwargs = dict(
            default=self.default,
            default_factory=self.default_factory,
            init=self.init,
            repr=self.repr,
            hash=self.hash,
            compare=self.compare,
            metadata=metadata,  # supported by dataclasses.field since 3.7
        )
        # `kw_only` was only added to dataclasses.field in Python 3.10.
        if sys.version_info >= (3, 10):
            field_kwargs["kw_only"] = self.kw_only
        return field(**field_kwargs)

    def __invert__(self):
        # `~spec` is shorthand for `spec()`.
        return self()

@dataclass
class RandomVariable(PythonField):
    """A dataclass field spec that also carries a description and an optuna distribution.

    Calling it (or applying ``~``) yields a ``dataclasses.field`` whose metadata
    holds ``description``, ``distribution`` and the ``RandomVariable`` itself
    under ``rv_dataclass_metadata_key``, followed by any extra call-time kwargs.
    """
    description: str = "MISSING description. "  # Human-readable explanation of the field
    distribution: BaseDistribution = "MISSING distribution. "  # optuna search space; a placeholder string when not given

    def __call__(self, **kwargs: Any) -> Any:
        rv_metadata = dict(description=self.description, distribution=self.distribution)
        # Store the spec itself so helpers can recover it from `Field.metadata`.
        rv_metadata[rv_dataclass_metadata_key] = self
        return super().__call__(**rv_metadata, **kwargs)

    def __invert__(self):
        # `~rv` is shorthand for `rv()`.
        return self()

In [None]:
# Build a dataclass field from a default RandomVariable, then look at the
# RandomVariable stored in its metadata (as the object, and flattened via asdict).
RandomVariable()()
RandomVariable()().metadata[rv_dataclass_metadata_key]
asdict(RandomVariable()().metadata[rv_dataclass_metadata_key])

{'default': <dataclasses._MISSING_TYPE>,
 'default_factory': <dataclasses._MISSING_TYPE>,
 'init': True,
 'repr': True,
 'hash': None,
 'compare': True,
 'metadata': None,
 'kw_only': <dataclasses._MISSING_TYPE>,
 'description': 'MISSING description. ',
 'distribution': 'MISSING distribution. '}

In [None]:
#| export 
from decorator import decorator
from fastcore.basics import patch_to
from dataclasses import asdict
import pandas as pd
from optuna import Trial

def is_experiment_setting(cls):
    """Return True when every dataclass field of *cls* carries a RandomVariable.

    A field qualifies when its metadata maps ``rv_dataclass_metadata_key`` to a
    ``RandomVariable`` instance, as produced by ``~RandomVariable(...)``.
    A dataclass with no fields trivially qualifies.
    """
    return all(
        isinstance(f.metadata.get(rv_dataclass_metadata_key, None), RandomVariable)
        for f in fields(cls)
    )
        
def show_dataframe_doc(cls):
    """Tabulate the fields of an experiment-setting dataclass.

    Each row holds the field's ``name`` and ``type``, followed by every
    attribute of its ``RandomVariable`` (default, description, distribution, ...).

    Raises:
        ValueError: if a field was not declared with ``~RandomVariable(...)``.
    """
    rows = []
    for f in fields(cls):
        rv = f.metadata.get(rv_dataclass_metadata_key, None)
        if rv is None:
            raise ValueError("Class decorated with @experiment_setting needs to use ~RandomVariable fields. ")
        # Use `{**a, **b}` rather than `a | b`: dict union needs Python 3.9,
        # while this module only requires 3.7.
        rows.append({"name": f.name, "type": f.type, **asdict(rv)})
    return pd.DataFrame(rows)


def get_optuna_search_space(cls, frozen_rvs:set = None):
    """Map each non-frozen field name of *cls* to its RandomVariable's distribution.

    Args:
        frozen_rvs: Field names to leave out of the search space.

    Raises:
        ValueError: if a considered field was not declared with ``~RandomVariable(...)``.
    """
    frozen = frozen_rvs if frozen_rvs is not None else set()
    space = {}
    for f in fields(cls):
        if f.name in frozen:
            continue
        rv = f.metadata.get(rv_dataclass_metadata_key, None)
        if rv is None:
            raise ValueError("Class decorated with @experiment_setting needs to use ~RandomVariable fields. ")
        space[f.name] = rv.distribution
    return space

from copy import deepcopy
def optuna_suggest(cls:Type, trial:Trial, fixed_meta_params, suggest_params_only_in: set = None, frozen_rvs:set = None):
    """Return a copy of *fixed_meta_params* with selected fields sampled from *trial*.

    Args:
        trial: optuna trial used to sample each field from its RandomVariable distribution.
        fixed_meta_params: instance of *cls* supplying the values of fields that are not sampled.
        suggest_params_only_in: if given, only these field names are sampled.
        frozen_rvs: field names that are never sampled.

    Returns:
        A deep copy of *fixed_meta_params* with the sampled fields overwritten;
        the input object is not modified.

    Raises:
        ValueError: if *fixed_meta_params* is not an instance of *cls*, or a
            sampled field was not declared with ``~RandomVariable(...)``.
    """
    # Validate first so a wrong argument fails fast, before the deepcopy.
    if not isinstance(fixed_meta_params, cls):
        raise ValueError(f"fixed_meta_params should be an instance of the {cls.__name__} class.")
    suggested_params = deepcopy(fixed_meta_params)
    if frozen_rvs is None:
        frozen_rvs = set()
    for f in fields(cls):
        # None means "sample every field".
        if suggest_params_only_in is not None and f.name not in suggest_params_only_in:
            continue
        if f.name in frozen_rvs:
            continue
        rv = f.metadata.get(rv_dataclass_metadata_key, None)
        if rv is None:
            raise ValueError("Class decorated with @experiment_setting needs to use ~RandomVariable fields. ")
        # NOTE(review): `Trial._suggest` is private optuna API (samples from a
        # distribution object directly); re-check it when upgrading optuna.
        suggested_value = trial._suggest(f.name, rv.distribution)
        setattr(suggested_params, f.name, suggested_value)
    return suggested_params

import argparse

def _str_to_bool(text: str) -> bool:
    """argparse ``type`` for bool fields: only a case-insensitive "true" is True.

    Needed because ``bool("False")`` is True.
    """
    return text.lower() == 'true'


def argparse_parser_add_arguments(cls:Type, parser:argparse.ArgumentParser, frozen_rvs:set = None):
    """Add one ``--<field>`` option to *parser* per RandomVariable field of *cls*.

    The option's help text is the RandomVariable description. Its default is the
    field's default, or the result of its ``default_factory``, or None when the
    field has neither. Options already registered on *parser* are skipped, so
    several settings classes can share one parser.

    Args:
        parser: Parser extended in place.
        frozen_rvs: Field names for which no option is added.

    Raises:
        ValueError: if a field's annotation is neither a string nor a class, or a
            field was not declared with ``~RandomVariable(...)``.
    """
    if frozen_rvs is None:
        frozen_rvs = set()
    for f in fields(cls):
        field_name = f.name
        if field_name in frozen_rvs:
            continue
        option = f"--{field_name}"
        # Skip options that already exist. `_option_string_actions` maps option
        # strings (e.g. "--lr") to their actions; `_group_actions` holds Action
        # objects, so testing a name against it never matched.
        if option in parser._option_string_actions:
            continue
        if isinstance(f.type, str):
            # String annotations (e.g. under `from __future__ import annotations`).
            # NOTE(review): eval runs in this module's globals; acceptable for trusted
            # class definitions, never for untrusted annotation strings.
            try:
                field_type = eval(f.type)
            except Exception:
                # Not resolvable here (e.g. a user-defined name): no type conversion.
                field_type = None
        elif isinstance(f.type, type):
            field_type = f.type
        else:
            raise ValueError(
                f"Field {field_name} has an unsupported type: {f.type}"
            )

        if field_type is bool:
            field_type = _str_to_bool

        rv:RandomVariable = f.metadata.get(rv_dataclass_metadata_key, None)
        if rv is None:
            raise ValueError("Class decorated with @experiment_setting needs to use ~RandomVariable fields. ")
        # PythonField.__post_init__ turns "not given" into dataclasses.MISSING, so
        # that is the sentinel to test against (the raw rv_missing_value never survives).
        if f.default is not MISSING:
            default_value = f.default
        elif f.default_factory is not MISSING:
            default_value = f.default_factory()
        else:
            default_value = None
        parser.add_argument(option, type=field_type,
                            help=rv.description, default=default_value)

    


def _finalize_experiment_class(result_cls):
    """Check that *result_cls* uses only RandomVariable fields and attach the helper classmethods."""
    if not is_experiment_setting(result_cls):
        raise ValueError("Class decorated with @experiment_setting needs to use ~RandomVariable fields. ")
    for method in (show_dataframe_doc, get_optuna_search_space, optuna_suggest, argparse_parser_add_arguments):
        patch_to(result_cls, cls_method=True)(method)
    return result_cls


@decorator
def experiment_setting_decorator(dataclass_func, *args, **kwargs):
    """Wrap a ``dataclass``-like function so the class it builds is validated and extended.

    The class produced by ``dataclass_func(*args, **kwargs)`` must declare every field
    with ``~RandomVariable(...)``. It gains the classmethods ``show_dataframe_doc``,
    ``get_optuna_search_space``, ``optuna_suggest`` and ``argparse_parser_add_arguments``.

    Raises:
        ValueError: if some field was not declared with ``~RandomVariable(...)``.
    """
    result = dataclass_func(*args, **kwargs)
    if isinstance(result, type):
        return _finalize_experiment_class(result)
    # Used with options but no class (e.g. `@experiment_setting(eq=False)`):
    # dataclass then returns a decorator, so finalize the class it builds later.
    return lambda cls: _finalize_experiment_class(result(cls))


# `@experiment_setting` behaves like `@dataclass`, plus the validation and
# classmethods added above.
experiment_setting = experiment_setting_decorator(dataclass)

一些使用案例

In [None]:
from optuna.distributions import IntDistribution, FloatDistribution, CategoricalDistribution
@experiment_setting
class SupportVectorClassifierConfig:
    # Example search space for a support vector classifier.

    # Penalty coefficient C, searched on a log scale.
    C: float = ~RandomVariable(
        default=1.0,
        description="Regularization parameter. The strength of the regularization is inversely proportional to C.",
        distribution=FloatDistribution(low=1e-5, high=1e2, log=True),
    )

    # Kernel function type.
    # NOTE(review): "precomputed" presumably expects a precomputed kernel matrix as
    # input; sampling it while searching over raw features may fail — confirm it belongs here.
    kernel: str = ~RandomVariable(
        default="rbf",
        description="Kernel type to be used in the algorithm.",
        distribution=CategoricalDistribution(
            choices=["linear", "poly", "rbf", "sigmoid", "precomputed"],
        ),
    )

    # Degree of the polynomial kernel ('poly').
    degree: int = ~RandomVariable(
        default=3,
        description="Degree of the polynomial kernel function ('poly').",
        distribution=IntDistribution(low=1, high=10, log=False),
    )
    

In [None]:
# With no arguments, every field takes its RandomVariable's default.
SupportVectorClassifierConfig()

SupportVectorClassifierConfig(C=1.0, kernel='rbf', degree=3)

## Combining `dataclass` (experiment_setting) with PyTorch `nn.Module`

In [None]:
#| hide
# deprecated
# def my_dataclass(cls):
#     # https://discuss.pytorch.org/t/typeerror-unhashable-type-for-my-torch-nn-module/109424/5
#     cls = dataclass(cls, eq=False) 
#     old_init = cls.__init__
#     def new_init(*args, **kwargs):
#         cls.__pre_init__(*args, **kwargs)
#         old_init(*args, **kwargs)
#     cls.__init__ = new_init
#     return cls

为了解决 https://discuss.pytorch.org/t/typeerror-unhashable-type-for-my-torch-nn-module/109424/5 中提到的问题，首先定义

In [None]:
#| export
@decorator
def pre_init_decorator(init_func, self, *args, **kwargs):
    """Call ``self.__pre_init__`` with the constructor arguments, then the wrapped ``__init__``.

    This gives a dataclass-generated ``__init__`` a hook that runs before any
    field is assigned. ``self`` is the instance being constructed; it is filled
    in by ``decorator`` at call time.
    """
    pre_init = self.__pre_init__
    pre_init(*args, **kwargs)
    result = init_func(self, *args, **kwargs)
    return result

In [None]:
#| hide 
# TODO decorator style for dataclass_for_torch
# @decorator
# def dataclass_for_torch_decorator(dataclass_func, cls, eq=False, *args, **kwargs):
#     result_cls = dataclass_func(cls, eq=eq, *args, **kwargs)
#     result_cls.__init__ = pre_init_decorator(result_cls.__init__, cls)
#     return result_cls

In [None]:
#| export
def dataclass_for_torch_decorator(dataclass_func):
    """Adapt a dataclass-style decorator for ``torch.nn.Module`` subclasses.

    The returned decorator applies ``dataclass_func(cls, eq=False)`` and wraps
    the generated ``__init__`` with ``pre_init_decorator``, so ``__pre_init__``
    runs before the dataclass assigns any field.

    ``eq=False`` keeps the inherited ``__eq__``/``__hash__``. With ``eq=True`` the
    dataclass sets ``__hash__`` to None and the module becomes unhashable.
    """
    def wrapped_func(cls):
        result_cls = dataclass_func(cls, eq=False)
        # `pre_init_decorator` is a caller-style `@decorator`: its `self` parameter
        # receives the instance when `__init__` is called, so nothing is bound here.
        result_cls.__init__ = pre_init_decorator(result_cls.__init__)
        # Return the class dataclass produced (the same object unless it had to
        # rebuild the class, e.g. with slots=True).
        return result_cls
    return wrapped_func

In [None]:
#| export
# experiment_setting + eq=False + __pre_init__ support; used to build ExperimentModule.
_experiment_module = dataclass_for_torch_decorator(experiment_setting) # Internal; not recommended for direct use.

In [None]:
#| export
import torch
import torch.nn as nn
@_experiment_module
class ExperimentModule(nn.Module):
    """Base class combining an ``@experiment_setting`` dataclass with ``torch.nn.Module``.

    Subclasses declare hyperparameters as ``~RandomVariable(...)`` fields and build
    their layers in :meth:`setup`. ``__init_subclass__`` turns every subclass into an
    experiment dataclass automatically. The dataclass field values are shown through
    PyTorch's ``extra_repr``, while ``__repr__`` stays the ``nn.Module`` one.
    """
    def __pre_init__(self, *args, **kwargs):
        # The dataclass-generated __init__ never calls super().__init__(), and
        # dataclasses has no pre-init hook. So PyTorch's own initialisation runs
        # here, before the generated __init__ assigns any field.
        super().__init__()  # PyTorch initialisation

    def __post_init__(self):
        # Called by the dataclass-generated __init__ after all fields are set:
        # https://docs.python.org/3/library/dataclasses.html#dataclasses.__post_init__
        # PyTorch is already initialised (see __pre_init__), so `self.xx = xx` inside
        # setup registers parameters and submodules. Delegating to a differently named
        # method saves subclasses from having to remember super().__post_init__().
        self.setup()

    def setup(self):
        # To be implemented by subclasses: create the layers / parameters here.
        raise NotImplementedError("Should be implemented by subclass! ")

    def __repr__(self):
        return super().__repr__()

    def extra_repr(self) -> str:
        return super().extra_repr()

    def __init_subclass__(cls, **kwargs) -> None:
        # Forward class keyword arguments so cooperative multiple inheritance
        # keeps working (PEP 487).
        super().__init_subclass__(**kwargs)
        original_repr = cls.__repr__
        original_extra_repr = cls.extra_repr
        # Relies on dataclass decorating the class in place (true without
        # slots=True), which is why the return value is not used.
        _experiment_module(cls)
        dataclass_repr = cls.__repr__

        def extra_repr(self):
            # Keep only the "field=value, ..." part of the dataclass repr
            # ("ClassName(...)"), then append the class's own extra_repr.
            dcr = dataclass_repr(self)
            dcr = dcr[dcr.index("(")+1:dcr.rindex(")")]
            return dcr+original_extra_repr(self)
        # NOTE(review): for a subclass of a subclass, `original_extra_repr` is the
        # parent's already-wrapped version, so the parent's fields are probably
        # printed twice — confirm before relying on multi-level subclassing.
        cls.extra_repr = extra_repr  # expose the dataclass fields through PyTorch's repr
        cls.__repr__ = original_repr  # restore the nn.Module-style nested repr

使用示例

In [None]:
class ExampleDataclassModule(ExperimentModule):
    # Demo module with two hyperparameters declared as RandomVariable fields.
    # No distribution is given, so both carry only the placeholder string.
    name:str =~RandomVariable(default="root", description="Name of the person")
    age:int =~RandomVariable(default=25, description="Age of the person")   
    def setup(self):
        # Runs from __post_init__, after the dataclass fields are assigned.
        print(self.name, self.age)
        self.linear = nn.Linear(self.age, self.age)
    def forward(self, x):
        # Wraps x in a 1-element tensor, which matches the layer's input size
        # only when age == 1 (as in the demo call).
        x = torch.Tensor([x])
        return self.linear(x)
    def extra_repr(self) -> str:
        # Printed after the "name=..., age=..." field listing added by ExperimentModule.
        return super().extra_repr()+" Hello World!"

In [None]:
# Construct with explicit field values; `setup` runs (and prints) during __init__.
test = ExampleDataclassModule(name="hello", age=1)
print(test) # str() and repr() both produce nn.Module's nested repr
print(repr(test))
print(test(1))

hello 1
ExampleDataclassModule(
  name='hello', age=1 Hello World!
  (linear): Linear(in_features=1, out_features=1, bias=True)
)
ExampleDataclassModule(
  name='hello', age=1 Hello World!
  (linear): Linear(in_features=1, out_features=1, bias=True)
)
tensor([0.4311], grad_fn=<ViewBackward0>)


In [None]:
#| hide
# Export this notebook's `#| export` cells to the module named by `#| default_exp`.
import nbdev; nbdev.nbdev_export()