Skip to content

Commit

Permalink
Atom instances are now compatible with NumPy arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentRDC committed Nov 8, 2018
1 parent 29ba47e commit c63ba7e
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 7 deletions.
11 changes: 9 additions & 2 deletions crystals/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def __init__(self, element, coords, displacement = (0,0,0), magmom = None, **kwa
magmom = ELEM_TO_MAGMOM[element]

self.element = element
self.coords = np.array(coords, dtype = np.float)
self.displacement = np.array(displacement, dtype = np.float)
self.coords = np.asfarray(coords)
self.displacement = np.asfarray(displacement)
self.magmom = magmom

def __repr__(self):
Expand Down Expand Up @@ -205,3 +205,10 @@ def transform(self, *matrices):
"""
for matrix in matrices:
self.coords = transform(matrix, self.coords)

def __array__(self, *args, **kwargs):
""" Returns an array [Z, x, y, z] """
arr = np.empty(shape = (4,), *args, **kwargs)
arr[0] = self.atomic_number
arr[1::] = self.coords
return arr
3 changes: 1 addition & 2 deletions crystals/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,5 @@ def __array__(self, *args, **kwargs):
arr = np.empty(shape = (len(self), 4), *args, **kwargs)
atoms = self.itersorted(key = lambda atm: atm.atomic_number)
for row, atm in enumerate(atoms):
arr[row, 0] = atm.atomic_number
arr[row, 1:] = atm.coords
arr[row, :] = np.array(atm, *args, **kwargs)
return arr
16 changes: 14 additions & 2 deletions tests/test_atom.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# -*- coding: utf-8 -*-
import unittest
from copy import deepcopy
from random import choice, randint, random, seed
from random import choice
from random import randint
from random import random
from random import seed

import numpy as np

from crystals import Atom, Lattice
from crystals import Atom
from crystals import Lattice
from crystals.affine import rotation_matrix

seed(23)
Expand Down Expand Up @@ -72,6 +76,14 @@ def test_transform_back_and_forth(self):
# No assert sequence almost equal
for x1, x2 in zip(tuple(before), tuple(after)):
self.assertAlmostEqual(x1, x2)

def test_atom_array(self):
""" Test that numpy.array(Atom(...)) works as expected """
arr = np.array(self.atom)
self.assertTupleEqual(arr.shape, (4,))
self.assertEqual(arr[0], self.atom.atomic_number)
self.assertTrue(np.allclose(arr[1::], self.atom.coords))


if __name__ == '__main__':
unittest.main()
4 changes: 3 additions & 1 deletion tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

import numpy as np

from crystals import Atom, AtomicStructure, Crystal
from crystals import Atom
from crystals import AtomicStructure
from crystals import Crystal


class TestAtomicStructure(unittest.TestCase):
Expand Down

0 comments on commit c63ba7e

Please sign in to comment.