Skip to content

Commit

Permalink
Merge pull request #245 from SCM-NV/digits
Browse files Browse the repository at this point in the history
ENH: Add new `yaml` loaders with duplicate key checking
  • Loading branch information
BvB93 committed May 6, 2021
2 parents 4850c6a + 77ba859 commit 16530e4
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 20 deletions.
10 changes: 7 additions & 3 deletions src/qmflows/fileFunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@

__all__ = ['yaml2Settings']

from typing import AnyStr, Callable
from typing import Union, Callable

import yaml

from .settings import Settings
from .type_hints import T
from .yaml_utils import UniqueSafeLoader


def yaml2Settings(xs: AnyStr, mapping_type: Callable[[dict], T] = Settings) -> T:
def yaml2Settings(
xs: Union[str, bytes],
mapping_type: Callable[[dict], T] = Settings,
) -> T:
"""Transform a string containing some data in .yaml format to a Settings object."""
if isinstance(xs, bytes):
xs = xs.decode()

dct = yaml.load(xs, Loader=yaml.FullLoader) # yaml object must be string
dct = yaml.load(xs, Loader=UniqueSafeLoader) # yaml object must be string
return mapping_type(dct)
13 changes: 7 additions & 6 deletions src/qmflows/templates/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import yaml

from ..settings import Settings
from ..yaml_utils import UniqueSafeLoader


#: Templates for single-point calculations.
Expand Down Expand Up @@ -98,7 +99,7 @@
functional: lda
basis:
basis: sto_sz
""", Loader=yaml.FullLoader))
""", Loader=UniqueSafeLoader))

#: Templates for geometry optimization calculations.
geometry = Settings(yaml.load("""
Expand Down Expand Up @@ -197,7 +198,7 @@
runtyp: opt
basis:
basis: sto_sz
""", Loader=yaml.FullLoader))
""", Loader=UniqueSafeLoader))

#: Templates for transition state calculations.
ts = Settings(yaml.load("""
Expand Down Expand Up @@ -229,7 +230,7 @@
ts_search: ef
basis:
basis: sto_sz
""", Loader=yaml.FullLoader))
""", Loader=UniqueSafeLoader))

#: Templates for frequency analyses calculations.
freq = Settings(yaml.load("""
Expand Down Expand Up @@ -296,7 +297,7 @@
basis:
basis: sto_sz
main: freq
""", Loader=yaml.FullLoader))
""", Loader=UniqueSafeLoader))

#: Templates for molecular dynamics (MD) calculations.
md = Settings(yaml.load("""
Expand Down Expand Up @@ -351,7 +352,7 @@
print_level: low
project: cp2k
run_type: MD
""", Loader=yaml.FullLoader))
""", Loader=UniqueSafeLoader))

#: Templates for cell optimization calculations.
cell_opt = Settings(yaml.load("""
Expand Down Expand Up @@ -407,4 +408,4 @@
print_level: low
project: cp2k
run_type: cell_opt
""", Loader=yaml.FullLoader))
""", Loader=UniqueSafeLoader))
91 changes: 91 additions & 0 deletions src/qmflows/yaml_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""A module containing containing :mod:`yaml` loaders with duplicate key checking.
Index
-----
.. currentmodule:: qmflows.yaml_utils
.. autosummary::
UniqueUnsafeLoader
UniqueFullLoader
UniqueSafeLoader
API
---
.. autoclass:: UniqueUnsafeLoader
.. autoclass:: UniqueFullLoader
.. autoclass:: UniqueSafeLoader
"""

from collections.abc import Hashable
from typing import Dict, Any

# Use the fast C-based loaders if possible
try:
from yaml import (
CUnsafeLoader as UnsafeLoader,
CFullLoader as FullLoader,
CSafeLoader as SafeLoader,
)
except ImportError:
from yaml import UnsafeLoader, FullLoader, SafeLoader
from yaml.nodes import MappingNode
from yaml.constructor import ConstructorError, BaseConstructor, SafeConstructor

__all__ = ['UniqueLoader', 'UniqueUnsafeLoader', 'UniqueFullLoader', 'UniqueSafeLoader']


def _construct_mapping(
loader: BaseConstructor,
node: MappingNode,
deep: bool = False,
) -> Dict[Any, Any]:
"""A helper function for handling :meth:`~yaml.BaseConstructor.construct_mapping` methods."""
if not isinstance(node, MappingNode):
raise ConstructorError(
None, None,
f"expected a mapping node, but found {node.id}", node.start_mark,
)

if isinstance(loader, SafeConstructor):
loader.flatten_mapping(node)

