Skip to content

Commit

Permalink
Refs #11268. Proper unit tests for builders.
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Wedel committed Jan 14, 2016
1 parent 4f6cd1f commit 69c5f65
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 128 deletions.
55 changes: 31 additions & 24 deletions Framework/PythonInterface/plugins/algorithms/LoadCIF.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@ class CrystalStructureBuilder(object):
be passed in as well, so the source of the parsed data is replaceable.
'''

def __init__(self, cifFile):
cifData = cifFile[cifFile.keys()[0]]
def __init__(self, cifFile=None):
if cifFile is not None:
cifData = cifFile[cifFile.keys()[0]]

self.spaceGroup = self._getSpaceGroup(cifData)
self.unitCell = self._getUnitCell(cifData)
self.atoms = self._getAtoms(cifData)
self.spaceGroup = self._getSpaceGroup(cifData)
self.unitCell = self._getUnitCell(cifData)
self.atoms = self._getAtoms(cifData)

def getCrystalStructure(self):
return CrystalStructure(self.unitCell, self.spaceGroup, self.atoms)

def _getSpaceGroup(self, cifData):
try:
return self._getSpaceGroupFromString(cifData)
except RuntimeError as error:
except (RuntimeError, ValueError) as error:
try:
return self._getSpaceGroupFromNumber(cifData)
except RuntimeError as e:
Expand All @@ -55,14 +56,17 @@ def _getSpaceGroupFromString(self, cifData):
def _getCleanSpaceGroupSymbol(self, rawSpaceGroupSymbol):
# Remove :1 and :H from the symbol. Those are not required at the moment because they are the default.
removalRe = re.compile(':[1H]', re.IGNORECASE)
return re.sub(removalRe, '', rawSpaceGroupSymbol)
return re.sub(removalRe, '', rawSpaceGroupSymbol).strip()

def _getSpaceGroupFromNumber(self, cifData):
spaceGroupNumber = [int(cifData[x]) for x in
[u'_space_group_it_number', u'_symmetry_int_tables_number'] if
x in cifData.keys()][0]
[u'_space_group_it_number', u'_symmetry_int_tables_number'] if
x in cifData.keys()]

possibleSpaceGroupSymbols = SpaceGroupFactory.subscribedSpaceGroupSymbols(spaceGroupNumber)
if len(spaceGroupNumber) == 0:
raise RuntimeError('No space group symbol in CIF.')

possibleSpaceGroupSymbols = SpaceGroupFactory.subscribedSpaceGroupSymbols(spaceGroupNumber[0])

if len(possibleSpaceGroupSymbols) != 1:
raise RuntimeError(
Expand All @@ -76,19 +80,20 @@ def _getUnitCell(self, cifData):

unitCellValueMap = dict(
[(str(x), str(cifData[x])) if x in cifData.keys() else (str(x), None) for x in
unitCellComponents])
unitCellComponents])

if unitCellValueMap['_cell_length_a'] is None:
raise RuntimeError('The a-parameter of the unit cell is not specified in the supplied CIF.\n' \
'Key to look for: _cell_length_a')

replacementMap = {
'_cell_length_b': str(unitCellValueMap['_cell_length_a']),
'_cell_length_c': str(unitCellValueMap['_cell_length_a']),
'_cell_angle_alpha': '90.0', '_cell_angle_beta': '90.0', '_cell_angle_gamma': '90.0'}
'_cell_length_b': str(unitCellValueMap['_cell_length_a']),
'_cell_length_c': str(unitCellValueMap['_cell_length_a']),
'_cell_angle_alpha': '90.0', '_cell_angle_beta': '90.0', '_cell_angle_gamma': '90.0'}


unitCellValues = [unitCellValueMap[str(key)] if unitCellValueMap[str(key)] is not None else replacementMap[str(key)] for key in unitCellComponents]
unitCellValues = [
unitCellValueMap[str(key)] if unitCellValueMap[str(key)] is not None else replacementMap[str(key)] for key
in unitCellComponents]

return ' '.join(unitCellValues)

Expand Down Expand Up @@ -129,10 +134,11 @@ def _getCleanAtomSymbol(self, atomSymbol):


class UBMatrixBuilder(object):
def __init__(self, cifFile):
cifData = cifFile[cifFile.keys()[0]]
def __init__(self, cifFile = None):
if cifFile is not None:
cifData = cifFile[cifFile.keys()[0]]

self._ubMatrix = self._getUBMatrix(cifData)
self._ubMatrix = self._getUBMatrix(cifData)

def getUBMatrix(self):
return self._ubMatrix
Expand All @@ -149,6 +155,7 @@ def _getUBMatrix(self, cifData):

return ','.join(ubValues)


class LoadCIF(PythonAlgorithm):
def category(self):
return "Diffraction\\DataHandling"
Expand Down Expand Up @@ -198,11 +205,11 @@ def _getCrystalStructureFromCifFile(self, cifFile):
crystalStructure = builder.getCrystalStructure()

self.log().information(
'''Loaded the following crystal structure:
Unit cell: {0}
Space group: {1}
Atoms: {2}
'''.format(builder.unitCell, builder.spaceGroup, builder.atoms))
'''Loaded the following crystal structure:
Unit cell: {0}
Space group: {1}
Atoms: {2}
'''.format(builder.unitCell, builder.spaceGroup, builder.atoms))

return crystalStructure

Expand Down
247 changes: 143 additions & 104 deletions Framework/PythonInterface/test/python/plugins/algorithms/LoadCIFTest.py
Original file line number Diff line number Diff line change
@@ -1,112 +1,151 @@
# pylint: disable=no-init,invalid-name,too-many-public-methods
import unittest
from testhelpers import assertRaisesNothing, run_algorithm
from testhelpers.tempfile_wrapper import TemporaryFileHelper
from testhelpers import assertRaisesNothing

from mantid.kernel import *
from mantid.api import *
from mantid.simpleapi import *
from LoadCIF import UBMatrixBuilder, CrystalStructureBuilder

class CrystalStructureBuilderTest(unittest.TestCase):
from mantid.api import AlgorithmFactory


def merge_dicts(lhs, rhs):
merged = lhs.copy()
merged.update(rhs)

return merged


class CrystalStructureBuilderTestSpaceGroup(unittest.TestCase):
def setUp(self):
self.base_file = 'data_test'

self.valid_space_group_new = ['_space_group_name_h-m_alt \'P m -3 m\'']
self.valid_space_group_old = ['_symmetry_space_group_name_h-m \'P m -3 m\'']
self.invalid_space_group_wrong_new = ['_space_group_name_h-m_alt \'doesnotexist\'']
self.invalid_space_group_wrong_old = ['_symmetry_space_group_name_h-m \'doesnotExistEither\'']

self.valid_space_group_number_new = ['_space_group_it_number 230']
self.valid_space_group_number_old = ['_symmetry_int_tables_number 230']
self.invalid_space_group_number_new = ['_space_group_it_number 13']
self.invalid_space_group_number_old = ['_symmetry_int_tables_number 13']


self.cell_a = ['_cell_length_a 5.6']
self.cell_b = ['_cell_length_b 3.6']
self.cell_c = ['_cell_length_c 1.6']
self.cell_alpha = ['_cell_angle_alpha 101.1']
self.cell_beta = ['_cell_angle_beta 105.4']
self.cell_gamma = ['_cell_angle_gamma 102.2']

self.valid_atoms = ['''
loop_
_atom_site_label
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_U_iso_or_equiv
Si 1/8 1/8 1/8 1.0 0.02
Al 0.232 0.2112 0.43 0.5 0.01
''']

self.valid_ub = ['''
_diffrn_orient_matrix_UB_11 -0.03788345
_diffrn_orient_matrix_UB_12 0.13866313
_diffrn_orient_matrix_UB_13 0.31593824
_diffrn_orient_matrix_UB_21 0.01467139
_diffrn_orient_matrix_UB_22 -0.31690207
_diffrn_orient_matrix_UB_23 0.14084537
_diffrn_orient_matrix_UB_31 0.34471609
_diffrn_orient_matrix_UB_32 0.02872634
_diffrn_orient_matrix_UB_33 0.02872634
''']

self.workspace = CreateSingleValuedWorkspace(OutputWorkspace='testws', DataValue=100)

def testGetSpaceGroup(self):
otherComponents = self.cell_a + self.valid_atoms

self._checkRaisesNothing(otherComponents + self.valid_space_group_new)
self._checkRaisesNothing(otherComponents + self.valid_space_group_old)
self._checkRaisesNothing(otherComponents + self.valid_space_group_new)
self._checkRaisesNothing(otherComponents + self.valid_space_group_new + self.valid_space_group_old)

self._checkRaisesRuntimeError(otherComponents + self.invalid_space_group_wrong_new)
self._checkRaisesRuntimeError(otherComponents + self.invalid_space_group_wrong_old)

self._checkRaisesNothing(otherComponents + self.valid_space_group_number_new)
self._checkRaisesNothing(otherComponents + self.valid_space_group_number_old)
self._checkRaisesNothing(otherComponents + self.valid_space_group_old + self.valid_space_group_number_new)

# These tests need to be re-enabled once PR 14913 is merged.
# self._checkRaisesRuntimeError(otherComponents + self.invalid_space_group_number_new)
# self._checkRaisesRuntimeError(otherComponents + self.invalid_space_group_number_old)

def testGetUnitCell(self):
otherComponents = self.valid_space_group_new + self.valid_atoms

self._checkRaisesNothing(otherComponents + self.cell_a)
self._checkRaisesNothing(otherComponents + self.cell_a + self.cell_c)
self._checkRaisesNothing(otherComponents + self.cell_a + self.cell_b + self.cell_c)
self._checkRaisesNothing(otherComponents + self.cell_a + self.cell_b + self.cell_c + self.cell_beta)
self._checkRaisesNothing(otherComponents + self.cell_a + self.cell_b + self.cell_c + self.cell_alpha)
self._checkRaisesNothing(otherComponents + self.cell_a + self.cell_b + self.cell_c + self.cell_gamma)
self._checkRaisesNothing(otherComponents + self.cell_a + self.cell_b + self.cell_c + self.cell_alpha +
self.cell_gamma + self.cell_beta)

def _checkRaisesNothing(self, iterable):
validCif = self._getCifFromList(iterable)
validFile = TemporaryFileHelper(validCif)

assertRaisesNothing(self, LoadCIF,
Workspace = self.workspace,
InputFile = validFile.getName())

def _checkRaisesRuntimeError(self, iterable):
invalidCif = self._getCifFromList(iterable)
invalidFile = TemporaryFileHelper(invalidCif)

self.assertRaises(RuntimeError, LoadCIF,
Workspace = self.workspace,
InputFile = invalidFile.getName())

def _getCifFromList(self, iterable):
return '''{}\n{}'''.format(self.base_file, '\n'.join(iterable))
self.builder = CrystalStructureBuilder()

def test_getSpaceGroupFromString_valid_no_exceptions(self):
valid_new = {u'_space_group_name_h-m_alt': u'P m -3 m'}
valid_old = {u'_symmetry_space_group_name_h-m': u'P m -3 m'}

assertRaisesNothing(self, self.builder._getSpaceGroupFromString, cifData=valid_old)
assertRaisesNothing(self, self.builder._getSpaceGroupFromString, cifData=valid_new)
assertRaisesNothing(self, self.builder._getSpaceGroupFromString, cifData=merge_dicts(valid_new, valid_old))

def test_getSpaceGroupFromString_valid_correct_value(self):
valid_new = {u'_space_group_name_h-m_alt': u'P m -3 m'}
valid_old = {u'_symmetry_space_group_name_h-m': u'P m -3 m'}
valid_old_different = {u'_symmetry_space_group_name_h-m': u'F d d d'}
invalid_old = {u'_symmetry_space_group_name_h-m': u'invalid'}

self.assertEqual(self.builder._getSpaceGroupFromString(valid_new), 'P m -3 m')
self.assertEqual(self.builder._getSpaceGroupFromString(valid_old), 'P m -3 m')
self.assertEqual(self.builder._getSpaceGroupFromString(merge_dicts(valid_new, valid_old)), 'P m -3 m')
self.assertEqual(self.builder._getSpaceGroupFromString(merge_dicts(valid_new, valid_old_different)), 'P m -3 m')
self.assertEqual(self.builder._getSpaceGroupFromString(merge_dicts(valid_new, invalid_old)), 'P m -3 m')

def test_getSpaceGroupFromString_invalid(self):
valid_new = {u'_space_group_name_h-m_alt': u'P m -3 m'}
valid_old = {u'_symmetry_space_group_name_h-m': u'P m -3 m'}
invalid_new = {u'_space_group_name_h-m_alt': u'invalid'}
invalid_old = {u'_symmetry_space_group_name_h-m': u'invalid'}

self.assertRaises(RuntimeError, self.builder._getSpaceGroupFromString, cifData={})
self.assertRaises(ValueError, self.builder._getSpaceGroupFromString, cifData=invalid_new)
self.assertRaises(ValueError, self.builder._getSpaceGroupFromString, cifData=invalid_old)
self.assertRaises(ValueError, self.builder._getSpaceGroupFromString,
cifData=merge_dicts(invalid_new, valid_old))

def test_getCleanSpaceGroupSymbol(self):
fn = self.builder._getCleanSpaceGroupSymbol

self.assertEqual(fn('P m -3 m :1'), 'P m -3 m')
self.assertEqual(fn('P m -3 m :H'), 'P m -3 m')

def test_getSpaceGroupFromNumber_invalid(self):
invalid_old = {u'_symmetry_int_tables_number': u'400'}
invalid_new = {u'_space_group_it_number': u'400'}

self.assertRaises(RuntimeError, self.builder._getSpaceGroupFromNumber, cifData={})
self.assertRaises(RuntimeError, self.builder._getSpaceGroupFromNumber, cifData=invalid_old)
self.assertRaises(RuntimeError, self.builder._getSpaceGroupFromNumber, cifData=invalid_new)


class CrystalStructureBuilderTestUnitCell(unittest.TestCase):
def setUp(self):
self.builder = CrystalStructureBuilder()

def test_getUnitCell_invalid(self):
invalid_no_a = {u'_cell_length_b': u'5.6'}
self.assertRaises(RuntimeError, self.builder._getUnitCell, cifData={})
self.assertRaises(RuntimeError, self.builder._getUnitCell, cifData=invalid_no_a)

def test_getUnitCell_cubic(self):
cell = {u'_cell_length_a': u'5.6'}

self.assertEqual(self.builder._getUnitCell(cell), '5.6 5.6 5.6 90.0 90.0 90.0')

def test_getUnitCell_tetragonal(self):
cell = {u'_cell_length_a': u'5.6', u'_cell_length_c': u'2.3'}

self.assertEqual(self.builder._getUnitCell(cell), '5.6 5.6 2.3 90.0 90.0 90.0')

def test_getUnitCell_orthorhombic(self):
cell = {u'_cell_length_a': u'5.6', u'_cell_length_b': u'1.6', u'_cell_length_c': u'2.3'}

self.assertEqual(self.builder._getUnitCell(cell), '5.6 1.6 2.3 90.0 90.0 90.0')

def test_getUnitCell_hexagonal(self):
cell = {u'_cell_length_a': u'5.6', u'_cell_length_c': u'2.3', u'_cell_angle_gamma': u'120.0'}

self.assertEqual(self.builder._getUnitCell(cell), '5.6 5.6 2.3 90.0 90.0 120.0')


class CrystalStructureBuilderTestAtoms(unittest.TestCase):
def setUp(self):
self.builder = CrystalStructureBuilder()

def test_getAtoms_required_keys(self):
mandatoryKeys = dict([(u'_atom_site_label', [u'Si']),
(u'_atom_site_fract_x', [u'1/8']),
(u'_atom_site_fract_y', [u'1/8']),
(u'_atom_site_fract_z', [u'1/8'])])

for key in mandatoryKeys:
tmp = mandatoryKeys.copy()
del tmp[key]
self.assertRaises(RuntimeError, self.builder._getAtoms, cifData=tmp)

def test_getAtoms_correct(self):
data = dict([(u'_atom_site_label', [u'Si', u'Al']),
(u'_atom_site_fract_x', [u'1/8', u'0.34']),
(u'_atom_site_fract_y', [u'1/8', u'0.56']),
(u'_atom_site_fract_z', [u'1/8', u'0.23']),
(u'_atom_site_occupancy', [u'1.0', u'1.0']),
(u'_atom_site_U_iso_or_equiv', [u'0.01', u'0.02'])])

self.assertEqual(self.builder._getAtoms(data), 'Si 1/8 1/8 1/8 1.0 0.01;Al 0.34 0.56 0.23 1.0 0.02')


class UBMatrixBuilderTest(unittest.TestCase):
def setUp(self):
self.builder = UBMatrixBuilder()
self.valid_matrix = {u'_diffrn_orient_matrix_ub_11': u'-0.03',
u'_diffrn_orient_matrix_ub_12': u'0.13',
u'_diffrn_orient_matrix_ub_13': u'0.31',
u'_diffrn_orient_matrix_ub_21': u'0.01',
u'_diffrn_orient_matrix_ub_22': u'-0.31',
u'_diffrn_orient_matrix_ub_23': u'0.14',
u'_diffrn_orient_matrix_ub_31': u'0.34',
u'_diffrn_orient_matrix_ub_32': u'0.02',
u'_diffrn_orient_matrix_ub_33': u'0.02'}

def test_getUBMatrix_invalid(self):
for key in self.valid_matrix:
tmp = self.valid_matrix.copy()
del tmp[key]

self.assertRaises(RuntimeError, self.builder._getUBMatrix, cifData=tmp)

def test_getUBMatrix_correct(self):
self.assertEqual(self.builder._getUBMatrix(self.valid_matrix), '-0.03,0.13,0.31,0.01,-0.31,0.14,0.34,0.02,0.02')


if __name__ == '__main__':
# Only test if algorithm is registered (pyparsing dependency).
# Only test if algorithm is registered (PyCifRW dependency).
if AlgorithmFactory.exists("LoadCIF"):
unittest.main()
unittest.main()

0 comments on commit 69c5f65

Please sign in to comment.