Skip to content

Commit

Permalink
Made dirax_indexing more noise-resistant
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentRDC committed Feb 2, 2021
1 parent 1174129 commit 3513648
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
9 changes: 4 additions & 5 deletions crystals/indexing/dirax.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def index_dirax(reflections, initial=None, length_bounds=(2, 20)):
points = [np.squeeze(a) for a in np.vsplit(reflections, reflections.shape[0])]
for a1, a2, a3 in product(points, repeat=3):
normal = np.cross(a2 - a1, a3 - a1)
if np.allclose(normal, 0, atol=1e-4):
if np.allclose(normal, 0, atol=1e-2):
continue
normal /= np.linalg.norm(normal)

Expand All @@ -102,7 +102,7 @@ def index_dirax(reflections, initial=None, length_bounds=(2, 20)):
continue

frac_dist = proj / d_star
nf = np.sum(np.isclose(frac_dist, np.rint(frac_dist), atol=0.01))
nf = np.sum(np.isclose(frac_dist, np.rint(frac_dist), atol=1 / 24))

t = 2 * np.pi * normal / d_star
potential_direct_vectors.add(LatVec(nf, t))
Expand Down Expand Up @@ -144,12 +144,11 @@ def _find_basis(vectors, reflections):
""" Find the shorted three linearly-independent vectors from a list. """
vectors = sorted(vectors, key=np.linalg.norm)
a1 = vectors.pop(0)

try:
index, a2 = next(
(i, v)
for i, v in enumerate(vectors)
if not np.allclose(np.cross(a1, v), 0, atol=1e-4)
if np.linalg.norm(np.cross(a1, v)) >= 1
)
except StopIteration:
raise IndexingError(
Expand All @@ -165,7 +164,7 @@ def _find_basis(vectors, reflections):
m[1, :] = a2
for v in vectors:
m[2, :] = v
if np.linalg.matrix_rank(m) == 3:
if np.linalg.matrix_rank(m, tol=0.1) == 3:
return Lattice(row_echelon_form([a1, a2, v]))

raise IndexingError(
Expand Down
20 changes: 19 additions & 1 deletion crystals/indexing/tests/test_dirax.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_dirax_indexing_ideal(name, bound):
refls = [cryst.scattering_vector(r) for r in cryst.bounded_reflections(bound=bound)]
lat, hkls = index_dirax(refls)

assert np.allclose(hkls - np.rint(hkls), 0, atol=0.1)
assert np.allclose(hkls - np.rint(hkls), 0, atol=0.001)


@pytest.mark.parametrize("name", ["Pu-epsilon", "C", "vo2-m1", "BaTiO3_cubic"])
Expand Down Expand Up @@ -53,6 +53,24 @@ def test_dirax_indexing_alien_reflections(name, bound):
lat, hkls = index_dirax(refls + aliens)
# The alien reflections will not be indexed correctly, of course
hkls = hkls[:num_aliens]
assert np.allclose(hkls - np.rint(hkls), 0, atol=0.01)


@pytest.mark.parametrize(
"name,bound", zip(["Pu-epsilon", "C", "vo2-m1", "BaTiO3_cubic"], [2, 3, 2, 2])
)
def test_dirax_indexing_noise(name, bound):
"""
Test that indexing always succeeds despite noise in reflection positions.
"""
np.random.seed(0)

cryst = Crystal.from_database(name)
refls = [
cryst.scattering_vector(r) + np.random.normal(0, scale=0.01, size=(3,))
for r in cryst.bounded_reflections(bound=bound)
]
lat, hkls = index_dirax(refls)
assert np.allclose(hkls - np.rint(hkls), 0, atol=0.1)


Expand Down

0 comments on commit 3513648

Please sign in to comment.