Skip to content

Commit

Permalink
update tests and enable generator with fixed seed
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Aug 17, 2024
1 parent 79513d5 commit 65f933b
Show file tree
Hide file tree
Showing 15 changed files with 123 additions and 24 deletions.
30 changes: 29 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,33 @@ jobs:
- name: Install dependencies
run: uv pip install .[test] --system

- name: Run Tests
- name: Test_general
run: pytest tests/test_all.py

- name: Test_crystal
run: pytest tests/test_crystal.py

- name: Test_lattice
run: pytest tests/test_lattice.py

- name: Test_molecule
run: pytest tests/test_molecule.py

- name: Test_wyckoff
run: pytest tests/test_wyckoff.py

- name: Test_symmetry
run: pytest tests/test_symmetry.py

- name: Test_supergroup
run: pytest tests/test_supergroup.py

- name: Test_group
run: pytest tests/test_group.py

#- name: Test_xrd
# run: pytest tests/test_xrd.py




4 changes: 2 additions & 2 deletions pyxtal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3351,8 +3351,8 @@ def get_1d_rep_x(self):
cell, xyzs = rep.x[0][1:], rep.x[1:]
x = cell
for xyz in xyzs:
x = np.hstack((x, xyz[2:]))

if len(xyz) > 2:
x = np.hstack((x, xyz[2:]))
return x

def from_spg_wps_rep(self, spg, wps, x, elements=None):
Expand Down
2 changes: 1 addition & 1 deletion pyxtal/crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def _set_ion_wyckoffs(self, numIon, specie, cell, wyks):
# print('good', pt, tol, len(wp.short_distances(pt, cell, tol)))
else:
# generate wp
wp = choose_wyckoff(self.group, numIon - numIon_added, site, self.dim)
wp = choose_wyckoff(self.group, numIon - numIon_added, site, self.dim, self.rng)
if wp is not False:
# print(wp.letter)
# Generate a list of coords from ops
Expand Down
7 changes: 4 additions & 3 deletions pyxtal/molecular_crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def set_molecules(self, molecules, torsions):
if isinstance(mol, pyxtal_molecule):
p_mol = mol
else:
p_mol = pyxtal_molecule(mol, seed=self.seed, torsions=torsions[i], tm=self.tol_matrix)
p_mol = pyxtal_molecule(mol, seed=self.seed, torsions=torsions[i], tm=self.tol_matrix, random_state=self.random_state)
self.molecules.append(p_mol)

def set_orientations(self):
Expand Down Expand Up @@ -271,6 +271,7 @@ def set_lattice(self, lattice):
thickness=self.thickness,
area=self.area,
min_special=coef * max([mol.get_max_length() for mol in self.molecules]),
random_state=self.random_state,
)
good_lattice = True
break
Expand Down Expand Up @@ -381,9 +382,9 @@ def _set_mol_wyckoffs(self, id, numMol, pyxtal_mol, valid_ori, mol_wyks):

if type(site) is dict: # site with coordinates
key = next(iter(site.keys()))
wp = wyc_mol(self.group, diff, key, valid_ori, True, self.dim)
wp = wyc_mol(self.group, diff, key, valid_ori, True, self.dim, self.random_state)
else:
wp = wyc_mol(self.group, diff, site, valid_ori, True, self.dim)
wp = wyc_mol(self.group, diff, site, valid_ori, True, self.dim, self.random_state)

