# rv_args: Arguments are random variables

> Using Python dataclass and optuna distribution to define arguments of a function, in order to enable documentatable, easy and pythonic way to handle hyperparameters optimization.

In [298]:
#| default_exp rv_args.nucleus

In [299]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [300]:
from dataclasses import dataclass, field, MISSING, _MISSING_TYPE, fields, asdict, Field
# Field?
t = field(default_factory=list)
# 没有 __dict__ 需要特殊判断

我们使用dataclass，要求传入函数的参数是强类型，而且有一个随机概率分布，这样方便后面定义调参。

In [301]:
#| export
from dataclasses import dataclass, field, MISSING, _MISSING_TYPE, fields, asdict, Field
from typing import List, Dict, Any, Type, Optional, Callable, Union
from optuna.distributions import BaseDistribution, distribution_to_json, json_to_distribution

rv_dataclass_metadata_key = "thu_rv"
rv_missing_value = "thu_rv_missing"


import sys
import warnings
assert sys.version_info >= (3, 7), "Python version >= 3.7 is required."


@dataclass
class ScholarPythonField:
    default:Any = rv_missing_value# The default value of the field
    default_factory:Callable[[], Any] = rv_missing_value# A function to generate the default value of the field
    init:bool=True
    repr:bool=True
    hash:Union[None, bool]=None
    compare:bool=True
    metadata:Union[Dict[str, Any], None]=None
    # kw_only:Union[_MISSING_TYPE, bool]=MISSING
    kw_only:Union[None, bool]=rv_missing_value
    def __post_init__(self):        # print(self)
        if self.default == rv_missing_value:
            self.default = MISSING
        if self.default_factory == rv_missing_value:
            self.default_factory = MISSING
        if self.kw_only == rv_missing_value:
            self.kw_only = MISSING
        # self.default = self.default or MISSING
        # self.default_factory = self.default_factory or MISSING
        # self.kw_only = self.kw_only or MISSING
    def __call__(self, **kwargs: Any) -> Any:
        if self.metadata is None:
            # self.metadata = {**kwargs}
            metadata = {**kwargs}
        else:
            metadata = {**self.metadata, **kwargs}

        if sys.version_info < (3, 9):
            return field(default=self.default, 
                         default_factory=self.default_factory, 
                         init=self.init, 
                         repr=self.repr, 
                         hash=self.hash, 
                         compare=self.compare)
        else:
            return field(default=self.default, 
                        default_factory=self.default_factory, 
                        init=self.init, 
                        repr=self.repr, 
                        hash=self.hash, 
                        compare=self.compare, 
                        metadata=metadata, 
                        kw_only=self.kw_only)
    def __invert__(self):
        # 也就是 `~` 运算符
        return self()
    @classmethod
    def from_field(cls, py_field:Field, **kwargs) -> "Self":
        # py_field = cls(**py_field)
        if sys.version_info < (3, 9):
            scholar_field = cls(default=py_field.default, 
                         default_factory=py_field.default_factory, 
                         init=py_field.init, 
                         repr=py_field.repr, 
                         hash=py_field.hash, 
                         compare=py_field.compare, 
                         **kwargs)
        else:
            scholar_field = cls(default=py_field.default, 
                        default_factory=py_field.default_factory, 
                        init=py_field.init, 
                        repr=py_field.repr, 
                        hash=py_field.hash, 
                        compare=py_field.compare, 
                        metadata=py_field.metadata, 
                        kw_only=py_field.kw_only, 
                        **kwargs)
            # scholar_field.metadata[rv_dataclass_metadata_key] = scholar_field
            # if len(py_field.metadata)==0:
            #     scholar_field.metadata = asdict(scholar_field) # 自我信息指向
            # else:
            #     metadata = py_field.metadata.copy()
            #     pass # 否则就保留原本的field自带的信息
            # **{rv_dataclass_metadata_key: self}, 
        return scholar_field



