diff --git a/src/ConfigSpace/hyperparameters/hp_components.py b/src/ConfigSpace/hyperparameters/hp_components.py index aed73274..1e53101f 100644 --- a/src/ConfigSpace/hyperparameters/hp_components.py +++ b/src/ConfigSpace/hyperparameters/hp_components.py @@ -12,7 +12,7 @@ normalize, scale, ) -from ConfigSpace.types import DType, f64, i64 +from ConfigSpace.types import DType, ObjectArray, f64, i64 if TYPE_CHECKING: from ConfigSpace.types import Array, Mask @@ -446,11 +446,14 @@ def ordinal_neighborhood( return np.array([seed.choice(neighbors)], dtype=f64) +# HACK: Technically `Any` isn't an `np.number` that the Transformer expects +# as it's type variable. However for a Constant, we can like with this typing +# hack. @dataclass -class TransformerConstant(Transformer[DType]): +class TransformerConstant(Transformer[Any]): """Implementation of a transformer for a constant value.""" - value: DType + value: Any """The constant value.""" vector_value_yes: f64 @@ -479,7 +482,7 @@ def __post_init__(self) -> None: self.upper_vectorized = self.vector_value_yes @override - def to_vector(self, value: Array[DType]) -> Array[f64]: + def to_vector(self, value: ObjectArray) -> Array[f64]: return np.where( value == self.value, self.vector_value_yes, @@ -487,15 +490,11 @@ def to_vector(self, value: Array[DType]) -> Array[f64]: ) @override - def to_value(self, vector: Array[f64]) -> Array[DType]: - try: - return np.full_like(vector, self.value, dtype=type(self.value)) - except TypeError: - # Let numpy figure it out - return np.array([self.value] * len(vector)) + def to_value(self, vector: Array[f64]) -> ObjectArray: + return np.full_like(vector, self.value, dtype=object) @override - def legal_value(self, value: Array[DType]) -> Mask: + def legal_value(self, value: ObjectArray) -> Mask: return value == self.value # type: ignore @override diff --git a/src/ConfigSpace/types.py b/src/ConfigSpace/types.py index 4f87c508..d80d5448 100644 --- a/src/ConfigSpace/types.py +++ b/src/ConfigSpace/types.py @@ -27,6 +27,9 @@ Array: TypeAlias = npt.NDArray[DType] """Array, a numpy array of a specific dtype.""" +ObjectArray: TypeAlias = npt.NDArray[np.object_] +"""Object array, a numpy array of objects.""" + f64: TypeAlias = np.float64 """64-bit floating point number."""