Skip to content

Commit

Permalink
Dependencies: Update to pydantic~=2.3
Browse files Browse the repository at this point in the history
  • Loading branch information
sphuber committed Oct 11, 2023
1 parent 7d77444 commit ec7a332
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 59 deletions.
64 changes: 35 additions & 29 deletions aiida/manage/configuration/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple
import uuid

from pydantic import BaseModel, Field, ValidationError, validator # pylint: disable=no-name-in-module
from pydantic import ( # pylint: disable=no-name-in-module
BaseModel,
ConfigDict,
Field,
ValidationError,
field_serializer,
field_validator,
)

from aiida.common.exceptions import ConfigurationError
from aiida.common.log import LogLevels
Expand Down Expand Up @@ -59,16 +66,18 @@ def __str__(self) -> str:
return f'Validation Error: {prefix}{path}{self._message}{schema}'


class ConfigVersionSchema(BaseModel):
class ConfigVersionSchema(BaseModel, defer_build=True):
"""Schema for the version configuration of an AiiDA instance."""

CURRENT: int
OLDEST_COMPATIBLE: int


class ProfileOptionsSchema(BaseModel):
class ProfileOptionsSchema(BaseModel, defer_build=True):
"""Schema for the options of an AiiDA profile."""

model_config = ConfigDict(use_enum_values=True)

runner__poll__interval: int = Field(60, description='Polling interval in seconds to be used by process runners.')
daemon__default_workers: int = Field(
1, description='Default number of workers to be launched by `verdi daemon start`.'
Expand Down Expand Up @@ -130,39 +139,37 @@ class ProfileOptionsSchema(BaseModel):
5, description='Maximum number of transport task attempts before a Process is Paused.'
)
rmq__task_timeout: int = Field(10, description='Timeout in seconds for communications with RabbitMQ.')
storage__sandbox: Optional[str] = Field(description='Absolute path to the directory to store sandbox folders.')
storage__sandbox: Optional[str] = Field(
None, description='Absolute path to the directory to store sandbox folders.'
)
caching__default_enabled: bool = Field(False, description='Enable calculation caching by default.')
caching__enabled_for: List[str] = Field([], description='Calculation entry points to enable caching on.')
caching__disabled_for: List[str] = Field([], description='Calculation entry points to disable caching on.')

class Config:
use_enum_values = True

@validator('caching__enabled_for', 'caching__disabled_for')
@field_validator('caching__enabled_for', 'caching__disabled_for')
@classmethod
def validate_caching_identifier_pattern(cls, value: List[str]) -> List[str]:
"""Validate the caching identifier patterns."""
from aiida.manage.caching import _validate_identifier_pattern
for identifier in value:
try:
_validate_identifier_pattern(identifier=identifier)
except ValueError as exception:
raise ValidationError(str(exception)) from exception
_validate_identifier_pattern(identifier=identifier)

return value


class GlobalOptionsSchema(ProfileOptionsSchema):
"""Schema for the global options of an AiiDA instance."""
autofill__user__email: Optional[str] = Field(description='Default user email to use when creating new profiles.')
autofill__user__email: Optional[str] = Field(
None, description='Default user email to use when creating new profiles.'
)
autofill__user__first_name: Optional[str] = Field(
description='Default user first name to use when creating new profiles.'
None, description='Default user first name to use when creating new profiles.'
)
autofill__user__last_name: Optional[str] = Field(
description='Default user last name to use when creating new profiles.'
None, description='Default user last name to use when creating new profiles.'
)
autofill__user__institution: Optional[str] = Field(
description='Default user institution to use when creating new profiles.'
None, description='Default user institution to use when creating new profiles.'
)
rest_api__profile_switching: bool = Field(
False, description='Toggle whether the profile can be specified in requests submitted to the REST API.'
Expand All @@ -173,14 +180,14 @@ class GlobalOptionsSchema(ProfileOptionsSchema):
)


class ProfileStorageConfig(BaseModel):
class ProfileStorageConfig(BaseModel, defer_build=True):
"""Schema for the storage backend configuration of an AiiDA profile."""

backend: str
config: Dict[str, Any]


class ProcessControlConfig(BaseModel):
class ProcessControlConfig(BaseModel, defer_build=True):
"""Schema for the process control configuration of an AiiDA profile."""

