In [1]:
from tab_benchmark.models import RidgeCV
import inspect
import numpy as np
import re

In [2]:
def replace_numpy_types(signature):
    """Replace type annotations of the format <class 'numpy.*'> with numpy.*."""
    pattern = re.compile(r"<class 'numpy\.(\w+)'>")
    signature = pattern.sub(r"numpy.\1", signature)
    return signature

def generate_stub(cls):
    parent_classes = cls.__bases__
    # import modules
    import_statement = ""
    for base_class in parent_classes:
        module = base_class.__module__
        name = base_class.__name__
        import_statement += f"from {module} import {name}\n"
    # Retrieve parent classes
    parent_classes = ", ".join(base.__name__ for base in cls.__bases__)
    stub = f"class {cls.__name__}({parent_classes}):\n"

    # Filter out special methods and attributes
    special_attributes = {
        "__module__", "__annotations__", "set_fit_request", "set_score_request", "__doc__", "__abstractmethods__", "_abc_impl"
    }
    
    # Filter methods and attributes defined in the class type definition
    class_members = {name: obj for name, obj in cls.__dict__.items() if name not in special_attributes}
    for name, obj in class_members.items():
        if inspect.isfunction(obj) or inspect.ismethod(obj):
            signature = str(inspect.signature(obj))
            signature = replace_numpy_types(signature)
            stub += f"    def {name}{signature}: ...\n"
        elif isinstance(obj, property):
            stub += f"    @property\n    def {name}(self) -> {obj.__class__.__name__}: ...\n"
        else:
            # Assume these are attributes
            stub += f"    {name}: {type(obj).__name__}\n"
    
    return import_statement + stub

# Generate stub for RidgeCV
ridgecv_stub = generate_stub(RidgeCV)
print(ridgecv_stub)

from sklearn.linear_model._ridge import RidgeCV
from tabular_benchmark.models import SkLearnExtension
class RidgeCV(RidgeCV, SkLearnExtension):
    def __init__(self, alphas=(0.1, 1.0, 10.0), *, fit_intercept=True, scoring=None, cv=None, gcv_mode=None, store_cv_results=None, alpha_per_target=False, store_cv_values='deprecated', categorical_imputer: Union[str, int, float, NoneType] = 'most_frequent', continuous_imputer: Union[str, int, float, NoneType] = 'median', categorical_encoder: Optional[str] = 'one_hot', handle_unknown_categories: bool = True, variance_threshold: Optional[float] = 0.0, data_scaler: Optional[str] = 'standard', categorical_type: Union[numpy.dtype, str, NoneType] = numpy.float32, continuous_type: Optional[numpy.dtype] = numpy.float32, target_imputer: Union[str, int, float, NoneType] = None, categorical_target_encoder: Optional[str] = 'ordinal', categorical_target_min_frequency: Union[int, float, NoneType] = 10, continuous_target_scaler: Optional[str] = 'standard', c

In [3]:
model = RidgeCV()

In [None]:
model.fit()