Skip to content

Commit

Permalink
Merge pull request #282 from SCM-NV/basis
Browse files Browse the repository at this point in the history
ENH: Rework the CP2K basis set parser
  • Loading branch information
BvB93 committed Feb 4, 2022
2 parents 5a8c947 + 7c38430 commit 99befe1
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 77 deletions.
13 changes: 10 additions & 3 deletions src/qmflows/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@
class AtomBasisKey(NamedTuple):
"""Namedtuple containing the `basisFormat` for a given `basis` and `atom`."""

#: Atomic symbol.
atom: str

#: Name of the basis set.
basis: str
basisFormat: Union[Sequence[int],
Sequence[Tuple[str, int]]]
# Orca uses 2-tuples while CP2K uses integer

# NOTE: Orca uses 2-tuples while CP2K uses integer
#: The basis set format.
basisFormat: Union[Sequence[int], Sequence[Tuple[str, int]]]

#: The primary basis key. Used if this instance is an alias of another `AtomBasisKey`.
alias: "None | AtomBasisKey" = None


class AtomBasisData(NamedTuple):
Expand Down
66 changes: 66 additions & 0 deletions src/qmflows/parsers/_cp2k_basis_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Functions for reading CP2K basis set files."""

from typing import Iterator, Generator, Tuple, List
from itertools import islice

import numpy as np

from ..type_hints import PathLike
from ..common import AtomBasisData, AtomBasisKey

__all__ = ["readCp2KBasis"]

_Basis2Tuple = Tuple[List[AtomBasisKey], List[AtomBasisData]]


def _basis_file_iter(f: Iterator[str]) -> Generator[str, None, None]:
"""Iterate through `f` and remove all empty and commented lines."""
for i in f:
i = i.strip().rstrip("\n")
if not i or i.startswith("#"):
continue
yield i


def _read_basis(iterator: Iterator[str]) -> _Basis2Tuple:
"""Helper function for parsing the opened basis set file."""
f = _basis_file_iter(iterator)
keys = []
values = []
for i in f:
# Read the atom type and basis set name(s)
atom, *basis_list = i.split()
atom = atom.capitalize()

# Identify the number of exponent sets
n_sets = int(next(f))
if n_sets != 1:
raise NotImplementedError(
"Basis sets with more than 1 set of exponents are not supported yet"
)

for _ in range(n_sets):
# Parse the basis format, its exponents and its coefficients
basis_fmt = [int(j) for j in next(f).split()]
n_exp = basis_fmt[3]
basis_data = np.array([j.split() for j in islice(f, 0, n_exp)], dtype=np.float64)
exp, coef = basis_data[:, 0], basis_data[:, 1:]

# Two things happen whenever an basis set alias is encountered (i.e. `is_alias > 0`):
# 1. The `alias` field is set for the keys
# 2. The `AtomBasisData` instance, used for the original value, is reused
for is_alias, basis in enumerate(basis_list):
if not is_alias:
basis_key = AtomBasisKey(atom, basis, basis_fmt)
basis_value = AtomBasisData(exp, coef)
keys.append(basis_key)
else:
keys.append(AtomBasisKey(atom, basis, basis_fmt, alias=basis_key))
values.append(basis_value)
return keys, values


def readCp2KBasis(file: PathLike) -> _Basis2Tuple:
"""Read the Contracted Gauss function primitives format from a text file."""
with open(file, "r") as f:
return _read_basis(f)
68 changes: 7 additions & 61 deletions src/qmflows/parsers/cp2KParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,20 @@

import numpy as np
from more_itertools import chunked
from pyparsing import (FollowedBy, Group, Literal, NotAny, OneOrMore, Optional,
SkipTo, Suppress, Word, ZeroOrMore, alphanums, alphas,
nums, oneOf, restOfLine, srange)
from pyparsing import (
Literal, Optional, SkipTo, Suppress, Word, ZeroOrMore, alphas,
nums, oneOf, srange,
)
from scm.plams import Molecule, Units

from ..common import AtomBasisData, AtomBasisKey, InfoMO, MO_metadata, CP2KVersion
from ..common import InfoMO, MO_metadata, CP2KVersion
from ..type_hints import Literal as Literal_
from ..type_hints import PathLike, T, WarnDict, WarnMap
from ..utils import file_to_context
from ..warnings_qmflows import QMFlows_Warning
from .parser import (floatNumber, minusOrplus, natural, point,
try_search_pattern)
from .parser import minusOrplus, natural, point, try_search_pattern
from .xyzParser import manyXYZ, tuplesXYZ_to_plams
from ._cp2k_basis_parser import readCp2KBasis

__all__ = ['readCp2KBasis', 'read_cp2k_coefficients', 'get_cp2k_freq',
'read_cp2k_number_of_orbitals', 'read_cp2k_xyz', 'read_cp2k_table',
Expand Down Expand Up @@ -303,38 +304,6 @@ def remove_trailing(xs: List[List[str]]) -> List[List[str]]:
# Orbital Information:" 12 1 cd 4d+1"
orbInfo = natural * 2 + Word(alphas, max=2) + orbitals

# ====================> Basis File <==========================
comment = Literal("#") + restOfLine

parser_atom_label = (
Word(srange("[A-Z]"), max=1) +
Optional(Word(srange("[a-z]"), max=1))
)

parser_basis_name = Word(alphanums + "-") + Suppress(restOfLine)

parser_format = OneOrMore(natural + NotAny(FollowedBy(point)))

parser_key = (
parser_atom_label.setResultsName("atom") +
parser_basis_name.setResultsName("basisName") +
Suppress(Literal("1"))
)

parser_basis_data = OneOrMore(floatNumber)

parser_basis = (
parser_key +
parser_format.setResultsName("format") +
parser_basis_data.setResultsName("coeffs")
)

top_parser_basis = (
OneOrMore(Suppress(comment)) + OneOrMore(
Group(parser_basis + Suppress(Optional(OneOrMore(comment)))))
)


# ===============================<>====================================
# Parsing From File

Expand Down Expand Up @@ -441,29 +410,6 @@ def split_log_file(path: PathLike, root: str, file_name: str) -> None:
subprocess.check_call(cmd, shell=True, stdout=subprocess.DEVNULL, cwd=root)


#: A 2-tuple; output of :func:`readCp2KBasis`.
Tuple2List = Tuple[List[AtomBasisKey], List[AtomBasisData]]


def readCp2KBasis(path: PathLike) -> Tuple2List:
"""Read the Contracted Gauss function primitives format from a text file."""
bss = top_parser_basis.parseFile(path)
atoms = [''.join(xs.atom[:]).lower() for xs in bss]
names = [' '.join(xs.basisName[:]).upper() for xs in bss]
formats = [list(map(int, xs.format[:])) for xs in bss]

# for example 2 0 3 7 3 3 2 1 there are sum(3 3 2 1) =9 Lists
# of Coefficients + 1 lists of exponents
nCoeffs = [int(sum(xs[4:]) + 1) for xs in formats]
coefficients = [list(map(float, cs.coeffs[:])) for cs in bss]
rss = [swap_coefficients(*args) for args in zip(nCoeffs, coefficients)]
tss = [get_head_and_tail(xs) for xs in rss]
basisData = [AtomBasisData(xs[0], xs[1]) for xs in tss]
basiskey = [AtomBasisKey(at, name, fmt) for at, name, fmt in zip(atoms, names, formats)]

return (basiskey, basisData)


#: A :class:`~collections.abc.Sequence` typevar.
ST = TypeVar('ST', bound=Sequence)

Expand Down
40 changes: 27 additions & 13 deletions test/test_cp2k_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
import shutil
from pathlib import Path

import numpy as np
import h5py
import pytest
from assertionlib import assertion

from qmflows.parsers.cp2KParser import parse_cp2k_warnings, readCp2KBasis, get_cp2k_version_run
from qmflows.test_utils import PATH, requires_cp2k
from qmflows.warnings_qmflows import QMFlows_Warning, cp2k_warnings
from qmflows.common import AtomBasisKey


def test_parse_cp2k_warnings():
Expand All @@ -20,20 +23,31 @@ def test_parse_cp2k_warnings():
for val in map_warns.values()))


def test_read_basis():
class TestReadBasis:
"""Test that the basis are read correctly."""
BASIS_FILE = PATH / "BASIS_MOLOPT"

for key, data in zip(*readCp2KBasis(BASIS_FILE)):
# The formats contains a list
assertion.len(key.basisFormat)
# Atoms are either 1 or two characters
assertion.le(len(key.atom), 2)
# All basis are MOLOPT
assertion.contains(key.basis, "MOLOPT")
# There is a list of exponents and coefficinets
assertion.len(data.exponents)
assertion.len(data.coefficients[0])

@staticmethod
def get_key(key_tup: AtomBasisKey) -> str:
return os.path.join(
key_tup.atom,
key_tup.basis,
"-".join(str(i) for i in key_tup.basisFormat),
)

def test_pass(self):
basis_file = PATH / "BASIS_MOLOPT"
keys, values = readCp2KBasis(basis_file)

with h5py.File(PATH / "basis.hdf5", "r") as f:
for key_tup, value_tup in zip(keys, values):
key = self.get_key(key_tup)
group = f[key]
alias = group["alias"][...].astype(str)

np.testing.assert_allclose(value_tup.exponents, group["exponents"], err_msg=key)
np.testing.assert_allclose(value_tup.coefficients, group["coefficients"], err_msg=key)
if key_tup.alias is not None:
assertion.eq(self.get_key(key_tup.alias), alias, message=key)


@requires_cp2k
Expand Down
Binary file added test/test_files/basis.hdf5
Binary file not shown.

0 comments on commit 99befe1

Please sign in to comment.