broker_protocol: str = Field('amqp', description='Protocol for connecting to the message broker.')
Expand All @@ -192,29 +199,28 @@ class ProcessControlConfig(BaseModel):
broker_parameters: dict[str, Any] = Field('guest', description='Arguments to be encoded as query parameters.')


class ProfileSchema(BaseModel):
class ProfileSchema(BaseModel, defer_build=True):
"""Schema for the configuration of an AiiDA profile."""

uuid: str = Field(description='', default_factory=uuid.uuid4)
storage: ProfileStorageConfig
process_control: ProcessControlConfig
default_user_email: Optional[str] = None
test_profile: bool = False
options: Optional[ProfileOptionsSchema]
options: Optional[ProfileOptionsSchema] = None

class Config:
json_encoders = {
uuid.UUID: lambda u: str(u), # pylint: disable=unnecessary-lambda
}
@field_serializer('uuid')
def serialize_dt(self, value: uuid.UUID, _info):
return str(value)


class ConfigSchema(BaseModel):
class ConfigSchema(BaseModel, defer_build=True):
"""Schema for the configuration of an AiiDA instance."""

CONFIG_VERSION: Optional[ConfigVersionSchema]
profiles: Optional[dict[str, ProfileSchema]]
options: Optional[GlobalOptionsSchema]
default_profile: Optional[str]
CONFIG_VERSION: Optional[ConfigVersionSchema] = None
profiles: Optional[dict[str, ProfileSchema]] = None
options: Optional[GlobalOptionsSchema] = None
default_profile: Optional[str] = None


class Config: # pylint: disable=too-many-public-methods
Expand Down
32 changes: 19 additions & 13 deletions aiida/manage/configuration/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def name(self) -> str:

@property
def valid_type(self) -> Any:
return self._field.type_
return self._field.annotation

@property
def schema(self) -> Dict[str, Any]:
Expand All @@ -44,45 +44,51 @@ def default(self) -> Any:

@property
def description(self) -> str:
return self._field.field_info.description
return self._field.description

@property
def global_only(self) -> bool:
from .config import ProfileOptionsSchema
return self._name in ProfileOptionsSchema.__fields__
return self._name.replace('.', '__') not in ProfileOptionsSchema.model_fields

def validate(self, value: Any) -> Any:
"""Validate a value
:param value: The input value
:param cast: Attempt to cast the value to the required type
:return: The output value
:raise: ConfigValidationError
"""
value, validation_error = self._field.validate(value, {}, loc=None)
from pydantic import ValidationError

from .config import GlobalOptionsSchema

attribute = self.name.replace('.', '__')

if validation_error:
raise ConfigurationError(validation_error)
try:
result = GlobalOptionsSchema.__pydantic_validator__.validate_assignment(
GlobalOptionsSchema.model_construct(), attribute, value
)
except ValidationError as exception:
raise ConfigurationError(str(exception)) from exception

return value
# Return the value from the constructed model as this will have casted the value to the right type
return getattr(result, attribute)


def get_option_names() -> List[str]:
"""Return a list of available option names."""
from .config import GlobalOptionsSchema
return [key.replace('__', '.') for key in GlobalOptionsSchema.__fields__]
return [key.replace('__', '.') for key in GlobalOptionsSchema.model_fields]


def get_option(name: str) -> Option:
"""Return option."""
from .config import GlobalOptionsSchema
options = GlobalOptionsSchema.__fields__
options = GlobalOptionsSchema.model_fields
option_name = name.replace('.', '__')
if option_name not in options:
raise ConfigurationError(f'the option {name} does not exist')
return Option(name, GlobalOptionsSchema.schema()['properties'][option_name], options[option_name])
return Option(name, GlobalOptionsSchema.model_json_schema()['properties'][option_name], options[option_name])


def parse_option(option_name: str, option_value: Any) -> Tuple[Option, Any]:
Expand Down
9 changes: 9 additions & 0 deletions docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,15 @@ py:class ndarray
py:class paramiko.proxy.ProxyCommand

py:class pydantic.main.BaseModel
py:class ModelPrivateAttr
py:class CoreSchema
py:class _decorators.DecoratorInfos
py:class _generics.PydanticGenericMetadata
py:class SchemaSerializer
py:class SchemaValidator
py:class Signature
py:class ConfigDict
py:class FieldInfo

