Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions src/ConfigSpace/hyperparameters/hp_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
normalize,
scale,
)
from ConfigSpace.types import DType, f64, i64
from ConfigSpace.types import DType, ObjectArray, f64, i64

if TYPE_CHECKING:
from ConfigSpace.types import Array, Mask
Expand Down Expand Up @@ -446,11 +446,14 @@ def ordinal_neighborhood(
return np.array([seed.choice(neighbors)], dtype=f64)


# HACK: Technically `Any` isn't an `np.number` that the Transformer expects
# as it's type variable. However for a Constant, we can like with this typing
# hack.
@dataclass
class TransformerConstant(Transformer[DType]):
class TransformerConstant(Transformer[Any]):
"""Implementation of a transformer for a constant value."""

value: DType
value: Any
"""The constant value."""

vector_value_yes: f64
Expand Down Expand Up @@ -479,23 +482,19 @@ def __post_init__(self) -> None:
self.upper_vectorized = self.vector_value_yes

@override
def to_vector(self, value: Array[DType]) -> Array[f64]:
def to_vector(self, value: ObjectArray) -> Array[f64]:
return np.where(
value == self.value,
self.vector_value_yes,
self.vector_value_no,
)

@override
def to_value(self, vector: Array[f64]) -> Array[DType]:
try:
return np.full_like(vector, self.value, dtype=type(self.value))
except TypeError:
# Let numpy figure it out
return np.array([self.value] * len(vector))
def to_value(self, vector: Array[f64]) -> ObjectArray:
return np.full_like(vector, self.value, dtype=object)

@override
def legal_value(self, value: Array[DType]) -> Mask:
def legal_value(self, value: ObjectArray) -> Mask:
return value == self.value # type: ignore

@override
Expand Down
3 changes: 3 additions & 0 deletions src/ConfigSpace/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
Array: TypeAlias = npt.NDArray[DType]
"""Array, a numpy array of a specific dtype."""

ObjectArray: TypeAlias = npt.NDArray[np.object_]
"""Object array, a numpy array of objects."""

f64: TypeAlias = np.float64
"""64-bit floating point number."""

Expand Down