@dataclass
class Variable(ScholarPythonField):
    description: str = "MISSING description. "# The description of the field
    distribution:BaseDistribution = "MISSING distribution. "# The distribution of the data
    
    def __call__(self, **kwargs: Any) -> Any:
        return super().__call__(description=self.description, distribution=self.distribution, 
                                **{rv_dataclass_metadata_key: self}, 
                                **kwargs)
    @classmethod
    def from_field(cls, py_field:Field) -> "Self":
        description = py_field.metadata.get("description", "MISSING description. ")
        distribution = py_field.metadata.get("distribution", "MISSING distribution. ")
        return super().from_field(py_field, description=description, distribution=distribution)



@dataclass
class IndependentVariable(Variable):
    """独立变量"""

@dataclass
class DependentVariable(Variable):
    """相关变量"""

@dataclass
class RandomVariable(Variable):
    """Variable 是 RandomVariable 的别名，已废弃，请使用 RandomVariable。"""
    def __init__(self, *args, **kwargs):
        # warnings.warn(
        #     "RandomVariable 已废弃，请使用 Variable 代替。",
        #     DeprecationWarning,
        #     stacklevel=2
        # )
        super().__init__(*args, **kwargs)


In [302]:
py_field = field(
    default=1.0,
    metadata={
        "help": "The regularization parameter. The strength of the regularization is inversely proportional to C. Must be strictly positive. The penalty is a squared l2 penalty."
    },
)
var = Variable.from_field(py_field)
var

Variable(default=1.0, default_factory=<dataclasses._MISSING_TYPE object at 0x7f5ba85e7010>, init=True, repr=True, hash=None, compare=True, metadata=mappingproxy({'help': 'The regularization parameter. The strength of the regularization is inversely proportional to C. Must be strictly positive. The penalty is a squared l2 penalty.'}), kw_only=<dataclasses._MISSING_TYPE object at 0x7f5ba85e7010>, description='MISSING description. ', distribution='MISSING distribution. ')

In [303]:
~var

