Skip to content

Commit

Permalink
1.Dropped torch.det for MPS compatibility 2.better testing for rotati…
Browse files Browse the repository at this point in the history
…on/supercell invariance
  • Loading branch information
BowenD-UCB committed Sep 27, 2023
1 parent d5ba24b commit 4cee9e6
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 137 deletions.
2 changes: 1 addition & 1 deletion chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def from_graphs(
else:
strain = None
lattice = graph.lattice
volumes.append(torch.det(lattice))
volumes.append(torch.dot(lattice[0], torch.cross(lattice[1], lattice[2])))
strains.append(strain)

# Bonds
Expand Down
196 changes: 60 additions & 136 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import pytest
from numpy.testing import assert_allclose
from pymatgen.core import Structure
from pytest import mark

Expand Down Expand Up @@ -119,163 +118,88 @@ def test_predict_structure() -> None:
assert out["atom_fea"].shape == (8, 64)


def test_predict_structure_rotated() -> None:
@mark.parametrize("axis", [[0, 0, 1], [1, 1, 0], [-2, 3, 1]])
@mark.parametrize("rotation_angle", [5, 30, 45, 120])
def test_predict_structure_rotated(rotation_angle: float, axis: list) -> None:
from pymatgen.transformations.standard_transformations import RotationTransformation

rotation_transformation = RotationTransformation(axis=[0, 0, 1], angle=30)
rotated_structure = rotation_transformation.apply_transformation(structure)
pristine_structure = structure.copy()
pristine_structure.perturb(0.1)
pristine_prediction = model.predict_structure(
pristine_structure, return_site_energies=True
)

# Rotation
rotation_transformation = RotationTransformation(axis=axis, angle=rotation_angle)
rotated_structure = rotation_transformation.apply_transformation(pristine_structure)
out = model.predict_structure(rotated_structure, return_site_energies=True)

assert sorted(out) == ["e", "f", "m", "s", "site_energies"]
assert out["e"] == pytest.approx(-7.37159, abs=1e-4)

# Define a rotation matrix for rotation about Z-axis by 90 degrees
theta = np.radians(30) # Convert angle to radians
c, s = np.cos(theta), np.sin(theta)

rotation_matrix = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])

force = np.array(
[
[4.4703484e-08, -4.2840838e-08, 2.4071064e-02],
[-4.4703484e-08, -1.4551915e-08, -2.4071217e-02],
[-1.7881393e-07, 1.0244548e-08, 2.5402933e-02],
[5.9604645e-08, -2.3283064e-08, -2.5402665e-02],
[-1.1920929e-07, 6.6356733e-08, -2.1660209e-02],
[2.3543835e-06, -8.0077443e-06, 9.5508099e-03],
[-2.2947788e-06, 7.9898164e-06, -9.5513463e-03],
[-5.9604645e-08, -0.0000000e00, 2.1660626e-02],
]
)
rotated_force = force @ rotation_matrix
assert out["f"] == pytest.approx(rotated_force, abs=1e-4)
assert out["e"] == pytest.approx(pristine_prediction["e"], rel=1e-4, abs=1e-4)

magmom = [0.00521, 0.00521, 3.85728, 3.85729, 0.02538, 0.03706, 0.03706, 0.02538]
assert out["m"] == pytest.approx(magmom, abs=1e-4)
# Convert angle to radians
theta = np.radians(rotation_angle)

site_energies = [
-3.8090043,
-3.8090036,
-10.2737875,
-10.2737875,
-7.659066,
-7.744509,
-7.744509,
-7.659066,
]
assert out["site_energies"] == pytest.approx(site_energies, rel=1e-4, abs=1e-4)
# Normalize the axis
axis_normalized = axis / np.linalg.norm(axis)
a, b, c = axis_normalized

# Compute the skew-symmetric matrix K
K = np.array([[0, -c, b], [c, 0, -a], [-b, a, 0]])

# Compute the rotation matrix using Rodrigues' formula
R = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * np.dot(K, K)

rotated_force = pristine_prediction["f"] @ R.transpose()
assert out["f"] == pytest.approx(rotated_force, rel=1e-4, abs=1e-3)

rotated_stress = R @ pristine_prediction["s"] @ R.transpose()
assert out["s"] == pytest.approx(np.array(rotated_stress), rel=1e-4, abs=1e-3)

assert out["m"] == pytest.approx(pristine_prediction["m"], rel=1e-4, abs=1e-4)

assert out["site_energies"] == pytest.approx(
pristine_prediction["site_energies"], rel=1e-4, abs=1e-4
)


def test_predict_supercell() -> None:
supercell = structure.make_supercell([2, 2, 1], in_place=False)
pristine_structure = structure.copy()
pristine_structure.perturb(0.1)
pristine_prediction = model.predict_structure(
pristine_structure, return_site_energies=True
)
supercell = pristine_structure.make_supercell([2, 2, 1], in_place=False)
out = model.predict_structure(supercell, return_site_energies=True)

