diff --git a/pyat/README.rst b/pyat/README.rst index e48e03d6f..7b08d72d3 100644 --- a/pyat/README.rst +++ b/pyat/README.rst @@ -27,30 +27,13 @@ It is easiest to do this using a virtualenv, inside pyat: * ``virtualenv --no-site-packages venv`` * ``source venv/bin/activate # or venv\Scripts\activate on Windows`` -* ``pip install numpy`` -* ``pip install scipy`` -* ``pip install pytest`` +* ``pip install -r requirements.txt`` * ``python setup.py develop`` Finally, you should be able to run the tests: * ``py.test test`` -Any changes to .py files are automatically reinstalled in the build, but to -ensure any changes to .c files are reinstalled rerun: - -* ``python setup.py develop`` - -If you get strange behaviour even after running setup.py develop again, then -running the following should fix it: - -* ``find at -name "*.pyc" -exec rm '{}' \;`` -* ``find at -name "*.so" -exec rm '{}' \;`` -* ``python setup.py develop`` - -N.B. setup.py develop needs to be run with the same version of Python that -you are using to run pyAT. - Comparing results with Matlab ----------------------------- @@ -67,3 +50,19 @@ Print statements in the C code will work once the integrators are recompiled. To force recompilation, remove the build directory: * ``rm -rf build`` + +Any changes to .py files are automatically reinstalled in the build, but to +ensure any changes to .c files are reinstalled rerun: + +* ``python setup.py develop`` + +If you get strange behaviour even after running setup.py develop again, then +running the following, inside pyat, should fix it: + +* ``rm -rf build`` +* ``find at -name "*.pyc" -exec rm '{}' \;`` +* ``find at -name "*.so" -exec rm '{}' \;`` +* ``python setup.py develop`` + +N.B. setup.py develop needs to be run with the same version of Python (and +numpy) that you are using to run pyAT. diff --git a/pyat/at/lattice/elements.py b/pyat/at/lattice/elements.py index 313a4f4e3..3c51ab74d 100644 --- a/pyat/at/lattice/elements.py +++ b/pyat/at/lattice/elements.py @@ -60,19 +60,14 @@ def __str__(self): return '\n'.join((self.__class__.__name__ + ':', '\n'.join(keywords))) def __repr__(self): - def differ(v1, v2): - if isinstance(v1, numpy.ndarray): - return not numpy.array_equal(v1, v2) - else: - return v1 != v2 attrs = vars(self).copy() - args = [attrs.pop(k, getattr(self, k)) for k in self.REQUIRED_ATTRIBUTES] - defelem = self.__class__(*args) - arguments = ('{0!r}'.format(arg) for arg in args) - keywords = ('{0}={1!r}'.format(k, v) for k, v in attrs.items() - if differ(v, getattr(defelem, k, None))) - return '{0}({1})'.format(self.__class__.__name__, ', '.join( - itertools.chain(arguments, keywords))) + arguments = [attrs.pop(k, getattr(self, k)) for k in + self.REQUIRED_ATTRIBUTES] + defelem = self.__class__(*arguments) + keywords = ['{0!r}'.format(arg) for arg in arguments] + keywords += ['{0}={1!r}'.format(k, v) for k, v in attrs.items() + if not numpy.array_equal(v, getattr(defelem, k, None))] + return '{0}({1})'.format(self.__class__.__name__, ', '.join(keywords)) def divide(self, frac, keep_axis=False): """split the element in len(frac) pieces whose length diff --git a/pyat/at/lattice/lattice_object.py b/pyat/at/lattice/lattice_object.py index 69b34a6cb..e312dac0d 100644 --- a/pyat/at/lattice/lattice_object.py +++ b/pyat/at/lattice/lattice_object.py @@ -9,6 +9,7 @@ from at.physics import find_orbit4, find_orbit6, find_sync_orbit, find_m44 from at.physics import find_m66, linopt, ohmi_envelope, get_mcf + __all__ = ['Lattice'] TWO_PI_ERROR = 1.E-4 @@ -174,16 +175,22 @@ def __delitem__(self, key): super(Lattice, self).__delitem__(i) def __repr__(self): - at = ', '.join( - '{0}={1!r}'.format(key, val) for key, val in vars(self).items() if - not key.startswith('_')) - return 'Lattice({0}, {1})'.format(super(Lattice, self).__repr__(), at) + attrs = vars(self).copy() + keywords = ['{0}={1!r}'.format(key, attrs.pop(key)) for key in + Lattice._translate.values()] + keywords += ['{0}={1!r}'.format(key, val) for key, val in + attrs.items() if not key.startswith('_')] + return 'Lattice({0}, {1})'.format(super(Lattice, self).__repr__(), + ', '.join(keywords)) def __str__(self): - at = ', '.join( - '{0}={1!r}'.format(key, val) for key, val in vars(self).items() if - not key.startswith('_')) - return 'Lattice(<{0} elements>, {1})'.format(len(self), at) + attrs = vars(self).copy() + keywords = ['{0}={1!r}'.format(key, attrs.pop(key)) for key in + Lattice._translate.values()] + keywords += ['{0}={1!r}'.format(key, val) for key, val in + attrs.items() if not key.startswith('_')] + return 'Lattice(<{0} elements>, {1})'.format(len(self), + ', '.join(keywords)) def copy(self): """Return a shallow copy""" @@ -332,6 +339,13 @@ def linopt(self, *args, **kwargs): raise AtError('linopt needs no radiation in the lattice') return linopt(self, *args, **kwargs) + def get_mcf(self, *args, **kwargs): + """See at.physics.get_mcf(): + """ + if self._radiation: + raise AtError('get_mcf needs no radiation in the lattice') + return get_mcf(self, *args, **kwargs) + def ohmi_envelope(self, *args, **kwargs): """See at.physics.ohmi_envelope(): """ @@ -341,7 +355,6 @@ def ohmi_envelope(self, *args, **kwargs): Lattice.get_s_pos = get_s_pos -Lattice.get_mcf = get_mcf Lattice.find_orbit4 = find_orbit4 Lattice.find_sync_orbit = find_sync_orbit Lattice.find_orbit6 = find_orbit6 @@ -350,7 +363,9 @@ def ohmi_envelope(self, *args, **kwargs): if sys.version_info < (3, 0): Lattice.linopt.__func__.__doc__ += linopt.__doc__ + Lattice.get_mcf.__func__.__doc__ += get_mcf.__doc__ Lattice.ohmi_envelope.__func__.__doc__ += ohmi_envelope.__doc__ else: Lattice.linopt.__doc__ += linopt.__doc__ + Lattice.get_mcf.__doc__ += get_mcf.__doc__ Lattice.ohmi_envelope.__doc__ += ohmi_envelope.__doc__ diff --git a/pyat/at/lattice/utils.py b/pyat/at/lattice/utils.py index a88788b02..6d8f0bb68 100644 --- a/pyat/at/lattice/utils.py +++ b/pyat/at/lattice/utils.py @@ -16,6 +16,9 @@ """ import numpy import itertools +from warnings import warn +from fnmatch import fnmatch +from at.lattice import elements class AtError(Exception): @@ -165,6 +168,40 @@ def refpts_iterator(ring, refpts): yield ring[i] +def get_elements(ring, key, quiet=True): + """Get the elements of a family or class (type) from the lattice. + + Args: + ring: lattice from which to retrieve the elements. + key: can be: + 1) an element instance, will return all elements of the same type + in the lattice, e.g. key=Drift('d1', 1.0) + 2) an element type, will return all elements of that type in the + lattice, e.g. key=at.elements.Sextupole + 3) a string to match against elements' FamName, supports Unix + shell-style wildcards, e.g. key='BPM_*1' + quiet: if false print information about matched elements for FamName + matches, defaults to True. + """ + if isinstance(key, elements.Element): + elems = [elem for elem in ring if isinstance(elem, type(key))] + elif isinstance(key, type): + elems = [elem for elem in ring if isinstance(elem, key)] + elif numpy.issubdtype(type(key), numpy.str_): + elems = [elem for elem in ring if fnmatch(elem.FamName, key)] + if not quiet: + matched_fams = set(elem.FamName for elem in elems) + ending = 'y' if len(matched_fams) == 1 else 'ies' + print("String '{0}' matched {1} famil{2}: {3}\n" + "all corresponding elements have been " + "returned.".format(key, len(matched_fams), ending, + ', '.join(matched_fams))) + else: + raise TypeError("Invalid key type {0}; please enter a string, element" + " type, or element instance.".format(type(key))) + return elems + + def get_s_pos(ring, refpts=None): """ Return a numpy array corresponding to the s position of the specified @@ -185,6 +222,41 @@ def get_s_pos(ring, refpts=None): return s_pos[refpts] +def get_ring_energy(ring): + """Establish the energy of the ring from the Energy attribute of the + elements. Energies of RingParam elements are most prioritised, if none are + found then the energies from RFCavity elements will be used, if none are + found then the energies from all elements will be used. An error will be + raised if no elements have a 'Energy' attribute or if inconsistent values + for energy are found. + + Args: + ring: sequence of elements of which you wish to establish the energy. + """ + rp_energies = [] + rf_energies = [] + energies = [] + for elem in ring: + if hasattr(elem, 'Energy'): + energies.append(elem.Energy) + if isinstance(elem, elements.RingParam): + rp_energies.append(elem.Energy) + elif isinstance(elem, elements.RFCavity): + rf_energies.append(elem.Energy) + if not energies: + raise AtError('Lattice energy is not defined.') + elif rp_energies: + energy = max(rp_energies) + elif rf_energies: + energy = max(rf_energies) + else: + energy = max(energies) + if len(set(energies)) > 1: + warn(AtWarning('Inconsistent energy values in ring, {0} has been ' + 'used.'.format(energy))) + return energy + + def tilt_elem(elem, rots): """ set a new tilt angle to an element. The rotation matrices are stored in the R1 and R2 attributes diff --git a/pyat/at/load/matfile.py b/pyat/at/load/matfile.py index 7a82ed7c6..f3b995fb8 100644 --- a/pyat/at/load/matfile.py +++ b/pyat/at/load/matfile.py @@ -4,7 +4,6 @@ import scipy.io import numpy from . import element_from_dict -from ..lattice import Lattice def _load_element(index, element_array, check=True, quiet=False): @@ -35,7 +34,7 @@ def load_mat(filename, key=None, check=True, quiet=False): check=True if False, skip the coherence tests OUTPUT - pyat Lattice object + list pyat ring """ m = scipy.io.loadmat(filename) if key is None: diff --git a/pyat/at/load/utils.py b/pyat/at/load/utils.py index 809f9c157..7a0e00ee1 100644 --- a/pyat/at/load/utils.py +++ b/pyat/at/load/utils.py @@ -1,8 +1,11 @@ """ Conversion utilities for creating pyat elements """ +import os import numpy from warnings import warn +from distutils import sysconfig +from at import integrators from at.lattice import elements from at.lattice.utils import AtWarning @@ -77,6 +80,13 @@ def low_order(key): return CLASS_MAPPING[fam_name.lower()] except KeyError: pass_method = elem_dict.get('PassMethod', '') + if (quiet is False) and (pass_method is ''): + warn(AtWarning("No PassMethod provided." + "\n{0}".format(elem_dict))) + elif (quiet is False) and (not pass_method.endswith('Pass')): + warn(AtWarning("Invalid PassMethod ({0}), provided pass " + "methods should end in 'Pass'." + "\n{1}".format(pass_method, elem_dict))) class_from_pass = PASS_MAPPING.get(pass_method) if class_from_pass is not None: return class_from_pass @@ -139,28 +149,35 @@ def sanitise_class(index, class_name, kwargs): AttributeError: if the PassMethod and Class are incompatible. """ def err_message(message, *args): - location = '' if index is None else 'Error in element {0}: '.format( - index) - return ''.join( - (location, - 'PassMethod {0} is not compatible with '.format(pass_method), - message.format(*args), '\n{0}'.format(kwargs))) + location = ': ' if index is None else ' {0}: '.format(index) + return ''.join(('Error in element', location, + 'PassMethod {0} '.format(pass_method), + message.format(*args), '\n{0}'.format(kwargs))) pass_method = kwargs.get('PassMethod') if pass_method is not None: pass_to_class = PASS_MAPPING.get(pass_method) length = float(kwargs.get('Length', 0.0)) - if (pass_method == 'IdentityPass') and (length != 0.0): - raise AttributeError(err_message("length {0}.", length)) + extension = sysconfig.get_config_vars().get('EXT_SUFFIX', '.so') + file_path = os.path.realpath(os.path.join(integrators.__path__[0], + pass_method + extension)) + if not os.path.isfile(file_path): + raise AttributeError(err_message("does not exist.")) + elif (pass_method == 'IdentityPass') and (length != 0.0): + raise AttributeError(err_message("is not compatible with " + "length {0}.", length)) elif pass_to_class is not None: if pass_to_class != class_name: - raise AttributeError(err_message("Class {0}.", class_name)) + raise AttributeError(err_message("is not compatible with " + "Class {0}.", class_name)) elif class_name in ['Marker', 'Monitor', 'RingParam']: if pass_method != 'IdentityPass': - raise AttributeError(err_message("Class {0}.", class_name)) + raise AttributeError(err_message("is not compatible with " + "Class {0}.", class_name)) elif class_name == 'Drift': if pass_method != 'DriftPass': - raise AttributeError(err_message("Class {0}.", class_name)) + raise AttributeError(err_message("is not compatible with " + "Class {0}.", class_name)) class_name = find_class_name(elem_dict, quiet=quiet) if check: diff --git a/pyat/at/physics/amat.py b/pyat/at/physics/amat.py index 3232cff37..7477dc440 100644 --- a/pyat/at/physics/amat.py +++ b/pyat/at/physics/amat.py @@ -1,7 +1,6 @@ """""" from collections import namedtuple import numpy -from numpy.linalg import multi_dot as md from scipy.linalg import block_diag, eig, inv, det from math import pi from ..physics import jmat @@ -56,7 +55,7 @@ def amat(tt): # n1x n1y n1z # n2x n2y n2z # n3x n3y n3z - nn = 0.5 * abs(numpy.sqrt(-1.j * md((vn.conj().T, jmt, _vxyz[dms - 1])))) + nn = 0.5 * abs(numpy.sqrt(-1.j * vn.conj().T.dot(jmt).dot(_vxyz[dms - 1]))) ind = numpy.argmax(nn[select, :][:, select], axis=0) v_ordered = vn[:, 2 * ind] aa = numpy.vstack((numpy.real(v_ordered), numpy.imag(v_ordered))).reshape( @@ -99,11 +98,11 @@ def decode(rot22): dms = int(nv / 2) jmt = jmat(dms) aa = amat(tt) - rmat = md((inv(aa), tt, aa)) + rmat = inv(aa).dot(tt.dot(aa)) damping_rates, tunes = zip(*(decode(rmat[s, s]) for s in _submat[:dms])) if rr is None: return _Data1(tunes, damping_rates, get_mode_matrices(aa)) else: - rdiag = numpy.diag(md((aa.T, jmt, rr, jmt, aa))) + rdiag = numpy.diag(aa.T.dot(jmt.dot(rr.dot(jmt.dot(aa))))) mode_emit = -0.5 * (rdiag[0:nv:2] + rdiag[1:nv:2]) return _Data2(tunes, damping_rates, get_mode_matrices(aa), mode_emit) diff --git a/pyat/at/physics/linear.py b/pyat/at/physics/linear.py index 2a1e58287..6a3dcddb5 100644 --- a/pyat/at/physics/linear.py +++ b/pyat/at/physics/linear.py @@ -146,7 +146,7 @@ def get_twiss(ring, dp=0.0, refpts=None, get_chrom=False, orbit=None, l_down['closed_orbit'])[:, :4] / ddp disp0 = (d0_up['closed_orbit'] - d0_down['closed_orbit'])[:4] / ddp else: - chrom = None + chrom = numpy.array([numpy.NaN, numpy.NaN]) dispersion = numpy.NaN disp0 = numpy.NaN @@ -314,7 +314,7 @@ def analyze(r44): l_down['closed_orbit'])[:, :4] / ddp disp0 = (d0_up['closed_orbit'] - d0_down['closed_orbit'])[:4] / ddp else: - chrom = None + chrom = numpy.array([numpy.NaN, numpy.NaN]) dispersion = numpy.NaN disp0 = numpy.NaN diff --git a/pyat/at/physics/matrix.py b/pyat/at/physics/matrix.py index b1433a3d0..dcf19aaf3 100644 --- a/pyat/at/physics/matrix.py +++ b/pyat/at/physics/matrix.py @@ -81,7 +81,7 @@ def find_m44(ring, dp=0.0, refpts=None, orbit=None, keep_lattice=False, **kwargs def mrotate(m): m = numpy.squeeze(m) - return numpy.linalg.multi_dot([m, m44, _jmt.T, m.T, _jmt]) + return m.dot(m44.dot(_jmt.T.dot(m.T.dot(_jmt)))) xy_step = kwargs.pop('XYStep', XYDEFSTEP) full = kwargs.pop('full', False) diff --git a/pyat/at/physics/radiation.py b/pyat/at/physics/radiation.py index d6f10cc9a..efbfe8c2f 100644 --- a/pyat/at/physics/radiation.py +++ b/pyat/at/physics/radiation.py @@ -3,7 +3,8 @@ """ import numpy from scipy.linalg import inv, det, solve_sylvester -from at.lattice import uint32_refpts +import at +from at.lattice import uint32_refpts, get_ring_energy from at.tracking import lattice_pass from at.physics import find_orbit6, find_m66, find_elem_m66, get_tunes_damp # noinspection PyUnresolvedReferences @@ -22,15 +23,16 @@ ('emitXYZ', numpy.float64, (3,))] -def ohmi_envelope(ring, refpts=None, orbit=None, keep_lattice=False): +def ohmi_envelope(ring, refpts=None, orbit=None, keep_lattice=False, + energy=None): """ Calculate the equilibrium beam envelope in a circular accelerator using Ohmi's beam envelope formalism [1] - emit0, mode_emit, damping_rates, tunes, emit = ohmi_envelope(ring[, refpts]) + emit0, beamdata, emit = ohmi_envelope(ring[, refpts]) PARAMETERS - ring at.Lattice object + ring lattice description. refpts=None elements at which data is returned. It can be: 1) an integer in the range [-len(ring), len(ring)-1] selecting the element according to python indexing @@ -43,7 +45,11 @@ def ohmi_envelope(ring, refpts=None, orbit=None, keep_lattice=False): KEYWORDS orbit=None Avoids looking for the closed orbit if it is already known (6,) array) - keep_lattice=False Assume no lattice change since the previous tracking + keep_lattice=False Assume no lattice change since the previous + tracking + energy=None Energy of the ring; if it is not specified it is: + - lattice.energy if a lattice object is passed, + - otherwise, taken from the elements. OUTPUT emit0 emittance data at the start/end of the ring @@ -109,6 +115,11 @@ def propag(m, cumb, orbit6): nelems = len(ring) uint32refs = uint32_refpts(refpts, nelems) allrefs = uint32_refpts(range(nelems + 1), nelems) + if energy is None: + if isinstance(ring, at.lattice.Lattice): + energy = ring.energy + else: + energy = get_ring_energy(ring) if orbit is None: orbit, _ = find_orbit6(ring, keep_lattice=keep_lattice) @@ -120,7 +131,7 @@ def propag(m, cumb, orbit6): axis=(1, 3)).T mring, ms = find_m66(ring, uint32refs, orbit=orbit, keep_lattice=True) b0 = numpy.zeros((6, 6)) - bb = [find_mpole_raddiff_matrix(elem, orbit, ring.energy) + bb = [find_mpole_raddiff_matrix(elem, orbit, energy) if elem.PassMethod.endswith('RadPass') else b0 for elem in ring] bbcum = numpy.stack(list(cumulb(zip(ring, orbs, bb))), axis=0) # ------------------------------------------------------------------------ @@ -133,7 +144,7 @@ def propag(m, cumb, orbit6): # A = inv(MRING) # B = -MRING' # Q = inv(MRING)*BCUM - # ----------------------------------------------------------------------- + # ------------------------------------------------------------------------ aa = inv(mring) bb = -mring.T qq = numpy.dot(aa, bbcum[-1]) @@ -145,8 +156,11 @@ def propag(m, cumb, orbit6): data0 = numpy.rec.fromarrays( (rr, rr4, mring, orbit, emitxy, emitxyz), dtype=ENVELOPE_DTYPE) - data = numpy.rec.fromrecords( - list(map(propag, ms, bbcum[uint32refs], orbs[uint32refs, :])), - dtype=ENVELOPE_DTYPE) + if uint32refs.shape == (0,): + data = numpy.recarray((0,), dtype=ENVELOPE_DTYPE) + else: + data = numpy.rec.fromrecords( + list(map(propag, ms, bbcum[uint32refs], orbs[uint32refs, :])), + dtype=ENVELOPE_DTYPE) return data0, r66data, data diff --git a/pyat/test/conftest.py b/pyat/test/conftest.py index de3f7fb45..c4535887e 100644 --- a/pyat/test/conftest.py +++ b/pyat/test/conftest.py @@ -1,11 +1,47 @@ """ A special file that contains test fixtures for the other test files to use. """ +import os import numpy import pytest +from at import elements, load, lattice @pytest.fixture def rin(): rin = numpy.array(numpy.zeros((6, 1)), order='F') return rin + + +@pytest.fixture(scope='session') +def simple_ring(): + ring = [elements.Drift('D1', 1, R1=numpy.eye(6), R2=numpy.eye(6)), + elements.Marker('M1', attr='a_value'), elements.M66('M66'), + elements.Drift('D2', 1, T1=numpy.zeros(6), T2=numpy.zeros(6)), + elements.Drift('D3', 1, R1=numpy.eye(6), R2=numpy.eye(6)), + elements.Drift('D4', 1, T1=numpy.zeros(6), T2=numpy.zeros(6))] + return ring + + +@pytest.fixture(scope='session') +def simple_lattice(simple_ring): + return lattice.Lattice(simple_ring, name='lat', energy=5, periodicity=1) + + +@pytest.fixture(scope='session') +def dba_ring(): + path = os.path.realpath(os.path.join(os.path.dirname(__file__), + '../test_matlab/dba.mat')) + return load.load_mat(path) + + +@pytest.fixture(scope='session') +def hmba_ring(): + path = os.path.realpath(os.path.join(os.path.dirname(__file__), + '../test_matlab/hmba.mat')) + return load.load_mat(path) + + +@pytest.fixture(scope='session') +def hmba_lattice(hmba_ring): + return lattice.Lattice(hmba_ring) diff --git a/pyat/test/test_basic_elements.py b/pyat/test/test_basic_elements.py index 4381f227a..9520d4c4f 100644 --- a/pyat/test/test_basic_elements.py +++ b/pyat/test/test_basic_elements.py @@ -22,6 +22,56 @@ def test_element_creation_raises_exception(): elements.Element('family_name', R1='not_an_array') +def test_base_element_methods(): + e = elements.Element('family_name') + assert e.divide([0.2, 0.5, 0.3]) == [e] + assert id(e.copy()) != id(e) + + +def test_divide_splits_attributes_correctly(): + pre = elements.Drift('drift', 1, KickAngle=0.5) + post = pre.divide([0.2, 0.5, 0.3]) + assert len(post) == 3 + assert sum([e.Length for e in post]) == pre.Length + assert sum([e.KickAngle for e in post]) == pre.KickAngle + pre = elements.Dipole('dipole', 1, KickAngle=[0.5, -0.5], BendingAngle=0.2) + post = pre.divide([0.2, 0.5, 0.3]) + assert len(post) == 3 + assert sum([e.Length for e in post]) == pre.Length + assert sum([e.KickAngle[0] for e in post]) == pre.KickAngle[0] + assert sum([e.KickAngle[1] for e in post]) == pre.KickAngle[1] + assert sum([e.BendingAngle for e in post]) == pre.BendingAngle + pre = elements.RFCavity('rfc', 1, voltage=187500, frequency=3.5237e+8, + harmonic_number=31, energy=6.e+9, KickAngle=0.5) + post = pre.divide([0.2, 0.5, 0.3]) + assert len(post) == 3 + assert sum([e.Length for e in post]) == pre.Length + assert sum([e.KickAngle for e in post]) == pre.KickAngle + assert sum([e.Voltage for e in post]) == pre.Voltage + + +def test_insert_into_drift(): + # Create elements + drift = elements.Drift('drift', 1) + monitor = elements.Monitor('bpm') + quad = elements.Quadrupole('quad', 0.3) + # Test None splitting behaviour + el_list = drift.insert([(0., None), (0.3, None), (0.7, None), (1., None)]) + assert len(el_list) == 3 + numpy.testing.assert_almost_equal([e.Length for e in el_list], + [0.3, 0.4, 0.3]) + # Test normal insertion + el_list = drift.insert([(0.3, monitor), (0.7, quad)]) + assert len(el_list) == 5 + numpy.testing.assert_almost_equal([e.Length for e in el_list], + [0.3, 0.0, 0.25, 0.3, 0.15]) + # Test insertion at either end produces -ve length drifts + el_list = drift.insert([(0.0, quad), (1.0, quad)]) + assert len(el_list) == 5 + numpy.testing.assert_almost_equal([e.Length for e in el_list], + [-0.15, 0.3, 0.7, 0.3, -0.15]) + + def test_correct_dimensions_does_not_raise_error(rin): l = [] atpass(l, rin, 1) @@ -40,6 +90,11 @@ def test_dipole(rin, dipole_class): atpass(l, rin, 1) rin_expected = numpy.array([1e-6, 0, 0, 0, 0, 1e-7]).reshape((6, 1)) numpy.testing.assert_almost_equal(rin_orig, rin_expected) + assert b.K == 0.0 + b.PolynomB[1] = 0.2 + assert b.K == 0.2 + b.K = 0.1 + assert b.PolynomB[1] == 0.1 def test_marker(rin): @@ -134,6 +189,11 @@ def test_quad(rin): expected = numpy.array([0.921060994002885, -0.389418342308651, 0, 0, 0, 0.000000010330489]).reshape(6, 1) * 1e-6 numpy.testing.assert_allclose(rin, expected) + assert q.K == 1 + q.PolynomB[1] = 0.2 + assert q.K == 0.2 + q.K = 0.1 + assert q.PolynomB[1] == 0.1 def test_quad_incorrect_array(rin): diff --git a/pyat/test/test_lattice_object.py b/pyat/test/test_lattice_object.py new file mode 100644 index 000000000..7575e79b8 --- /dev/null +++ b/pyat/test/test_lattice_object.py @@ -0,0 +1,207 @@ +import sys +import numpy +import pytest +from at import elements +from at.load import load_mat +from at.lattice import Lattice, AtWarning, AtError + + +def test_lattice_creation_gets_attributes_from_arguments(): + l = Lattice(name='lattice', energy=3.e+6, periodicity=32, an_attr=12) + assert len(l) == 0 + assert l.name == 'lattice' + assert l.energy == 3.e+6 + assert l.periodicity == 32 + assert l._radiation is False + assert l.an_attr == 12 + + +def test_lattice_creation_short_scan_reads_radiation_status_correctly(): + d = elements.Dipole('d1', 1, BendingAngle=numpy.pi, Energy=5.e+6, + PassMethod='BndMPoleSymplectic4RadPass') + l = Lattice([d], name='lattice', energy=3.e+6, periodicity=32) + assert len(l) == 1 + assert l.name == 'lattice' + assert l.energy == 3.e+6 + assert l.periodicity == 32 + assert l._radiation is True + + +def test_lattice_creation_from_lattice_inherits_attributes(): + d = elements.Dipole('d1', 1, BendingAngle=numpy.pi, Energy=5.e+6, + PassMethod='BndMPoleSymplectic4RadPass') + l = Lattice([d], name='lattice', energy=3.e+6, periodicity=32, an_attr=12) + l.another_attr = 5 + lat = Lattice(l) + assert id(l) != id(lat) + assert len(lat) == 1 + assert lat.name == 'lattice' + assert lat.energy == 3.e+6 + assert lat.periodicity == 32 + assert lat._radiation is True + assert lat.an_attr == 12 + assert lat.another_attr == 5 + + +def test_lattice_gets_attributes_from_RingParam(): + rp = elements.RingParam('lattice_name', 3.e+6, Periodicity=32) + l = Lattice([rp]) + assert len(l) == 0 + assert l.name == 'lattice_name' + assert l.energy == 3.e+6 + assert l.periodicity == 32 + assert l._radiation is False + assert len(Lattice([rp], keep_all=True)) == 1 + + +def test_lattice_gets_attributes_from_elements(): + d = elements.Dipole('d1', 1, BendingAngle=numpy.pi, Energy=3.e+6, + PassMethod='BndMPoleSymplectic4RadPass') + l = Lattice([d]) + assert len(l) == 1 + assert l.name == '' + assert l.energy == 3.e+6 + assert l.periodicity == 2 + assert l._radiation is True + + +def test_lattice_energy_is_not_defined_raises_AtError(): + with pytest.raises(AtError): + Lattice() + + +def test_no_bending_in_the_cell_warns_correctly(): + with pytest.warns(AtWarning): + Lattice([], energy=0) + + +def test_item_is_not_an_AT_element_warns_correctly(): + with pytest.warns(AtWarning): + Lattice(['a'], energy=0, periodicity=1) + + +def test__non_integer_number_of_cells_warns_correctly(): + d = elements.Dipole('d1', 1, BendingAngle=0.5) + with pytest.warns(AtWarning): + Lattice([d], energy=0) + + +def test_inconsistent_energy_values_warns_correctly(): + m1 = elements.Marker('m1', Energy=5) + m2 = elements.Marker('m2', Energy=3) + with pytest.warns(AtWarning): + Lattice([m1, m2], periodicity=1) + + +def test_more_than_one_RingParam_in_ring_raises_warning(): + with pytest.warns(AtWarning): + l = Lattice([elements.RingParam('rp1', 3.e+6), + elements.RingParam('rp2', 12)]) + + +def test_lattice_string_ordering(): + l = Lattice([elements.Drift('D0', 1.0, attr1=numpy.array(0))], name='lat', + energy=5, periodicity=1, attr2=3) + # Default dictionary ordering is only in Python >= 3.6 + if sys.version_info < (3, 6): + assert l.__str__().startswith("Lattice(<1 elements>, ") + assert l.__str__().endswith(", attr2=3)") + assert l.__repr__().startswith("Lattice([Drift('D0', 1.0, " + "attr1=array(0))], ") + assert l.__repr__().endswith(", attr2=3)") + else: + assert l.__str__() == ("Lattice(<1 elements>, energy=5, periodicity=1," + " name='lat', attr2=3)") + assert l.__repr__() == ("Lattice([Drift('D0', 1.0, attr1=array(0))]," + " energy=5, periodicity=1, name='lat', " + "attr2=3)") + + +def test_getitem(simple_lattice, simple_ring): + assert simple_lattice[-1] == simple_ring[-1] + bool_refs = numpy.array([False, True, False, True, False, True]) + assert simple_lattice[bool_refs] == simple_ring[1::2] + + +def test_setitem(simple_lattice): + new = elements.Monitor('M2') + old = simple_lattice[5] + simple_lattice[5] = new + assert simple_lattice[5] != old + assert simple_lattice[5] == new + bool_refs = numpy.array([False, False, False, False, False, True]) + simple_lattice[bool_refs] = old + assert simple_lattice[5] != new + assert simple_lattice[5] == old + + +def test_delitem(simple_lattice, simple_ring): + mon = elements.Monitor('M2') + simple_lattice.append(mon) + assert len(simple_lattice) == 7 + del simple_lattice[-1] + assert simple_lattice[:] == simple_ring + bool_refs = numpy.array([False, False, False, False, False, False, True]) + simple_lattice.append(mon) + assert len(simple_lattice) == 7 + del simple_lattice[bool_refs] + assert simple_lattice[:] == simple_ring + + +def test_copy(hmba_lattice): + assert id(hmba_lattice.copy()) != id(hmba_lattice) + assert id(hmba_lattice.copy()[0]) == id(hmba_lattice[0]) + + +def test_deepcopy(hmba_lattice): + assert id(hmba_lattice.deepcopy()) != id(hmba_lattice) + assert id(hmba_lattice.deepcopy()[0]) != id(hmba_lattice[0]) + + +def test_property_values_against_known(hmba_lattice): + assert hmba_lattice.voltage == 6000000 + assert hmba_lattice.harmonic_number == 992 + assert hmba_lattice.radiation == False + numpy.testing.assert_almost_equal(hmba_lattice.energy_loss, + 2526188.713461808) + + +def test_radiation_change(hmba_lattice): + rfs = [elem for elem in hmba_lattice if isinstance(elem, + elements.RFCavity)] + dipoles = [elem for elem in hmba_lattice if isinstance(elem, + elements.Dipole)] + quads = [elem for elem in hmba_lattice if isinstance(elem, + elements.Quadrupole)] + hmba_lattice.radiation_on(None, 'pass2', 'auto') + assert hmba_lattice.radiation == True + for elem in rfs: + assert elem.PassMethod == 'IdentityPass' + for elem in dipoles: + assert elem.PassMethod == 'pass2' + for elem in quads: + assert elem.PassMethod == 'StrMPoleSymplectic4RadPass' + hmba_lattice.radiation_off(None, 'BndMPoleSymplectic4Pass', 'auto') + assert hmba_lattice.radiation == False + for elem in rfs: + assert elem.PassMethod == 'IdentityPass' + for elem in dipoles: + assert elem.PassMethod == 'BndMPoleSymplectic4Pass' + for elem in quads: + assert elem.PassMethod == 'StrMPoleSymplectic4Pass' + + +def test_radiation_state_errors(hmba_lattice): + hmba_lattice.radiation_on() + with pytest.raises(AtError): + hmba_lattice.linopt() + hmba_lattice.radiation_off() + hmba_lattice.linopt() + with pytest.raises(AtError): + hmba_lattice.ohmi_envelope() + hmba_lattice.radiation_on() + hmba_lattice.ohmi_envelope() + with pytest.raises(AtError): + hmba_lattice.get_mcf() + hmba_lattice.radiation_off() + hmba_lattice.get_mcf() diff --git a/pyat/test/test_lattice.py b/pyat/test/test_lattice_utils.py similarity index 50% rename from pyat/test/test_lattice.py rename to pyat/test/test_lattice_utils.py index 09f906cdf..2b462d323 100644 --- a/pyat/test/test_lattice.py +++ b/pyat/test/test_lattice_utils.py @@ -1,18 +1,13 @@ +import sys import numpy -from at import lattice, elements +from io import BytesIO, StringIO +from at.lattice import elements, uint32_refpts, bool_refpts, checkattr +from at.lattice import checktype, get_cells, refpts_iterator, get_elements +from at.lattice import get_s_pos, tilt_elem, shift_elem, set_tilt, set_shift +from at.lattice import get_ring_energy, AtWarning, AtError import pytest -@pytest.fixture -def simple_ring(): - ring = [elements.Drift('D1', 1, R1=numpy.eye(6), R2=numpy.eye(6)), - elements.Marker('M', attr='a_value'), elements.M66('M66'), - elements.Drift('D2', 1, T1=numpy.zeros(6), T2=numpy.zeros(6)), - elements.Drift('D3', 1, R1=numpy.eye(6), R2=numpy.eye(6)), - elements.Drift('D4', 1, T1=numpy.zeros(6), T2=numpy.zeros(6))] - return ring - - @pytest.mark.parametrize('ref_in, expected', ( [2, numpy.array([2], dtype=numpy.uint32)], [-1, numpy.array([4], dtype=numpy.uint32)], @@ -22,11 +17,11 @@ def simple_ring(): [[0, 6, 2, -2, 4, 5], numpy.array([0, 1, 2, 3, 4, 5], dtype=numpy.uint32)], )) def test_uint32_refpts_converts_numerical_inputs_correctly(ref_in, expected): - numpy.testing.assert_equal(lattice.uint32_refpts(ref_in, 5), expected) + numpy.testing.assert_equal(uint32_refpts(ref_in, 5), expected) ref_in2 = numpy.asarray(ref_in).astype(numpy.float64) - numpy.testing.assert_equal(lattice.uint32_refpts(ref_in2, 5), expected) + numpy.testing.assert_equal(uint32_refpts(ref_in2, 5), expected) ref_in2 = numpy.asarray(ref_in).astype(numpy.int64) - numpy.testing.assert_equal(lattice.uint32_refpts(ref_in2, 5), expected) + numpy.testing.assert_equal(uint32_refpts(ref_in2, 5), expected) @pytest.mark.parametrize('ref_in, expected', ( @@ -39,7 +34,7 @@ def test_uint32_refpts_converts_numerical_inputs_correctly(ref_in, expected): [[True, True, True, True, True, True], numpy.array([0, 1, 2, 3, 4, 5], dtype=numpy.uint32)] )) def test_uint32_refpts_converts_other_input_types_correctly(ref_in, expected): - numpy.testing.assert_equal(lattice.uint32_refpts(ref_in, 5), expected) + numpy.testing.assert_equal(uint32_refpts(ref_in, 5), expected) # too long, misordered, duplicate, -ve indexing misordered, -ve indexing @@ -50,72 +45,108 @@ def test_uint32_refpts_converts_other_input_types_correctly(ref_in, expected): [0, -2], [3, 0], [1, 3], [-1, 3], [3, -2])) def test_uint32_refpts_throws_ValueError_correctly(ref_in): with pytest.raises(ValueError): - r = lattice.uint32_refpts(ref_in, 2) + r = uint32_refpts(ref_in, 2) def test_bool_refpts(): bool_rps1 = numpy.ones(5, dtype=bool) bool_rps1[3] = False - numpy.testing.assert_equal(bool_rps1, lattice.bool_refpts(bool_rps1, 4)) - numpy.testing.assert_equal(lattice.bool_refpts([0, 1, 2, 4], 4), bool_rps1) + numpy.testing.assert_equal(bool_rps1, bool_refpts(bool_rps1, 4)) + numpy.testing.assert_equal(bool_refpts([0, 1, 2, 4], 4), bool_rps1) bool_rps3 = numpy.ones(12, dtype=bool) bool_rps3[3] = False - numpy.testing.assert_equal(bool_rps1, lattice.bool_refpts(bool_rps3, 4)) + numpy.testing.assert_equal(bool_rps1, bool_refpts(bool_rps3, 4)) bool_rps2 = numpy.ones(4, dtype=bool) numpy.testing.assert_equal(numpy.array([True, True, True, True, False]), - lattice.bool_refpts(bool_rps2, 4)) + bool_refpts(bool_rps2, 4)) def test_checkattr(simple_ring): - assert lattice.checkattr('Length')(simple_ring[0]) is True - assert lattice.checkattr('not_an_attr')(simple_ring[0]) is False - assert (list(filter(lattice.checkattr('Length', 1), simple_ring)) == + assert checkattr('Length')(simple_ring[0]) is True + assert checkattr('not_an_attr')(simple_ring[0]) is False + assert (list(filter(checkattr('Length', 1), simple_ring)) == [simple_ring[0], simple_ring[3], simple_ring[4], simple_ring[5]]) - assert list(filter(lattice.checkattr('Length', 2), simple_ring)) == [] - assert list(filter(lattice.checkattr('not_an_attr'), simple_ring)) == [] + assert list(filter(checkattr('Length', 2), simple_ring)) == [] + assert list(filter(checkattr('not_an_attr'), simple_ring)) == [] def test_checktype(simple_ring): - assert lattice.checktype(elements.Drift)(simple_ring[0]) is True - assert lattice.checktype(elements.Marker)(simple_ring[0]) is False - assert (list(filter(lattice.checktype(elements.Drift), simple_ring)) == + assert checktype(elements.Drift)(simple_ring[0]) is True + assert checktype(elements.Marker)(simple_ring[0]) is False + assert (list(filter(checktype(elements.Drift), simple_ring)) == [simple_ring[0], simple_ring[3], simple_ring[4], simple_ring[5]]) - assert list(filter(lattice.checktype(elements.Monitor), simple_ring)) == [] + assert list(filter(checktype(elements.Monitor), simple_ring)) == [] def test_get_cells(simple_ring): a = numpy.ones(6, dtype=bool) - numpy.testing.assert_equal(lattice.get_cells(simple_ring, lattice.checkattr('Length')), a) + numpy.testing.assert_equal(get_cells(simple_ring, checkattr('Length')), a) a = numpy.array([False, True, False, False, False, False]) - numpy.testing.assert_equal(lattice.get_cells(simple_ring, 'attr'), a) + numpy.testing.assert_equal(get_cells(simple_ring, 'attr'), a) a = numpy.array([True, False, False, False, False, False]) - numpy.testing.assert_equal(lattice.get_cells(simple_ring, 'FamName', 'D1'), a) + numpy.testing.assert_equal(get_cells(simple_ring, 'FamName', 'D1'), a) def test_refpts_iterator(simple_ring): - assert (list(lattice.refpts_iterator(simple_ring, [0, 1, 2, 3, 4, 5])) == + assert (list(refpts_iterator(simple_ring, [0, 1, 2, 3, 4, 5])) == simple_ring) - assert (list(lattice.refpts_iterator(simple_ring, numpy.ones(6, dtype=bool))) + assert (list(refpts_iterator(simple_ring, numpy.ones(6, dtype=bool))) == simple_ring) - assert list(lattice.refpts_iterator(simple_ring, [1])) == [simple_ring[1]] + assert list(refpts_iterator(simple_ring, [1])) == [simple_ring[1]] a = numpy.array([False, True, False, False, False, False]) - assert list(lattice.refpts_iterator(simple_ring, a)) == [simple_ring[1]] + assert list(refpts_iterator(simple_ring, a)) == [simple_ring[1]] + + +def test_get_elements(hmba_lattice): + # test FamName direct match + assert get_elements(hmba_lattice, 'BPM_06') == [hmba_lattice[65]] + # test FamName wildcard matching + assert get_elements(hmba_lattice, 'QD2?') == hmba_lattice[9, 113] + assert get_elements(hmba_lattice, 'QD3*') == hmba_lattice[19, 105] + assert get_elements(hmba_lattice, 'S*H2B') == [hmba_lattice[55]] + assert get_elements(hmba_lattice, '*C_1') == hmba_lattice[59, 60] + assert get_elements(hmba_lattice, 'DR_2[1-3]') == hmba_lattice[54, 56, 58] + assert get_elements(hmba_lattice, 'DR_2[!1-7]') == hmba_lattice[52, 78, 80] + # test element instance + marker = elements.Marker('M1') + assert get_elements(hmba_lattice, marker) == hmba_lattice[1, 12, 61, + 67, 73] + # test element type + assert get_elements(hmba_lattice, elements.RFCavity) == [hmba_lattice[0]] + # test invalid key raises TypeError + with pytest.raises(TypeError): + get_elements(hmba_lattice, None) + # test quiet suppresses print statement correctly + if sys.version_info < (3, 0): + capturedOutput = BytesIO() + else: + capturedOutput = StringIO() + sys.stdout = capturedOutput + get_elements(hmba_lattice, 'BPM_06', quiet=True) + sys.stdout = sys.__stdout__ + assert capturedOutput.getvalue() == '' + sys.stdout = capturedOutput + get_elements(hmba_lattice, 'BPM_06', quiet=False) + sys.stdout = sys.__stdout__ + assert capturedOutput.getvalue() == ("String 'BPM_06' matched 1 family: " + "BPM_06\nall corresponding elements " + "have been returned.\n") def test_get_s_pos_returns_zero_for_empty_lattice(): - numpy.testing.assert_equal(lattice.get_s_pos([], None), numpy.array((0,))) + numpy.testing.assert_equal(get_s_pos([], None), numpy.array((0,))) def test_get_s_pos_returns_length_for_lattice_with_one_element(): e = elements.Element('e', 0.1) - assert lattice.get_s_pos([e], [1]) == numpy.array([0.1]) + assert get_s_pos([e], [1]) == numpy.array([0.1]) def test_get_s_pos_returns_all_points_for_lattice_with_two_elements_and_refpts_None(): e = elements.Element('e', 0.1) f = elements.Element('e', 2.1) - print(lattice.get_s_pos([e, f], None)) - numpy.testing.assert_equal(lattice.get_s_pos([e, f], None), + print(get_s_pos([e, f], None)) + numpy.testing.assert_equal(get_s_pos([e, f], None), numpy.array([0, 0.1, 2.2])) @@ -123,19 +154,33 @@ def test_get_s_pos_returns_all_points_for_lattice_with_two_elements_using_int_re e = elements.Element('e', 0.1) f = elements.Element('e', 2.1) lat = [e, f] - numpy.testing.assert_equal(lattice.get_s_pos(lat, range(len(lat) + 1)), + numpy.testing.assert_equal(get_s_pos(lat, range(len(lat) + 1)), numpy.array([0, 0.1, 2.2])) def test_get_s_pos_returns_all_points_for_lattice_with_two_elements_using_bool_refpts(): e = elements.Element('e', 0.1) f = elements.Element('e', 2.1) lat = [e, f] - numpy.testing.assert_equal(lattice.get_s_pos(lat, numpy.ones(3, dtype=bool)), + numpy.testing.assert_equal(get_s_pos(lat, numpy.ones(3, dtype=bool)), numpy.array([0, 0.1, 2.2])) +def test_get_ring_energy(): + ring = [elements.RingParam('RP', 1.e+6), + elements.RFCavity('RF', 1.0, 24, 46, 12, 2.e+6), + elements.Element('EL1', Energy=3.e+6), + elements.Element('EL2', Energy=4.e+6)] + with pytest.warns(AtWarning): + assert get_ring_energy(ring) == 1.e+6 + assert get_ring_energy(ring[1:]) == 2.e+6 + assert get_ring_energy(ring[2:]) == 4.e+6 + assert get_ring_energy(ring[2:3]) == 3.e+6 # shouldn't warn + with pytest.raises(AtError): + get_ring_energy([elements.Drift('D1', 1.0)]) + + def test_tilt_elem(simple_ring): - lattice.tilt_elem(simple_ring[0], (numpy.pi/4)) + tilt_elem(simple_ring[0], (numpy.pi/4)) v = 1/2**0.5 a = numpy.diag([v, v, v, v, 1.0, 1.0]) a[0, 2], a[1, 3], a[2, 0], a[3, 1] = v, v, -v, -v @@ -144,14 +189,14 @@ def test_tilt_elem(simple_ring): def test_shift_elem(simple_ring): - lattice.shift_elem(simple_ring[2], 1.0, 0.5) + shift_elem(simple_ring[2], 1.0, 0.5) a = numpy.array([1.0, 0.0, 0.5, 0.0, 0.0, 0.0]) numpy.testing.assert_equal(simple_ring[2].T1, -a) numpy.testing.assert_equal(simple_ring[2].T2, a) def test_set_tilt(simple_ring): - lattice.set_tilt(simple_ring, [(numpy.pi/4), 0, 0, 0, (numpy.pi/4), 0]) + set_tilt(simple_ring, [(numpy.pi/4), 0, 0, 0, (numpy.pi/4), 0]) v = 1/2**0.5 a = numpy.diag([v, v, v, v, 1.0, 1.0]) a[0, 2], a[1, 3], a[2, 0], a[3, 1] = v, v, -v, -v @@ -160,13 +205,13 @@ def test_set_tilt(simple_ring): numpy.testing.assert_allclose(simple_ring[0].R1, a) numpy.testing.assert_allclose(simple_ring[0].R2, a.T) ring = [simple_ring[0]] - lattice.set_tilt(ring, (0)) + set_tilt(ring, (0)) numpy.testing.assert_allclose(ring[0].R1, numpy.eye(6)) numpy.testing.assert_allclose(ring[0].R2, numpy.eye(6)) def test_set_shift(simple_ring): - lattice.set_shift(simple_ring, numpy.array([0., 0., 0., 1., 0., 0.5,]), + set_shift(simple_ring, numpy.array([0., 0., 0., 1., 0., 0.5,]), numpy.array([0., 0., 0., 2., 0., 1.,])) a = numpy.array([0.5, 0., 1., 0., 0., 0.]) numpy.testing.assert_equal(simple_ring[3].T1, -a*2) @@ -174,7 +219,7 @@ def test_set_shift(simple_ring): numpy.testing.assert_equal(simple_ring[5].T1, -a) numpy.testing.assert_equal(simple_ring[5].T2, a) ring = [simple_ring[3]] - lattice.set_shift(ring, 3, 5) + set_shift(ring, 3, 5) a = numpy.array([3., 0., 5., 0., 0., 0.]) numpy.testing.assert_equal(simple_ring[3].T1, -a) numpy.testing.assert_equal(simple_ring[3].T2, a) diff --git a/pyat/test/test_load_utils.py b/pyat/test/test_load_utils.py index 929edae26..a135ff9f8 100644 --- a/pyat/test/test_load_utils.py +++ b/pyat/test/test_load_utils.py @@ -5,14 +5,28 @@ from at.load import CLASS_MAPPING, PASS_MAPPING -def test_invalid_class_warns_when_quiet_is_False(): +def test_invalid_class_warns_correctly(): elem_kwargs = {'Class': 'Invalid'} with pytest.warns(at.AtWarning): find_class_name(elem_kwargs, quiet=False) + with pytest.warns(None) as record: + find_class_name(elem_kwargs, quiet=True) + assert len(record) is 0 -def test_invalid_class_does_not_warn_when_quiet_is_True(): - elem_kwargs = {'Class': 'Invalid'} +def test_no_pass_method_warns_correctly(): + elem_kwargs = {} + with pytest.warns(at.AtWarning): + find_class_name(elem_kwargs, quiet=False) + with pytest.warns(None) as record: + find_class_name(elem_kwargs, quiet=True) + assert len(record) is 0 + + +def test_invalid_pass_method_warns_correctly(): + elem_kwargs = {'PassMethod': 'invalid'} + with pytest.warns(at.AtWarning): + find_class_name(elem_kwargs, quiet=False) with pytest.warns(None) as record: find_class_name(elem_kwargs, quiet=True) assert len(record) is 0 @@ -45,7 +59,7 @@ def test_PassMethod_mapping(): def test_find_Aperture(): elem_kwargs = {'Limits': [-0.5, 0.5, -0.5, 0.5], 'FamName': 'fam'} - assert find_class_name(elem_kwargs) == 'Aperture' + assert find_class_name(elem_kwargs, True) == 'Aperture' @pytest.mark.parametrize('elem_kwargs', ( @@ -55,12 +69,12 @@ def test_find_Aperture(): {'PhaseLag': 0, 'FamName': 'fam'}, {'TimeLag': 0.0, 'FamName': 'fam'})) def test_find_RFCavity(elem_kwargs): - assert find_class_name(elem_kwargs) == 'RFCavity' + assert find_class_name(elem_kwargs, True) == 'RFCavity' def test_find_Monitor(): elem_kwargs = {'GCR': [1, 1, 0, 0], 'FamName': 'fam'} - assert find_class_name(elem_kwargs) == 'Monitor' + assert find_class_name(elem_kwargs, True) == 'Monitor' @pytest.mark.parametrize('elem_kwargs', ( @@ -71,29 +85,29 @@ def test_find_Monitor(): {'EntranceAngle': 0.05, 'FamName': 'fam'}, {'ExitAngle': 0.05, 'FamName': 'fam'})) def test_find_Dipole(elem_kwargs): - assert find_class_name(elem_kwargs) == 'Dipole' + assert find_class_name(elem_kwargs, True) == 'Dipole' def test_find_Corrector(): elem_kwargs = {'KickAngle': [0, 0], 'FamName': 'fam'} - assert find_class_name(elem_kwargs) == 'Corrector' + assert find_class_name(elem_kwargs, True) == 'Corrector' def test_find_RingParam(): elem_kwargs = {'Periodicity': 1, 'FamName': 'fam'} - assert find_class_name(elem_kwargs) == 'RingParam' + assert find_class_name(elem_kwargs, True) == 'RingParam' def test_find_M66(): elem_kwargs = {'M66': numpy.eye(6), 'FamName': 'fam'} - assert find_class_name(elem_kwargs) == 'M66' + assert find_class_name(elem_kwargs, True) == 'M66' @pytest.mark.parametrize('elem_kwargs', ( {'K': -0.5, 'FamName': 'fam'}, {'PolynomB': [0, 1, 0, 0], 'FamName': 'fam'})) def test_find_Quadrupole(elem_kwargs): - assert find_class_name(elem_kwargs) == 'Quadrupole' + assert find_class_name(elem_kwargs, True) == 'Quadrupole' @pytest.mark.parametrize('elem_kwargs', ( @@ -101,22 +115,22 @@ def test_find_Quadrupole(elem_kwargs): 'FamName': 'fam'}, {'PolynomB': [0, 0, 0, 0], 'Length': 1, 'FamName': 'fam'})) def test_find_Multipole(elem_kwargs): - assert find_class_name(elem_kwargs) == 'Multipole' + assert find_class_name(elem_kwargs, True) == 'Multipole' def test_find_Drift(): elem_kwargs = {'Length': 1, 'FamName': 'fam'} - assert find_class_name(elem_kwargs) == 'Drift' + assert find_class_name(elem_kwargs, True) == 'Drift' def test_find_Sextupole(): elem_kwargs = {'PolynomB': [0, 0, 1, 0], 'FamName': 'fam'} - assert find_class_name(elem_kwargs) == 'Sextupole' + assert find_class_name(elem_kwargs, True) == 'Sextupole' def test_find_Octupole(): elem_kwargs = {'PolynomB': [0, 0, 0, 1], 'FamName': 'fam'} - assert find_class_name(elem_kwargs) == 'Octupole' + assert find_class_name(elem_kwargs, True) == 'Octupole' @pytest.mark.parametrize('elem_kwargs', ( @@ -125,13 +139,13 @@ def test_find_Octupole(): {'PolynomB': [0, 0, 0, 0], 'FamName': 'fam'}, {'PolynomB': [0, 0, 0, 0, 1], 'Length': 0, 'FamName': 'fam'})) def test_find_ThinMultipole(elem_kwargs): - assert find_class_name(elem_kwargs) == 'ThinMultipole' + assert find_class_name(elem_kwargs, True) == 'ThinMultipole' @pytest.mark.parametrize('elem_kwargs', ({'FamName': 'fam'}, {'Length': 0.0, 'FamName': 'fam'})) def test_find_Marker(elem_kwargs): - assert find_class_name(elem_kwargs) == 'Marker' + assert find_class_name(elem_kwargs, True) == 'Marker' @pytest.mark.parametrize('elem_kwargs', ( @@ -146,7 +160,8 @@ def test_find_Marker(elem_kwargs): {'Class': 'RingParam', 'PassMethod': 'StrMPoleSymplectic4Pass', 'Energy': 3E9, 'FamName': 'fam'}, {'Class': 'Drift', 'PassMethod': 'StrMPoleSymplectic4Pass', - 'Length': 1.0, 'FamName': 'fam'})) + 'Length': 1.0, 'FamName': 'fam'}, + {'Class': 'Drift', 'PassMethod': 'InvalidPass'})) def test_sanitise_class_error(elem_kwargs): with pytest.raises(AttributeError): elem = element_from_dict(elem_kwargs) diff --git a/pyat/test/test_physics.py b/pyat/test/test_physics.py index c6cee98bf..550ddfb04 100644 --- a/pyat/test/test_physics.py +++ b/pyat/test/test_physics.py @@ -1,11 +1,8 @@ import at import numpy import pytest -from at import physics, load, atpass - - -LATTICE_FILE = 'test_matlab/dba.mat' -CAVITY_FILE = 'test_matlab/hmba.mat' +from at import physics, load, atpass, elements +from at.lattice import AtWarning, AtError DP = 1e-5 @@ -15,62 +12,60 @@ [0, 0, -0.0059496965, -0.99921979]]) -@pytest.fixture -def ring(): - ring = load.load_mat(LATTICE_FILE) - return ring - - -@pytest.fixture -def cavity_ring(): - ring = load.load_mat(CAVITY_FILE) - return ring - - -def test_find_orbit4(ring): - orbit4, _ = physics.find_orbit4(ring, DP) +def test_find_orbit4(dba_ring): + orbit4, _ = physics.find_orbit4(dba_ring, DP) expected = numpy.array([1.091636e-7, 1.276747e-15, 0, 0, DP, 0]) numpy.testing.assert_allclose(orbit4, expected, atol=1e-12) -def test_find_orbit4_finds_zeros_if_dp_zero(ring): - orbit4, _ = physics.find_orbit4(ring, 0) +def test_find_orbit4_finds_zeros_if_dp_zero(dba_ring): + orbit4, _ = physics.find_orbit4(dba_ring, 0) expected = numpy.zeros((6,)) numpy.testing.assert_allclose(orbit4, expected, atol=1e-7) -def test_find_orbit4_result_unchanged_by_atpass(ring): - orbit, _ = physics.find_orbit4(ring, DP) +def test_find_orbit4_result_unchanged_by_atpass(dba_ring): + orbit, _ = physics.find_orbit4(dba_ring, DP) orbit_copy = numpy.copy(orbit) orbit[4] = DP - atpass(ring, orbit, 1) + atpass(dba_ring, orbit, 1) numpy.testing.assert_allclose(orbit[:4], orbit_copy[:4], atol=1e-12) -def test_find_orbit4_with_two_refpts_with_and_without_guess(ring): +def test_find_orbit4_with_two_refpts_with_and_without_guess(dba_ring): expected = numpy.array( [[8.148212e-6, 1.0993354e-5, 0, 0, DP, 2.963929e-6], [3.0422808e-8, 9.1635269e-8, 0, 0, DP, 5.9280346e-6]] ) - _, all_points = physics.find_orbit4(ring, DP, [49, 99]) + _, all_points = physics.find_orbit4(dba_ring, DP, [49, 99]) numpy.testing.assert_allclose(all_points, expected, atol=1e-12) - _, all_points = physics.find_orbit4(ring, DP, [49, 99], + _, all_points = physics.find_orbit4(dba_ring, DP, [49, 99], numpy.array([0., 0., 0., 0., DP, 0.])) numpy.testing.assert_allclose(all_points, expected, atol=1e-12) -@pytest.mark.parametrize('refpts', ([145], [20], [1, 2, 3])) -def test_find_m44_returns_same_answer_as_matlab(ring, refpts): - m44, mstack = physics.find_m44(ring, dp=DP, refpts=refpts) +def test_orbit_maxiter_warnings(hmba_ring): + with pytest.warns(AtWarning): + physics.find_orbit4(hmba_ring, max_iterations=1) + with pytest.warns(AtWarning): + physics.find_sync_orbit(hmba_ring, max_iterations=1) + with pytest.warns(AtWarning): + physics.find_orbit6(hmba_ring, max_iterations=1) + +@pytest.mark.parametrize('refpts', ([145], [20], [1, 2, 3])) +def test_find_m44_returns_same_answer_as_matlab(dba_ring, refpts): + m44, mstack = physics.find_m44(dba_ring, dp=DP, refpts=refpts) numpy.testing.assert_allclose(m44, M44_MATLAB, rtol=1e-5, atol=1e-7) - stack_size = 0 if refpts is None else len(refpts) - assert mstack.shape == (stack_size, 4, 4) + assert mstack.shape == (len(refpts), 4, 4) + m44, mstack = physics.find_m44(dba_ring, dp=DP, refpts=refpts, full=True) + numpy.testing.assert_allclose(m44, M44_MATLAB, rtol=1e-5, atol=1e-7) + assert mstack.shape == (len(refpts), 4, 4) @pytest.mark.parametrize('refpts', ([145], [20], [1, 2, 3])) -def test_find_m66(cavity_ring, refpts): - m66, mstack = physics.find_m66(cavity_ring, refpts=refpts) +def test_find_m66(hmba_ring, refpts): + m66, mstack = physics.find_m66(hmba_ring, refpts=refpts) expected = numpy.array([[-0.735654, 4.673766, 0., 0., 2.997161e-3, 0.], [-9.816788e-2, -0.735654, 0., 0., 1.695263e-4, 0.], [0., 0., 0.609804, -2.096051, 0., 0.], @@ -83,8 +78,8 @@ def test_find_m66(cavity_ring, refpts): @pytest.mark.parametrize('index', (20, 1, 2)) -def test_find_elem_m66(cavity_ring, index): - m66 = physics.find_elem_m66(cavity_ring[index]) +def test_find_elem_m66(hmba_ring, index): + m66 = physics.find_elem_m66(hmba_ring[index]) if index is 20: expected = numpy.array([[1.0386, 0.180911, 0., 0., 0., 0.], [0.434959, 1.0386, 0., 0., 0., 0.], @@ -97,32 +92,32 @@ def test_find_elem_m66(cavity_ring, index): numpy.testing.assert_allclose(m66, expected, rtol=1e-5, atol=1e-7) -def test_find_sync_orbit(ring): +def test_find_sync_orbit(dba_ring): expected = numpy.array([[1.030844e-5, 1.390795e-5, -2.439041e-30, 4.701621e-30, 1.265181e-5, 3.749859e-6], [3.86388e-8, 1.163782e-7, -9.671192e-30, 3.567819e-30, 1.265181e-5, 7.5e-6]]) - _, all_points = physics.find_sync_orbit(ring, DP, [49, 99]) + _, all_points = physics.find_sync_orbit(dba_ring, DP, [49, 99]) numpy.testing.assert_allclose(all_points, expected, rtol=1e-5, atol=1e-7) -def test_find_sync_orbit_finds_zeros(ring): - sync_orbit = physics.find_sync_orbit(ring)[0] +def test_find_sync_orbit_finds_zeros(dba_ring): + sync_orbit = physics.find_sync_orbit(dba_ring)[0] numpy.testing.assert_equal(sync_orbit, numpy.zeros(6)) -def test_find_orbit6(cavity_ring): - expected = numpy.zeros((len(cavity_ring), 6)) - refpts = numpy.ones(len(cavity_ring), dtype=bool) - _, all_points = physics.find_orbit6(cavity_ring, refpts) +def test_find_orbit6(hmba_ring): + expected = numpy.zeros((len(hmba_ring), 6)) + refpts = numpy.ones(len(hmba_ring), dtype=bool) + _, all_points = physics.find_orbit6(hmba_ring, refpts) numpy.testing.assert_allclose(all_points, expected, atol=1e-12) -def test_find_orbit6_raises_AtError_if_there_is_no_cavity(ring): +def test_find_orbit6_raises_AtError_if_there_is_no_cavity(dba_ring): with pytest.raises(at.lattice.utils.AtError): - physics.find_orbit6(ring) + physics.find_orbit6(dba_ring) -def test_find_m44_no_refpts(ring): - m44 = physics.find_m44(ring, dp=DP)[0] +def test_find_m44_no_refpts(dba_ring): + m44 = physics.find_m44(dba_ring, dp=DP)[0] expected = numpy.array([[-0.66380, 2.23415, 0., 0.], [-0.25037, -0.66380, 0.,0.], [-1.45698e-31, -1.15008e-30, -0.99922, 0.26217], @@ -131,8 +126,8 @@ def test_find_m44_no_refpts(ring): @pytest.mark.parametrize('refpts', ([145], [1, 2, 3, 145])) -def test_get_twiss(ring, refpts): - twiss0, tune, chrom, twiss = physics.get_twiss(ring, DP, refpts, +def test_get_twiss(dba_ring, refpts): + twiss0, tune, chrom, twiss = physics.get_twiss(dba_ring, DP, refpts, get_chrom=True) numpy.testing.assert_allclose(twiss['s_pos'][-1], 56.209377216, atol=1e-9) numpy.testing.assert_allclose(twiss['closed_orbit'][0][:5], @@ -147,15 +142,15 @@ def test_get_twiss(ring, refpts): atol=1e-7) -def test_get_twiss_no_refpts(ring): - twiss0, tune, chrom, twiss = physics.get_twiss(ring, DP, get_chrom=True) +def test_get_twiss_no_refpts(dba_ring): + twiss0, tune, chrom, twiss = physics.get_twiss(dba_ring, DP, get_chrom=True) assert list(twiss) == [] - assert len(physics.get_twiss(ring, DP, get_chrom=True)) is 4 + assert len(physics.get_twiss(dba_ring, DP, get_chrom=True)) is 4 @pytest.mark.parametrize('refpts', ([145], [1, 2, 3, 145])) -def test_linopt(ring, refpts): - lindata0, tune, chrom, lindata = physics.linopt(ring, DP, refpts, +def test_linopt(dba_ring, refpts): + lindata0, tune, chrom, lindata = physics.linopt(dba_ring, DP, refpts, get_chrom=True) numpy.testing.assert_allclose(tune, [0.365529, 0.493713], rtol=1e-5) numpy.testing.assert_allclose(chrom, [-0.309037, -0.441859], rtol=1e-5) @@ -191,8 +186,8 @@ def test_linopt(ring, refpts): @pytest.mark.parametrize('refpts', ([145], [1, 2, 3, 145])) -def test_linopt_uncoupled(ring, refpts): - lindata0, tune, chrom, lindata = physics.linopt(ring, DP, refpts, +def test_linopt_uncoupled(dba_ring, refpts): + lindata0, tune, chrom, lindata = physics.linopt(dba_ring, DP, refpts, coupled=False) numpy.testing.assert_allclose(tune, [0.365529, 0.493713], rtol=1e-5) numpy.testing.assert_allclose(lindata['s_pos'][-1], 56.209377216, atol=1e-9) @@ -222,17 +217,19 @@ def test_linopt_uncoupled(ring, refpts): rtol=1e-5, atol=1e-7) -def test_linopt_no_refpts(ring): - lindata0, tune, chrom, lindata = physics.linopt(ring, DP, get_chrom=True) +def test_linopt_no_refpts(dba_ring): + lindata0, tune, chrom, lindata = physics.linopt(dba_ring, DP, get_chrom=True) assert list(lindata) == [] - assert len(physics.linopt(ring, DP, get_chrom=True)) is 4 + assert len(physics.linopt(dba_ring, DP, get_chrom=True)) is 4 @pytest.mark.parametrize('refpts', ([145], [1, 2, 3, 145])) -def test_ohmi_envelope(cavity_ring, refpts): - lattice = at.Lattice(cavity_ring) - lattice.radiation_on() - emit0, beamdata, emit = lattice.ohmi_envelope(refpts) +@pytest.mark.parametrize('ring_test', (False, True)) +def test_ohmi_envelope(hmba_lattice, refpts, ring_test): + hmba_lattice.radiation_on() + if ring_test: + hmba_lattice = hmba_lattice[:] + emit0, beamdata, emit = physics.ohmi_envelope(hmba_lattice, refpts) expected_beamdata = [([0.38156302, 0.85437641, 1.0906073e-4]), ([1.0044543e-5, 6.6238162e-6, 9.6533473e-6]), ([[[6.9000153, -2.6064253e-5, 1.643376e-25, diff --git a/pyat/test_matlab/README.rst b/pyat/test_matlab/README.rst index 086cd35ef..9591e84af 100644 --- a/pyat/test_matlab/README.rst +++ b/pyat/test_matlab/README.rst @@ -37,13 +37,13 @@ Install the Matlab engine for Python, ensuring your virtualenv is still active: Now run the tests inside your virtualenv: * ``cd $AT_ROOT/pyat`` -* ``$PYTHON_EXECUTABLE -m pytest`` +* ``$PYTHON_EXECUTABLE -m pytest test_matlab`` Footnotes --------- -.. [1] Matlab versions and the Pyton versions thay support: +.. [1] Matlab versions and the Python versions they support: +----------------+--------------------------+ | Matlab Release | Supported Python Version | @@ -66,6 +66,8 @@ Footnotes +----------------+--------------------------+ | 2014b | 2.7, 3.3 | +----------------+--------------------------+ + | <=2014a | Not supported | + +----------------+--------------------------+ .. [2] To check if your Python version is compiled with ucs2 or ucs4:: diff --git a/pyat/test_matlab/conftest.py b/pyat/test_matlab/conftest.py index 765d1c3e7..eaf9bc8ad 100644 --- a/pyat/test_matlab/conftest.py +++ b/pyat/test_matlab/conftest.py @@ -15,22 +15,34 @@ def engine(): eng.quit() -@pytest.fixture +@pytest.fixture(scope='session') def ml_dba(engine): lattice = engine.load(utils.dba_ring) return lattice['RING'] -@pytest.fixture + +@pytest.fixture(scope='session') def ml_hmba(engine): lattice = engine.load(utils.hmba_ring) return lattice['RING'] -@pytest.fixture +@pytest.fixture(scope='session') +def ml_err(engine): + lattice = engine.load(utils.err_ring) + return lattice['RING'] + + +@pytest.fixture(scope='session') def py_dba(): return Lattice(load_mat(utils.dba_ring, key='RING'), keep_all=True) -@pytest.fixture +@pytest.fixture(scope='session') def py_hmba(): return Lattice(load_mat(utils.hmba_ring, key='RING'), keep_all=True) + + +@pytest.fixture(scope='session') +def py_err(): + return Lattice(load_mat(utils.err_ring, key='RING'), keep_all=True) diff --git a/pyat/test_matlab/test_cmp_physics.py b/pyat/test_matlab/test_cmp_physics.py index b9ce76f2c..44023aab1 100644 --- a/pyat/test_matlab/test_cmp_physics.py +++ b/pyat/test_matlab/test_cmp_physics.py @@ -29,8 +29,10 @@ def _compare_physdata(py_data, ml_data, fields, decimal=8): @pytest.mark.parametrize('dp', (-0.01, 0.0, 0.01)) @pytest.mark.parametrize('refpts', (0, [0, 1, 2, -1], None)) @pytest.mark.parametrize('ml_lattice, py_lattice', - [(pytest.lazy_fixture('ml_dba'), pytest.lazy_fixture('py_dba')), - (pytest.lazy_fixture('ml_hmba'), pytest.lazy_fixture('py_hmba'))]) + [(pytest.lazy_fixture('ml_dba'), + pytest.lazy_fixture('py_dba')), + (pytest.lazy_fixture('ml_hmba'), + pytest.lazy_fixture('py_hmba'))]) def test_find_orbit4(engine, ml_lattice, py_lattice, dp, refpts): nelems = len(py_lattice) refpts = range(nelems + 1) if refpts is None else refpts @@ -38,7 +40,8 @@ def test_find_orbit4(engine, ml_lattice, py_lattice, dp, refpts): # Python call py_orb4, py_orbit4 = physics.find_orbit4(py_lattice, dp, refpts) # Matlab call - ml_orbit4, ml_orb4 = engine.findorbit4(ml_lattice, dp, _ml_refs(refpts, nelems), nargout=2) + ml_orbit4, ml_orb4 = engine.findorbit4(ml_lattice, dp, + _ml_refs(refpts, nelems), nargout=2) ml_orbit4 = numpy.rollaxis(numpy.asarray(ml_orbit4), -1) # Comparison numpy.testing.assert_almost_equal(py_orb4, _py_data(ml_orb4), decimal=8) @@ -48,8 +51,10 @@ def test_find_orbit4(engine, ml_lattice, py_lattice, dp, refpts): @pytest.mark.parametrize('dp', (-0.01, 0.0, 0.01)) @pytest.mark.parametrize('refpts', (0, [0, 1, 2, -1], None)) @pytest.mark.parametrize('ml_lattice, py_lattice', - [(pytest.lazy_fixture('ml_dba'), pytest.lazy_fixture('py_dba')), - (pytest.lazy_fixture('ml_hmba'), pytest.lazy_fixture('py_hmba'))]) + [(pytest.lazy_fixture('ml_dba'), + pytest.lazy_fixture('py_dba')), + (pytest.lazy_fixture('ml_hmba'), + pytest.lazy_fixture('py_hmba'))]) def test_find_m44(engine, ml_lattice, py_lattice, dp, refpts): nelems = len(py_lattice) refpts = range(nelems + 1) if refpts is None else refpts @@ -58,8 +63,10 @@ def test_find_m44(engine, ml_lattice, py_lattice, dp, refpts): # Python call py_m44, py_mstack = physics.find_m44(py_lattice, dp, refpts) # Matlab call - ml_m44, ml_mstack = engine.findm44(ml_lattice, dp, _ml_refs(refpts, nelems), nargout=2) - ml_mstack = numpy.rollaxis(numpy.asarray(ml_mstack).reshape((4, 4, nrefs)), -1) + ml_m44, ml_mstack = engine.findm44(ml_lattice, dp, + _ml_refs(refpts, nelems), nargout=2) + ml_mstack = numpy.rollaxis(numpy.asarray(ml_mstack).reshape((4, 4, + nrefs)), -1) # Comparison numpy.testing.assert_almost_equal(py_m44, numpy.asarray(ml_m44), decimal=8) numpy.testing.assert_almost_equal(py_mstack, ml_mstack, decimal=8) @@ -67,15 +74,22 @@ def test_find_m44(engine, ml_lattice, py_lattice, dp, refpts): @pytest.mark.parametrize('dp', (-0.01, 0.0, 0.01)) @pytest.mark.parametrize('refpts', (0, [0, 1, 2, -1], None)) -@pytest.mark.parametrize('func_data', (('twissring', [('SPos', 's_pos'), - ('ClosedOrbit', 'closed_orbit'), ('Dispersion', 'dispersion'), - ('alpha', 'alpha'), ('beta', 'beta'), ('M44', 'm44')]), - ('atlinopt', [('SPos', 's_pos'), ('ClosedOrbit', 'closed_orbit'), - ('Dispersion', 'dispersion'), ('alpha', 'alpha'), ('beta', 'beta'), - ('mu', 'mu'), ('M44', 'm44'), ('A', 'A'), ('B', 'B'), ('C', 'C'), - ('gamma', 'gamma')]))) -@pytest.mark.parametrize('ml_lattice, py_lattice', [(pytest.lazy_fixture('ml_hmba'), - pytest.lazy_fixture('py_hmba'))]) +@pytest.mark.parametrize('ml_lattice, py_lattice', + [(pytest.lazy_fixture('ml_hmba'), + pytest.lazy_fixture('py_hmba'))]) +@pytest.mark.parametrize('func_data', + (('twissring', [('SPos', 's_pos'), + ('ClosedOrbit', 'closed_orbit'), + ('Dispersion', 'dispersion'), + ('alpha', 'alpha'), ('beta', 'beta'), + ('M44', 'm44')]), + ('atlinopt', [('SPos', 's_pos'), + ('ClosedOrbit', 'closed_orbit'), + ('Dispersion', 'dispersion'), + ('alpha', 'alpha'), ('beta', 'beta'), + ('mu', 'mu'), ('M44', 'm44'), + ('A', 'A'), ('B', 'B'), ('C', 'C'), + ('gamma', 'gamma')]))) def test_linear_analysis(engine, ml_lattice, py_lattice, dp, refpts, func_data): """N.B. a 'mu' comparison is left out for twiss data as the values for 'mu' returned by 'twissring' in Matlab are inconsistent with those from @@ -87,24 +101,35 @@ def test_linear_analysis(engine, ml_lattice, py_lattice, dp, refpts, func_data): # Python call if func_data[0] == 'twissring': - py_data0, py_tune, py_chrom, py_data = physics.get_twiss(py_lattice, dp, refpts, get_chrom=True, ddp=1.E-6) + py_data0, py_tune, py_chrom, py_data = physics.get_twiss(py_lattice, + dp, refpts, + True, + ddp=1.E-6) else: - py_data0, py_tune, py_chrom, py_data = physics.linopt(py_lattice, dp, refpts, get_chrom=True, ddp=1.E-6) + py_data0, py_tune, py_chrom, py_data = physics.linopt(py_lattice, dp, + refpts, True, + ddp=1.E-6) # Matlab call - ml_data, ml_tune, ml_chrom = engine.pyproxy(func_data[0], ml_lattice, dp, _ml_refs(refpts, nelems), nargout=3) - ml_data0 = engine.pyproxy(func_data[0], ml_lattice, dp, _ml_refs(nelems, nelems), nargout=3)[0] + ml_data, ml_tune, ml_chrom = engine.pyproxy(func_data[0], ml_lattice, dp, + _ml_refs(refpts, nelems), + nargout=3) + ml_data0 = engine.pyproxy(func_data[0], ml_lattice, dp, + _ml_refs(nelems, nelems), nargout=3)[0] # Comparison numpy.testing.assert_almost_equal(py_tune, _py_data(ml_tune), decimal=6) numpy.testing.assert_almost_equal(py_chrom, _py_data(ml_chrom), decimal=4) - _compare_physdata(numpy.expand_dims(py_data0, 0), ml_data0, func_data[1], decimal=5) + _compare_physdata(numpy.expand_dims(py_data0, 0), ml_data0, func_data[1], + decimal=5) _compare_physdata(py_data, ml_data, func_data[1], decimal=6) @pytest.mark.parametrize('refpts', (0, [0, 1, 2, -1], None)) -@pytest.mark.parametrize('ml_lattice, py_lattice', [(pytest.lazy_fixture('ml_hmba'), - pytest.lazy_fixture('py_hmba'))]) +@pytest.mark.parametrize('ml_lattice, py_lattice', + [(pytest.lazy_fixture('ml_hmba'), + pytest.lazy_fixture('py_hmba'))]) def test_ohmi_envelope(engine, ml_lattice, py_lattice, refpts): - fields = [('beam66', 'r66'), ('beam44', 'r44'), ('emit66', 'emitXYZ'), ('emit44', 'emitXY')] + fields = [('beam66', 'r66'), ('beam44', 'r44'), ('emit66', 'emitXYZ'), + ('emit44', 'emitXY')] nelems = len(py_lattice) refpts = range(nelems + 1) if refpts is None else refpts @@ -113,19 +138,24 @@ def test_ohmi_envelope(engine, ml_lattice, py_lattice, refpts): py_emit0, py_beamdata, py_emit = physics.ohmi_envelope(py_lattice, refpts) # Matlab call ml_emit = engine.pyproxy('atx', ml_lattice, 0.0, _ml_refs(refpts, nelems)) - ml_emit0, ml_params = engine.pyproxy('atx', ml_lattice, 0.0, _ml_refs(0, nelems), nargout=2) + ml_emit0, ml_params = engine.pyproxy('atx', ml_lattice, 0.0, + _ml_refs(0, nelems), nargout=2) revolution_period = get_s_pos(py_lattice, nelems) / speed_of_light damping_times = revolution_period / py_beamdata.damping_rates # Comparison - numpy.testing.assert_almost_equal(damping_times, _py_data(ml_params['dampingtime']), decimal=8) - numpy.testing.assert_almost_equal(py_beamdata.mode_emittances, _py_data(ml_emit0['modemit']), decimal=8) + numpy.testing.assert_almost_equal(damping_times, + _py_data(ml_params['dampingtime']), + decimal=8) + numpy.testing.assert_almost_equal(py_beamdata.mode_emittances, + _py_data(ml_emit0['modemit']), decimal=8) _compare_physdata(numpy.expand_dims(py_emit0, 0), ml_emit0, fields) _compare_physdata(py_emit, ml_emit, fields) @pytest.mark.parametrize('dp', (0.00, 0.01, -0.01)) -@pytest.mark.parametrize('ml_lattice, py_lattice', [(pytest.lazy_fixture('ml_hmba'), - pytest.lazy_fixture('py_hmba'))]) +@pytest.mark.parametrize('ml_lattice, py_lattice', + [(pytest.lazy_fixture('ml_hmba'), + pytest.lazy_fixture('py_hmba'))]) def test_parameters(engine, ml_lattice, py_lattice, dp): # Test perimeter @@ -142,6 +172,7 @@ def test_parameters(engine, ml_lattice, py_lattice, dp): assert py_lattice.harmonic_number == int(ml_harms) # test momentum compaction factor + py_lattice.radiation_off() py_mcf = py_lattice.get_mcf(dp, ddp=1.E-6) # Matlab uses ddp=1.E-6 ml_mcf = engine.mcf(ml_lattice, dp) numpy.testing.assert_allclose(py_mcf, ml_mcf, rtol=1.E-8) diff --git a/pyat/test_matlab/utils.py b/pyat/test_matlab/utils.py index 5213dfdb4..2f06cd762 100644 --- a/pyat/test_matlab/utils.py +++ b/pyat/test_matlab/utils.py @@ -11,6 +11,7 @@ ROOT_DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), '../..')) dba_ring = os.path.join(ROOT_DIR, 'pyat/test_matlab/dba.mat') hmba_ring = os.path.join(ROOT_DIR, 'pyat/test_matlab/hmba.mat') +err_ring = os.path.join(ROOT_DIR, 'pyat/test_matlab/err.mat') def initialise_matlab():