Field(name=None,type=None,default=1.0,default_factory=<dataclasses._MISSING_TYPE object at 0x7f5ba85e7010>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The regularization parameter. The strength of the regularization is inversely proportional to C. Must be strictly positive. The penalty is a squared l2 penalty.', 'description': 'MISSING description. ', 'distribution': 'MISSING distribution. ', 'thu_rv': Variable(default=1.0, default_factory=<dataclasses._MISSING_TYPE object at 0x7f5ba85e7010>, init=True, repr=True, hash=None, compare=True, metadata=mappingproxy({'help': 'The regularization parameter. The strength of the regularization is inversely proportional to C. Must be strictly positive. The penalty is a squared l2 penalty.'}), kw_only=<dataclasses._MISSING_TYPE object at 0x7f5ba85e7010>, description='MISSING description. ', distribution='MISSING distribution. ')}),kw_only=<dataclasses._MISSING_TYPE object at 0x7f5ba85e7010>,_field_type=None)

In [304]:
# 其实不支持 Python 3.9 以下的没有metadata的东西

In [305]:
RandomVariable()()
RandomVariable()().metadata[rv_dataclass_metadata_key]
asdict(RandomVariable()().metadata[rv_dataclass_metadata_key])

{'default': <dataclasses._MISSING_TYPE at 0x7f59f5937550>,
 'default_factory': <dataclasses._MISSING_TYPE at 0x7f59f5a8b090>,
 'init': True,
 'repr': True,
 'hash': None,
 'compare': True,
 'metadata': None,
 'kw_only': <dataclasses._MISSING_TYPE at 0x7f59f5a8b9d0>,
 'description': 'MISSING description. ',
 'distribution': 'MISSING distribution. '}

In [306]:
# | export
from decorator import decorator
from fastcore.basics import patch_to
from dataclasses import asdict
import pandas as pd
from optuna import Trial


def is_experiment_setting(cls):
    for py_field in fields(cls):
        if not isinstance(
            py_field.metadata.get(rv_dataclass_metadata_key, None), Variable
        ):
            return False
    return True


def show_dataframe_doc(cls):
    results = []
    for py_field in fields(cls):
        rv = py_field.metadata.get(rv_dataclass_metadata_key, None)
        if rv is None:
            raise ValueError(
                "Class decorated with @experiment_setting needs to use ~RandomVariable fields. "
            )
        field_info = dict(name=py_field.name, type=py_field.type) | asdict(rv)
        results.append(field_info)
    return pd.DataFrame(results)


def get_optuna_search_space(cls, frozen_rvs: set = None):
    search_space = {}
    for py_field in fields(cls):
        field_name = py_field.name
        if frozen_rvs is not None and field_name in frozen_rvs:
            continue
        rv = py_field.metadata.get(rv_dataclass_metadata_key, None)
        if rv is None:
            raise ValueError(
                "Class decorated with @experiment_setting needs to use ~RandomVariable fields. "
            )
        search_space[field_name] = rv.distribution
    return search_space


from copy import deepcopy


def optuna_suggest(
    cls: Type,
    trial: Trial,
    fixed_meta_params,
    suggest_params_only_in: set = None,
    frozen_rvs: set = None,
):
    suggested_params = deepcopy(fixed_meta_params)
    if suggest_params_only_in is None:
        suggest_params_only_in = set(py_field.name for py_field in fields(cls))
    if frozen_rvs is None:
        frozen_rvs = set()
    # fixed_meta_params is dataclass
    if not isinstance(fixed_meta_params, cls):
        raise ValueError(
            f"fixed_meta_params should be an instance of the {cls.__name__} class."
        )
    for py_field in fields(cls):
        if py_field.name not in suggest_params_only_in:
            continue
        if py_field.name in frozen_rvs:
            continue
        rv = py_field.metadata.get(rv_dataclass_metadata_key, None)
        if rv is None:
            raise ValueError(
                "Class decorated with @experiment_setting needs to use ~RandomVariable fields. "
            )
        suggested_value = trial._suggest(py_field.name, rv.distribution)
        setattr(suggested_params, py_field.name, suggested_value)
    return suggested_params


import argparse


def argparse_parser_add_arguments(
    cls: Type, parser: argparse.ArgumentParser, frozen_rvs: set = None
):
    if frozen_rvs is None:
        frozen_rvs = set()
    for py_field in fields(cls):
        field_name = py_field.name
        if frozen_rvs is not None and field_name in frozen_rvs:
            continue
        # 如果已经添加过这个 argument，就不要了
        if field_name in parser._optionals._group_actions:
            # print(f"Field {field_name} already exists in parser, skipping.")
            continue
        if isinstance(py_field.type, str):
            try:
                # Try to evaluate the string as a type
                field_type = eval(py_field.type)
            except:
                # If evaluation fails, skip type conversion
                field_type = None
        elif isinstance(py_field.type, type):
            field_type = py_field.type
        else:
            raise ValueError(
                f"Field {field_name} has an unsupported type: {py_field.type}"
            )

        if field_type is bool:
            field_type = lambda x: x.lower() == "true"

        rv: Variable = py_field.metadata.get(rv_dataclass_metadata_key, None)
        if rv is None:
            raise ValueError(
                "Class decorated with @experiment_setting needs to use ~RandomVariable fields. "
            )
        default_value = (
            rv.default
            if rv.default != rv_missing_value
            else rv.default_factory()
            if rv.default_factory != rv_missing_value
            else None
        )
        parser.add_argument(
            f"--{field_name}",
            type=field_type,
            help=rv.description,
            default=default_value,
        )




def experiment_setting_decorator_buggy(dataclass_func, *args, **kwargs):
    def experiment_setting(cls, *args, **kwargs):
        # 先应用一次才能检查出fields
        normal_dataclass_cls = dataclass_func(cls, *args, **kwargs)
        # 1. 自动修复cls中 不是 Variable的位置
        has_warned = False
        for py_field in fields(normal_dataclass_cls):
            if not isinstance(
                py_field.metadata.get(rv_dataclass_metadata_key, None), Variable
            ):
                if not has_warned:
                    has_warned = True
                    warnings.warn(
                        f"""Field {py_field.name} is not a ~Variable, 
                        Why am I seeing this warning: Class {cls.__name__} is decorated with @experiment_setting.
                        No worries, we will construct a ~Variable for it.
                        """,
                        category=UserWarning,
                        stacklevel=4,
                    )
                var = Variable.from_field(py_field)
                setattr(cls, py_field.name, var.__call__())
                # py_field.metadata = var.metadata
        # 再来一次应用dataclass
        result_cls = dataclass_func(cls, *args, **kwargs)
        # 增加特色方法
        patch_to(result_cls, cls_method=True)(show_dataframe_doc)
        patch_to(result_cls, cls_method=True)(get_optuna_search_space)
        patch_to(result_cls, cls_method=True)(optuna_suggest)
        patch_to(result_cls, cls_method=True)(argparse_parser_add_arguments)

        return result_cls
    
    return experiment_setting


In [307]:
isinstance(field(), Field)

True

In [308]:
class A:
    b = 34
    
A.__dict__

mappingproxy({'__module__': '__main__',
              'b': 34,
              '__dict__': <attribute '__dict__' of 'A' objects>,
              '__weakref__': <attribute '__weakref__' of 'A' objects>,
              '__doc__': None})

In [313]:
#| export
def pyfields(any_cls, return_name: bool = False):
    """迭代器：遍历类中所有类型为 Field 的类变量，返回 (name, field) 对
        筛选 MyClass 中类型为 Field 的类变量
        获取类的所有属性字典（仅包含类自身定义的）
    """
    for name, value in vars(any_cls).items():
        # 排除特殊方法（如 __module__、__doc__ 等）
        if not name.startswith("__") and isinstance(value, Field):
            if return_name:
                yield name, value
            else:
                yield value


In [314]:
class MyClass:
    # 类变量，类型为 Field
    name = field(default="test",)
    age = field(default=18)
    other = 1
    
print("类中类型为 Field 的变量：")
for name, pyfield in pyfields(MyClass, return_name=True):
    print(f"{name}: {pyfield}")

类中类型为 Field 的变量：
name: Field(name=None,type=None,default='test',default_factory=<dataclasses._MISSING_TYPE object at 0x7f5ba85e7010>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=<dataclasses._MISSING_TYPE object at 0x7f5ba85e7010>,_field_type=None)
age: Field(name=None,type=None,default=18,default_factory=<dataclasses._MISSING_TYPE object at 0x7f5ba85e7010>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=<dataclasses._MISSING_TYPE object at 0x7f5ba85e7010>,_field_type=None)


In [315]:
#| export
from copy import deepcopy
@decorator
def experiment_setting_decorator(dataclass_func, cls, **kwargs):
    # decorator不允许*args，这也很正常，装饰器确实应该写出kwargs。
    # 先应用一次才能检查出fields
    # 注意deepcopy无效
    # 如果重复来dataclass会有bug。
    # normal_dataclass_cls = dataclass_func(deepcopy(cls), **kwargs)
    # 1. 自动修复cls中 不是 Variable的位置
    
    has_warned = False
    for py_field in pyfields(cls):
        if not isinstance(
            py_field.metadata.get(rv_dataclass_metadata_key, None), Variable
        ):
            if not has_warned:
                has_warned = True
                warnings.warn(
                    f"""Field {py_field.name} is not a ~Variable, 
                    Why am I seeing this warning: Class {cls.__name__} is decorated with @experiment_setting.
                    No worries, we will construct a ~Variable for it.
                    """,
                    category=UserWarning,
                    stacklevel=4,
                )
            var = Variable.from_field(py_field)
            setattr(cls, py_field.name, var.__call__())
            # py_field.metadata = var.metadata
    # # 再来一次应用dataclass
    result_cls = dataclass_func(cls, **kwargs)
        
    patch_to(result_cls, cls_method=True)(show_dataframe_doc)
    patch_to(result_cls, cls_method=True)(get_optuna_search_space)
    patch_to(result_cls, cls_method=True)(optuna_suggest)
    patch_to(result_cls, cls_method=True)(argparse_parser_add_arguments)
    return result_cls


experiment_setting = experiment_setting_decorator(dataclass)

一些使用案例

In [316]:
from optuna.distributions import IntDistribution, FloatDistribution, CategoricalDistribution
@experiment_setting
class SupportVectorClassifierConfig:
    # 惩罚系数 C
    C: float = ~RandomVariable(
        default=1.0,
        description="Regularization parameter. The strength of the regularization is inversely proportional to C.",
        distribution=FloatDistribution(1e-5, 1e2, log=True)
    )
    # 核函数类型
    kernel: str = ~RandomVariable(
        default="rbf",
        description="Kernel type to be used in the algorithm.",
        distribution=CategoricalDistribution(choices=["linear", "poly", "rbf", "sigmoid", "precomputed"])
    )
    
    # 多项式核函数的度数
    degree: int = ~RandomVariable(
        default=3,
        description="Degree of the polynomial kernel function ('poly').",
        distribution=IntDistribution(1, 10, log=False)
    )
    
    # other_config: int = 1
    
    ...
    

In [317]:
SupportVectorClassifierConfig()
fields(SupportVectorClassifierConfig)[0]
is_experiment_setting(SupportVectorClassifierConfig)

True

### DataClass的自动装饰

使用子类继承来定义新的配置类，自动免除装饰器

In [318]:
#| export 

class DataClass:
    decorator = dataclass
    
    def __init_subclass__(cls) -> None:
        super().__init_subclass__()
        cls.decorator(cls) 

class ExperimentSetting(DataClass):
    decorator = experiment_setting


In [319]:
class TestMockedSVMConfig(ExperimentSetting):
    C: float = 1.0
    kernel: str = "linear"
    degree: int = ~Variable(
        default=3,
        description="Degree of the polynomial kernel function ('poly').",
        distribution=IntDistribution(1, 10, log=False)
    )
test_config = TestMockedSVMConfig()
test_config

TestMockedSVMConfig(C=1.0, kernel='linear', degree=3)

In [323]:
t=fields(TestMockedSVMConfig)[0].metadata
# t=fields(TestMockedSVMConfig)[2].metadata
# t is mappingproxy
# hasattr(t, 'metadata')
# t?
# len(t)
t

mappingproxy({})

In [324]:
fields(TestMockedSVMConfig)[2]

Field(name='degree',type=<class 'int'>,default=3,default_factory=<dataclasses._MISSING_TYPE object at 0x7f5ba85e7010>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'description': "Degree of the polynomial kernel function ('poly').", 'distribution': IntDistribution(high=10, log=False, low=1, step=1), 'thu_rv': Variable(default=3, default_factory=<dataclasses._MISSING_TYPE object at 0x7f5ba85e7010>, init=True, repr=True, hash=None, compare=True, metadata=None, kw_only=<dataclasses._MISSING_TYPE object at 0x7f5ba85e7010>, description="Degree of the polynomial kernel function ('poly').", distribution=IntDistribution(high=10, log=False, low=1, step=1))}),kw_only=False,_field_type=_FIELD)