# These can be removed once they are properly included in the `__all__` in `plumpy`
py:class plumpy.ports.PortNamespace
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dependencies:
- pgsu~=0.2.1
- psutil~=5.6
- psycopg2-binary~=2.8
- pydantic~=1.10
- pydantic~=2.4
- pytz~=2021.1
- pyyaml~=6.0
- requests~=2.0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dependencies = [
"pgsu~=0.2.1",
"psutil~=5.6",
"psycopg2-binary~=2.8",
"pydantic~=1.10",
"pydantic~=2.4",
"pytz~=2021.1",
"pyyaml~=6.0",
"requests~=2.0",
Expand Down
8 changes: 3 additions & 5 deletions requirements/requirements-py-3.10.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ deprecation==2.1.0
disk-objectstore==0.6.0
docstring-parser==0.15
docutils==0.16
emmet-core==0.57.1
exceptiongroup==1.1.1
executing==1.2.0
fastjsonschema==2.17.1
Expand Down Expand Up @@ -89,8 +88,7 @@ matplotlib-inline==0.1.6
mdit-py-plugins==0.3.5
mdurl==0.1.2
mistune==3.0.1
monty==2023.5.8
mp-api==0.33.3
monty==2023.9.25
mpmath==1.3.0
msgpack==1.0.5
multidict==6.0.4
Expand Down Expand Up @@ -133,10 +131,10 @@ py-cpuinfo==9.0.0
pybtex==0.24.0
pycifrw==4.4.5
pycparser==2.21
pydantic==1.10.9
pydantic==2.4.0
pydata-sphinx-theme==0.13.3
pygments==2.15.1
pymatgen==2023.5.31
pymatgen==2023.9.25
pympler==0.9
pymysql==0.9.3
pynacl==1.5.0
Expand Down
8 changes: 3 additions & 5 deletions requirements/requirements-py-3.11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ deprecation==2.1.0
disk-objectstore==0.6.0
docstring-parser==0.15
docutils==0.16
emmet-core==0.57.1
executing==1.2.0
fastjsonschema==2.17.1
flask==2.3.2
Expand Down Expand Up @@ -88,8 +87,7 @@ matplotlib-inline==0.1.6
mdit-py-plugins==0.3.5
mdurl==0.1.2
mistune==3.0.1
monty==2023.5.8
mp-api==0.33.3
monty==2023.9.25
mpmath==1.3.0
msgpack==1.0.5
multidict==6.0.4
Expand Down Expand Up @@ -132,10 +130,10 @@ py-cpuinfo==9.0.0
pybtex==0.24.0
pycifrw==4.4.5
pycparser==2.21
pydantic==1.10.9
pydantic==2.4.0
pydata-sphinx-theme==0.13.3
pygments==2.15.1
pymatgen==2023.9.2
pymatgen==2023.9.25
pympler==0.9
pymysql==0.9.3
pynacl==1.5.0
Expand Down
8 changes: 3 additions & 5 deletions requirements/requirements-py-3.9.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ deprecation==2.1.0
disk-objectstore==0.6.0
docstring-parser==0.15
docutils==0.16
emmet-core==0.57.1
exceptiongroup==1.1.1
executing==1.2.0
fastjsonschema==2.17.1
Expand Down Expand Up @@ -91,8 +90,7 @@ matplotlib-inline==0.1.6
mdit-py-plugins==0.3.5
mdurl==0.1.2
mistune==3.0.1
monty==2023.5.8
mp-api==0.33.3
monty==2023.9.25
mpmath==1.3.0
msgpack==1.0.5
multidict==6.0.4
Expand Down Expand Up @@ -135,10 +133,10 @@ py-cpuinfo==9.0.0
pybtex==0.24.0
pycifrw==4.4.5
pycparser==2.21
pydantic==1.10.9
pydantic==2.4.0
pydata-sphinx-theme==0.13.3
pygments==2.15.1
pymatgen==2023.5.31
pymatgen==2023.9.25
pympler==0.9
pymysql==0.9.3
pynacl==1.5.0
Expand Down

0 comments on commit ec7a332

Please sign in to comment.