Skip to content

Commit

Permalink
Adapt to the latest version of PWmat output file (materialsproject#3823)
Browse files Browse the repository at this point in the history
* Modify code to be compatible with the latest version of PWmat

* pre-commit auto-fixes

* Add test for the `exclusion` keyword

* pre-commit auto-fixes

* refactor tests to use pytest parametrization

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
  • Loading branch information
3 people committed May 13, 2024
1 parent 578d29c commit 51180fc
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 10 deletions.
21 changes: 12 additions & 9 deletions pymatgen/io/pwmat/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,22 @@ class LineLocator(MSONable):
"""Find the line indices (starts from 1) of a certain paragraph of text from the file."""

@staticmethod
def locate_all_lines(file_path: PathLike, content: str, exclusion: str = "") -> list[int]:
    """Locate the lines in a file where a certain paragraph of text is located (return all indices).

    Matching is case-insensitive for both ``content`` and ``exclusion``.

    Args:
        file_path (PathLike): Absolute path to file.
        content (str): Certain paragraph of text that needs to be located.
        exclusion (str): Certain paragraph of text that is excluded. A line that
            contains ``content`` is skipped when it also contains this text.
            The default "" disables the exclusion filter.

    Returns:
        list[int]: Line numbers (starting from 1, to be compatible with the
            linecache package) of every matching line, in file order.
    """
    # Upper-case the search terms once instead of once per line.
    target = content.upper()
    excluded = exclusion.upper()

    row_idxs: list[int] = []
    with zopen(file_path, mode="rt") as file:
        for row_no, row_content in enumerate(file, start=1):
            row_upper = row_content.upper()
            if target not in row_upper:
                continue
            # An empty exclusion string must not filter anything out.
            if excluded and excluded in row_upper:
                continue
            row_idxs.append(row_no)
    return row_idxs

Expand All @@ -47,18 +50,19 @@ class ListLocator(MSONable):
"""Find the element indices (starts from 0) of a certain paragraph of text from the list."""

@staticmethod
def locate_all_lines(strs_lst: list[str], content: str, exclusion: str = "") -> list[int]:
    """Locate the elements in a list where a certain paragraph of text is located (return all indices).

    Matching is case-insensitive for both ``content`` and ``exclusion``.

    Args:
        strs_lst (list[str]): List of strings.
        content (str): Certain paragraph of text that needs to be located.
        exclusion (str): Certain paragraph of text that is excluded. An element
            that contains ``content`` is skipped when it also contains this text.
            The default "" disables the exclusion filter.

    Returns:
        list[int]: Indices (starting from 0, to be compatible with list
            indexing) of every matching element, in list order.
    """
    # Upper-case the search terms once instead of once per element.
    target = content.upper()
    excluded = exclusion.upper()

    str_idxs: list[int] = []
    for str_no, tmp_str in enumerate(strs_lst):
        upper_str = tmp_str.upper()
        if target not in upper_str:
            continue
        # An empty exclusion string must not filter anything out.
        if excluded and excluded in upper_str:
            continue
        str_idxs.append(str_no)
    return str_idxs

Expand Down Expand Up @@ -277,13 +281,13 @@ def get_e_tot(self) -> np.ndarray:
"""
# strs_lst:
# [' 216 atoms', 'Iteration (fs) = 0.3000000000E+01',
# ' Etot', 'Ep', 'Ek (eV) = -0.2831881714E+05 -0.2836665392E+05 0.4783678177E+02',
# ' Etot', 'Ep', 'Ek = -0.2831881714E+05 -0.2836665392E+05 0.4783678177E+02',
# ' SCF = 7']
strs_lst = self.strs_lst[0].split(",")
aim_index = ListLocator.locate_all_lines(strs_lst=strs_lst, content="EK (EV) =")[0]
aim_index = ListLocator.locate_all_lines(strs_lst=strs_lst, content="EK")[0]
# strs_lst[aim_index].split("=")[1].split():
# ['-0.2831881714E+05', '-0.2836665392E+05', '0.4783678177E+02']
return np.array([float(strs_lst[aim_index].split()[3].strip())])
return np.array([float(strs_lst[aim_index].split("=")[1].split()[0].strip())])

def get_atom_energies(self) -> np.ndarray | None:
"""Return the energies of individual atoms in material system.
Expand Down Expand Up @@ -315,8 +319,7 @@ def get_atom_forces(self) -> np.ndarray:
"""
forces = []
aim_content = "Force".upper()
aim_idx = ListLocator.locate_all_lines(strs_lst=self.strs_lst, content=aim_content)[0]

aim_idx = ListLocator.locate_all_lines(strs_lst=self.strs_lst, content=aim_content, exclusion="average")[0]
for line in self.strs_lst[aim_idx + 1 : aim_idx + self.num_atoms + 1]:
# ['14', '0.089910342901203', '0.077164252174742', '0.254144099204679']
forces.append([float(val) for val in line.split()[1:4]])
Expand Down
Binary file modified tests/files/io/pwmat/MOVEMENT.lzma
Binary file not shown.
30 changes: 29 additions & 1 deletion tests/io/pwmat/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,40 @@
from numpy.testing import assert_allclose

from pymatgen.core import Composition, Structure
from pymatgen.io.pwmat.inputs import ACExtractor, ACstrExtractor, AtomConfig, GenKpt, HighSymmetryPoint
from pymatgen.io.pwmat.inputs import (
ACExtractor,
ACstrExtractor,
AtomConfig,
GenKpt,
HighSymmetryPoint,
LineLocator,
ListLocator,
)
from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest

TEST_DIR = f"{TEST_FILES_DIR}/io/pwmat"


@pytest.mark.parametrize(
    ("exclusion", "expected_idx"),
    [("", 1), ("average", 163), ("AVERAGE", 163)],
)
def test_line_locator(exclusion: str, expected_idx: int):
    """Check the first "FORCE" line found in MOVEMENT.lzma for each exclusion value."""
    movement_file = f"{TEST_DIR}/MOVEMENT.lzma"
    matching_rows = LineLocator.locate_all_lines(
        file_path=movement_file,
        content="FORCE",
        exclusion=exclusion,
    )
    # Only the first hit is compared against the expected 1-based line number.
    first_row = matching_rows[0]
    assert first_row == expected_idx


@pytest.mark.parametrize(
    ("exclusion", "expected_idx"),
    [("", 0), ("average", 1), ("AVERAGE", 1)],
)
def test_list_locator(exclusion: str, expected_idx: int):
    """Check the first "FORCE" element found in a two-item list for each exclusion value."""
    candidates = ["Average Force= 0.12342E+01", "Force"]
    matching_idxs = ListLocator.locate_all_lines(
        strs_lst=candidates,
        content="FORCE",
        exclusion=exclusion,
    )
    # Only the first hit is compared against the expected 0-based index.
    first_idx = matching_idxs[0]
    assert first_idx == expected_idx


class TestACstrExtractor(PymatgenTest):
def test_extract(self):
filepath = f"{TEST_DIR}/atom.config"
Expand Down

0 comments on commit 51180fc

Please sign in to comment.