assert sorted(out) == ["e", "f", "m", "s", "site_energies"]
assert out["e"] == pytest.approx(-7.37159, abs=1e-4)
assert out["e"] == pytest.approx(pristine_prediction["e"], rel=1e-4, abs=1e-4)

forces = [
[4.4703484e-08, -4.2840838e-08, 2.4071064e-02],
[-4.4703484e-08, -1.4551915e-08, -2.4071217e-02],
[-1.7881393e-07, 1.0244548e-08, 2.5402933e-02],
[5.9604645e-08, -2.3283064e-08, -2.5402665e-02],
[-1.1920929e-07, 6.6356733e-08, -2.1660209e-02],
[2.3543835e-06, -8.0077443e-06, 9.5508099e-03],
[-2.2947788e-06, 7.9898164e-06, -9.5513463e-03],
[-5.9604645e-08, -0.0000000e00, 2.1660626e-02],
]
for idx, force in enumerate(forces):
for cell_idx in range(4):
assert_allclose(out["f"][idx * 4 + cell_idx], force, atol=1e-4)

stress = [
[3.3677614e-01, -1.9665707e-07, -5.6416429e-06],
[4.9939729e-07, 2.4675032e-01, 1.8549043e-05],
[-4.0414070e-06, 1.9096897e-05, 4.0323928e-02],
]
assert_allclose(out["s"], stress, atol=1e-4)
assert out["f"] == pytest.approx(
np.repeat(pristine_prediction["f"], 4, axis=0), rel=1e-4, abs=1e-4
)

magmoms = [0.00521, 0.00521, 3.85728, 3.85729, 0.02538, 0.03706, 0.03706, 0.02538]
for idx, magmom in enumerate(magmoms):
for cell_idx in range(4):
assert_allclose(out["m"][idx * 4 + cell_idx], magmom, atol=1e-4)
assert out["s"] == pytest.approx(pristine_prediction["s"], rel=1e-4, abs=1e-4)

site_energies = [
-3.8090043,
-3.8090036,
-3.8090043,
-3.8090036,
-3.8090043,
-3.8090036,
-3.8090043,
-3.8090036,
-10.2737875,
-10.2737875,
-10.2737875,
-10.2737875,
-10.2737875,
-10.2737875,
-10.2737875,
-10.2737875,
-7.659066,
-7.659066,
-7.659066,
-7.659066,
-7.744509,
-7.744509,
-7.744509,
-7.744509,
-7.744509,
-7.744509,
-7.744509,
-7.744509,
-7.659066,
-7.659066,
-7.659066,
-7.659066,
]
assert out["site_energies"] == pytest.approx(site_energies, rel=1e-4, abs=1e-4)
assert out["site_energies"] == pytest.approx(
np.repeat(pristine_prediction["site_energies"], 4), rel=1e-4, abs=1e-4
)


def test_predict_batched_structures() -> None:
structs = [structure, structure, structure]
pristine_structure = structure.copy()
pristine_structure.perturb(0.1)
pristine_prediction = model.predict_structure(
pristine_structure, return_site_energies=True
)
structs = [pristine_structure, pristine_structure, pristine_structure]
out = model.predict_structure(structs, return_site_energies=True)
assert len(out) == len(structs)

assert all(preds["e"] == pytest.approx(-7.37159, abs=1e-4) for preds in out)

forces = [
[4.4703484e-08, -4.2840838e-08, 2.4071064e-02],
[-4.4703484e-08, -1.4551915e-08, -2.4071217e-02],
[-1.7881393e-07, 1.0244548e-08, 2.5402933e-02],
[5.9604645e-08, -2.3283064e-08, -2.5402665e-02],
[-1.1920929e-07, 6.6356733e-08, -2.1660209e-02],
[2.3543835e-06, -8.0077443e-06, 9.5508099e-03],
[-2.2947788e-06, 7.9898164e-06, -9.5513463e-03],
[-5.9604645e-08, -0.0000000e00, 2.1660626e-02],
]
stress = [
[3.3677614e-01, -1.9665707e-07, -5.6416429e-06],
[4.9939729e-07, 2.4675032e-01, 1.8549043e-05],
[-4.0414070e-06, 1.9096897e-05, 4.0323928e-02],
]
magmom = [0.00521, 0.00521, 3.85728, 3.85729, 0.02538, 0.03706, 0.03706, 0.02538]
site_energies = [
-3.8090043,
-3.8090036,
-10.2737875,
-10.2737875,
-7.659066,
-7.744509,
-7.744509,
-7.659066,
]
for preds in out:
assert_allclose(preds["f"], forces, atol=1e-4)
assert_allclose(preds["s"], stress, atol=1e-4)
assert preds["m"] == pytest.approx(magmom, rel=1e-4, abs=1e-4)
assert preds["site_energies"] == pytest.approx(
site_energies, rel=1e-4, abs=1e-4
)
for property in ["e", "f", "s", "m", "site_energies"]:
assert preds[property] == pytest.approx(
pristine_prediction[property], rel=1e-4, abs=1e-4
)


model_arg_keys = frozenset(
Expand Down

0 comments on commit 4cee9e6

Please sign in to comment.