if wp is not False:
# Generate a list of coords from the wyckoff position
Expand Down
32 changes: 17 additions & 15 deletions pyxtal/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def __init__(
):
mo = None
self.smile = None
self.torsionlist = None
self.torsionlist = [] #None
self.reflect = False
if seed is None:
seed = 0xF00D
Expand Down Expand Up @@ -820,12 +820,14 @@ def rdkit_mol_init(self, smile, fix, torsions):
if fix or torsions is not None or len(torsionlist) == 0:
conf = ref_conf
else:
randomSeed = -1 if self.random_state is None else 1
AllChem.EmbedMultipleConfs(
mol,
numConfs=max([1, 4 * len(torsionlist)]),
maxAttempts=200,
useRandomCoords=True,
pruneRmsThresh=0.5,
randomSeed=randomSeed,
)
N_confs = mol.GetNumConformers()
conf_id = int(self.random_state.choice(range(N_confs)))
Expand Down Expand Up @@ -857,7 +859,7 @@ def perturb_torsion(self, xyz):
slightly perturb the torsion
"""
angs = self.get_torsion_angles(xyz, self.torsionlist)
angs *= 1 + 0.1 * np.random.uniform(-1.0, 1.0, len(angs))
angs *= 1 + 0.1 * self.random_state.uniform(-1.0, 1.0, len(angs))
xyz = self.set_torsion_angles(conf, angs, torsionlist=self.torsionlist)
xyz -= self.get_center(xyz)
return xyz
Expand Down Expand Up @@ -1113,7 +1115,7 @@ def get_orientation(self, xyz, rtol=0.15):

xyz -= self.get_center(xyz)

if len(self.smile) > 1: # not in ["O", "o"]:
if self.smile is not None and len(self.smile) > 1: # not in ["O", "o"]:
rmsd, trans, reflect = self.get_rmsd(xyz)
tol = rtol * len(xyz)

Expand Down Expand Up @@ -1159,7 +1161,7 @@ def get_orientation(self, xyz, rtol=0.15):
for i, lib in enumerate(libs):
matrix0 = matrix * np.repeat(lib, 3, axis=0)
res = np.dot(ref, np.linalg.inv(matrix0))
dists[i] = np.sum((res - xyz) ** 2)
dists[i] = np.sum((res - xyz[:len(ref)]) ** 2)
# print(i, res)
id = np.argmin(dists)
matrix = matrix * np.repeat(libs[id], 3, axis=0)
Expand Down Expand Up @@ -1297,7 +1299,7 @@ def get_orientations_in_wp(self, wp, rtol=1e-2):
"""
# For single atoms, there are no constraints
if len(self.mol) == 1 or wp.index == 0:
return [Orientation([[1, 0, 0], [0, 1, 0], [0, 0, 1]], degrees=2)]
return [Orientation([[1, 0, 0], [0, 1, 0], [0, 0, 1]], degrees=2, random_state=self.random_state)]
# C1 molecule cannot take specical position
elif wp.index > 1 and self.pga.sch_symbol == "C1":
return []
Expand Down Expand Up @@ -1425,7 +1427,7 @@ def get_orientations_in_wp(self, wp, rtol=1e-2):
T = rotate_vector(v1, v2)
# If there is only one constraint
if c1[1] == []:
o = Orientation(T, degrees=1, axis=constraint1.axis)
o = Orientation(T, degrees=1, axis=constraint1.axis, random_state=self.random_state)
orientations.append(o)
else:
# Loop over second molecular constraints
Expand All @@ -1442,12 +1444,12 @@ def get_orientations_in_wp(self, wp, rtol=1e-2):
a = angle(np.dot(T2, opa.axis), constraint2.axis)
if not np.isclose(a, 0, rtol=rtol):
T2 = np.dot(np.linalg.inv(R), T)
o = Orientation(T2, degrees=0)
o = Orientation(T2, degrees=0, random_state=self.random_state)
orientations.append(o)

# Ensure the identity orientation is checked if no constraints are found
if constraints_m == []:
o = Orientation(np.identity(3), degrees=2)
o = Orientation(np.identity(3), degrees=2, random_state=self.random_state)
orientations.append(o)
# Remove redundancy from orientations
list_i = list(range(len(orientations)))
Expand Down Expand Up @@ -1606,18 +1608,18 @@ def change_orientation(self, angle="random", flip=False):
if self.degrees >= 1:
# choose the axis
if self.axis is None:
axis = np.random.rand(3) - 0.5
axis = self.random_state.random(3) - 0.5
self.axis = axis / np.linalg.norm(axis)

# parse the angle
if angle == "random":
angle = np.random.rand() * np.pi * 2
angle = self.random_state.random() * np.pi * 2
self.angle = angle

# update the matrix
r1 = Rotation.from_rotvec(self.angle * self.axis)

if self.degrees == 2 and flip and np.random.random() > 0.5:
if self.degrees == 2 and flip and self.random_state.random() > 0.5:
ax = self.random_state.choice(["x", "y", "z"])
angle0 = self.random_state.choice([90, 180, 270])
r2 = Rotation.from_euler(ax, angle0, degrees=True)
Expand Down Expand Up @@ -1646,7 +1648,7 @@ def rotate_by_matrix(self, matrix, ignore_constraint=True):
axis = None

matrix = matrix.dot(self.matrix)
return Orientation(matrix, self.degrees, axis)
return Orientation(matrix, self.degrees, axis, self.random_state)

def get_matrix(self, angle="random"):
"""
Expand All @@ -1665,15 +1667,15 @@ def get_matrix(self, angle="random"):
"""
if self.degrees == 2:
if angle == "random":
axis = np.random.sample(3)
axis = self.random_state.sample(3)
axis = axis / np.linalg.norm(axis)
angle = np.random.random() * np.pi * 2
angle = self.random_state.random() * np.pi * 2
else:
axis = self.axis
return Rotation.from_rotvec(angle * axis).as_matrix()

elif self.degrees == 1:
angle = np.random.random() * np.pi * 2 if angle == "random" else self.angle
angle = self.random_state.random() * np.pi * 2 if angle == "random" else self.angle
return Rotation.from_rotvec(angle * self.axis).as_matrix()

elif self.degrees == 0:
Expand Down
2 changes: 1 addition & 1 deletion pyxtal/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def rotate_vector(v1, v2, rtol=1e-4):
if np.abs(dot - 1) < rtol:
return np.identity(3)
elif np.abs(dot + 1) < rtol:
r = [np.random.random(), np.random.random(), np.random.random()]
r = np.random.sample(3) #[np.random.random(), np.random.random(), np.random.random()]
v3 = np.cross(v1, r)
v3 /= np.linalg.norm(v3)
# return aa2matrix(v3, np.pi)
Expand Down
1 change: 1 addition & 0 deletions pyxtal/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def from_pyxtal(cls, struc, standard=False):
vector.append(site.encode())
smiles.append(site.molecule.smile)
x = vector
if smiles[0] is None: smiles = None
return cls(x, smiles)