In [None]:
is_experiment_setting(TestMockedSVMConfig)
# TODO 这个测试没有通过，应该是True才对

False

## Combining `dataclass` (experiment_setting) with PyTorch `nn.Module`

In [None]:
#| hide
# deprecated
# def my_dataclass(cls):
#     # https://discuss.pytorch.org/t/typeerror-unhashable-type-for-my-torch-nn-module/109424/5
#     cls = dataclass(cls, eq=False) 
#     old_init = cls.__init__
#     def new_init(*args, **kwargs):
#         cls.__pre_init__(*args, **kwargs)
#         old_init(*args, **kwargs)
#     cls.__init__ = new_init
#     return cls

为了解决 https://discuss.pytorch.org/t/typeerror-unhashable-type-for-my-torch-nn-module/109424/5 中提到的问题，首先定义

In [326]:
#| export
@decorator
def pre_init_decorator(init_func, self, *args, **kwargs):
    self.__pre_init__(*args, **kwargs)
    return init_func(self, *args, **kwargs)

In [327]:
#| hide 
# TODO decorator style for dataclass_for_torch
# @decorator
# def dataclass_for_torch_decorator(dataclass_func, cls, eq=False, *args, **kwargs):
#     result_cls = dataclass_func(cls, eq=eq, *args, **kwargs)
#     result_cls.__init__ = pre_init_decorator(result_cls.__init__, cls)
#     return result_cls

