diff --git a/orchestrator/schema/domain.py b/orchestrator/schema/domain.py
index 6ea713d7..9e26e45e 100644
--- a/orchestrator/schema/domain.py
+++ b/orchestrator/schema/domain.py
@@ -26,6 +26,7 @@ class VariableTypeEnum(str, enum.Enum):
     BINARY_VARIABLE_TYPE = "BINARY_VARIABLE_TYPE"  # the value of the variable is binary
     UNKNOWN_VARIABLE_TYPE = "UNKNOWN_VARIABLE_TYPE"  # the type of value of the variable is unknown/unspecified
     IDENTIFIER_VARIABLE_TYPE = "IDENTIFIER_VARIABLE_TYPE"  # the value is some type of, possible unique, identifier
+    VECTOR_VARIABLE_TYPE = "VECTOR_VARIABLE_TYPE"  # the value is a vector
 
 
 class ProbabilityFunctionsEnum(str, enum.Enum):
@@ -416,6 +417,10 @@ def variableType_matches_values(cls, value, values: "pydantic.FieldValidationInf
         elif value == VariableTypeEnum.OPEN_CATEGORICAL_VARIABLE_TYPE:
             assert values.data.get("interval") is None
             assert values.data.get("domainRange") is None
+        elif value == VariableTypeEnum.VECTOR_VARIABLE_TYPE:
+            raise ValueError(
+                "Vector variables are not supported by PropertyDomain - use VectorPropertyDomain instead"
+            )
 
         return value
 
diff --git a/orchestrator/schema/property.py b/orchestrator/schema/property.py
index 11cb7f1b..88b385a3 100644
--- a/orchestrator/schema/property.py
+++ b/orchestrator/schema/property.py
@@ -2,11 +2,32 @@
 # SPDX-License-Identifier: MIT
 
 import enum
+from typing import Annotated
 
 import pydantic
 from pydantic import ConfigDict
 
 from orchestrator.schema.domain import PropertyDomain
+from orchestrator.schema.vector_domain import VectorPropertyDomain
+
+
+def domain_type_discriminator(domain):
+    """Return the union tag ("scalar" or "vector") for a domain instance or raw dict."""
+    if isinstance(domain, PropertyDomain):
+        return "scalar"
+    if isinstance(domain, VectorPropertyDomain):
+        return "vector"
+    if isinstance(domain, dict):
+        return "vector" if "element_domain" in domain else "scalar"
+
+    raise ValueError(f"Unable to determine domain type for domain: {domain}")
+
+
+Domain = Annotated[
+    Annotated[PropertyDomain, pydantic.Tag("scalar")]
+    | Annotated[VectorPropertyDomain, pydantic.Tag("vector")],
+    pydantic.Discriminator(domain_type_discriminator),
+]
 
 
 class MeasuredPropertyTypeEnum(str, enum.Enum):
@@ -117,7 +138,7 @@ class Property(pydantic.BaseModel):
     metadata: dict | None = pydantic.Field(
         default=None, description="Metadata on the property"
     )
-    propertyDomain: PropertyDomain = pydantic.Field(
+    propertyDomain: Domain = pydantic.Field(
         default=PropertyDomain(),
         description="Provides information on the variable type and the valid values it can take",
     )
diff --git a/orchestrator/schema/vector_domain.py b/orchestrator/schema/vector_domain.py
new file mode 100644
index 00000000..8f649f67
--- /dev/null
+++ b/orchestrator/schema/vector_domain.py
@@ -0,0 +1,96 @@
+# Copyright (c) IBM Corporation
+# SPDX-License-Identifier: MIT
+
+import itertools
+import math
+
+import pydantic
+from pydantic import BaseModel, ConfigDict, Field
+
+from orchestrator.schema.domain import PropertyDomain, VariableTypeEnum
+
+
+class VectorPropertyDomain(BaseModel):
+    """Domain of a fixed-length vector whose elements all share one PropertyDomain."""
+
+    element_domain: PropertyDomain = Field(..., description="Domain of elements")
+    number_elements: int = Field(..., description="Length/dimension of the vector")
+    variableType: VariableTypeEnum = Field(
+        default=VariableTypeEnum.VECTOR_VARIABLE_TYPE
+    )
+
+    model_config = ConfigDict(frozen=True, extra="forbid")
+
+    @pydantic.field_validator("variableType")
+    def variableType_matches_values(cls, value, values: "pydantic.FieldValidationInfo"):
+        if value != VariableTypeEnum.VECTOR_VARIABLE_TYPE:
+            raise ValueError("VariableType must be VECTOR_VARIABLE_TYPE")
+        return value
+
+    def valueInDomain(self, value: list) -> bool:
+        """Check value is a list/tuple of number_elements items, each in element_domain."""
+        if not isinstance(value, (list, tuple)):
+            return False
+        # A vector of the wrong length is never in the domain
+        if len(value) != self.number_elements:
+            return False
+
+        return all(self.element_domain.valueInDomain(v) for v in value)
+
+    def isSubDomain(self, otherDomain: "VectorPropertyDomain") -> bool:
+        """Must be a subdomain only to another VectorPropertyDomain."""
+        # Must be a VectorPropertyDomain and have the proper variableType (robustness)
+        if not hasattr(otherDomain, "variableType") or (
+            otherDomain.variableType != self.variableType
+        ):
+            return False
+        # Must have equal or fewer dimensions
+        if self.number_elements > otherDomain.number_elements:
+            return False
+        # Each element subdomain
+        return self.element_domain.isSubDomain(otherDomain.element_domain)
+
+    @property
+    def domain_values(self) -> list:
+        # The cartesian product of the element domain values, number_elements times
+        # Returns a list of vectors
+        try:
+            elem_values = self.element_domain.domain_values
+        except Exception as e:
+            raise ValueError(
+                f"element_domain must be discrete and have domain_values: {e!s}"
+            ) from e
+        # Cartesian product
+        return list(itertools.product(elem_values, repeat=self.number_elements))
+
+    @property
+    def size(self) -> int | float:
+        """Number of possible vectors, or math.inf if element values cannot be enumerated."""
+        try:
+            n_elem_values = len(self.element_domain.domain_values)
+        except Exception:
+            # Mirrors domain_values: a non-enumerable (e.g. continuous) element
+            # domain means the vector domain is not countable.
+            return math.inf
+        return n_elem_values**self.number_elements
+
+    def __eq__(self, other):
+        if not isinstance(other, VectorPropertyDomain):
+            return False
+        return (
+            self.number_elements == other.number_elements
+            and self.element_domain == other.element_domain
+            and self.variableType == other.variableType
+        )
+
+    def _repr_pretty_(self, p, cycle=False):
+        if cycle:
+            p.text("Cycle detected")
+        else:
+            p.text(f"Type: {self.variableType}")
+            p.breakable()
+            p.text(f"Number of elements: {self.number_elements}")
+            p.breakable()
+            with p.group(2, "Element Domain:"):
+                p.break_()
+                p.pretty(self.element_domain)
diff --git a/tests/schema/test_vector_domain.py b/tests/schema/test_vector_domain.py
new file mode 100644
index 00000000..c3e263d1
--- /dev/null
+++ b/tests/schema/test_vector_domain.py
@@ -0,0 +1,87 @@
+# Copyright (c) IBM Corporation
+# SPDX-License-Identifier: MIT
+
+import math
+
+import pydantic
+import pytest
+
+from orchestrator.schema.domain import PropertyDomain, VariableTypeEnum
+from orchestrator.schema.vector_domain import VectorPropertyDomain
+
+
+@pytest.fixture
+def simple_element_domain():
+    # Discrete domain: {1, 2, 3}
+    return PropertyDomain(
+        variableType=VariableTypeEnum.DISCRETE_VARIABLE_TYPE,
+        values=[1, 2, 3],
+    )
+
+
+def test_vector_property_domain_valid_vector(simple_element_domain):
+    vpd = VectorPropertyDomain(element_domain=simple_element_domain, number_elements=2)
+    assert vpd.valueInDomain([1, 2])
+    assert vpd.valueInDomain([2, 3])
+    assert not vpd.valueInDomain([1, 999])  # 999 not in element domain
+    assert not vpd.valueInDomain([1])  # Too short
+    assert not vpd.valueInDomain([1, 2, 3])  # Too long
+
+
+def test_vector_property_domain_domain_values(simple_element_domain):
+    vpd = VectorPropertyDomain(element_domain=simple_element_domain, number_elements=2)
+    values = vpd.domain_values
+    # Should be the cartesian product
+    expected = [(a, b) for a in [1, 2, 3] for b in [1, 2, 3]]
+    assert set(values) == set(expected)
+    assert len(values) == 9  # 3^2
+
+
+def test_vector_property_domain_size(simple_element_domain):
+    vpd = VectorPropertyDomain(element_domain=simple_element_domain, number_elements=3)
+    assert vpd.size == 27
+    # Inf if element_domain not countable
+    # Make continuous domain (should not allow domain_values)
+    cd = PropertyDomain(
+        variableType=VariableTypeEnum.CONTINUOUS_VARIABLE_TYPE, domainRange=[0, 1]
+    )
+    vpd_cont = VectorPropertyDomain(element_domain=cd, number_elements=2)
+    assert math.isinf(vpd_cont.size)
+    with pytest.raises(Exception, match="element_domain must be discrete"):
+        _ = vpd_cont.domain_values
+
+
+def test_vector_property_domain_isSubDomain(simple_element_domain):
+    eldom_small = PropertyDomain(
+        variableType=VariableTypeEnum.DISCRETE_VARIABLE_TYPE, values=[1]
+    )
+    eldom_big = PropertyDomain(
+        variableType=VariableTypeEnum.DISCRETE_VARIABLE_TYPE, values=[1, 2, 3]
+    )
+    vpd_small = VectorPropertyDomain(element_domain=eldom_small, number_elements=2)
+    vpd_big = VectorPropertyDomain(element_domain=eldom_big, number_elements=3)
+    # Check: fewer dims
+    assert vpd_small.isSubDomain(vpd_big)
+    # Reverse: should fail (more dims)
+    assert not vpd_big.isSubDomain(vpd_small)
+    # Same number dims but element subdomain wrong
+    vpd2 = VectorPropertyDomain(element_domain=eldom_big, number_elements=2)
+    assert not vpd2.isSubDomain(vpd_small)
+
+
+def test_vector_property_domain_eq(simple_element_domain):
+    vpd1 = VectorPropertyDomain(element_domain=simple_element_domain, number_elements=2)
+    vpd2 = VectorPropertyDomain(element_domain=simple_element_domain, number_elements=2)
+    assert vpd1 == vpd2
+    vpd3 = VectorPropertyDomain(element_domain=simple_element_domain, number_elements=3)
+    assert vpd1 != vpd3
+
+
+def test_vector_property_domain_variableType_guard(simple_element_domain):
+    # If someone tries to construct it with wrong variableType, should raise error
+    with pytest.raises(pydantic.ValidationError, match="VariableType must be VECTOR"):
+        VectorPropertyDomain(
+            element_domain=simple_element_domain,
+            number_elements=2,
+            variableType=VariableTypeEnum.DISCRETE_VARIABLE_TYPE,
+        )