@classmethod
Expand Down
48 changes: 47 additions & 1 deletion tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def test_Al2SiO5(self):
assert s.valid


class resort(unittest.TestCase):
class Test_resort(unittest.TestCase):
def test_molecule(self):
rng = np.random.default_rng(0)
# glycine dihydrate
Expand Down Expand Up @@ -215,6 +215,52 @@ def test_atom(self):
assert N1 == N2


class Test_rng(unittest.TestCase):
"""
Test rng generators in two ways
1. random_state as a fixed integer
2. random_state as a generator
"""
def test_rng_integer(self):
xtal = pyxtal(); xtal.from_random(3, 194, ['C'], [8], random_state=0)
xs = xtal.get_1d_rep_x()
assert np.sum((xs - np.array([4.679, 6.418, 0.943])**2)) < 1e-2

xtal = pyxtal(molecular=True)
xtal.from_random(3, 19, ['aspirin'], [4], random_state=0)
rep = xtal.get_1D_representation().x
d1 = np.array([115, 17.294, 15.077, 9.018])
d2 = np.array([0.677, 0.243, 0.612, 0.057, -1.194, 0.110])
assert np.sum((rep[0] - d1)**2) < 1e-2
assert np.sum((rep[1][1:-1] - d2)**2) < 1e-2


def test_rng_generator(self):
rng = np.random.default_rng(1)
xtal = pyxtal()
xtal.from_random(3, 194, ['C'], [8], random_state=rng)
xs = xtal.get_1d_rep_x()
assert np.sum((xs - np.array([7.442, 2.110])**2)) < 1e-2

xtal.from_random(3, 194, ['C'], [8], random_state=rng)
xs = xtal.get_1d_rep_x()
assert np.sum((xs - np.array([5.864, 5.809])**2)) < 1e-2

xtal = pyxtal(molecular=True)
xtal.from_random(3, 19, ['aspirin'], [4], random_state=rng)
rep = xtal.get_1D_representation().x
d1 = np.array([115, 14.207, 18.334, 9.028])
d2 = np.array([0.294, 0.627, 0.528, 157.052, -11.968, -171.851])
assert np.sum((rep[0] - d1)**2) < 1e-2
assert np.sum((rep[1][1:-1] - d2)**2) < 1e-2

xtal.from_random(3, 19, ['aspirin'], [4], random_state=rng)
rep = xtal.get_1D_representation().x
d1 = np.array([115, 12.763, 16.639, 11.073])
d2 = np.array([0.504, 0.127, 0.585, -21.523, -68.406, 152.839])
assert np.sum((rep[0] - d1)**2) < 1e-2
assert np.sum((rep[1][1:-1] - d2)**2) < 1e-2

class Test_operations(unittest.TestCase):
def test_inverse(self):
coord0 = [0.35, 0.1, 0.4]
Expand Down
3 changes: 3 additions & 0 deletions tests/test_crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,6 @@ def test_mutiple_species(self):
struc = pyxtal()
struc.from_random(0, 4, ["Mo", "S"], [2, 4], 1.0)
assert struc.valid

if __name__ == "__main__":
unittest.main()
3 changes: 3 additions & 0 deletions tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,6 @@ def test_get_wyckoff_position_from_xyz(self):
assert wp0 is None
else:
assert wp.get_label() == wp0

if __name__ == "__main__":
unittest.main()
3 changes: 3 additions & 0 deletions tests/test_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,6 @@ def test_optlat_setting(self):
d1 = sm.StructureMatcher().get_rms_dist(pmg0, pmg1)
d2 = sm.StructureMatcher().get_rms_dist(pmg0, pmg2)
assert sum(d1) + sum(d2) < 0.001

if __name__ == "__main__":
unittest.main()
3 changes: 3 additions & 0 deletions tests/test_molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,6 @@ def test_special_sites(self):
struc = pyxtal(molecular=True)
struc.from_random(3, 61, ["Benzene"], [4])
assert struc.valid

if __name__ == "__main__":
unittest.main()
3 changes: 3 additions & 0 deletions tests/test_symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,6 @@ def test_from_symops(self):
wp = Wyckoff_position.from_symops(strs, G)
assert wp.number == spg
assert wp.hall_number == hall

if __name__ == "__main__":
unittest.main()
3 changes: 3 additions & 0 deletions tests/test_wyckoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,6 @@ def test_atom_site(self):
]:
site = atom_site(wp, xyz, search=True)
assert np.allclose(site.position, arr, rtol=0.001)

if __name__ == "__main__":
unittest.main()
3 changes: 3 additions & 0 deletions tests/test_xrd.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,6 @@ def test_similarity(self):
xrd3.get_profile()
s = Similarity(p1, p2, x_range=[15, 90])
assert 0.95 < s.value < 1.001

if __name__ == "__main__":
unittest.main()

0 comments on commit 65f933b

Please sign in to comment.