In [328]:
#| export
def dataclass_for_torch_decorator(dataclass_func):
    def wrapped_func(cls):
        result_cls = dataclass_func(cls, eq=False) 
        result_cls.__init__ = pre_init_decorator(result_cls.__init__, self=cls) #TODO 非常奇怪，但是似乎测试逻辑是对的
        return cls
    return wrapped_func

In [329]:
#| export
_experiment_module = dataclass_for_torch_decorator(experiment_setting) # 隐藏，不建议直接使用

In [330]:
#| export
import torch
import torch.nn as nn
@_experiment_module
class ExperimentModule(nn.Module):
    def __pre_init__(self, *args, **kwargs):
        # 为什么官方 dataclass 没有 pre init 我气死了。
        super().__init__() # torch 的初始化
        
    def __post_init__(self):
        # dataclass生成的init是没有调用super().__init__()的，所以需要手动调用
        # https://docs.python.org/3/library/dataclasses.html#dataclasses.__post_init__
        # 这里调用PyTorch的init，接下来用户写self.xx = xx就能注册参数、子模块之类的。
        # super().__init__() 
        # 为了防止用户自己忘记写 super().__post_init__() ，我们换个名字方便用户记忆。
        self.setup()
    def setup(self):
        # 用户实现，初始化增量神经网络的增量参数v
        raise NotImplementedError("Should be implemented by subclass! ")
    
    def __repr__(self):
        return super().__repr__()
    
    def extra_repr(self) -> str:
        return super().extra_repr()
    
    def __init_subclass__(cls) -> None:
        super().__init_subclass__()
        original_repr = cls.__repr__
        original_extra_repr = cls.extra_repr
        # dataclass(cls) # 这个3.10以后是in place的， 不保证？
        _experiment_module(cls) # 这个3.10以后是in place的， 不保证？
        dataclass_repr = cls.__repr__
        def extra_repr(self):
            dcr = dataclass_repr(self)
            dcr = dcr[dcr.index("(")+1:dcr.rindex(")")]
            return dcr+original_extra_repr(self)
        # cls.extra_repr = lambda self:(dataclass_repr(self)+original_extra_repr(self)) # dataclass的 repr提供给PyTorch
        cls.extra_repr = extra_repr # dataclass的 repr提供给PyTorch
        cls.__repr__ = original_repr

使用示例

In [331]:
class ExampleDataclassModule(ExperimentModule):
    name:str =~RandomVariable(default="root", description="Name of the person")
    age:int =~RandomVariable(default=25, description="Age of the person")   
    def setup(self):
        print(self.name, self.age)
        self.linear = nn.Linear(self.age, self.age)
    def forward(self, x):
        x = torch.Tensor([x])
        return self.linear(x)
    def extra_repr(self) -> str:
        return super().extra_repr()+" Hello World!"

In [332]:
test = ExampleDataclassModule(name="hello", age=1)
# test = TestFlaxStyle()
print(test) # str
print(repr(test))
print(test(1))

hello 1
ExampleDataclassModule(
  name='hello', age=1 Hello World!
  (linear): Linear(in_features=1, out_features=1, bias=True)
)
ExampleDataclassModule(
  name='hello', age=1 Hello World!
  (linear): Linear(in_features=1, out_features=1, bias=True)
)
tensor([-0.7333], grad_fn=<ViewBackward0>)


In [333]:
#| hide
import nbdev; nbdev.nbdev_export()