mapping = {}
for key_node, value_node in node.value:
key = loader.construct_object(key_node, deep=deep)
if not isinstance(key, Hashable):
raise ConstructorError("while constructing a mapping", node.start_mark,
"found unhashable key", key_node.start_mark)
elif key in mapping:
raise ConstructorError("while constructing a mapping", node.start_mark,
"found duplicate key", key_node.start_mark)

value = loader.construct_object(value_node, deep=deep)
mapping[key] = value
return mapping


class UniqueUnsafeLoader(UnsafeLoader):
"""A :class:`~yaml.UnsafeLoader` subclass with duplicate key checking."""

def construct_mapping(self, node: MappingNode, deep: bool = False) -> Dict[Any, Any]:
"""Construct Convert the passed **node** into a :class:`dict`."""
return _construct_mapping(self, node, deep)


class UniqueFullLoader(FullLoader):
"""A :class:`~yaml.FullLoader` subclass with duplicate key checking."""

def construct_mapping(self, node: MappingNode, deep: bool = False) -> Dict[Any, Any]:
"""Construct Convert the passed **node** into a :class:`dict`."""
return _construct_mapping(self, node, deep)


class UniqueSafeLoader(SafeLoader):
"""A :class:`~yaml.SafeLoader` subclass with duplicate key checking."""

def construct_mapping(self, node: MappingNode, deep: bool = False) -> Dict[Any, Any]:
"""Construct Convert the passed **node** into a :class:`dict`."""
return _construct_mapping(self, node, deep)


UniqueLoader = UniqueUnsafeLoader
66 changes: 55 additions & 11 deletions test/test_yaml_settings.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,35 @@
"""Test the conversion from yaml to settings."""

from typing import Type

import yaml
import pytest
from assertionlib import assertion
from yaml.constructor import BaseConstructor, ConstructorError

from qmflows import Settings
from qmflows.fileFunctions import yaml2Settings
from qmflows.yaml_utils import UniqueFullLoader, UniqueLoader, UniqueSafeLoader, UniqueUnsafeLoader

CP2K_PBE_GUESS_DUPLICATE = """
specific:
cp2k:
global:
run_type:
energy
force_eval:
subsys:
cell:
periodic: "None"
periodic: "None"
dft:
xc:
xc_functional pbe: {}
scf:
eps_scf: 1e-6
"""

cp2k_pbe_guess = """
CP2K_PBE_GUESS = """
specific:
cp2k:
global:
Expand All @@ -21,16 +46,35 @@
eps_scf: 1e-6
"""

REF = Settings()
REF.specific.cp2k.force_eval.dft.xc["xc_functional pbe"] = {}
REF.specific.cp2k.force_eval.subsys.cell.periodic = "None"
REF.specific.cp2k["global"]["run_type"] = "energy"
REF.specific.cp2k.force_eval.dft.scf.eps_scf = "1e-6"

LOADER_TYPES = (UniqueFullLoader, UniqueLoader, UniqueSafeLoader, UniqueUnsafeLoader)


def test_yaml2Settings():
"""Test the conversion from yaml to settings."""
s1 = yaml2Settings(cp2k_pbe_guess)
s2 = yaml2Settings(cp2k_pbe_guess.encode())

ref = Settings()
ref.specific.cp2k.force_eval.dft.xc["xc_functional pbe"] = {}
ref.specific.cp2k.force_eval.subsys.cell.periodic = "None"
ref.specific.cp2k["global"]["run_type"] = "energy"
ref.specific.cp2k.force_eval.dft.scf.eps_scf = "1e-6"
assertion.eq(s1.specific, ref.specific)
assertion.eq(s2.specific, ref.specific)
s1 = yaml2Settings(CP2K_PBE_GUESS)
s2 = yaml2Settings(CP2K_PBE_GUESS.encode())

assertion.eq(s1.specific, REF.specific)
assertion.eq(s2.specific, REF.specific)


class TestLoader:
"""Tests for the :mod:`qmflows.yaml_utils` loaders."""

@pytest.mark.parametrize("loader", LOADER_TYPES)
def test_pass(self, loader: Type[BaseConstructor]) -> None:
"""Test for succesful :func:`yaml.load` calls."""
dct = yaml.load(CP2K_PBE_GUESS, Loader=loader)
assertion.eq(dct["specific"], REF.specific)

@pytest.mark.parametrize("loader", LOADER_TYPES)
def test_raise(self, loader: Type[BaseConstructor]) -> None:
"""Test for failed :func:`yaml.load` calls."""
with pytest.raises(ConstructorError):
yaml.load(CP2K_PBE_GUESS_DUPLICATE, Loader=loader)

0 comments on commit 16530e4

Please sign in to comment.