From dde42e7bb492bc17efa9a48c93ada093a44ce133 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Fri, 22 Mar 2024 10:21:14 +0100 Subject: [PATCH 01/22] rebased on mster --- pyat/at/lattice/elements.py | 260 ++++++++++++++++-------------------- pyat/at/load/__init__.py | 1 + pyat/at/load/json.py | 108 +++++++++++++++ 3 files changed, 227 insertions(+), 142 deletions(-) create mode 100644 pyat/at/load/json.py diff --git a/pyat/at/lattice/elements.py b/pyat/at/lattice/elements.py index 4dbfff9bd..99c161284 100644 --- a/pyat/at/lattice/elements.py +++ b/pyat/at/lattice/elements.py @@ -5,7 +5,6 @@ appropriate attributes. If a different PassMethod is set, it is the caller's responsibility to ensure that the appropriate attributes are present. """ - from __future__ import annotations import abc @@ -13,13 +12,10 @@ from abc import ABC from collections.abc import Generator, Iterable from copy import copy, deepcopy -from typing import Any, Optional +from typing import Optional, Any import numpy -# noinspection PyProtectedMember -from .variables import _nop - def _array(value, shape=(-1,), dtype=numpy.float64): # Ensure proper ordering(F) and alignment(A) for "C" access in integrators @@ -31,17 +27,8 @@ def _array66(value): return _array(value, shape=(6, 6)) -def _float(value) -> float: - return float(value) - - -def _int(value, vmin: Optional[int] = None, vmax: Optional[int] = None) -> int: - intv = int(value) - if vmin is not None and intv < vmin: - raise ValueError(f"Value must be greater of equal to {vmin}") - if vmax is not None and intv > vmax: - raise ValueError(f"Value must be smaller of equal to {vmax}") - return intv +def _nop(value): + return value class LongtMotion(ABC): @@ -57,7 +44,6 @@ class LongtMotion(ABC): * ``set_longt_motion(self, enable, new_pass=None, copy=False, **kwargs)`` must enable or disable longitudinal motion. """ - @abc.abstractmethod def _get_longt_motion(self): return False @@ -119,8 +105,7 @@ class _DictLongtMotion(LongtMotion): Defines a class such that :py:meth:`set_longt_motion` will select ``'IdentityPass'`` or ``'IdentityPass'``. - """ - + """ def _get_longt_motion(self): return self.PassMethod != self.default_pass[False] @@ -178,20 +163,16 @@ def set_longt_motion(self, enable, new_pass=None, copy=False, **kwargs): if new_pass is None or new_pass == self.PassMethod: return self if copy else None if enable: - def setpass(el): el.PassMethod = new_pass el.Energy = kwargs['energy'] - else: - def setpass(el): el.PassMethod = new_pass try: del el.Energy except AttributeError: pass - if copy: newelem = deepcopy(self) setpass(newelem) @@ -261,7 +242,7 @@ class Element(object): """Base class for AT elements""" _BUILD_ATTRIBUTES = ['FamName'] - _conversions = dict(FamName=str, PassMethod=str, Length=_float, + _conversions = dict(FamName=str, PassMethod=str, Length=float, R1=_array66, R2=_array66, T1=lambda v: _array(v, (6,)), T2=lambda v: _array(v, (6,)), @@ -269,9 +250,9 @@ class Element(object): EApertures=lambda v: _array(v, (2,)), KickAngle=lambda v: _array(v, (2,)), PolynomB=_array, PolynomA=_array, - BendingAngle=_float, - MaxOrder=_int, NumIntSteps=lambda v: _int(v, vmin=0), - Energy=_float, + BendingAngle=float, + MaxOrder=int, NumIntSteps=int, + Energy=float, ) _entrance_fields = ['T1', 'R1'] @@ -293,35 +274,27 @@ def __init__(self, family_name: str, **kwargs): def __setattr__(self, key, value): try: - value = self._conversions.get(key, _nop)(value) + super(Element, self).__setattr__( + key, self._conversions.get(key, _nop)(value)) except Exception as exc: exc.args = ('In element {0}, parameter {1}: {2}'.format( self.FamName, key, exc),) raise - else: - super(Element, self).__setattr__(key, value) def __str__(self): - first3 = ["FamName", "Length", "PassMethod"] - # Get values and parameter objects attrs = dict(self.items()) - keywords = [f"\t{k} : {attrs.pop(k)!s}" for k in first3] - keywords += [f"\t{k} : {v!s}" for k, v in attrs.items()] - return "\n".join((type(self).__name__ + ":", "\n".join(keywords))) + return "\n".join( + [self.__class__.__name__ + ":"] + + [f"{k:>14}: {attrs.pop(k)!s}" for k in ["FamName", "Length", "PassMethod"]] + + [f"{k:>14}: {v!s}" for k, v in attrs.items()] + ) def __repr__(self): - # Get values only, even for parameters - attrs = dict((k, getattr(self, k)) for k, v in self.items()) - arguments = [attrs.pop(k) for k in self._BUILD_ATTRIBUTES] - defelem = self.__class__(*arguments) - keywords = [f"{v!r}" for v in arguments] - keywords += [ - f"{k}={v!r}" - for k, v in sorted(attrs.items()) - if not numpy.array_equal(v, getattr(defelem, k, None)) - ] + clsname, args, kwargs = self.definition + keywords = [f"{arg!r}" for arg in args] + keywords += [f"{k}={v!r}" for k, v in kwargs.items()] args = re.sub(r"\n\s*", " ", ", ".join(keywords)) - return "{0}({1})".format(self.__class__.__name__, args) + return f"{clsname}({args})" def equals(self, other) -> bool: """Whether an element is equivalent to another. @@ -352,12 +325,10 @@ def divide(self, frac) -> list[Element]: def swap_faces(self, copy=False): """Swap the faces of an element, alignment errors are ignored""" - def swapattr(element, attro, attri): val = getattr(element, attri) delattr(element, attri) return attro, val - if copy: el = self.copy() else: @@ -388,7 +359,7 @@ def update(self, *args, **kwargs): Update the element attributes with the given arguments """ attrs = dict(*args, **kwargs) - for key, value in attrs.items(): + for (key, value) in attrs.items(): setattr(self, key, value) def copy(self) -> Element: @@ -399,10 +370,23 @@ def deepcopy(self) -> Element: """Return a deep copy of the element""" return deepcopy(self) + @property + def definition(self) -> tuple[str, tuple, dict]: + """tuple (class_name, args, kwargs) defining the element""" + attrs = dict(self.items()) + arguments = tuple(attrs.pop(k, getattr(self, k)) for k in self._BUILD_ATTRIBUTES) + defelem = self.__class__(*arguments) + keywords = dict( + (k, v) + for k, v in sorted(attrs.items()) + if not numpy.array_equal(v, getattr(defelem, k, None)) + ) + return self.__class__.__name__, arguments, keywords + def items(self) -> Generator[tuple[str, Any], None, None]: """Iterates through the data members""" - # Properties may be added by overloading this method - yield from vars(self).items() + for k, v in vars(self).items(): + yield k, v def is_compatible(self, other: Element) -> bool: """Checks if another :py:class:`Element` can be merged""" @@ -412,7 +396,8 @@ def merge(self, other) -> None: """Merge another element""" if not self.is_compatible(other): badname = getattr(other, 'FamName', type(other)) - raise TypeError("Cannot merge {0} and {1}".format(self.FamName, badname)) + raise TypeError('Cannot merge {0} and {1}'.format(self.FamName, + badname)) # noinspection PyMethodMayBeStatic def _get_longt_motion(self): @@ -434,8 +419,8 @@ def is_collective(self) -> bool: class LongElement(Element): - """Base class for long elements""" - + """Base class for long elements + """ _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ['Length'] def __init__(self, family_name: str, length: float, *args, **kwargs): @@ -469,7 +454,8 @@ def popattr(element, attr): # Remove entrance and exit attributes fin = dict(popattr(el, key) for key in vars(self) if key in self._entrance_fields) - fout = dict(popattr(el, key) for key in vars(self) if key in self._exit_fields) + fout = dict(popattr(el, key) for key in vars(self) if + key in self._exit_fields) # Split element element_list = [el._part(f, numpy.sum(frac)) for f in frac] # Restore entrance and exit attributes @@ -480,22 +466,8 @@ def popattr(element, attr): return element_list def is_compatible(self, other) -> bool: - def compatible_field(fieldname): - f1 = getattr(self, fieldname, None) - f2 = getattr(other, fieldname, None) - if f1 is None and f2 is None: # no such field - return True - elif f1 is None or f2 is None: # only one - return False - else: # both - return numpy.all(f1 == f2) - - if not (type(other) is type(self) and self.PassMethod == other.PassMethod): - return False - for fname in ("RApertures", "EApertures"): - if not compatible_field(fname): - return False - return True + return type(other) is type(self) and \ + self.PassMethod == other.PassMethod def merge(self, other) -> None: super().merge(other) @@ -542,7 +514,6 @@ def means(self): class SliceMoments(Element): """Element to compute slices mean and std""" - _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ['nslice'] _conversions = dict(Element._conversions, nslice=int) @@ -561,7 +532,8 @@ def __init__(self, family_name: str, nslice: int, **kwargs): kwargs.setdefault('PassMethod', 'SliceMomentsPass') self._startturn = kwargs.pop('startturn', 0) self._endturn = kwargs.pop('endturn', 1) - super(SliceMoments, self).__init__(family_name, nslice=nslice, **kwargs) + super(SliceMoments, self).__init__(family_name, nslice=nslice, + **kwargs) self._nbunch = 1 self.startturn = self._startturn self.endturn = self._endturn @@ -576,33 +548,45 @@ def set_buffers(self, nturns, nbunch): self.endturn = min(self.endturn, nturns) self._dturns = self.endturn - self.startturn self._nbunch = nbunch - self._stds = numpy.zeros((3, nbunch*self.nslice, self._dturns), order="F") - self._means = numpy.zeros((3, nbunch*self.nslice, self._dturns), order="F") - self._spos = numpy.zeros((nbunch*self.nslice, self._dturns), order="F") - self._weights = numpy.zeros((nbunch*self.nslice, self._dturns), order="F") + self._stds = numpy.zeros((3, nbunch*self.nslice, self._dturns), + order='F') + self._means = numpy.zeros((3, nbunch*self.nslice, self._dturns), + order='F') + self._spos = numpy.zeros((nbunch*self.nslice, self._dturns), + order='F') + self._weights = numpy.zeros((nbunch*self.nslice, self._dturns), + order='F') @property def stds(self): """Slices x,y,dp standard deviation""" - return self._stds.reshape((3, self._nbunch, self.nslice, self._dturns)) + return self._stds.reshape((3, self._nbunch, + self.nslice, + self._dturns)) @property def means(self): """Slices x,y,dp center of mass""" - return self._means.reshape((3, self._nbunch, self.nslice, self._dturns)) + return self._means.reshape((3, self._nbunch, + self.nslice, + self._dturns)) @property def spos(self): """Slices s position""" - return self._spos.reshape((self._nbunch, self.nslice, self._dturns)) + return self._spos.reshape((self._nbunch, + self.nslice, + self._dturns)) @property def weights(self): """Slices weights in mA if beam current >0, - otherwise fraction of total number of - particles in the bunch + otherwise fraction of total number of + particles in the bunch """ - return self._weights.reshape((self._nbunch, self.nslice, self._dturns)) + return self._weights.reshape((self._nbunch, + self.nslice, + self._dturns)) @property def startturn(self): @@ -633,7 +617,6 @@ def endturn(self, value): class Aperture(Element): """Aperture element""" - _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ['Limits'] _conversions = dict(Element._conversions, Limits=lambda v: _array(v, (4,))) @@ -712,7 +695,6 @@ def insert(self, class Collimator(Drift): """Collimator element""" - _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['RApertures'] def __init__(self, family_name: str, length: float, limits, **kwargs): @@ -731,8 +713,8 @@ def __init__(self, family_name: str, length: float, limits, **kwargs): class ThinMultipole(Element): """Thin multipole element""" - - _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ["PolynomA", "PolynomB"] + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ['PolynomA', + 'PolynomB'] def __init__(self, family_name: str, poly_a, poly_b, **kwargs): """ @@ -760,13 +742,10 @@ def lengthen(poly, dl): else: return poly - # PolynomA and PolynomB and convert to ParamArray - prmpola = self._conversions["PolynomA"](kwargs.pop("PolynomA", poly_a)) - prmpolb = self._conversions["PolynomB"](kwargs.pop("PolynomB", poly_b)) - poly_a, len_a, ord_a = getpol(prmpola) - poly_b, len_b, ord_b = getpol(prmpolb) + # Remove MaxOrder, PolynomA and PolynomB + poly_a, len_a, ord_a = getpol(_array(kwargs.pop('PolynomA', poly_a))) + poly_b, len_b, ord_b = getpol(_array(kwargs.pop('PolynomB', poly_b))) deforder = max(getattr(self, 'DefaultOrder', 0), ord_a, ord_b) - # Remove MaxOrder maxorder = kwargs.pop('MaxOrder', deforder) kwargs.setdefault('PassMethod', 'ThinMPolePass') super(ThinMultipole, self).__init__(family_name, **kwargs) @@ -774,32 +753,36 @@ def lengthen(poly, dl): super(ThinMultipole, self).__setattr__('MaxOrder', maxorder) # Adjust polynom lengths and set them len_ab = max(self.MaxOrder + 1, len_a, len_b) - self.PolynomA = lengthen(prmpola, len_ab - len_a) - self.PolynomB = lengthen(prmpolb, len_ab - len_b) + self.PolynomA = lengthen(poly_a, len_ab - len_a) + self.PolynomB = lengthen(poly_b, len_ab - len_b) def __setattr__(self, key, value): """Check the compatibility of MaxOrder, PolynomA and PolynomB""" polys = ('PolynomA', 'PolynomB') if key in polys: - lmin = self.MaxOrder + value = _array(value) + lmin = getattr(self, 'MaxOrder') if not len(value) > lmin: raise ValueError( 'Length of {0} must be larger than {1}'.format(key, lmin)) elif key == 'MaxOrder': - intval = int(value) + value = int(value) lmax = min(len(getattr(self, k)) for k in polys) - if not intval < lmax: - raise ValueError("MaxOrder must be smaller than {0}".format(lmax)) + if not value < lmax: + raise ValueError( + 'MaxOrder must be smaller than {0}'.format(lmax)) + super(ThinMultipole, self).__setattr__(key, value) class Multipole(_Radiative, LongElement, ThinMultipole): """Multipole element""" - - _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ["PolynomA", "PolynomB"] + _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['PolynomA', + 'PolynomB'] _conversions = dict(ThinMultipole._conversions, K=float, H=float) - def __init__(self, family_name: str, length: float, poly_a, poly_b, **kwargs): + def __init__(self, family_name: str, length: float, poly_a, poly_b, + **kwargs): """ Args: family_name: Name of the element @@ -819,10 +802,12 @@ def __init__(self, family_name: str, length: float, poly_a, poly_b, **kwargs): """ kwargs.setdefault('PassMethod', 'StrMPoleSymplectic4Pass') kwargs.setdefault('NumIntSteps', 10) - super(Multipole, self).__init__(family_name, length, poly_a, poly_b, **kwargs) + super(Multipole, self).__init__(family_name, length, + poly_a, poly_b, **kwargs) def is_compatible(self, other) -> bool: - if super().is_compatible(other) and self.MaxOrder == other.MaxOrder: + if super().is_compatible(other) and \ + self.MaxOrder == other.MaxOrder: for i in range(self.MaxOrder + 1): if self.PolynomB[i] != other.PolynomB[i]: return False @@ -836,8 +821,7 @@ def is_compatible(self, other) -> bool: @property def K(self) -> float: """Focusing strength [mˆ-2]""" - arr = self.PolynomB - return 0.0 if len(arr) < 2 else arr[1] + return 0.0 if len(self.PolynomB) < 2 else self.PolynomB[1] # noinspection PyPep8Naming @K.setter @@ -848,8 +832,7 @@ def K(self, strength: float): @property def H(self) -> float: """Sextupolar strength [mˆ-3]""" - arr = self.PolynomB - return 0.0 if len(arr) < 3 else arr[2] + return 0.0 if len(self.PolynomB) < 3 else self.PolynomB[2] # noinspection PyPep8Naming @H.setter @@ -922,15 +905,16 @@ def __init__(self, family_name: str, length: float, Default PassMethod: :ref:`BndMPoleSymplectic4Pass` """ + poly_b = kwargs.pop('PolynomB', numpy.array([0, k])) kwargs.setdefault('BendingAngle', bending_angle) kwargs.setdefault('EntranceAngle', 0.0) kwargs.setdefault('ExitAngle', 0.0) kwargs.setdefault('PassMethod', 'BndMPoleSymplectic4Pass') - super(Dipole, self).__init__(family_name, length, [], [0.0, k], **kwargs) + super(Dipole, self).__init__(family_name, length, [], poly_b, **kwargs) - def items(self) -> Generator[tuple[str, Any], None, None]: + def items(self) -> Generator[tuple, None, None]: yield from super().items() - yield "K", vars(self)["PolynomB"][1] + yield 'K', self.K def _part(self, fr, sumfr): pp = super(Dipole, self)._part(fr, sumfr) @@ -943,9 +927,9 @@ def is_compatible(self, other) -> bool: def invrho(dip: Dipole): return dip.BendingAngle / dip.Length - return (super().is_compatible(other) - and self.ExitAngle == -other.EntranceAngle - and abs(invrho(self) - invrho(other)) <= 1.e-6) + return (super().is_compatible(other) and + self.ExitAngle == -other.EntranceAngle and + abs(invrho(self) - invrho(other)) <= 1.e-6) def merge(self, other) -> None: super().merge(other) @@ -960,7 +944,6 @@ def merge(self, other) -> None: class Quadrupole(Radiative, Multipole): """Quadrupole element""" - _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['K'] _conversions = dict(Multipole._conversions, FringeQuadEntrance=int, FringeQuadExit=int) @@ -998,17 +981,18 @@ def __init__(self, family_name: str, length: float, Default PassMethod: ``StrMPoleSymplectic4Pass`` """ - kwargs.setdefault("PassMethod", "StrMPoleSymplectic4Pass") - super(Quadrupole, self).__init__(family_name, length, [], [0.0, k], **kwargs) + poly_b = kwargs.pop('PolynomB', numpy.array([0, k])) + kwargs.setdefault('PassMethod', 'StrMPoleSymplectic4Pass') + super(Quadrupole, self).__init__(family_name, length, [], poly_b, + **kwargs) - def items(self) -> Generator[tuple[str, Any], None, None]: + def items(self) -> Generator[tuple, None, None]: yield from super().items() - yield "K", vars(self)["PolynomB"][1] + yield 'K', self.K class Sextupole(Multipole): """Sextupole element""" - _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['H'] DefaultOrder = 2 @@ -1032,18 +1016,14 @@ def __init__(self, family_name: str, length: float, Default PassMethod: ``StrMPoleSymplectic4Pass`` """ - kwargs.setdefault("PassMethod", "StrMPoleSymplectic4Pass") - super(Sextupole, self).__init__(family_name, length, [], [0.0, 0.0, h], + poly_b = kwargs.pop('PolynomB', [0, 0, h]) + kwargs.setdefault('PassMethod', 'StrMPoleSymplectic4Pass') + super(Sextupole, self).__init__(family_name, length, [], poly_b, **kwargs) - def items(self) -> Generator[tuple[str, Any], None, None]: - yield from super().items() - yield "H", vars(self)["PolynomB"][2] - class Octupole(Multipole): """Octupole element, with no changes from multipole at present""" - _BUILD_ATTRIBUTES = Multipole._BUILD_ATTRIBUTES DefaultOrder = 3 @@ -1051,7 +1031,6 @@ class Octupole(Multipole): class RFCavity(LongtMotion, LongElement): """RF cavity element""" - _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['Voltage', 'Frequency', 'HarmNumber', @@ -1092,9 +1071,9 @@ def _part(self, fr, sumfr): return pp def is_compatible(self, other) -> bool: - return (super().is_compatible(other) - and self.Frequency == other.Frequency - and self.TimeLag == other.TimeLag) + return (super().is_compatible(other) and + self.Frequency == other.Frequency and + self.TimeLag == other.TimeLag) def merge(self, other) -> None: super().merge(other) @@ -1113,7 +1092,6 @@ def set_longt_motion(self, enable, new_pass=None, **kwargs): class M66(Element): """Linear (6, 6) transfer matrix""" - _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ["M66"] _conversions = dict(Element._conversions, M66=_array66) @@ -1124,7 +1102,7 @@ def __init__(self, family_name: str, m66=None, **kwargs): m66: Transfer matrix. Default: Identity matrix Default PassMethod: ``Matrix66Pass`` - """ + """ if m66 is None: m66 = numpy.identity(6) kwargs.setdefault('PassMethod', 'Matrix66Pass') @@ -1140,7 +1118,6 @@ class SimpleQuantDiff(_DictLongtMotion, Element): Note: The damping times are needed to compute the correct kick for the emittance. Radiation damping is NOT applied. """ - _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES default_pass = {False: 'IdentityPass', True: 'SimpleQuantDiffPass'} @@ -1165,8 +1142,8 @@ def __init__(self, family_name: str, betax: float = 1.0, tauz: Longitudinal damping time [turns] Default PassMethod: ``SimpleQuantDiffPass`` - """ - kwargs.setdefault("PassMethod", self.default_pass[True]) + """ + kwargs.setdefault('PassMethod', self.default_pass[True]) assert taux >= 0.0, 'taux must be greater than or equal to 0' self.taux = taux @@ -1199,7 +1176,6 @@ def __init__(self, family_name: str, betax: float = 1.0, class SimpleRadiation(_DictLongtMotion, Radiative, Element): """Simple radiation damping and energy loss""" - _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES _conversions = dict(Element._conversions, U0=float, damp_mat_diag=lambda v: _array(v, shape=(6,))) @@ -1221,7 +1197,7 @@ def __init__(self, family_name: str, U0: Energy loss per turn [eV] Default PassMethod: ``SimpleRadiationPass`` - """ + """ assert taux >= 0.0, 'taux must be greater than or equal to 0' if taux == 0.0: dampx = 1 @@ -1250,7 +1226,6 @@ def __init__(self, family_name: str, class Corrector(LongElement): """Corrector element""" - _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['KickAngle'] def __init__(self, family_name: str, length: float, kick_angle, **kwargs): @@ -1276,7 +1251,6 @@ class Wiggler(Radiative, LongElement): See atwiggler.m """ - _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['Lw', 'Bmax', 'Energy'] _conversions = dict(Element._conversions, Lw=float, Bmax=float, @@ -1319,12 +1293,14 @@ def __init__(self, family_name: str, length: float, wiggle_period: float, for i, b in enumerate(self.By.T): dk = abs(b[3] ** 2 - b[4] ** 2 - b[2] ** 2) / abs(b[4]) if dk > 1e-6: - raise ValueError("Wiggler(H): kx^2 + kz^2 -ky^2 !=0, i = {0}".format(i)) + raise ValueError("Wiggler(H): kx^2 + kz^2 -ky^2 !=0, i = " + "{0}".format(i)) for i, b in enumerate(self.Bx.T): dk = abs(b[2] ** 2 - b[4] ** 2 - b[3] ** 2) / abs(b[4]) if dk > 1e-6: - raise ValueError("Wiggler(V): ky^2 + kz^2 -kx^2 !=0, i = {0}".format(i)) + raise ValueError("Wiggler(V): ky^2 + kz^2 -kx^2 !=0, i = " + "{0}".format(i)) self.NHharm = self.By.shape[1] self.NVharm = self.Bx.shape[1] diff --git a/pyat/at/load/__init__.py b/pyat/at/load/__init__.py index 28c187e3e..196af9f2a 100644 --- a/pyat/at/load/__init__.py +++ b/pyat/at/load/__init__.py @@ -8,3 +8,4 @@ from .reprfile import * from .tracy import * from .elegant import * +from .json import * diff --git a/pyat/at/load/json.py b/pyat/at/load/json.py new file mode 100644 index 000000000..141d62a25 --- /dev/null +++ b/pyat/at/load/json.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import json +from typing import Optional, Any + +import numpy as np + +from . import register_format +from ..lattice import Element, Lattice, Particle, get_class_map + +_CLASS_MAP = get_class_map() + + +def elemstr(self): + attrs = dict(self.items()) + return "\n".join( + [self.__class__.__name__ + ":"] + + [f"{k:>14}: {attrs.pop(k)!s}" for k in ["FamName", "Length", "PassMethod"]] + + [f"{k:>14}: {v!s}" for k, v in attrs.items()] + ) + + +class _AtEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, Element): + return obj.definition + elif isinstance(obj, Particle): + return obj.to_dict() + elif isinstance(obj, np.ndarray): + return obj.tolist() + else: + return super().default(obj) + + +def save_json(ring: Lattice, filename: Optional[str] = None) -> None: + """Save a :py:class:`.Lattice` as a JSON file + + Parameters: + ring: Lattice description + filename: Name of the JSON file. Default: outputs on + :py:obj:`sys.stdout` + + See Also: + :py:func:`.save_lattice` for a generic lattice-saving function. + :py:meth:`.Lattice.save` for a generic lattice-saving method. + """ + if filename is None: + json.dumps(("Lattice", ring, ring.attrs), cls=_AtEncoder) + else: + with open(filename, "wt") as jsonfile: + json.dump(("Lattice", ring, ring.attrs), jsonfile, cls=_AtEncoder) + + +def load_json(filename: str, **kwargs) -> Lattice: + """Create a :py:class:`.Lattice` from a JSON file + + Parameters: + filename: Name of a JSON file + + Keyword Args: + name (str): Name of the lattice. Default: taken from + the lattice + energy (float): Energy of the lattice [eV]. Default: taken + from the lattice elements + periodicity(int): Number of periods. Default: taken from the + elements, or 1 + *: All other keywords will be set as Lattice + attributes + + Returns: + lattice (Lattice): New :py:class:`.Lattice` object + + See Also: + :py:meth:`.Lattice.load` for a generic lattice-loading method. + """ + + def json_generator(params: dict[str, Any], elem_list): + particle_dict = params.pop("particle", {}) + params["particle"] = Particle(**particle_dict) + for clname, args, keys in elem_list: + cls = _CLASS_MAP[clname] + yield cls(*args, **keys) + + with open(filename, "rt") as jsonfile: + data = json.load(jsonfile) + + try: + code, elems, prms = data + except ValueError: + raise TypeError("Not a Lattice") + if not ( + isinstance(code, str) + and isinstance(elems, list) + and isinstance(prms, dict) + and (code == "Lattice") + ): + raise TypeError("Not a lattice") + + prms.update(kwargs) + return Lattice(elems, iterator=json_generator, **prms) + + +register_format( + ".json", + load_json, + save_json, + descr="JSON representation of a python AT Lattice", +) From b0308eb55db48fec0cc55deeb76f9385fc223810 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Fri, 22 Mar 2024 20:24:55 +0100 Subject: [PATCH 02/22] merged fro master --- pyat/at/lattice/elements.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/pyat/at/lattice/elements.py b/pyat/at/lattice/elements.py index 99c161284..7c15759f2 100644 --- a/pyat/at/lattice/elements.py +++ b/pyat/at/lattice/elements.py @@ -6,15 +6,13 @@ responsibility to ensure that the appropriate attributes are present. """ from __future__ import annotations - import abc import re +import numpy +from copy import copy, deepcopy from abc import ABC from collections.abc import Generator, Iterable -from copy import copy, deepcopy -from typing import Optional, Any - -import numpy +from typing import Any, Optional def _array(value, shape=(-1,), dtype=numpy.float64): @@ -466,8 +464,22 @@ def popattr(element, attr): return element_list def is_compatible(self, other) -> bool: - return type(other) is type(self) and \ - self.PassMethod == other.PassMethod + def compatible_field(fieldname): + f1 = getattr(self, fieldname, None) + f2 = getattr(other, fieldname, None) + if f1 is None and f2 is None: # no such field + return True + elif f1 is None or f2 is None: # only one + return False + else: # both + return numpy.all(f1 == f2) + + if not (type(other) is type(self) and self.PassMethod == other.PassMethod): + return False + for fname in ("RApertures", "EApertures"): + if not compatible_field(fname): + return False + return True def merge(self, other) -> None: super().merge(other) From f69c9ecd71963f053e5b0344e14dde261dfbf2f8 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Sat, 23 Mar 2024 11:18:00 +0100 Subject: [PATCH 03/22] string output --- pyat/at/load/json.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/pyat/at/load/json.py b/pyat/at/load/json.py index 141d62a25..ec75c0b6f 100644 --- a/pyat/at/load/json.py +++ b/pyat/at/load/json.py @@ -11,16 +11,8 @@ _CLASS_MAP = get_class_map() -def elemstr(self): - attrs = dict(self.items()) - return "\n".join( - [self.__class__.__name__ + ":"] - + [f"{k:>14}: {attrs.pop(k)!s}" for k in ["FamName", "Length", "PassMethod"]] - + [f"{k:>14}: {v!s}" for k, v in attrs.items()] - ) - - class _AtEncoder(json.JSONEncoder): + """JSON encoder for specific AT types""" def default(self, obj): if isinstance(obj, Element): return obj.definition @@ -45,7 +37,7 @@ def save_json(ring: Lattice, filename: Optional[str] = None) -> None: :py:meth:`.Lattice.save` for a generic lattice-saving method. """ if filename is None: - json.dumps(("Lattice", ring, ring.attrs), cls=_AtEncoder) + print(json.dumps(("Lattice", ring, ring.attrs), cls=_AtEncoder)) else: with open(filename, "wt") as jsonfile: json.dump(("Lattice", ring, ring.attrs), jsonfile, cls=_AtEncoder) From e8a7318ecb23fb32dbe25c5f3c722db588a9f238 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Sun, 24 Mar 2024 18:01:06 +0100 Subject: [PATCH 04/22] Handle JSON files in python and Matlab --- atmat/lattice/atloadlattice.m | 24 ++++++++++++- atmat/lattice/atwritejson.m | 64 +++++++++++++++++++++++++++++++++++ pyat/at/load/json.py | 52 ++++++++++++++-------------- pyat/at/load/matfile.py | 33 +++++++++++------- pyat/at/load/reprfile.py | 44 ++++++++++++------------ pyat/at/load/utils.py | 55 +++++++++++++++++------------- 6 files changed, 186 insertions(+), 86 deletions(-) create mode 100644 atmat/lattice/atwritejson.m diff --git a/atmat/lattice/atloadlattice.m b/atmat/lattice/atloadlattice.m index 04260fe06..bdb14efdc 100644 --- a/atmat/lattice/atloadlattice.m +++ b/atmat/lattice/atloadlattice.m @@ -21,12 +21,16 @@ % the variable name must be specified using the 'matkey' keyword. % % .m Matlab function. The function must output a valid AT structure. +% .json JSON file +% +%see also atwritem, atwritejson persistent link_table if isempty(link_table) link_table.mat=@load_mat; link_table.m=@load_m; + link_table.json=@load_json; end [~,~,fext]=fileparts(fspec); @@ -57,7 +61,7 @@ dt=load(fpath); vnames=fieldnames(dt); key='RING'; - if length(vnames) == 1 + if isscalar(vnames) key=vnames{1}; else for v={'ring','lattice'} @@ -75,7 +79,25 @@ error('AT:load','Cannot find variable %s\nmatkey must be in: %s',... key, strjoin(vnames,', ')); end + end + function [lattice, opts]=load_json(fpath, opts) + data=jsondecode(fileread(fpath)); + prms=data.parameters; + name=prms.name; + energy=prms.energy; + periodicity=prms.periodicity; + particle=atparticle.loadobj(prms.particle); + harmnumber=prms.harmonic_number; + prms=rmfield(prms,{'name','energy','periodicity','particle','harmonic_number'}); + args=[fieldnames(prms) struct2cell(prms)]'; + lattice=atSetRingProperties(data.elements,... + 'FamName', name,... + 'Energy', energy,... + 'Periodicity', periodicity,... + 'Particle', particle,... + 'HarmNumber', harmnumber, ... + args{:}); end end \ No newline at end of file diff --git a/atmat/lattice/atwritejson.m b/atmat/lattice/atwritejson.m new file mode 100644 index 000000000..76d9b67c2 --- /dev/null +++ b/atmat/lattice/atwritejson.m @@ -0,0 +1,64 @@ +function varargout=atwritejson(ring, varargin) +%ATWRITEJSON Create a JSON file to store an AT lattice +% +%JS=ATWRITEJSON(RING) +% Return the JSON representation of RING as a character array +% +%ATWRITEJSON(RING, FILENAME) +% Write the JSON representation of RING to the file FILENAME +% +%ATWRITEJSON(RING, ..., 'compact', true) +% If compact is true, write a compact JSON file (no linefeeds) +% +%see also atloadlattice + +[compact, varargs]=getoption(varargin, 'compact', false); +[filename, ~]=getargs(varargs,[]); + +if ~isempty(filename) + %get filename information + [pname,fname,ext]=fileparts(filename); + + %Check file extension + if isempty(ext), ext='.json'; end + + % Make fullname + fn=fullfile(pname,[fname ext]); + + % Open file to be written + [fid,mess]=fopen(fullfile(pname,[fname ext]),'wt'); + + if fid==-1 + error('AT:FileErr','Cannot Create file %s\n%s',fn,mess); + else + fprintf(fid, sjson(ring)); + fclose(fid); + end + varargout={}; +else + varargout={sjson(ring)}; +end + + function jsondata=sjson(ring) + ok=~atgetcells(ring, 'Class', 'RingParam'); + data.elements=ring(ok); + data.parameters=get_params(ring); + jsondata=jsonencode(data, 'PrettyPrint', ~compact); + end + + function prms=get_params(ring) + [name, energy, part, periodicity, harmonic_number]=... + atGetRingProperties(ring,'FamName', 'Energy', 'Particle',... + 'Periodicity', 'HarmNumber'); + prms=struct('name', name, 'energy', energy, 'periodicity', periodicity,... + 'particle', saveobj(part), 'harmonic_number', harmonic_number); + idx=atlocateparam(ring); + p2=rmfield(ring{idx},{'FamName','PassMethod','Length','Class',... + 'Energy', 'Particle','Periodicity','cell_harmnumber'}); + for nm=fieldnames(p2)' + na=nm{1}; + prms.(na)=p2.(na); + end + end + +end \ No newline at end of file diff --git a/pyat/at/load/json.py b/pyat/at/load/json.py index ec75c0b6f..84aa54598 100644 --- a/pyat/at/load/json.py +++ b/pyat/at/load/json.py @@ -1,18 +1,24 @@ +""" +Handling of JSON files +""" + from __future__ import annotations +__all__ = ["save_json", "load_json"] + import json from typing import Optional, Any import numpy as np -from . import register_format -from ..lattice import Element, Lattice, Particle, get_class_map - -_CLASS_MAP = get_class_map() +from .allfiles import register_format +from .utils import element_to_dict, element_from_dict, save_filter +from ..lattice import Element, Lattice, Particle class _AtEncoder(json.JSONEncoder): """JSON encoder for specific AT types""" + def default(self, obj): if isinstance(obj, Element): return obj.definition @@ -36,11 +42,15 @@ def save_json(ring: Lattice, filename: Optional[str] = None) -> None: :py:func:`.save_lattice` for a generic lattice-saving function. :py:meth:`.Lattice.save` for a generic lattice-saving method. """ + data = dict( + elements=[element_to_dict(el) for el in save_filter(ring)], + parameters=ring.attrs, + ) if filename is None: - print(json.dumps(("Lattice", ring, ring.attrs), cls=_AtEncoder)) + print(json.dumps(data, cls=_AtEncoder, indent=2)) else: with open(filename, "wt") as jsonfile: - json.dump(("Lattice", ring, ring.attrs), jsonfile, cls=_AtEncoder) + json.dump(data, jsonfile, cls=_AtEncoder, indent=2) def load_json(filename: str, **kwargs) -> Lattice: @@ -50,14 +60,7 @@ def load_json(filename: str, **kwargs) -> Lattice: filename: Name of a JSON file Keyword Args: - name (str): Name of the lattice. Default: taken from - the lattice - energy (float): Energy of the lattice [eV]. Default: taken - from the lattice elements - periodicity(int): Number of periods. Default: taken from the - elements, or 1 - *: All other keywords will be set as Lattice - attributes + *: All keywords update the lattice properties Returns: lattice (Lattice): New :py:class:`.Lattice` object @@ -69,27 +72,22 @@ def load_json(filename: str, **kwargs) -> Lattice: def json_generator(params: dict[str, Any], elem_list): particle_dict = params.pop("particle", {}) params["particle"] = Particle(**particle_dict) - for clname, args, keys in elem_list: - cls = _CLASS_MAP[clname] - yield cls(*args, **keys) + for idx, elem_dict in enumerate(elem_list): + yield element_from_dict(elem_dict, index=idx, check=False) with open(filename, "rt") as jsonfile: data = json.load(jsonfile) try: - code, elems, prms = data - except ValueError: + elements = data["elements"] + parameters = data["parameters"] + except KeyError: raise TypeError("Not a Lattice") - if not ( - isinstance(code, str) - and isinstance(elems, list) - and isinstance(prms, dict) - and (code == "Lattice") - ): + if not (isinstance(elements, list) and isinstance(parameters, dict)): raise TypeError("Not a lattice") - prms.update(kwargs) - return Lattice(elems, iterator=json_generator, **prms) + parameters.update(kwargs) + return Lattice(elements, iterator=json_generator, **parameters) register_format( diff --git a/pyat/at/load/matfile.py b/pyat/at/load/matfile.py index b58152d33..c828da728 100644 --- a/pyat/at/load/matfile.py +++ b/pyat/at/load/matfile.py @@ -12,13 +12,13 @@ from collections.abc import Sequence, Generator from warnings import warn -import numpy +import numpy as np import scipy.io from .allfiles import register_format -from .utils import element_from_dict, element_from_m, RingParam +from .utils import element_from_dict, element_from_m, RingParam, save_filter from .utils import element_to_dict, element_to_m -from ..lattice import Element, Lattice, Filter +from ..lattice import Element, Lattice, Particle, Filter from ..lattice import elements, AtWarning, params_filter, AtError _m2p = { @@ -33,6 +33,18 @@ # Python to Matlab _p2m = {"name", "energy", "periodicity", "particle", "cell_harmnumber", "beam_current"} +# Python to Matlab type translation +_mattype_map = { + int: float, + np.ndarray: lambda attr: np.asanyarray(attr), + Particle: lambda attr: attr.to_dict(), +} + + +def _mat_encoder(v): + """type encoding for .mat files""" + return _mattype_map.get(type(v), lambda attr: attr)(v) + def matfile_generator( params: dict[str, Any], mat_file: str @@ -58,12 +70,12 @@ def matfile_generator( """ def mclean(data): - if data.dtype.type is numpy.str_: + if data.dtype.type is np.str_: # Convert strings in arrays back to strings. return str(data[0]) if data.size > 0 else "" elif data.size == 1: v = data[0, 0] - if issubclass(v.dtype.type, numpy.void): + if issubclass(v.dtype.type, np.void): # Object => Return a dict return {f: mclean(v[f]) for f in v.dtype.fields} else: @@ -71,7 +83,7 @@ def mclean(data): return v else: # Remove any surplus dimensions in arrays. - return numpy.squeeze(data) + return np.squeeze(data) # noinspection PyUnresolvedReferences m = scipy.io.loadmat(params.setdefault("mat_file", mat_file)) @@ -301,12 +313,7 @@ def required(rng): dct = dict(required(ring)) yield RingParam(**dct) - for elem in ring: - if not ( - isinstance(elem, elements.Marker) - and getattr(elem, "tag", None) == "RingParam" - ): - yield elem + yield from save_filter(ring) def save_mat(ring: Lattice, filename: str, mat_key: str = "RING") -> None: @@ -321,7 +328,7 @@ def save_mat(ring: Lattice, filename: str, mat_key: str = "RING") -> None: See Also: :py:func:`.save_lattice` for a generic lattice-saving function. """ - lring = tuple((element_to_dict(elem),) for elem in matlab_ring(ring)) + lring = tuple(element_to_dict(el, encoder=_mat_encoder) for el in matlab_ring(ring)) # noinspection PyUnresolvedReferences scipy.io.savemat(filename, {mat_key: lring}, long_field_names=True) diff --git a/pyat/at/load/reprfile.py b/pyat/at/load/reprfile.py index 045f47528..96802f423 100644 --- a/pyat/at/load/reprfile.py +++ b/pyat/at/load/reprfile.py @@ -1,18 +1,23 @@ """Text representation of a python AT lattice with each element represented by its :py:func:`repr` string """ + from __future__ import print_function + +__all__ = ["load_repr", "save_repr"] + import sys from os.path import abspath from typing import Optional + import numpy + from at.lattice import Lattice -from at.load import register_format -from at.load.utils import element_from_string + # imports necessary in' globals()' for 'eval' from at.lattice import Particle # noqa: F401 - -__all__ = ['load_repr', 'save_repr'] +from at.load import register_format +from at.load.utils import element_from_string def load_repr(filename: str, **kwargs) -> Lattice: @@ -37,8 +42,9 @@ def load_repr(filename: str, **kwargs) -> Lattice: See Also: :py:func:`.load_lattice` for a generic lattice-loading function. """ + def elem_iterator(params, repr_file): - with open(params.setdefault('repr_file', repr_file), 'rt') as file: + with open(params.setdefault("repr_file", repr_file), "rt") as file: # the 1st line is the dictionary of saved lattice parameters for k, v in eval(next(file)).items(): params.setdefault(k, v) @@ -59,25 +65,21 @@ def save_repr(ring: Lattice, filename: Optional[str] = None) -> None: See Also: :py:func:`.save_lattice` for a generic lattice-saving function. """ + def save(file): - # print(repr(dict((k, v) for k, v in vars(ring).items() - # if not k.startswith('_'))), file=file) print(repr(ring.attrs), file=file) for elem in ring: print(repr(elem), file=file) - # Save the current options - opts = numpy.get_printoptions() # Set options to print the full representation of float variables - numpy.set_printoptions(formatter={'float_kind': repr}) - if filename is None: - save(sys.stdout) - else: - with open(filename, 'wt') as reprfile: - save(reprfile) - # Restore the current options - numpy.set_printoptions(**opts) - - -register_format('.repr', load_repr, save_repr, - descr='Text representation of a python AT Lattice') + with numpy.printoptions(formatter={"float_kind": repr}): + if filename is None: + save(sys.stdout) + else: + with open(filename, "wt") as reprfile: + save(reprfile) + + +register_format( + ".repr", load_repr, save_repr, descr="Text representation of a python AT Lattice" +) diff --git a/pyat/at/load/utils.py b/pyat/at/load/utils.py index 446f618f6..3f0dc7f2f 100644 --- a/pyat/at/load/utils.py +++ b/pyat/at/load/utils.py @@ -8,24 +8,30 @@ import os import re import sysconfig -from typing import Optional +from typing import Any, Optional from warnings import warn +from collections.abc import Callable, Generator import numpy as np -# imports necessary in' globals()' for 'eval' +# imports necessary in 'globals()' for 'eval' from numpy import array, uint8, NaN # noqa: F401 from at import integrators from at.lattice import AtWarning from at.lattice import CLASS_MAP, elements as elt -from at.lattice import Particle, Element +from at.lattice import Lattice, Particle, Element, Marker from at.lattice import idtable_element _ext_suffix = sysconfig.get_config_var("EXT_SUFFIX") _relativistic_particle = Particle() +def _no_encoder(v): + """type encoding for .mat files""" + return v + + def _particle(value) -> Particle: if isinstance(value, Particle): # Create from python: save_mat @@ -124,13 +130,6 @@ def __init__( "M66": "Matrix66", } -# Python to Matlab type translation -_mattype_map = { - int: float, - np.ndarray: lambda attr: np.asanyarray(attr), - Particle: lambda attr: attr.to_dict(), -} - _class_to_matfunc = { elt.Dipole: "atsbend", elt.Bend: "atsbend", @@ -139,7 +138,7 @@ def __init__( } -def hasattrs(kwargs: dict, *attributes) -> bool: +def _hasattrs(kwargs: dict, *attributes) -> bool: """Checks the presence of keys in a :py:class:`dict` Returns :py:obj:`True` if any of the ``attributes`` is in ``kwargs`` @@ -159,6 +158,15 @@ def hasattrs(kwargs: dict, *attributes) -> bool: return False +def save_filter(ring: Lattice) -> Generator[Element, None, None]: + for elem in ring: + if not ( + isinstance(elem, Marker) + and getattr(elem, "tag", None) == "RingParam" + ): + yield elem + + def find_class( elem_dict: dict, quiet: bool = False, index: Optional[int] = None ) -> type(Element): @@ -206,7 +214,7 @@ def low_order(key): return class_from_pass else: length = float(elem_dict.get("Length", 0.0)) - if hasattrs( + if _hasattrs( elem_dict, "FullGap", "FringeInt1", @@ -216,7 +224,7 @@ def low_order(key): "ExitAngle", ): return elt.Dipole - elif hasattrs( + elif _hasattrs( elem_dict, "Voltage", "Frequency", @@ -225,16 +233,16 @@ def low_order(key): "TimeLag", ): return elt.RFCavity - elif hasattrs(elem_dict, "Periodicity"): + elif _hasattrs(elem_dict, "Periodicity"): # noinspection PyProtectedMember return RingParam - elif hasattrs(elem_dict, "Limits"): + elif _hasattrs(elem_dict, "Limits"): return elt.Aperture - elif hasattrs(elem_dict, "M66"): + elif _hasattrs(elem_dict, "M66"): return elt.M66 - elif hasattrs(elem_dict, "K"): + elif _hasattrs(elem_dict, "K"): return elt.Quadrupole - elif hasattrs(elem_dict, "PolynomB", "PolynomA"): + elif _hasattrs(elem_dict, "PolynomB", "PolynomA"): loworder = low_order("PolynomB") if loworder == 1: return elt.Quadrupole @@ -246,11 +254,11 @@ def low_order(key): return elt.Multipole else: return elt.ThinMultipole - elif hasattrs(elem_dict, "KickAngle"): + elif _hasattrs(elem_dict, "KickAngle"): return elt.Corrector elif length > 0.0: return elt.Drift - elif hasattrs(elem_dict, "GCR"): + elif _hasattrs(elem_dict, "GCR"): return elt.Monitor elif pass_method == "IdentityPass": return elt.Marker @@ -403,18 +411,17 @@ def convert(value): return cls(*args, **kwargs) -def element_to_dict(elem: Element) -> dict: +def element_to_dict(elem: Element, encoder: Callable[[Any], Any] = _no_encoder) -> dict: """Builds the Matlab structure of an :py:class:`.Element` Parameters: elem: :py:class:`.Element` + encoder: data converter Returns: dct (dict): Dictionary of :py:class:`.Element` attributes """ - dct = dict( - (k, _mattype_map.get(type(v), lambda attr: attr)(v)) for k, v in elem.items() - ) + dct = dict((k, encoder(v)) for k, v in elem.items()) class_name = elem.__class__.__name__ dct["Class"] = _matclass_map.get(class_name, class_name) return dct From a4c6c52b5649e3fedb95d13524f165fd8b19fb1c Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Mon, 25 Mar 2024 13:18:39 +0100 Subject: [PATCH 05/22] improve efficiency --- atmat/lattice/atdivelem.m | 4 ++-- atmat/lattice/atloadlattice.m | 16 ++++++------- atmat/lattice/atwritejson.m | 19 ++++++++------- pyat/at/lattice/elements.py | 11 +++++---- pyat/at/load/json.py | 44 ++++++++++++++++------------------- pyat/at/load/matfile.py | 23 ++++++++++-------- pyat/at/load/utils.py | 37 +++++++++++------------------ 7 files changed, 73 insertions(+), 81 deletions(-) diff --git a/atmat/lattice/atdivelem.m b/atmat/lattice/atdivelem.m index 1ab8899d7..b036fe96f 100644 --- a/atmat/lattice/atdivelem.m +++ b/atmat/lattice/atdivelem.m @@ -37,8 +37,8 @@ line=atsetfieldvalues(line,'ExitAngle',0.0); end if isfield(elem,'KickAngle') - line=atsetfieldvalues(line,'KickAngle',{1,1},el.KickAngle(1,1)*frac(:)/sum(frac)); - line=atsetfieldvalues(line,'KickAngle',{1,2},el.KickAngle(1,2)*frac(:)/sum(frac)); + line=atsetfieldvalues(line,'KickAngle',{1},el.KickAngle(1)*frac(:)/sum(frac)); + line=atsetfieldvalues(line,'KickAngle',{2},el.KickAngle(2)*frac(:)/sum(frac)); end line{1}=mvfield(line{1},entrancef); % Set back entrance fields diff --git a/atmat/lattice/atloadlattice.m b/atmat/lattice/atloadlattice.m index bdb14efdc..2c0190ae8 100644 --- a/atmat/lattice/atloadlattice.m +++ b/atmat/lattice/atloadlattice.m @@ -83,14 +83,14 @@ function [lattice, opts]=load_json(fpath, opts) data=jsondecode(fileread(fpath)); - prms=data.parameters; - name=prms.name; - energy=prms.energy; - periodicity=prms.periodicity; - particle=atparticle.loadobj(prms.particle); - harmnumber=prms.harmonic_number; - prms=rmfield(prms,{'name','energy','periodicity','particle','harmonic_number'}); - args=[fieldnames(prms) struct2cell(prms)]'; + props=data.properties; + name=props.name; + energy=props.energy; + periodicity=props.periodicity; + particle=atparticle.loadobj(props.particle); + harmnumber=props.harmonic_number; + props=rmfield(props,{'name','energy','periodicity','particle','harmonic_number'}); + args=[fieldnames(props) struct2cell(props)]'; lattice=atSetRingProperties(data.elements,... 'FamName', name,... 'Energy', energy,... diff --git a/atmat/lattice/atwritejson.m b/atmat/lattice/atwritejson.m index 76d9b67c2..7da7752c2 100644 --- a/atmat/lattice/atwritejson.m +++ b/atmat/lattice/atwritejson.m @@ -22,9 +22,6 @@ %Check file extension if isempty(ext), ext='.json'; end - % Make fullname - fn=fullfile(pname,[fname ext]); - % Open file to be written [fid,mess]=fopen(fullfile(pname,[fname ext]),'wt'); @@ -42,22 +39,26 @@ function jsondata=sjson(ring) ok=~atgetcells(ring, 'Class', 'RingParam'); data.elements=ring(ok); - data.parameters=get_params(ring); + data.properties=get_params(ring); jsondata=jsonencode(data, 'PrettyPrint', ~compact); end function prms=get_params(ring) + % Get "standard" properties [name, energy, part, periodicity, harmonic_number]=... atGetRingProperties(ring,'FamName', 'Energy', 'Particle',... 'Periodicity', 'HarmNumber'); prms=struct('name', name, 'energy', energy, 'periodicity', periodicity,... 'particle', saveobj(part), 'harmonic_number', harmonic_number); + % Add user-defined properties idx=atlocateparam(ring); - p2=rmfield(ring{idx},{'FamName','PassMethod','Length','Class',... - 'Energy', 'Particle','Periodicity','cell_harmnumber'}); - for nm=fieldnames(p2)' - na=nm{1}; - prms.(na)=p2.(na); + if ~isempty(idx) + p2=rmfield(ring{idx},{'FamName','PassMethod','Length','Class',... + 'Energy', 'Particle','Periodicity','cell_harmnumber'}); + for nm=fieldnames(p2)' + na=nm{1}; + prms.(na)=p2.(na); + end end end diff --git a/pyat/at/lattice/elements.py b/pyat/at/lattice/elements.py index 7c15759f2..d2a4b8873 100644 --- a/pyat/at/lattice/elements.py +++ b/pyat/at/lattice/elements.py @@ -280,11 +280,9 @@ def __setattr__(self, key, value): raise def __str__(self): - attrs = dict(self.items()) return "\n".join( [self.__class__.__name__ + ":"] - + [f"{k:>14}: {attrs.pop(k)!s}" for k in ["FamName", "Length", "PassMethod"]] - + [f"{k:>14}: {v!s}" for k, v in attrs.items()] + + [f"{k:>14}: {v!s}" for k, v in self.items()] ) def __repr__(self): @@ -376,14 +374,17 @@ def definition(self) -> tuple[str, tuple, dict]: defelem = self.__class__(*arguments) keywords = dict( (k, v) - for k, v in sorted(attrs.items()) + for k, v in attrs.items() if not numpy.array_equal(v, getattr(defelem, k, None)) ) return self.__class__.__name__, arguments, keywords def items(self) -> Generator[tuple[str, Any], None, None]: """Iterates through the data members""" - for k, v in vars(self).items(): + v = vars(self).copy() + for k in ["FamName", "Length", "PassMethod"]: + yield k, v.pop(k) + for k, v in sorted(v.items()): yield k, v def is_compatible(self, other: Element) -> bool: diff --git a/pyat/at/load/json.py b/pyat/at/load/json.py index 84aa54598..ec2877c6a 100644 --- a/pyat/at/load/json.py +++ b/pyat/at/load/json.py @@ -6,6 +6,7 @@ __all__ = ["save_json", "load_json"] +from os.path import abspath import json from typing import Optional, Any @@ -21,11 +22,11 @@ class _AtEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, Element): - return obj.definition - elif isinstance(obj, Particle): - return obj.to_dict() + return element_to_dict(obj) elif isinstance(obj, np.ndarray): return obj.tolist() + elif isinstance(obj, Particle): + return obj.to_dict() else: return super().default(obj) @@ -42,10 +43,7 @@ def save_json(ring: Lattice, filename: Optional[str] = None) -> None: :py:func:`.save_lattice` for a generic lattice-saving function. :py:meth:`.Lattice.save` for a generic lattice-saving method. """ - data = dict( - elements=[element_to_dict(el) for el in save_filter(ring)], - parameters=ring.attrs, - ) + data = dict(elements=list(save_filter(ring)), properties=ring.attrs) if filename is None: print(json.dumps(data, cls=_AtEncoder, indent=2)) else: @@ -69,25 +67,23 @@ def load_json(filename: str, **kwargs) -> Lattice: :py:meth:`.Lattice.load` for a generic lattice-loading method. """ - def json_generator(params: dict[str, Any], elem_list): - particle_dict = params.pop("particle", {}) - params["particle"] = Particle(**particle_dict) - for idx, elem_dict in enumerate(elem_list): - yield element_from_dict(elem_dict, index=idx, check=False) - - with open(filename, "rt") as jsonfile: - data = json.load(jsonfile) + def json_generator(params: dict[str, Any], fn): - try: + with open(params.setdefault("json_file", fn), "rt") as jsonfile: + data = json.load(jsonfile) elements = data["elements"] - parameters = data["parameters"] - except KeyError: - raise TypeError("Not a Lattice") - if not (isinstance(elements, list) and isinstance(parameters, dict)): - raise TypeError("Not a lattice") - - parameters.update(kwargs) - return Lattice(elements, iterator=json_generator, **parameters) + try: + properties = data["properties"] + except KeyError: + properties = {} + particle_dict = properties.pop("particle", {}) + params.setdefault("particle", Particle(**particle_dict)) + for k, v in properties.items(): + params.setdefault(k, v) + for idx, elem_dict in enumerate(elements): + yield element_from_dict(elem_dict, index=idx, check=False) + + return Lattice(abspath(filename), iterator=json_generator, **kwargs) register_format( diff --git a/pyat/at/load/matfile.py b/pyat/at/load/matfile.py index c828da728..9719d7dca 100644 --- a/pyat/at/load/matfile.py +++ b/pyat/at/load/matfile.py @@ -21,17 +21,19 @@ from ..lattice import Element, Lattice, Particle, Filter from ..lattice import elements, AtWarning, params_filter, AtError +# Translation of RingParam attributes _m2p = { "FamName": "name", "Energy": "energy", "Periodicity": "periodicity", "Particle": "particle", "cell_harmnumber": "cell_harmnumber", + "beam_current": "beam_current", + "PassMethod": None, + "Length": None, + "cavpts": None, } -_param_ignore = {"PassMethod", "Length", "cavpts"} - -# Python to Matlab -_p2m = {"name", "energy", "periodicity", "particle", "cell_harmnumber", "beam_current"} +_p2m = dict((v, k) for k, v in _m2p.items() if v is not None) # Python to Matlab type translation _mattype_map = { @@ -144,8 +146,9 @@ def ringparam_filter( if isinstance(elem, RingParam): ringparams.append(elem) for k, v in elem.items(): - if k not in _param_ignore: - params.setdefault(_m2p.get(k, k), v) + k2 = _m2p.get(k, k) + if k2 is not None: + params.setdefault(k2, v) if keep_all: pars = vars(elem).copy() name = pars.pop("FamName") @@ -300,14 +303,14 @@ def required(rng): # Public lattice attributes params = dict((k, v) for k, v in vars(rng).items() if not k.startswith("_")) # Output the required attributes/properties - for k in _p2m: + for kp, km in _p2m.items(): try: - v = getattr(rng, k) + v = getattr(rng, kp) except AttributeError: pass else: - params.pop(k, None) - yield k, v + params.pop(kp, None) + yield km, v # Output the remaining attributes yield from params.items() diff --git a/pyat/at/load/utils.py b/pyat/at/load/utils.py index 3f0dc7f2f..afe5a4426 100644 --- a/pyat/at/load/utils.py +++ b/pyat/at/load/utils.py @@ -19,12 +19,11 @@ from at import integrators from at.lattice import AtWarning -from at.lattice import CLASS_MAP, elements as elt +from at.lattice import get_class_map, elements as elt from at.lattice import Lattice, Particle, Element, Marker from at.lattice import idtable_element _ext_suffix = sysconfig.get_config_var("EXT_SUFFIX") -_relativistic_particle = Particle() def _no_encoder(v): @@ -36,10 +35,11 @@ def _particle(value) -> Particle: if isinstance(value, Particle): # Create from python: save_mat return value - else: + elif isinstance(value, dict): # Create from Matlab: load_mat - name = value.pop("name") - return Particle(name, **value) + return Particle(**value) + else: + return Particle(value) def _warn(index: int, message: str, elem_dict: dict) -> None: @@ -64,20 +64,18 @@ class RingParam(elt.Element): elt.Element._conversions, Energy=float, Periodicity=int, Particle=_particle ) + # noinspection PyPep8Naming def __init__( self, - name: str, - energy: float, - periodicity: int = 1, - particle: Particle = _relativistic_particle, + FamName: str, + Energy: float, + Periodicity: int , **kwargs, ): - if not np.isnan(float(energy)): - kwargs.setdefault("Energy", energy) - kwargs.setdefault("Periodicity", periodicity) - kwargs.setdefault("Particle", particle) + if not np.isnan(float(Energy)): + kwargs["Energy"] = Energy kwargs.setdefault("PassMethod", "IdentityPass") - super(RingParam, self).__init__(name, **kwargs) + super(RingParam, self).__init__(FamName, Periodicity=Periodicity, **kwargs) _alias_map = { @@ -96,7 +94,7 @@ def __init__( # Matlab to Python class translation -_CLASS_MAP = dict((k.lower(), v) for k, v in CLASS_MAP.items()) +_CLASS_MAP = dict((k.lower(), v) for k, v in get_class_map().items()) _CLASS_MAP.update(_alias_map) _PASS_MAP = { @@ -116,13 +114,6 @@ def __init__( "GWigSymplecticPass": elt.Wiggler, } -# Matlab to Python attribute translation -_param_to_lattice = { - "Energy": "energy", - "Periodicity": "periodicity", - "FamName": "name", -} - # Python to Matlab class translation _matclass_map = { "Dipole": "Bend", @@ -343,7 +334,7 @@ def element_from_string(elem_string: str) -> Element: Returns: elem (Element): new :py:class:`.Element` """ - return eval(elem_string, globals(), CLASS_MAP) + return eval(elem_string, globals(), get_class_map()) def element_from_m(line: str) -> Element: From 0654e5d40241f2fb70fe0ec1519ad380d9233925 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Mon, 25 Mar 2024 21:00:26 +0100 Subject: [PATCH 06/22] checks --- pyat/at/load/matfile.py | 6 +++- pyat/at/load/utils.py | 54 +++++++++++++++----------------- pyat/test/test_basic_elements.py | 4 +-- pyat/test/test_load_utils.py | 6 ++-- 4 files changed, 36 insertions(+), 34 deletions(-) diff --git a/pyat/at/load/matfile.py b/pyat/at/load/matfile.py index 9719d7dca..e96b9ad79 100644 --- a/pyat/at/load/matfile.py +++ b/pyat/at/load/matfile.py @@ -331,7 +331,11 @@ def save_mat(ring: Lattice, filename: str, mat_key: str = "RING") -> None: See Also: :py:func:`.save_lattice` for a generic lattice-saving function. """ - lring = tuple(element_to_dict(el, encoder=_mat_encoder) for el in matlab_ring(ring)) +# lring = tuple(element_to_dict(el, encoder=_mat_encoder) for el in matlab_ring(ring)) +# Ensure the lattice is a Matlab column vector + lring = np.array( + [element_to_dict(el, encoder=_mat_encoder) for el in matlab_ring(ring)] + ).reshape(-1, 1) # noinspection PyUnresolvedReferences scipy.io.savemat(filename, {mat_key: lring}, long_field_names=True) diff --git a/pyat/at/load/utils.py b/pyat/at/load/utils.py index afe5a4426..f79b22c9f 100644 --- a/pyat/at/load/utils.py +++ b/pyat/at/load/utils.py @@ -4,6 +4,18 @@ from __future__ import annotations +__all__ = [ + "element_from_dict", + "element_to_dict", + "element_from_m", + "element_to_m", + "element_from_string", + "find_class", + "save_filter", + "split_ignoring_parentheses", + "RingParam", +] + import collections import os import re @@ -69,7 +81,7 @@ def __init__( self, FamName: str, Energy: float, - Periodicity: int , + Periodicity: int, **kwargs, ): if not np.isnan(float(Energy)): @@ -79,6 +91,7 @@ def __init__( _alias_map = { + "bend": elt.Dipole, "rbend": elt.Dipole, "sbend": elt.Dipole, "quad": elt.Quadrupole, @@ -115,17 +128,16 @@ def __init__( } # Python to Matlab class translation -_matclass_map = { +# Default: element_class.__name__ +_mat_class = { "Dipole": "Bend", - "InsertionDeviceKickMap": "InsertionDeviceKickMap", "M66": "Matrix66", } - -_class_to_matfunc = { - elt.Dipole: "atsbend", - elt.Bend: "atsbend", - elt.M66: "atM66", - idtable_element.InsertionDeviceKickMap: "atinsertiondevicekickmap", +# Matlab constructor function +# Default: "".join(("at", element_class.__name__.lower())) +_mat_constructor = { + "Dipole": "atsbend", + "M66": "atM66", } @@ -151,10 +163,7 @@ def _hasattrs(kwargs: dict, *attributes) -> bool: def save_filter(ring: Lattice) -> Generator[Element, None, None]: for elem in ring: - if not ( - isinstance(elem, Marker) - and getattr(elem, "tag", None) == "RingParam" - ): + if not (isinstance(elem, Marker) and getattr(elem, "tag", None) == "RingParam"): yield elem @@ -414,7 +423,7 @@ def element_to_dict(elem: Element, encoder: Callable[[Any], Any] = _no_encoder) """ dct = dict((k, encoder(v)) for k, v in elem.items()) class_name = elem.__class__.__name__ - dct["Class"] = _matclass_map.get(class_name, class_name) + dct["Class"] = _mat_class.get(class_name, class_name) return dct @@ -457,8 +466,8 @@ def convert_array(arr): return repr(arg) def m_name(elclass): - stdname = "".join(("at", elclass.__name__.lower())) - return _class_to_matfunc.get(elclass, stdname) + classname = elclass.__name__ + return _mat_constructor.get(classname, "".join(("at", classname.lower()))) attrs = dict(elem.items()) # noinspection PyProtectedMember @@ -476,19 +485,6 @@ def m_name(elclass): return "{0:>15}({1});...".format(m_name(elem.__class__), ", ".join(argstrs)) -# Kept for compatibility but should be deprecated: - - -CLASS_MAPPING = dict((key, cls.__name__) for (key, cls) in _CLASS_MAP.items()) - -PASS_MAPPING = dict((key, cls.__name__) for (key, cls) in _PASS_MAP.items()) - - -def find_class_name(elem_dict, quiet=False): - """Derive the class name of an Element from its attributes""" - return find_class(elem_dict, quiet=quiet).__name__ - - def split_ignoring_parentheses(string, delimiter): placeholder = "placeholder" substituted = string[:] diff --git a/pyat/test/test_basic_elements.py b/pyat/test/test_basic_elements.py index 94ec5c349..380f4b424 100644 --- a/pyat/test/test_basic_elements.py +++ b/pyat/test/test_basic_elements.py @@ -15,8 +15,8 @@ def test_data_checks(): def test_element_string_ordering(): d = elements.Drift('D0', 1, attr=numpy.array(0)) - assert d.__str__() == ("Drift:\n\tFamName : D0\n\tLength : 1.0\n" - "\tPassMethod : DriftPass\n\tattr : 0") + assert d.__str__() == ("Drift:\n FamName: D0\n Length: 1.0\n" + " PassMethod: DriftPass\n attr: 0") assert d.__repr__() == "Drift('D0', 1.0, attr=array(0))" diff --git a/pyat/test/test_load_utils.py b/pyat/test/test_load_utils.py index 470328756..78c5de551 100644 --- a/pyat/test/test_load_utils.py +++ b/pyat/test/test_load_utils.py @@ -5,8 +5,10 @@ from at import AtWarning from at.lattice import Lattice, elements, params_filter, no_filter from at.load.utils import find_class, element_from_dict +# noinspection PyProtectedMember from at.load.utils import _CLASS_MAP, _PASS_MAP from at.load.utils import RingParam, split_ignoring_parentheses +# noinspection PyProtectedMember from at.load.matfile import ringparam_filter @@ -66,8 +68,8 @@ def test_inconsistent_energy_values_warns_correctly(): def test_more_than_one_RingParam_in_ring_raises_warning(): - p1 = RingParam('rp1', 5.e+6) - p2 = RingParam('rp2', 3.e+6) + p1 = RingParam('rp1', 5.e+6, 1) + p2 = RingParam('rp2', 3.e+6, 1) with pytest.warns(AtWarning): params = _matlab_scanner([p1, p2]) assert params['_energy'] == 5.e+6 From 3d0a9a6d1b7669bd8aff32b42555ddb9db758075 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Mon, 25 Mar 2024 21:09:48 +0100 Subject: [PATCH 07/22] checks --- pyat/at/load/matfile.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pyat/at/load/matfile.py b/pyat/at/load/matfile.py index e96b9ad79..f617e4794 100644 --- a/pyat/at/load/matfile.py +++ b/pyat/at/load/matfile.py @@ -331,12 +331,8 @@ def save_mat(ring: Lattice, filename: str, mat_key: str = "RING") -> None: See Also: :py:func:`.save_lattice` for a generic lattice-saving function. """ -# lring = tuple(element_to_dict(el, encoder=_mat_encoder) for el in matlab_ring(ring)) -# Ensure the lattice is a Matlab column vector - lring = np.array( - [element_to_dict(el, encoder=_mat_encoder) for el in matlab_ring(ring)] - ).reshape(-1, 1) - # noinspection PyUnresolvedReferences + # Ensure the lattice is a Matlab column vector: list(list) + lring = [[element_to_dict(el, encoder=_mat_encoder)] for el in matlab_ring(ring)] scipy.io.savemat(filename, {mat_key: lring}, long_field_names=True) From f75a4ef1f1e8a294b652b25b2f35aa8b63e11836 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Tue, 26 Mar 2024 11:18:31 +0100 Subject: [PATCH 08/22] checks --- atmat/lattice/atwritejson.m | 6 ++++-- pyat/at/load/json.py | 11 ++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/atmat/lattice/atwritejson.m b/atmat/lattice/atwritejson.m index 7da7752c2..524e06444 100644 --- a/atmat/lattice/atwritejson.m +++ b/atmat/lattice/atwritejson.m @@ -53,8 +53,10 @@ % Add user-defined properties idx=atlocateparam(ring); if ~isempty(idx) - p2=rmfield(ring{idx},{'FamName','PassMethod','Length','Class',... - 'Energy', 'Particle','Periodicity','cell_harmnumber'}); + flist={'FamName','PassMethod','Length','Class',... + 'Energy', 'Particle','Periodicity','cell_harmnumber'}; + present=isfield(ring{idx}, flist); + p2=rmfield(ring{idx},flist(present)); for nm=fieldnames(p2)' na=nm{1}; prms.(na)=p2.(na); diff --git a/pyat/at/load/json.py b/pyat/at/load/json.py index ec2877c6a..68d0d0777 100644 --- a/pyat/at/load/json.py +++ b/pyat/at/load/json.py @@ -31,24 +31,29 @@ def default(self, obj): return super().default(obj) -def save_json(ring: Lattice, filename: Optional[str] = None) -> None: +def save_json( + ring: Lattice, filename: Optional[str] = None, compact: bool = False +) -> None: """Save a :py:class:`.Lattice` as a JSON file Parameters: ring: Lattice description filename: Name of the JSON file. Default: outputs on :py:obj:`sys.stdout` + compact: If :py:obj:`False` (default), the JSON file is pretty-printed + with line feeds and indentation. Otherwise, the output is a single line. See Also: :py:func:`.save_lattice` for a generic lattice-saving function. :py:meth:`.Lattice.save` for a generic lattice-saving method. """ + indent = None if compact else 2 data = dict(elements=list(save_filter(ring)), properties=ring.attrs) if filename is None: - print(json.dumps(data, cls=_AtEncoder, indent=2)) + print(json.dumps(data, cls=_AtEncoder, indent=indent)) else: with open(filename, "wt") as jsonfile: - json.dump(data, jsonfile, cls=_AtEncoder, indent=2) + json.dump(data, jsonfile, cls=_AtEncoder, indent=indent) def load_json(filename: str, **kwargs) -> Lattice: From f865aa546bd353250357bbf36de5d0a3bd037292 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Tue, 26 Mar 2024 18:43:38 +0100 Subject: [PATCH 09/22] cleaned lattice attributes in files --- pyat/at/load/json.py | 6 +++--- pyat/at/load/matfile.py | 34 +++++++++++++++++++++++++++------- pyat/at/load/reprfile.py | 8 ++++---- pyat/at/load/utils.py | 22 ++++++++++++++++++++-- 4 files changed, 54 insertions(+), 16 deletions(-) diff --git a/pyat/at/load/json.py b/pyat/at/load/json.py index 68d0d0777..3d2991cbe 100644 --- a/pyat/at/load/json.py +++ b/pyat/at/load/json.py @@ -13,7 +13,7 @@ import numpy as np from .allfiles import register_format -from .utils import element_to_dict, element_from_dict, save_filter +from .utils import element_to_dict, element_from_dict, keep_elements, keep_attributes from ..lattice import Element, Lattice, Particle @@ -48,7 +48,7 @@ def save_json( :py:meth:`.Lattice.save` for a generic lattice-saving method. """ indent = None if compact else 2 - data = dict(elements=list(save_filter(ring)), properties=ring.attrs) + data = dict(elements=list(keep_elements(ring)), properties=keep_attributes(ring)) if filename is None: print(json.dumps(data, cls=_AtEncoder, indent=indent)) else: @@ -74,7 +74,7 @@ def load_json(filename: str, **kwargs) -> Lattice: def json_generator(params: dict[str, Any], fn): - with open(params.setdefault("json_file", fn), "rt") as jsonfile: + with open(params.setdefault("in_file", fn), "rt") as jsonfile: data = json.load(jsonfile) elements = data["elements"] try: diff --git a/pyat/at/load/matfile.py b/pyat/at/load/matfile.py index f617e4794..cf87b88f2 100644 --- a/pyat/at/load/matfile.py +++ b/pyat/at/load/matfile.py @@ -7,6 +7,7 @@ __all__ = ["load_mat", "save_mat", "load_m", "save_m", "load_var"] import sys +import os from os.path import abspath, basename, splitext from typing import Optional, Any from collections.abc import Sequence, Generator @@ -16,8 +17,9 @@ import scipy.io from .allfiles import register_format -from .utils import element_from_dict, element_from_m, RingParam, save_filter +from .utils import element_from_dict, element_from_m, RingParam, keep_elements from .utils import element_to_dict, element_to_m +from .utils import _drop_attrs from ..lattice import Element, Lattice, Particle, Filter from ..lattice import elements, AtWarning, params_filter, AtError @@ -27,13 +29,15 @@ "Energy": "energy", "Periodicity": "periodicity", "Particle": "particle", - "cell_harmnumber": "cell_harmnumber", - "beam_current": "beam_current", + "cell_harmnumber": "cell_harmnumber", # necessary: property + "beam_current": "beam_current", # necessary: property "PassMethod": None, "Length": None, "cavpts": None, } _p2m = dict((v, k) for k, v in _m2p.items() if v is not None) +# Attribute to drop when writing a file +_p2m.update(_drop_attrs) # Python to Matlab type translation _mattype_map = { @@ -88,7 +92,7 @@ def mclean(data): return np.squeeze(data) # noinspection PyUnresolvedReferences - m = scipy.io.loadmat(params.setdefault("mat_file", mat_file)) + m = scipy.io.loadmat(params.setdefault("in_file", mat_file)) matvars = [varname for varname in m if not varname.startswith("__")] default_key = matvars[0] if (len(matvars) == 1) else "RING" key = params.setdefault("mat_key", default_key) @@ -213,7 +217,7 @@ def mfile_generator(params: dict, m_file: str) -> Generator[Element, None, None] Yields: elem (Element): new Elements """ - with open(params.setdefault("m_file", m_file), "rt") as file: + with open(params.setdefault("in_file", m_file), "rt") as file: _ = next(file) # Matlab function definition _ = next(file) # Cell array opening for lineno, line in enumerate(file): @@ -310,13 +314,14 @@ def required(rng): pass else: params.pop(kp, None) - yield km, v + if km is not None: + yield km, v # Output the remaining attributes yield from params.items() dct = dict(required(ring)) yield RingParam(**dct) - yield from save_filter(ring) + yield from keep_elements(ring) def save_mat(ring: Lattice, filename: str, mat_key: str = "RING") -> None: @@ -364,5 +369,20 @@ def save(file): print("end", file=mfile) +def _mat_file(ring): + """.mat input file""" + try: + in_file = ring.in_file + except AttributeError: + raise AttributeError("'Lattice' object has no attribute 'mat_file'") + else: + _, ext = os.path.splitext(in_file) + if ext != ".mat": + raise AttributeError("'Lattice' object has no attribute 'mat_file'") + return in_file + + register_format(".mat", load_mat, save_mat, descr="Matlab binary mat-file") register_format(".m", load_m, save_m, descr="Matlab text m-file") + +Lattice.mat_file = property(_mat_file, None, None) diff --git a/pyat/at/load/reprfile.py b/pyat/at/load/reprfile.py index 96802f423..90e40078f 100644 --- a/pyat/at/load/reprfile.py +++ b/pyat/at/load/reprfile.py @@ -17,7 +17,7 @@ # imports necessary in' globals()' for 'eval' from at.lattice import Particle # noqa: F401 from at.load import register_format -from at.load.utils import element_from_string +from at.load.utils import element_from_string, keep_attributes, keep_elements def load_repr(filename: str, **kwargs) -> Lattice: @@ -44,7 +44,7 @@ def load_repr(filename: str, **kwargs) -> Lattice: """ def elem_iterator(params, repr_file): - with open(params.setdefault("repr_file", repr_file), "rt") as file: + with open(params.setdefault("in_file", repr_file), "rt") as file: # the 1st line is the dictionary of saved lattice parameters for k, v in eval(next(file)).items(): params.setdefault(k, v) @@ -67,8 +67,8 @@ def save_repr(ring: Lattice, filename: Optional[str] = None) -> None: """ def save(file): - print(repr(ring.attrs), file=file) - for elem in ring: + print(repr(keep_attributes(ring)), file=file) + for elem in keep_elements(ring): print(repr(elem), file=file) # Set options to print the full representation of float variables diff --git a/pyat/at/load/utils.py b/pyat/at/load/utils.py index f79b22c9f..e4245f21d 100644 --- a/pyat/at/load/utils.py +++ b/pyat/at/load/utils.py @@ -11,7 +11,8 @@ "element_to_m", "element_from_string", "find_class", - "save_filter", + "keep_elements", + "keep_attributes", "split_ignoring_parentheses", "RingParam", ] @@ -140,6 +141,15 @@ def __init__( "M66": "atM66", } +# Lattice attributes which must be dropped when writing a file +_drop_attrs = { + "in_file": None, + "mat_key": None, + "mat_file": None, # Not used anymore... + "m_file": None, + "repr_file": None, +} + def _hasattrs(kwargs: dict, *attributes) -> bool: """Checks the presence of keys in a :py:class:`dict` @@ -161,7 +171,15 @@ def _hasattrs(kwargs: dict, *attributes) -> bool: return False -def save_filter(ring: Lattice) -> Generator[Element, None, None]: +def keep_attributes(ring: Lattice): + """Remove Lattice attributes which must not be saved on file""" + return dict( + (k, v) for k, v in ring.attrs.items() if _drop_attrs.get(k, k) is not None + ) + + +def keep_elements(ring: Lattice) -> Generator[Element, None, None]: + """Remove the 'RingParam' Marker""" for elem in ring: if not (isinstance(elem, Marker) and getattr(elem, "tag", None) == "RingParam"): yield elem From eac54738ba45f9017bfaf9de38016d610f36cb0a Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Wed, 27 Mar 2024 15:02:16 +0100 Subject: [PATCH 10/22] restructured files --- pyat/at/load/json.py | 4 +- pyat/at/load/matfile.py | 155 +++++++++++++++-- pyat/at/load/reprfile.py | 34 +++- pyat/at/load/utils.py | 350 +++++++++++++-------------------------- 4 files changed, 289 insertions(+), 254 deletions(-) diff --git a/pyat/at/load/json.py b/pyat/at/load/json.py index 3d2991cbe..49c0bfc5a 100644 --- a/pyat/at/load/json.py +++ b/pyat/at/load/json.py @@ -13,7 +13,7 @@ import numpy as np from .allfiles import register_format -from .utils import element_to_dict, element_from_dict, keep_elements, keep_attributes +from .utils import element_from_dict, keep_elements, keep_attributes from ..lattice import Element, Lattice, Particle @@ -22,7 +22,7 @@ class _AtEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, Element): - return element_to_dict(obj) + return obj.to_dict() elif isinstance(obj, np.ndarray): return obj.tolist() elif isinstance(obj, Particle): diff --git a/pyat/at/load/matfile.py b/pyat/at/load/matfile.py index cf87b88f2..03c92fe69 100644 --- a/pyat/at/load/matfile.py +++ b/pyat/at/load/matfile.py @@ -16,10 +16,13 @@ import numpy as np import scipy.io +# imports necessary in 'globals()' for 'eval' +from numpy import array, uint8, NaN # noqa: F401 + from .allfiles import register_format -from .utils import element_from_dict, element_from_m, RingParam, keep_elements -from .utils import element_to_dict, element_to_m -from .utils import _drop_attrs +from .utils import split_ignoring_parentheses, RingParam, keep_elements +from .utils import element_from_dict +from .utils import _drop_attrs, _CLASS_MAP from ..lattice import Element, Lattice, Particle, Filter from ..lattice import elements, AtWarning, params_filter, AtError @@ -45,6 +48,12 @@ np.ndarray: lambda attr: np.asanyarray(attr), Particle: lambda attr: attr.to_dict(), } +# Matlab constructor function +# Default: "".join(("at", element_class.__name__.lower())) +_mat_constructor = { + "Dipole": "atsbend", + "M66": "atM66", +} def _mat_encoder(v): @@ -95,11 +104,11 @@ def mclean(data): m = scipy.io.loadmat(params.setdefault("in_file", mat_file)) matvars = [varname for varname in m if not varname.startswith("__")] default_key = matvars[0] if (len(matvars) == 1) else "RING" - key = params.setdefault("mat_key", default_key) + key = params.setdefault("use", default_key) if key not in m.keys(): kok = [k for k in m.keys() if "__" not in k] raise AtError( - "Selected mat_key does not exist, please select in: {}".format(kok) + f"Selected '{key}' variable does not exist, please select in: {kok}" ) check = params.pop("check", True) quiet = params.pop("quiet", False) @@ -172,9 +181,10 @@ def load_mat(filename: str, **kwargs) -> Lattice: filename: Name of a '.mat' file Keyword Args: - mat_key (str): Name of the Matlab variable containing + use (str): Name of the Matlab variable containing the lattice. Default: Matlab variable name if there is only one, otherwise 'RING' + mat_key (str): alias for *use* check (bool): Run the coherence tests. Default: :py:obj:`True` quiet (bool): Suppress the warning for non-standard @@ -197,7 +207,9 @@ def load_mat(filename: str, **kwargs) -> Lattice: :py:func:`.load_lattice` for a generic lattice-loading function. """ if "key" in kwargs: # process the deprecated 'key' keyword - kwargs.setdefault("mat_key", kwargs.pop("key")) + kwargs.setdefault("use", kwargs.pop("key")) + if "mat_key" in kwargs: # process the deprecated 'mat_key' keyword + kwargs.setdefault("use", kwargs.pop("key")) return Lattice( ringparam_filter, matfile_generator, @@ -207,6 +219,71 @@ def load_mat(filename: str, **kwargs) -> Lattice: ) +def _element_from_m(line: str) -> Element: + """Builds an :py:class:`.Element` from a line in an m-file + + Parameters: + line: Matlab string representation of an :py:class:`.Element` + + Returns: + elem (Element): new :py:class:`.Element` + """ + + def argsplit(value): + return [a.strip() for a in split_ignoring_parentheses(value, ",")] + + def makedir(mat_struct): + """Build directory from Matlab struct arguments""" + + def pairs(it): + while True: + try: + a = next(it) + except StopIteration: + break + yield eval(a), convert(next(it)) + + return dict(pairs(iter(mat_struct))) + + def makearray(mat_arr): + """Build numpy array for Matlab array syntax""" + + def arraystr(arr): + lns = arr.split(";") + rr = [arraystr(v) for v in lns] if len(lns) > 1 else lns[0].split() + return f"[{', '.join(rr)}]" + + return eval(f"array({arraystr(mat_arr)})") + + def convert(value): + """convert Matlab syntax to numpy syntax""" + if value.startswith("["): + result = makearray(value[1:-1]) + elif value.startswith("struct"): + result = makedir(argsplit(value[7:-1])) + else: + result = eval(value) + return result + + left = line.index("(") + right = line.rindex(")") + matcls = line[:left].strip()[2:] + cls = _CLASS_MAP[matcls] + arguments = argsplit(line[left + 1 : right]) + ll = len(cls._BUILD_ATTRIBUTES) + if ll < len(arguments) and arguments[ll].endswith("Pass'"): + arguments.insert(ll, "'PassMethod'") + args = [convert(v) for v in arguments[:ll]] + kwargs = makedir(arguments[ll:]) + if matcls == "rbend": + # the Matlab 'rbend' has no equivalent in PyAT. This adds parameters + # necessary for using the python sector bend + halfangle = 0.5 * args[2] + kwargs.setdefault("EntranceAngle", halfangle) + kwargs.setdefault("ExitAngle", halfangle) + return cls(*args, **kwargs) + + def mfile_generator(params: dict, m_file: str) -> Generator[Element, None, None]: """Run through the lines of a Matlab m-file and generate AT elements @@ -224,7 +301,7 @@ def mfile_generator(params: dict, m_file: str) -> Generator[Element, None, None] if line.startswith("};"): break try: - elem = element_from_m(line) + elem = _element_from_m(line) except ValueError: warn(AtWarning("Invalid line {0} skipped.".format(lineno))) continue @@ -337,10 +414,68 @@ def save_mat(ring: Lattice, filename: str, mat_key: str = "RING") -> None: :py:func:`.save_lattice` for a generic lattice-saving function. """ # Ensure the lattice is a Matlab column vector: list(list) - lring = [[element_to_dict(el, encoder=_mat_encoder)] for el in matlab_ring(ring)] + lring = [[el.to_dict(encoder=_mat_encoder)] for el in matlab_ring(ring)] scipy.io.savemat(filename, {mat_key: lring}, long_field_names=True) +def _element_to_m(elem: Element) -> str: + """Builds the Matlab-evaluable string for an :py:class:`.Element` + + Parameters: + elem: :py:class:`.Element` + + Returns: + mstr (str): Matlab string representation of the + :py:class:`.Element` attributes + """ + + def convert(arg): + def convert_dict(pdir): + def scan(d): + for k, v in d.items(): + yield convert(k) + yield convert(v) + + return "struct({0})".format(", ".join(scan(pdir))) + + def convert_array(arr): + if arr.ndim > 1: + lns = (str(list(ln)).replace(",", "")[1:-1] for ln in arr) + return "".join(("[", "; ".join(lns), "]")) + elif arr.ndim > 0: + return str(list(arr)).replace(",", "") + else: + return str(arr) + + if isinstance(arg, np.ndarray): + return convert_array(arg) + elif isinstance(arg, dict): + return convert_dict(arg) + elif isinstance(arg, Particle): + return convert_dict(arg.to_dict()) + else: + return repr(arg) + + def m_name(elclass): + classname = elclass.__name__ + return _mat_constructor.get(classname, "".join(("at", classname.lower()))) + + attrs = dict(elem.items()) + # noinspection PyProtectedMember + args = [attrs.pop(k, getattr(elem, k)) for k in elem._BUILD_ATTRIBUTES] + defelem = elem.__class__(*args) + kwds = dict( + (k, v) + for k, v in attrs.items() + if not np.array_equal(v, getattr(defelem, k, None)) + ) + argstrs = [convert(arg) for arg in args] + if "PassMethod" in kwds: + argstrs.append(convert(kwds.pop("PassMethod"))) + argstrs += [", ".join((repr(k), convert(v))) for k, v in kwds.items()] + return "{0:>15}({1});...".format(m_name(elem.__class__), ", ".join(argstrs)) + + def save_m(ring: Lattice, filename: Optional[str] = None) -> None: """Save a :py:class:`.Lattice` as a Matlab m-file @@ -356,7 +491,7 @@ def save_m(ring: Lattice, filename: Optional[str] = None) -> None: def save(file): print("ring = {...", file=file) for elem in matlab_ring(ring): - print(element_to_m(elem), file=file) + print(_element_to_m(elem), file=file) print("};", file=file) if filename is None: diff --git a/pyat/at/load/reprfile.py b/pyat/at/load/reprfile.py index 90e40078f..6b7f7028e 100644 --- a/pyat/at/load/reprfile.py +++ b/pyat/at/load/reprfile.py @@ -10,14 +10,32 @@ from os.path import abspath from typing import Optional -import numpy +import numpy as np -from at.lattice import Lattice +# imports necessary in 'globals()' for 'eval' +from numpy import array, uint8, NaN # noqa: F401 + +from at.lattice import Lattice, Element # imports necessary in' globals()' for 'eval' from at.lattice import Particle # noqa: F401 from at.load import register_format -from at.load.utils import element_from_string, keep_attributes, keep_elements +from at.load.utils import element_classes, keep_attributes, keep_elements + +# Map class names to Element classes +_CLASS_MAP = dict((cls.__name__, cls) for cls in element_classes()) + + +def _element_from_string(elem_string: str) -> Element: + """Builds an :py:class:`.Element` from its python :py:func:`repr` string + + Parameters: + elem_string: String representation of an :py:class:`.Element` + + Returns: + elem (Element): new :py:class:`.Element` + """ + return eval(elem_string, globals(), _CLASS_MAP) def load_repr(filename: str, **kwargs) -> Lattice: @@ -49,7 +67,7 @@ def elem_iterator(params, repr_file): for k, v in eval(next(file)).items(): params.setdefault(k, v) for line in file: - yield element_from_string(line.strip()) + yield _element_from_string(line.strip()) return Lattice(abspath(filename), iterator=elem_iterator, **kwargs) @@ -72,7 +90,7 @@ def save(file): print(repr(elem), file=file) # Set options to print the full representation of float variables - with numpy.printoptions(formatter={"float_kind": repr}): + with np.printoptions(formatter={"float_kind": repr}): if filename is None: save(sys.stdout) else: @@ -81,5 +99,9 @@ def save(file): register_format( - ".repr", load_repr, save_repr, descr="Text representation of a python AT Lattice" + ".repr", + load_repr, + save_repr, + descr=("Text representation of a python AT Lattice. " + "See also :py:func:`.load_repr`."), ) diff --git a/pyat/at/load/utils.py b/pyat/at/load/utils.py index e4245f21d..e968eb8b2 100644 --- a/pyat/at/load/utils.py +++ b/pyat/at/load/utils.py @@ -5,11 +5,9 @@ from __future__ import annotations __all__ = [ + "element_classes", "element_from_dict", "element_to_dict", - "element_from_m", - "element_to_m", - "element_from_string", "find_class", "keep_elements", "keep_attributes", @@ -27,12 +25,9 @@ import numpy as np -# imports necessary in 'globals()' for 'eval' -from numpy import array, uint8, NaN # noqa: F401 - from at import integrators from at.lattice import AtWarning -from at.lattice import get_class_map, elements as elt +from at.lattice import elements as elt from at.lattice import Lattice, Particle, Element, Marker from at.lattice import idtable_element @@ -62,6 +57,20 @@ def _warn(index: int, message: str, elem_dict: dict) -> None: warn(AtWarning(warning), stacklevel=2) +def element_classes() -> frozenset[type[Element]]: + """Build a set of all Element subclasses""" + + # Misses class aliases (Bend, Matrix66) + def subclasses_recursive(cl): + direct = cl.__subclasses__() + indirect = [] + for subclass in direct: + indirect.extend(subclasses_recursive(subclass)) + return frozenset([cl] + direct + indirect) + + return subclasses_recursive(Element) + + class RingParam(elt.Element): """Private class for Matlab RingParam element @@ -102,15 +111,14 @@ def __init__( "ap": elt.Aperture, "ringparam": RingParam, "wig": elt.Wiggler, - "insertiondevicekickmap": idtable_element.InsertionDeviceKickMap, "matrix66": elt.M66, } - -# Matlab to Python class translation -_CLASS_MAP = dict((k.lower(), v) for k, v in get_class_map().items()) +# Map class names to Element classes +_CLASS_MAP = dict((cls.__name__.lower(), cls) for cls in element_classes()) _CLASS_MAP.update(_alias_map) +# Maps passmethods to Element classes _PASS_MAP = { "BendLinearPass": elt.Dipole, "BndMPoleSymplectic4RadPass": elt.Dipole, @@ -128,22 +136,17 @@ def __init__( "GWigSymplecticPass": elt.Wiggler, } -# Python to Matlab class translation +# Maps python class name to Matlab class # Default: element_class.__name__ _mat_class = { "Dipole": "Bend", "M66": "Matrix66", } -# Matlab constructor function -# Default: "".join(("at", element_class.__name__.lower())) -_mat_constructor = { - "Dipole": "atsbend", - "M66": "atM66", -} # Lattice attributes which must be dropped when writing a file _drop_attrs = { "in_file": None, + "use": None, "mat_key": None, "mat_file": None, # Not used anymore... "m_file": None, @@ -185,10 +188,65 @@ def keep_elements(ring: Lattice) -> Generator[Element, None, None]: yield elem +def _from_contents(elem: dict) -> type[Element]: + """Deduce the element class from its contents""" + + def low_order(key): + polynom = np.array(elem[key], dtype=np.float64).reshape(-1) + try: + low = np.where(polynom != 0.0)[0][0] + except IndexError: + low = -1 + return low + + length = float(elem.get("Length", 0.0)) + pass_method = elem.get("PassMethod", "") + if _hasattrs( + elem, "FullGap", "FringeInt1", "FringeInt2", "gK", "EntranceAngle", "ExitAngle" + ): + return elt.Dipole + elif _hasattrs(elem, "Voltage", "Frequency", "HarmNumber", "PhaseLag", "TimeLag"): + return elt.RFCavity + elif _hasattrs(elem, "Periodicity"): + # noinspection PyProtectedMember + return RingParam + elif _hasattrs(elem, "Limits"): + return elt.Aperture + elif _hasattrs(elem, "M66"): + return elt.M66 + elif _hasattrs(elem, "K"): + return elt.Quadrupole + elif _hasattrs(elem, "PolynomB", "PolynomA"): + loworder = low_order("PolynomB") + if loworder == 1: + return elt.Quadrupole + elif loworder == 2: + return elt.Sextupole + elif loworder == 3: + return elt.Octupole + elif pass_method.startswith("StrMPoleSymplectic4") or (length > 0): + return elt.Multipole + else: + return elt.ThinMultipole + elif _hasattrs(elem, "KickAngle"): + return elt.Corrector + elif length > 0.0: + return elt.Drift + elif _hasattrs(elem, "GCR"): + return elt.Monitor + elif pass_method == "IdentityPass": + return elt.Marker + else: + return elt.Element + + def find_class( elem_dict: dict, quiet: bool = False, index: Optional[int] = None ) -> type(Element): - """Identify the class of an element from its attributes + """Deduce the class of an element from its attributes + + `find_class` looks first at the "Class" field, if existing. It then tries to deduce + the class from "FamName", from "PassMethod", and finally form the element contents. Args: elem_dict: The dictionary of keyword arguments passed to the @@ -200,88 +258,40 @@ def find_class( element_class: The guessed Class name """ - def low_order(key): - polynom = np.array(elem_dict[key], dtype=np.float64).reshape(-1) - try: - low = np.where(polynom != 0.0)[0][0] - except IndexError: - low = -1 - return low - - class_name = elem_dict.pop("Class", "") - try: - return _CLASS_MAP[class_name.lower()] - except KeyError: - if not quiet and class_name: - _warn(index, f"Class '{class_name}' does not exist.", elem_dict) - fam_name = elem_dict.get("FamName", "") - try: - return _CLASS_MAP[fam_name.lower()] - except KeyError: - pass_method = elem_dict.get("PassMethod", "") - if not quiet and not pass_method: - _warn(index, "No PassMethod provided.", elem_dict) - elif not quiet and not pass_method.endswith("Pass"): - message = ( - f"Invalid PassMethod '{pass_method}', " - "provided pass methods should end in 'Pass'." - ) - _warn(index, message, elem_dict) - class_from_pass = _PASS_MAP.get(pass_method) - if class_from_pass is not None: - return class_from_pass - else: - length = float(elem_dict.get("Length", 0.0)) - if _hasattrs( - elem_dict, - "FullGap", - "FringeInt1", - "FringeInt2", - "gK", - "EntranceAngle", - "ExitAngle", - ): - return elt.Dipole - elif _hasattrs( - elem_dict, - "Voltage", - "Frequency", - "HarmNumber", - "PhaseLag", - "TimeLag", - ): - return elt.RFCavity - elif _hasattrs(elem_dict, "Periodicity"): - # noinspection PyProtectedMember - return RingParam - elif _hasattrs(elem_dict, "Limits"): - return elt.Aperture - elif _hasattrs(elem_dict, "M66"): - return elt.M66 - elif _hasattrs(elem_dict, "K"): - return elt.Quadrupole - elif _hasattrs(elem_dict, "PolynomB", "PolynomA"): - loworder = low_order("PolynomB") - if loworder == 1: - return elt.Quadrupole - elif loworder == 2: - return elt.Sextupole - elif loworder == 3: - return elt.Octupole - elif pass_method.startswith("StrMPoleSymplectic4") or (length > 0): - return elt.Multipole - else: - return elt.ThinMultipole - elif _hasattrs(elem_dict, "KickAngle"): - return elt.Corrector - elif length > 0.0: - return elt.Drift - elif _hasattrs(elem_dict, "GCR"): - return elt.Monitor - elif pass_method == "IdentityPass": - return elt.Marker - else: - return elt.Element + def check_class(clname): + if clname: + _warn(index, f"Class '{clname}' does not exist.", elem_dict) + + def check_pass(passm): + if not passm: + _warn(index, "No PassMethod provided.", elem_dict) + elif not passm.endswith("Pass"): + message = ( + f"Invalid PassMethod '{passm}': " + "provided pass methods should end in 'Pass'." + ) + _warn(index, message, elem_dict) + + class_name = elem_dict.pop("Class", "") # try from class name + cls = _CLASS_MAP.get(class_name.lower(), None) + if cls is not None: + return cls + elif not quiet: + check_class(class_name) + + elname = elem_dict.get("FamName", "") # try from element name + cls = _CLASS_MAP.get(elname.lower(), None) + if cls is not None: + return cls + + pass_method = elem_dict.get("PassMethod", "") # try from passmethod + cls = _PASS_MAP.get(pass_method, None) + if cls is not None: + return cls + elif not quiet: + check_pass(pass_method) + + return _from_contents(elem_dict) # look for contents def element_from_dict( @@ -352,83 +362,6 @@ def sanitise_class(index, cls, elem_dict): return element -def element_from_string(elem_string: str) -> Element: - """Builds an :py:class:`.Element` from its python :py:func:`repr` string - - Parameters: - elem_string: String representation of an :py:class:`.Element` - - Returns: - elem (Element): new :py:class:`.Element` - """ - return eval(elem_string, globals(), get_class_map()) - - -def element_from_m(line: str) -> Element: - """Builds an :py:class:`.Element` from a line in an m-file - - Parameters: - line: Matlab string representation of an :py:class:`.Element` - - Returns: - elem (Element): new :py:class:`.Element` - """ - - def argsplit(value): - return [a.strip() for a in split_ignoring_parentheses(value, ",")] - - def makedir(mat_struct): - """Build directory from Matlab struct arguments""" - - def pairs(it): - while True: - try: - a = next(it) - except StopIteration: - break - yield eval(a), convert(next(it)) - - return dict(pairs(iter(mat_struct))) - - def makearray(mat_arr): - """Build numpy array for Matlab array syntax""" - - def arraystr(arr): - lns = arr.split(";") - rr = [arraystr(v) for v in lns] if len(lns) > 1 else lns[0].split() - return "[{0}]".format(", ".join(rr)) - - return eval("array({0})".format(arraystr(mat_arr))) - - def convert(value): - """convert Matlab syntax to numpy syntax""" - if value.startswith("["): - result = makearray(value[1:-1]) - elif value.startswith("struct"): - result = makedir(argsplit(value[7:-1])) - else: - result = eval(value) - return result - - left = line.index("(") - right = line.rindex(")") - matcls = line[:left].strip()[2:] - cls = _CLASS_MAP[matcls] - arguments = argsplit(line[left + 1 : right]) - ll = len(cls._BUILD_ATTRIBUTES) - if ll < len(arguments) and arguments[ll].endswith("Pass'"): - arguments.insert(ll, "'PassMethod'") - args = [convert(v) for v in arguments[:ll]] - kwargs = makedir(arguments[ll:]) - if matcls == "rbend": - # the Matlab 'rbend' has no equivalent in PyAT. This adds parameters - # necessary for using the python sector bend - halfangle = 0.5 * args[2] - kwargs.setdefault("EntranceAngle", halfangle) - kwargs.setdefault("ExitAngle", halfangle) - return cls(*args, **kwargs) - - def element_to_dict(elem: Element, encoder: Callable[[Any], Any] = _no_encoder) -> dict: """Builds the Matlab structure of an :py:class:`.Element` @@ -445,64 +378,6 @@ def element_to_dict(elem: Element, encoder: Callable[[Any], Any] = _no_encoder) return dct -def element_to_m(elem: Element) -> str: - """Builds the Matlab-evaluable string for an :py:class:`.Element` - - Parameters: - elem: :py:class:`.Element` - - Returns: - mstr (str): Matlab string representation of the - :py:class:`.Element` attributes - """ - - def convert(arg): - def convert_dict(pdir): - def scan(d): - for k, v in d.items(): - yield convert(k) - yield convert(v) - - return "struct({0})".format(", ".join(scan(pdir))) - - def convert_array(arr): - if arr.ndim > 1: - lns = (str(list(ln)).replace(",", "")[1:-1] for ln in arr) - return "".join(("[", "; ".join(lns), "]")) - elif arr.ndim > 0: - return str(list(arr)).replace(",", "") - else: - return str(arr) - - if isinstance(arg, np.ndarray): - return convert_array(arg) - elif isinstance(arg, dict): - return convert_dict(arg) - elif isinstance(arg, Particle): - return convert_dict(arg.to_dict()) - else: - return repr(arg) - - def m_name(elclass): - classname = elclass.__name__ - return _mat_constructor.get(classname, "".join(("at", classname.lower()))) - - attrs = dict(elem.items()) - # noinspection PyProtectedMember - args = [attrs.pop(k, getattr(elem, k)) for k in elem._BUILD_ATTRIBUTES] - defelem = elem.__class__(*args) - kwds = dict( - (k, v) - for k, v in attrs.items() - if not np.array_equal(v, getattr(defelem, k, None)) - ) - argstrs = [convert(arg) for arg in args] - if "PassMethod" in kwds: - argstrs.append(convert(kwds.pop("PassMethod"))) - argstrs += [", ".join((repr(k), convert(v))) for k, v in kwds.items()] - return "{0:>15}({1});...".format(m_name(elem.__class__), ", ".join(argstrs)) - - def split_ignoring_parentheses(string, delimiter): placeholder = "placeholder" substituted = string[:] @@ -519,3 +394,6 @@ def split_ignoring_parentheses(string, delimiter): assert not matches return replaced_parts + + +Element.to_dict = element_to_dict From a11005a9ff54af12b0bd9abd675085565806f050 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Wed, 27 Mar 2024 15:17:18 +0100 Subject: [PATCH 11/22] restructured files --- pyat/at/load/matfile.py | 85 +++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 50 deletions(-) diff --git a/pyat/at/load/matfile.py b/pyat/at/load/matfile.py index 03c92fe69..b6f1c7809 100644 --- a/pyat/at/load/matfile.py +++ b/pyat/at/load/matfile.py @@ -61,28 +61,10 @@ def _mat_encoder(v): return _mattype_map.get(type(v), lambda attr: attr)(v) -def matfile_generator( +def _matfile_generator( params: dict[str, Any], mat_file: str ) -> Generator[Element, None, None]: - """Run through Matlab cells and generate AT elements - - Parameters: - params: Lattice building parameters (see :py:class:`.Lattice`) - mat_file: File name - - The following keys in ``params`` are used: - - ============ =================== - **mat_key** name of the Matlab variable containing the lattice. - Default: Matlab variable name if there is only one, - otherwise 'RING' - **check** Skip the coherence tests - **quiet** Suppress the warning for non-standard classes - ============ =================== - - Yields: - elem (Element): new Elements - """ + """Run through Matlab cells and generate AT elements""" def mclean(data): if data.dtype.type is np.str_: @@ -212,7 +194,7 @@ def load_mat(filename: str, **kwargs) -> Lattice: kwargs.setdefault("use", kwargs.pop("key")) return Lattice( ringparam_filter, - matfile_generator, + _matfile_generator, abspath(filename), iterator=params_filter, **kwargs, @@ -284,34 +266,6 @@ def convert(value): return cls(*args, **kwargs) -def mfile_generator(params: dict, m_file: str) -> Generator[Element, None, None]: - """Run through the lines of a Matlab m-file and generate AT elements - - Parameters: - params: Lattice building parameters (see :py:class:`.Lattice`) - m_file: File name - - Yields: - elem (Element): new Elements - """ - with open(params.setdefault("in_file", m_file), "rt") as file: - _ = next(file) # Matlab function definition - _ = next(file) # Cell array opening - for lineno, line in enumerate(file): - if line.startswith("};"): - break - try: - elem = _element_from_m(line) - except ValueError: - warn(AtWarning("Invalid line {0} skipped.".format(lineno))) - continue - except KeyError: - warn(AtWarning("Line {0}: Unknown class.".format(lineno))) - continue - else: - yield elem - - def load_m(filename: str, **kwargs) -> Lattice: """Create a :py:class:`.Lattice` from a Matlab m-file @@ -336,6 +290,26 @@ def load_m(filename: str, **kwargs) -> Lattice: See Also: :py:func:`.load_lattice` for a generic lattice-loading function. """ + + def mfile_generator(params: dict, m_file: str) -> Generator[Element, None, None]: + """Run through the lines of a Matlab m-file and generate AT elements""" + with open(params.setdefault("in_file", m_file), "rt") as file: + _ = next(file) # Matlab function definition + _ = next(file) # Cell array opening + for lineno, line in enumerate(file): + if line.startswith("};"): + break + try: + elem = _element_from_m(line) + except ValueError: + warn(AtWarning("Invalid line {0} skipped.".format(lineno))) + continue + except KeyError: + warn(AtWarning("Line {0}: Unknown class.".format(lineno))) + continue + else: + yield elem + return Lattice( ringparam_filter, mfile_generator, @@ -504,8 +478,9 @@ def save(file): print("end", file=mfile) +# Simulates the deprecated "mat_file" and "mat_key" attributes def _mat_file(ring): - """.mat input file""" + """.mat input file. Deprecated, use 'in_file' instead.""" try: in_file = ring.in_file except AttributeError: @@ -517,7 +492,17 @@ def _mat_file(ring): return in_file +def _mat_key(ring): + """selected Matlab variable. Deprecated, use 'use' instead.""" + try: + mat_key = ring.use + except AttributeError: + raise AttributeError("'Lattice' object has no attribute 'mat_key'") + return mat_key + + register_format(".mat", load_mat, save_mat, descr="Matlab binary mat-file") register_format(".m", load_m, save_m, descr="Matlab text m-file") Lattice.mat_file = property(_mat_file, None, None) +Lattice.mat_key = property(_mat_key, None, None) From 4438652931c860f51bbfb90479f188837c541889 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Wed, 27 Mar 2024 15:23:23 +0100 Subject: [PATCH 12/22] restructured files --- pyat/at/load/reprfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyat/at/load/reprfile.py b/pyat/at/load/reprfile.py index 6b7f7028e..18f20dde2 100644 --- a/pyat/at/load/reprfile.py +++ b/pyat/at/load/reprfile.py @@ -2,7 +2,7 @@ its :py:func:`repr` string """ -from __future__ import print_function +from __future__ import annotations __all__ = ["load_repr", "save_repr"] From e8281f252d4c1d8fde55fdef629c68e9746ba3a6 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Wed, 27 Mar 2024 15:35:30 +0100 Subject: [PATCH 13/22] restructured files --- pyat/at/load/json.py | 4 ++-- pyat/at/load/matfile.py | 5 ++--- pyat/at/load/utils.py | 3 ++- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pyat/at/load/json.py b/pyat/at/load/json.py index 49c0bfc5a..79fb6e9db 100644 --- a/pyat/at/load/json.py +++ b/pyat/at/load/json.py @@ -13,7 +13,7 @@ import numpy as np from .allfiles import register_format -from .utils import element_from_dict, keep_elements, keep_attributes +from .utils import keep_elements, keep_attributes from ..lattice import Element, Lattice, Particle @@ -86,7 +86,7 @@ def json_generator(params: dict[str, Any], fn): for k, v in properties.items(): params.setdefault(k, v) for idx, elem_dict in enumerate(elements): - yield element_from_dict(elem_dict, index=idx, check=False) + yield Element.from_dict(elem_dict, index=idx, check=False) return Lattice(abspath(filename), iterator=json_generator, **kwargs) diff --git a/pyat/at/load/matfile.py b/pyat/at/load/matfile.py index b6f1c7809..88a8da556 100644 --- a/pyat/at/load/matfile.py +++ b/pyat/at/load/matfile.py @@ -21,7 +21,6 @@ from .allfiles import register_format from .utils import split_ignoring_parentheses, RingParam, keep_elements -from .utils import element_from_dict from .utils import _drop_attrs, _CLASS_MAP from ..lattice import Element, Lattice, Particle, Filter from ..lattice import elements, AtWarning, params_filter, AtError @@ -98,7 +97,7 @@ def mclean(data): for index, mat_elem in enumerate(cell_array): elem = mat_elem[0, 0] kwargs = {f: mclean(elem[f]) for f in elem.dtype.fields} - yield element_from_dict(kwargs, index=index, check=check, quiet=quiet) + yield Element.from_dict(kwargs, index=index, check=check, quiet=quiet) def ringparam_filter( @@ -344,7 +343,7 @@ def load_var(matlat: Sequence[dict], **kwargs) -> Lattice: # noinspection PyUnusedLocal def var_generator(params, latt): for elem in latt: - yield element_from_dict(elem) + yield Element.from_dict(elem) return Lattice( ringparam_filter, var_generator, matlat, iterator=params_filter, **kwargs diff --git a/pyat/at/load/utils.py b/pyat/at/load/utils.py index e968eb8b2..1e9d015ee 100644 --- a/pyat/at/load/utils.py +++ b/pyat/at/load/utils.py @@ -363,7 +363,7 @@ def sanitise_class(index, cls, elem_dict): def element_to_dict(elem: Element, encoder: Callable[[Any], Any] = _no_encoder) -> dict: - """Builds the Matlab structure of an :py:class:`.Element` + """Convert a :py:class:`.Element` to a :py:class:`dict` Parameters: elem: :py:class:`.Element` @@ -396,4 +396,5 @@ def split_ignoring_parentheses(string, delimiter): return replaced_parts +Element.from_dict = staticmethod(element_from_dict) Element.to_dict = element_to_dict From 0539ccf7d5d77928f55f1080010e452ae5608a2f Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Wed, 27 Mar 2024 15:37:43 +0100 Subject: [PATCH 14/22] restructured files --- pyat/at/lattice/elements.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyat/at/lattice/elements.py b/pyat/at/lattice/elements.py index d2a4b8873..75d5f9405 100644 --- a/pyat/at/lattice/elements.py +++ b/pyat/at/lattice/elements.py @@ -370,7 +370,9 @@ def deepcopy(self) -> Element: def definition(self) -> tuple[str, tuple, dict]: """tuple (class_name, args, kwargs) defining the element""" attrs = dict(self.items()) - arguments = tuple(attrs.pop(k, getattr(self, k)) for k in self._BUILD_ATTRIBUTES) + arguments = tuple(attrs.pop( + k, getattr(self, k)) for k in self._BUILD_ATTRIBUTES + ) defelem = self.__class__(*arguments) keywords = dict( (k, v) From 14626a13ef1fb4d7bbb029aac55b0af7f0fc101d Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Wed, 27 Mar 2024 17:48:18 +0100 Subject: [PATCH 15/22] Documentation --- pyat/at/load/allfiles.py | 26 ++++--------- pyat/at/load/elegant.py | 11 +++--- pyat/at/load/json.py | 2 +- pyat/at/load/matfile.py | 24 ++++++++---- pyat/at/load/reprfile.py | 2 +- pyat/at/load/tracy.py | 82 ++++++++++++++++++++++------------------ 6 files changed, 77 insertions(+), 70 deletions(-) diff --git a/pyat/at/load/allfiles.py b/pyat/at/load/allfiles.py index be051ef32..2ca9f0626 100644 --- a/pyat/at/load/allfiles.py +++ b/pyat/at/load/allfiles.py @@ -13,40 +13,30 @@ def load_lattice(filepath: str, **kwargs) -> Lattice: """Load a Lattice object from a file -The file format is indicated by the filepath extension. +The file format is indicated by the filepath extension. The file name is stored in +the *in_file* Lattice attribute. The selected variable, if relevant, is stored +in the *use* Lattice attribute. Parameters: filepath: Name of the file Keyword Args: + use (str): Name of the variable containing the desired lattice. + Default: if there is a single variable, use it, otherwise select ``"RING"`` name (str): Name of the lattice. - Default: taken from the file, or ``''`` + Default: taken from the file, or ``""`` energy (float): Energy of the lattice (default: taken from the file) - periodicity (int]): Number of periods + periodicity (int): Number of periods (default: taken from the file, or 1) *: All other keywords will be set as :py:class:`.Lattice` attributes -Specific keywords for .mat files - -Keyword Args: - mat_key (str): Name of the Matlab variable containing - the lattice. Default: Matlab variable name if there is only one, - otherwise ``'RING'`` - check (bool): Run coherence tests. Default: :py:obj:`True` - quiet (bool): Suppress the warning for non-standard classes. - Default: :py:obj:`False` - keep_all (bool): Keep Matlab RingParam elements as Markers. - Default: :py:obj:`False` +Check the format-specific function for specific keyword arguments. Returns: lattice (Lattice): New :py:class:`.Lattice` object -See Also: - :py:func:`.load_mat`, :py:func:`.load_m`, :py:func:`.load_repr`, - :py:func:`.load_elegant`, :py:func:`.load_tracy` - .. Admonition:: Known extensions are: """ _, ext = os.path.splitext(filepath) diff --git a/pyat/at/load/elegant.py b/pyat/at/load/elegant.py index ebb99835a..237824ae7 100644 --- a/pyat/at/load/elegant.py +++ b/pyat/at/load/elegant.py @@ -337,10 +337,8 @@ def load_elegant(filename: str, **kwargs) -> Lattice: name (str): Name of the lattice. Default: taken from the file. energy (float): Energy of the lattice [eV] - periodicity(int): Number of periods. Default: taken from the - elements, or 1 - *: All other keywords will be set as Lattice - attributes + periodicity(int): Number of periods. Default: taken from the elements, or 1 + *: All other keywords will be set as Lattice attributes Returns: lattice (Lattice): New :py:class:`.Lattice` object @@ -354,7 +352,7 @@ def load_elegant(filename: str, **kwargs) -> Lattice: harmonic_number = kwargs.pop("harmonic_number") def elem_iterator(params, elegant_file): - with open(params.setdefault("elegant_file", elegant_file)) as f: + with open(params.setdefault("in_file", elegant_file)) as f: contents = f.read() element_lines = expand_elegant( contents, lattice_key, energy, harmonic_number @@ -370,4 +368,5 @@ def elem_iterator(params, elegant_file): 'lattice {}: {}'.format(filename, e)) -register_format(".lte", load_elegant, descr="Elegant format") +register_format( + ".lte", load_elegant, descr="Elegant format. See :py:func:`.load_elegant`.") diff --git a/pyat/at/load/json.py b/pyat/at/load/json.py index 79fb6e9db..f32d093bf 100644 --- a/pyat/at/load/json.py +++ b/pyat/at/load/json.py @@ -95,5 +95,5 @@ def json_generator(params: dict[str, Any], fn): ".json", load_json, save_json, - descr="JSON representation of a python AT Lattice", + descr="JSON representation of a python AT Lattice. See :py:func:`.load_json`.", ) diff --git a/pyat/at/load/matfile.py b/pyat/at/load/matfile.py index 88a8da556..36ba2f45c 100644 --- a/pyat/at/load/matfile.py +++ b/pyat/at/load/matfile.py @@ -163,9 +163,9 @@ def load_mat(filename: str, **kwargs) -> Lattice: Keyword Args: use (str): Name of the Matlab variable containing - the lattice. Default: Matlab variable name if there is only one, - otherwise 'RING' - mat_key (str): alias for *use* + the lattice. Default: it there is a single variable, use it, otherwise + select 'RING' + mat_key (str): deprecated alias for *use* check (bool): Run the coherence tests. Default: :py:obj:`True` quiet (bool): Suppress the warning for non-standard @@ -479,7 +479,7 @@ def save(file): # Simulates the deprecated "mat_file" and "mat_key" attributes def _mat_file(ring): - """.mat input file. Deprecated, use 'in_file' instead.""" + """.mat input file. Deprecated, use *in_file* instead.""" try: in_file = ring.in_file except AttributeError: @@ -492,7 +492,7 @@ def _mat_file(ring): def _mat_key(ring): - """selected Matlab variable. Deprecated, use 'use' instead.""" + """selected Matlab variable. Deprecated, use *use* instead.""" try: mat_key = ring.use except AttributeError: @@ -500,8 +500,18 @@ def _mat_key(ring): return mat_key -register_format(".mat", load_mat, save_mat, descr="Matlab binary mat-file") -register_format(".m", load_m, save_m, descr="Matlab text m-file") +register_format( + ".mat", + load_mat, + save_mat, + descr="Matlab binary mat-file. See :py:func:`.load_mat`.", +), +register_format( + ".m", + load_m, + save_m, + descr="Matlab text m-file. See :py:func:`.load_m`." +), Lattice.mat_file = property(_mat_file, None, None) Lattice.mat_key = property(_mat_key, None, None) diff --git a/pyat/at/load/reprfile.py b/pyat/at/load/reprfile.py index 18f20dde2..8a1fe7cc6 100644 --- a/pyat/at/load/reprfile.py +++ b/pyat/at/load/reprfile.py @@ -103,5 +103,5 @@ def save(file): load_repr, save_repr, descr=("Text representation of a python AT Lattice. " - "See also :py:func:`.load_repr`."), + "See :py:func:`.load_repr`."), ) diff --git a/pyat/at/load/tracy.py b/pyat/at/load/tracy.py index 3630ed497..b82df1ceb 100644 --- a/pyat/at/load/tracy.py +++ b/pyat/at/load/tracy.py @@ -4,6 +4,7 @@ This parser is quite similar to the Elegant parser in elegant.py. """ + import logging as log from os.path import abspath import re @@ -21,7 +22,7 @@ from at.lattice import Lattice from at.load import register_format, utils -__all__ = ['load_tracy'] +__all__ = ["load_tracy"] def create_drift(name, params, variables): @@ -51,15 +52,13 @@ def create_sext(name, params, variables): def create_dipole(name, params, variables): length = parse_float(params.pop("l", 0), variables) - params["NumIntSteps"] = parse_float(params.pop("n", 10), - variables) + params["NumIntSteps"] = parse_float(params.pop("n", 10), variables) params["PassMethod"] = "BndMPoleSymplectic4Pass" - params["BendingAngle"] = (parse_float(params.pop("t"), - variables) / 180) * numpy.pi - params["EntranceAngle"] = (parse_float(params.pop("t1"), - variables) / 180) * numpy.pi - params["ExitAngle"] = (parse_float(params.pop("t2"), - variables) / 180) * numpy.pi + params["BendingAngle"] = (parse_float(params.pop("t"), variables) / 180) * numpy.pi + params["EntranceAngle"] = ( + parse_float(params.pop("t1"), variables) / 180 + ) * numpy.pi + params["ExitAngle"] = (parse_float(params.pop("t2"), variables) / 180) * numpy.pi # Tracy is encoding gap plus fringe int in the 'gap' field. # Since BndMPoleSymplectic4Pass only uses the product of FringeInt # and gap we can substitute the following. @@ -97,8 +96,7 @@ def create_cavity(name, params, variables): params["Phi"] = parse_float(params.pop("phi", 0), variables) harmonic_number = variables["harmonic_number"] energy = variables["energy"] - return RFCavity(name, length, voltage, frequency, - harmonic_number, energy, **params) + return RFCavity(name, length, voltage, frequency, harmonic_number, energy, **params) ELEMENT_MAP = { @@ -187,27 +185,40 @@ def evaluate(tokens): try: b1 = tokens.index("(") b2 = len(tokens) - 1 - tokens[::-1].index(")") - return evaluate(tokens[:b1] + [evaluate(tokens[b1 + 1:b2])] - + tokens[b2 + 1:]) + return evaluate( + tokens[:b1] + [evaluate(tokens[b1 + 1 : b2])] + tokens[b2 + 1 :] + ) except ValueError: # No open parentheses found. pass # Evaluate / and * from left to right. for i, token in enumerate(tokens[:-1]): if token == "/": - return evaluate(tokens[:i-1] + [float(tokens[i-1]) - / float(tokens[i+1])] + tokens[i+2:]) + return evaluate( + tokens[: i - 1] + + [float(tokens[i - 1]) / float(tokens[i + 1])] + + tokens[i + 2 :] + ) if token == "*": - return evaluate(tokens[:i-1] + [float(tokens[i-1]) - * float(tokens[i+1])] + tokens[i+2:]) + return evaluate( + tokens[: i - 1] + + [float(tokens[i - 1]) * float(tokens[i + 1])] + + tokens[i + 2 :] + ) # Evaluate + and - from left to right. for i, token in enumerate(tokens[:-1]): if token == "+": - return evaluate(tokens[:i-1] + [float(tokens[i-1]) - + float(tokens[i+1])] + tokens[i+2:]) + return evaluate( + tokens[: i - 1] + + [float(tokens[i - 1]) + float(tokens[i + 1])] + + tokens[i + 2 :] + ) if token == "-": - return evaluate(tokens[:i-1] + [float(tokens[i-1]) - - float(tokens[i+1])] + tokens[i+2:]) + return evaluate( + tokens[: i - 1] + + [float(tokens[i - 1]) - float(tokens[i + 1])] + + tokens[i + 2 :] + ) return evaluate(tokens) @@ -288,9 +299,7 @@ def expand_tracy(contents, lattice_key, harmonic_number): else: key, value = line.split(":") if value.split(",")[0].strip() in ELEMENT_MAP: - elements[key] = tracy_element_from_string(key, - value, - variables) + elements[key] = tracy_element_from_string(key, value, variables) else: chunk = parse_chunk(value, elements, chunks) chunks[key] = chunk @@ -351,13 +360,12 @@ def load_tracy(filename: str, **kwargs) -> Lattice: filename: Name of a Tracy file Keyword Args: - name (str): Name of the lattice. Default: taken from - the file. + use (str): Name of the variable containing the desired lattice. + Default: ``cell`` + name (str): Name of the lattice. Default: taken from the file. energy (float): Energy of the lattice [eV] - periodicity(int): Number of periods. Default: taken from the - elements, or 1 - *: All other keywords will be set as Lattice - attributes + periodicity(int): Number of periods. Default: taken from the elements, or 1 + *: All other keywords will be set as Lattice attributes Returns: lattice (Lattice): New :py:class:`.Lattice` object @@ -366,11 +374,8 @@ def load_tracy(filename: str, **kwargs) -> Lattice: :py:func:`.load_lattice` for a generic lattice-loading function. """ try: - harmonic_number = kwargs.pop("harmonic_number") - lattice_key = kwargs.pop("lattice_key", "cell") - def elem_iterator(params, tracy_file): - with open(params.setdefault("tracy_file", tracy_file)) as f: + with open(params.setdefault("in_file", tracy_file)) as f: contents = f.read() element_lines, energy = expand_tracy( contents, lattice_key, harmonic_number @@ -379,10 +384,13 @@ def elem_iterator(params, tracy_file): for line in element_lines: yield line + if "lattice_key" in kwargs: # process the deprecated 'lattice_key' keyword + kwargs.setdefault("use", kwargs.pop("lattice_key")) + harmonic_number = kwargs.pop("harmonic_number") + lattice_key = kwargs.pop("use", "cell") return Lattice(abspath(filename), iterator=elem_iterator, **kwargs) except Exception as e: - raise ValueError('Failed to load tracy ' - 'lattice {}: {}'.format(filename, e)) + raise ValueError("Failed to load tracy " "lattice {}: {}".format(filename, e)) -register_format(".lat", load_tracy, descr="Tracy format") +register_format(".lat", load_tracy, descr="Tracy format. See :py:func:`.load_tracy`.") From c62baac6a487fe072a479b5868199c1f24279bc7 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Thu, 28 Mar 2024 20:34:19 +0100 Subject: [PATCH 16/22] Documentation --- pyat/at/load/matfile.py | 9 +++++---- pyat/at/load/utils.py | 6 +++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/pyat/at/load/matfile.py b/pyat/at/load/matfile.py index 36ba2f45c..510b197aa 100644 --- a/pyat/at/load/matfile.py +++ b/pyat/at/load/matfile.py @@ -211,7 +211,7 @@ def _element_from_m(line: str) -> Element: """ def argsplit(value): - return [a.strip() for a in split_ignoring_parentheses(value, ",")] + return [a.strip() for a in split_ignoring_parentheses(value)] def makedir(mat_struct): """Build directory from Matlab struct arguments""" @@ -505,13 +505,14 @@ def _mat_key(ring): load_mat, save_mat, descr="Matlab binary mat-file. See :py:func:`.load_mat`.", -), +) + register_format( ".m", load_m, save_m, - descr="Matlab text m-file. See :py:func:`.load_m`." -), + descr="Matlab text m-file. See :py:func:`.load_m`.", +) Lattice.mat_file = property(_mat_file, None, None) Lattice.mat_key = property(_mat_key, None, None) diff --git a/pyat/at/load/utils.py b/pyat/at/load/utils.py index 1e9d015ee..3d0eccddf 100644 --- a/pyat/at/load/utils.py +++ b/pyat/at/load/utils.py @@ -378,7 +378,11 @@ def element_to_dict(elem: Element, encoder: Callable[[Any], Any] = _no_encoder) return dct -def split_ignoring_parentheses(string, delimiter): +def split_ignoring_parentheses(string: str, delimiter: str = ",") -> list[str]: + """Split a string while keeping parenthesized expressions intact + + Example: "l=0,hom(4,0.0,0)" -> ["l=0", "hom(4,0.0,0)"] + """ placeholder = "placeholder" substituted = string[:] matches = collections.deque(re.finditer("\\(.*?\\)", string)) From 8255ebdb841a40b2db0c8e707ef98f3c7ea59f52 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Tue, 28 May 2024 20:27:33 +0200 Subject: [PATCH 17/22] revert elements.py --- pyat/at/lattice/elements.py | 204 ++++++++++++++++++++---------------- 1 file changed, 111 insertions(+), 93 deletions(-) diff --git a/pyat/at/lattice/elements.py b/pyat/at/lattice/elements.py index 75d5f9405..9a52c5e9a 100644 --- a/pyat/at/lattice/elements.py +++ b/pyat/at/lattice/elements.py @@ -5,15 +5,21 @@ appropriate attributes. If a different PassMethod is set, it is the caller's responsibility to ensure that the appropriate attributes are present. """ + from __future__ import annotations + import abc import re -import numpy -from copy import copy, deepcopy from abc import ABC from collections.abc import Generator, Iterable +from copy import copy, deepcopy from typing import Any, Optional +import numpy + +# noinspection PyProtectedMember +from .variables import _nop + def _array(value, shape=(-1,), dtype=numpy.float64): # Ensure proper ordering(F) and alignment(A) for "C" access in integrators @@ -25,8 +31,17 @@ def _array66(value): return _array(value, shape=(6, 6)) -def _nop(value): - return value +def _float(value) -> float: + return float(value) + + +def _int(value, vmin: Optional[int] = None, vmax: Optional[int] = None) -> int: + intv = int(value) + if vmin is not None and intv < vmin: + raise ValueError(f"Value must be greater of equal to {vmin}") + if vmax is not None and intv > vmax: + raise ValueError(f"Value must be smaller of equal to {vmax}") + return intv class LongtMotion(ABC): @@ -42,6 +57,7 @@ class LongtMotion(ABC): * ``set_longt_motion(self, enable, new_pass=None, copy=False, **kwargs)`` must enable or disable longitudinal motion. """ + @abc.abstractmethod def _get_longt_motion(self): return False @@ -103,7 +119,8 @@ class _DictLongtMotion(LongtMotion): Defines a class such that :py:meth:`set_longt_motion` will select ``'IdentityPass'`` or ``'IdentityPass'``. - """ + """ + def _get_longt_motion(self): return self.PassMethod != self.default_pass[False] @@ -161,16 +178,20 @@ def set_longt_motion(self, enable, new_pass=None, copy=False, **kwargs): if new_pass is None or new_pass == self.PassMethod: return self if copy else None if enable: + def setpass(el): el.PassMethod = new_pass el.Energy = kwargs['energy'] + else: + def setpass(el): el.PassMethod = new_pass try: del el.Energy except AttributeError: pass + if copy: newelem = deepcopy(self) setpass(newelem) @@ -240,7 +261,7 @@ class Element(object): """Base class for AT elements""" _BUILD_ATTRIBUTES = ['FamName'] - _conversions = dict(FamName=str, PassMethod=str, Length=float, + _conversions = dict(FamName=str, PassMethod=str, Length=_float, R1=_array66, R2=_array66, T1=lambda v: _array(v, (6,)), T2=lambda v: _array(v, (6,)), @@ -248,9 +269,9 @@ class Element(object): EApertures=lambda v: _array(v, (2,)), KickAngle=lambda v: _array(v, (2,)), PolynomB=_array, PolynomA=_array, - BendingAngle=float, - MaxOrder=int, NumIntSteps=int, - Energy=float, + BendingAngle=_float, + MaxOrder=_int, NumIntSteps=lambda v: _int(v, vmin=0), + Energy=_float, ) _entrance_fields = ['T1', 'R1'] @@ -272,12 +293,13 @@ def __init__(self, family_name: str, **kwargs): def __setattr__(self, key, value): try: - super(Element, self).__setattr__( - key, self._conversions.get(key, _nop)(value)) + value = self._conversions.get(key, _nop)(value) except Exception as exc: exc.args = ('In element {0}, parameter {1}: {2}'.format( self.FamName, key, exc),) raise + else: + super(Element, self).__setattr__(key, value) def __str__(self): return "\n".join( @@ -321,10 +343,12 @@ def divide(self, frac) -> list[Element]: def swap_faces(self, copy=False): """Swap the faces of an element, alignment errors are ignored""" + def swapattr(element, attro, attri): val = getattr(element, attri) delattr(element, attri) return attro, val + if copy: el = self.copy() else: @@ -355,7 +379,7 @@ def update(self, *args, **kwargs): Update the element attributes with the given arguments """ attrs = dict(*args, **kwargs) - for (key, value) in attrs.items(): + for key, value in attrs.items(): setattr(self, key, value) def copy(self) -> Element: @@ -397,8 +421,7 @@ def merge(self, other) -> None: """Merge another element""" if not self.is_compatible(other): badname = getattr(other, 'FamName', type(other)) - raise TypeError('Cannot merge {0} and {1}'.format(self.FamName, - badname)) + raise TypeError("Cannot merge {0} and {1}".format(self.FamName, badname)) # noinspection PyMethodMayBeStatic def _get_longt_motion(self): @@ -420,8 +443,8 @@ def is_collective(self) -> bool: class LongElement(Element): - """Base class for long elements - """ + """Base class for long elements""" + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ['Length'] def __init__(self, family_name: str, length: float, *args, **kwargs): @@ -455,8 +478,7 @@ def popattr(element, attr): # Remove entrance and exit attributes fin = dict(popattr(el, key) for key in vars(self) if key in self._entrance_fields) - fout = dict(popattr(el, key) for key in vars(self) if - key in self._exit_fields) + fout = dict(popattr(el, key) for key in vars(self) if key in self._exit_fields) # Split element element_list = [el._part(f, numpy.sum(frac)) for f in frac] # Restore entrance and exit attributes @@ -529,6 +551,7 @@ def means(self): class SliceMoments(Element): """Element to compute slices mean and std""" + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ['nslice'] _conversions = dict(Element._conversions, nslice=int) @@ -547,8 +570,7 @@ def __init__(self, family_name: str, nslice: int, **kwargs): kwargs.setdefault('PassMethod', 'SliceMomentsPass') self._startturn = kwargs.pop('startturn', 0) self._endturn = kwargs.pop('endturn', 1) - super(SliceMoments, self).__init__(family_name, nslice=nslice, - **kwargs) + super(SliceMoments, self).__init__(family_name, nslice=nslice, **kwargs) self._nbunch = 1 self.startturn = self._startturn self.endturn = self._endturn @@ -563,45 +585,33 @@ def set_buffers(self, nturns, nbunch): self.endturn = min(self.endturn, nturns) self._dturns = self.endturn - self.startturn self._nbunch = nbunch - self._stds = numpy.zeros((3, nbunch*self.nslice, self._dturns), - order='F') - self._means = numpy.zeros((3, nbunch*self.nslice, self._dturns), - order='F') - self._spos = numpy.zeros((nbunch*self.nslice, self._dturns), - order='F') - self._weights = numpy.zeros((nbunch*self.nslice, self._dturns), - order='F') + self._stds = numpy.zeros((3, nbunch*self.nslice, self._dturns), order="F") + self._means = numpy.zeros((3, nbunch*self.nslice, self._dturns), order="F") + self._spos = numpy.zeros((nbunch*self.nslice, self._dturns), order="F") + self._weights = numpy.zeros((nbunch*self.nslice, self._dturns), order="F") @property def stds(self): """Slices x,y,dp standard deviation""" - return self._stds.reshape((3, self._nbunch, - self.nslice, - self._dturns)) + return self._stds.reshape((3, self._nbunch, self.nslice, self._dturns)) @property def means(self): """Slices x,y,dp center of mass""" - return self._means.reshape((3, self._nbunch, - self.nslice, - self._dturns)) + return self._means.reshape((3, self._nbunch, self.nslice, self._dturns)) @property def spos(self): """Slices s position""" - return self._spos.reshape((self._nbunch, - self.nslice, - self._dturns)) + return self._spos.reshape((self._nbunch, self.nslice, self._dturns)) @property def weights(self): """Slices weights in mA if beam current >0, - otherwise fraction of total number of - particles in the bunch + otherwise fraction of total number of + particles in the bunch """ - return self._weights.reshape((self._nbunch, - self.nslice, - self._dturns)) + return self._weights.reshape((self._nbunch, self.nslice, self._dturns)) @property def startturn(self): @@ -632,6 +642,7 @@ def endturn(self, value): class Aperture(Element): """Aperture element""" + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ['Limits'] _conversions = dict(Element._conversions, Limits=lambda v: _array(v, (4,))) @@ -710,6 +721,7 @@ def insert(self, class Collimator(Drift): """Collimator element""" + _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['RApertures'] def __init__(self, family_name: str, length: float, limits, **kwargs): @@ -728,8 +740,8 @@ def __init__(self, family_name: str, length: float, limits, **kwargs): class ThinMultipole(Element): """Thin multipole element""" - _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ['PolynomA', - 'PolynomB'] + + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ["PolynomA", "PolynomB"] def __init__(self, family_name: str, poly_a, poly_b, **kwargs): """ @@ -757,10 +769,13 @@ def lengthen(poly, dl): else: return poly - # Remove MaxOrder, PolynomA and PolynomB - poly_a, len_a, ord_a = getpol(_array(kwargs.pop('PolynomA', poly_a))) - poly_b, len_b, ord_b = getpol(_array(kwargs.pop('PolynomB', poly_b))) + # PolynomA and PolynomB and convert to ParamArray + prmpola = self._conversions["PolynomA"](kwargs.pop("PolynomA", poly_a)) + prmpolb = self._conversions["PolynomB"](kwargs.pop("PolynomB", poly_b)) + poly_a, len_a, ord_a = getpol(prmpola) + poly_b, len_b, ord_b = getpol(prmpolb) deforder = max(getattr(self, 'DefaultOrder', 0), ord_a, ord_b) + # Remove MaxOrder maxorder = kwargs.pop('MaxOrder', deforder) kwargs.setdefault('PassMethod', 'ThinMPolePass') super(ThinMultipole, self).__init__(family_name, **kwargs) @@ -768,36 +783,32 @@ def lengthen(poly, dl): super(ThinMultipole, self).__setattr__('MaxOrder', maxorder) # Adjust polynom lengths and set them len_ab = max(self.MaxOrder + 1, len_a, len_b) - self.PolynomA = lengthen(poly_a, len_ab - len_a) - self.PolynomB = lengthen(poly_b, len_ab - len_b) + self.PolynomA = lengthen(prmpola, len_ab - len_a) + self.PolynomB = lengthen(prmpolb, len_ab - len_b) def __setattr__(self, key, value): """Check the compatibility of MaxOrder, PolynomA and PolynomB""" polys = ('PolynomA', 'PolynomB') if key in polys: - value = _array(value) - lmin = getattr(self, 'MaxOrder') + lmin = self.MaxOrder if not len(value) > lmin: raise ValueError( 'Length of {0} must be larger than {1}'.format(key, lmin)) elif key == 'MaxOrder': - value = int(value) + intval = int(value) lmax = min(len(getattr(self, k)) for k in polys) - if not value < lmax: - raise ValueError( - 'MaxOrder must be smaller than {0}'.format(lmax)) - + if not intval < lmax: + raise ValueError("MaxOrder must be smaller than {0}".format(lmax)) super(ThinMultipole, self).__setattr__(key, value) class Multipole(_Radiative, LongElement, ThinMultipole): """Multipole element""" - _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['PolynomA', - 'PolynomB'] + + _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ["PolynomA", "PolynomB"] _conversions = dict(ThinMultipole._conversions, K=float, H=float) - def __init__(self, family_name: str, length: float, poly_a, poly_b, - **kwargs): + def __init__(self, family_name: str, length: float, poly_a, poly_b, **kwargs): """ Args: family_name: Name of the element @@ -817,12 +828,10 @@ def __init__(self, family_name: str, length: float, poly_a, poly_b, """ kwargs.setdefault('PassMethod', 'StrMPoleSymplectic4Pass') kwargs.setdefault('NumIntSteps', 10) - super(Multipole, self).__init__(family_name, length, - poly_a, poly_b, **kwargs) + super(Multipole, self).__init__(family_name, length, poly_a, poly_b, **kwargs) def is_compatible(self, other) -> bool: - if super().is_compatible(other) and \ - self.MaxOrder == other.MaxOrder: + if super().is_compatible(other) and self.MaxOrder == other.MaxOrder: for i in range(self.MaxOrder + 1): if self.PolynomB[i] != other.PolynomB[i]: return False @@ -836,7 +845,8 @@ def is_compatible(self, other) -> bool: @property def K(self) -> float: """Focusing strength [mˆ-2]""" - return 0.0 if len(self.PolynomB) < 2 else self.PolynomB[1] + arr = self.PolynomB + return 0.0 if len(arr) < 2 else arr[1] # noinspection PyPep8Naming @K.setter @@ -847,7 +857,8 @@ def K(self, strength: float): @property def H(self) -> float: """Sextupolar strength [mˆ-3]""" - return 0.0 if len(self.PolynomB) < 3 else self.PolynomB[2] + arr = self.PolynomB + return 0.0 if len(arr) < 3 else arr[2] # noinspection PyPep8Naming @H.setter @@ -920,16 +931,15 @@ def __init__(self, family_name: str, length: float, Default PassMethod: :ref:`BndMPoleSymplectic4Pass` """ - poly_b = kwargs.pop('PolynomB', numpy.array([0, k])) kwargs.setdefault('BendingAngle', bending_angle) kwargs.setdefault('EntranceAngle', 0.0) kwargs.setdefault('ExitAngle', 0.0) kwargs.setdefault('PassMethod', 'BndMPoleSymplectic4Pass') - super(Dipole, self).__init__(family_name, length, [], poly_b, **kwargs) + super(Dipole, self).__init__(family_name, length, [], [0.0, k], **kwargs) - def items(self) -> Generator[tuple, None, None]: + def items(self) -> Generator[tuple[str, Any], None, None]: yield from super().items() - yield 'K', self.K + yield "K", vars(self)["PolynomB"][1] def _part(self, fr, sumfr): pp = super(Dipole, self)._part(fr, sumfr) @@ -942,9 +952,9 @@ def is_compatible(self, other) -> bool: def invrho(dip: Dipole): return dip.BendingAngle / dip.Length - return (super().is_compatible(other) and - self.ExitAngle == -other.EntranceAngle and - abs(invrho(self) - invrho(other)) <= 1.e-6) + return (super().is_compatible(other) + and self.ExitAngle == -other.EntranceAngle + and abs(invrho(self) - invrho(other)) <= 1.e-6) def merge(self, other) -> None: super().merge(other) @@ -959,6 +969,7 @@ def merge(self, other) -> None: class Quadrupole(Radiative, Multipole): """Quadrupole element""" + _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['K'] _conversions = dict(Multipole._conversions, FringeQuadEntrance=int, FringeQuadExit=int) @@ -996,18 +1007,17 @@ def __init__(self, family_name: str, length: float, Default PassMethod: ``StrMPoleSymplectic4Pass`` """ - poly_b = kwargs.pop('PolynomB', numpy.array([0, k])) - kwargs.setdefault('PassMethod', 'StrMPoleSymplectic4Pass') - super(Quadrupole, self).__init__(family_name, length, [], poly_b, - **kwargs) + kwargs.setdefault("PassMethod", "StrMPoleSymplectic4Pass") + super(Quadrupole, self).__init__(family_name, length, [], [0.0, k], **kwargs) - def items(self) -> Generator[tuple, None, None]: + def items(self) -> Generator[tuple[str, Any], None, None]: yield from super().items() - yield 'K', self.K + yield "K", vars(self)["PolynomB"][1] class Sextupole(Multipole): """Sextupole element""" + _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['H'] DefaultOrder = 2 @@ -1031,14 +1041,18 @@ def __init__(self, family_name: str, length: float, Default PassMethod: ``StrMPoleSymplectic4Pass`` """ - poly_b = kwargs.pop('PolynomB', [0, 0, h]) - kwargs.setdefault('PassMethod', 'StrMPoleSymplectic4Pass') - super(Sextupole, self).__init__(family_name, length, [], poly_b, + kwargs.setdefault("PassMethod", "StrMPoleSymplectic4Pass") + super(Sextupole, self).__init__(family_name, length, [], [0.0, 0.0, h], **kwargs) + def items(self) -> Generator[tuple[str, Any], None, None]: + yield from super().items() + yield "H", vars(self)["PolynomB"][2] + class Octupole(Multipole): """Octupole element, with no changes from multipole at present""" + _BUILD_ATTRIBUTES = Multipole._BUILD_ATTRIBUTES DefaultOrder = 3 @@ -1046,6 +1060,7 @@ class Octupole(Multipole): class RFCavity(LongtMotion, LongElement): """RF cavity element""" + _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['Voltage', 'Frequency', 'HarmNumber', @@ -1086,9 +1101,9 @@ def _part(self, fr, sumfr): return pp def is_compatible(self, other) -> bool: - return (super().is_compatible(other) and - self.Frequency == other.Frequency and - self.TimeLag == other.TimeLag) + return (super().is_compatible(other) + and self.Frequency == other.Frequency + and self.TimeLag == other.TimeLag) def merge(self, other) -> None: super().merge(other) @@ -1107,6 +1122,7 @@ def set_longt_motion(self, enable, new_pass=None, **kwargs): class M66(Element): """Linear (6, 6) transfer matrix""" + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ["M66"] _conversions = dict(Element._conversions, M66=_array66) @@ -1117,7 +1133,7 @@ def __init__(self, family_name: str, m66=None, **kwargs): m66: Transfer matrix. Default: Identity matrix Default PassMethod: ``Matrix66Pass`` - """ + """ if m66 is None: m66 = numpy.identity(6) kwargs.setdefault('PassMethod', 'Matrix66Pass') @@ -1133,6 +1149,7 @@ class SimpleQuantDiff(_DictLongtMotion, Element): Note: The damping times are needed to compute the correct kick for the emittance. Radiation damping is NOT applied. """ + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES default_pass = {False: 'IdentityPass', True: 'SimpleQuantDiffPass'} @@ -1157,8 +1174,8 @@ def __init__(self, family_name: str, betax: float = 1.0, tauz: Longitudinal damping time [turns] Default PassMethod: ``SimpleQuantDiffPass`` - """ - kwargs.setdefault('PassMethod', self.default_pass[True]) + """ + kwargs.setdefault("PassMethod", self.default_pass[True]) assert taux >= 0.0, 'taux must be greater than or equal to 0' self.taux = taux @@ -1191,6 +1208,7 @@ def __init__(self, family_name: str, betax: float = 1.0, class SimpleRadiation(_DictLongtMotion, Radiative, Element): """Simple radiation damping and energy loss""" + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES _conversions = dict(Element._conversions, U0=float, damp_mat_diag=lambda v: _array(v, shape=(6,))) @@ -1212,7 +1230,7 @@ def __init__(self, family_name: str, U0: Energy loss per turn [eV] Default PassMethod: ``SimpleRadiationPass`` - """ + """ assert taux >= 0.0, 'taux must be greater than or equal to 0' if taux == 0.0: dampx = 1 @@ -1241,6 +1259,7 @@ def __init__(self, family_name: str, class Corrector(LongElement): """Corrector element""" + _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['KickAngle'] def __init__(self, family_name: str, length: float, kick_angle, **kwargs): @@ -1266,6 +1285,7 @@ class Wiggler(Radiative, LongElement): See atwiggler.m """ + _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['Lw', 'Bmax', 'Energy'] _conversions = dict(Element._conversions, Lw=float, Bmax=float, @@ -1308,14 +1328,12 @@ def __init__(self, family_name: str, length: float, wiggle_period: float, for i, b in enumerate(self.By.T): dk = abs(b[3] ** 2 - b[4] ** 2 - b[2] ** 2) / abs(b[4]) if dk > 1e-6: - raise ValueError("Wiggler(H): kx^2 + kz^2 -ky^2 !=0, i = " - "{0}".format(i)) + raise ValueError("Wiggler(H): kx^2 + kz^2 -ky^2 !=0, i = {0}".format(i)) for i, b in enumerate(self.Bx.T): dk = abs(b[2] ** 2 - b[4] ** 2 - b[3] ** 2) / abs(b[4]) if dk > 1e-6: - raise ValueError("Wiggler(V): ky^2 + kz^2 -kx^2 !=0, i = " - "{0}".format(i)) + raise ValueError("Wiggler(V): ky^2 + kz^2 -kx^2 !=0, i = {0}".format(i)) self.NHharm = self.By.shape[1] self.NVharm = self.Bx.shape[1] From 24a10e8f87407e2b1b8c06fef11803aecf49d782 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Wed, 29 May 2024 16:57:42 +0200 Subject: [PATCH 18/22] Restore the "mat_key" keyword argument in load_matfile --- pyat/at/load/matfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyat/at/load/matfile.py b/pyat/at/load/matfile.py index 510b197aa..2cd4abdb2 100644 --- a/pyat/at/load/matfile.py +++ b/pyat/at/load/matfile.py @@ -190,7 +190,7 @@ def load_mat(filename: str, **kwargs) -> Lattice: if "key" in kwargs: # process the deprecated 'key' keyword kwargs.setdefault("use", kwargs.pop("key")) if "mat_key" in kwargs: # process the deprecated 'mat_key' keyword - kwargs.setdefault("use", kwargs.pop("key")) + kwargs.setdefault("use", kwargs.pop("mat_key")) return Lattice( ringparam_filter, _matfile_generator, From 5e19ee6970982ec9957840c0aaf8247a0ded7657 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Wed, 29 May 2024 17:37:59 +0200 Subject: [PATCH 19/22] Documentation and type hints --- pyat/at/load/allfiles.py | 95 ++++++++++++++++++++-------------------- 1 file changed, 47 insertions(+), 48 deletions(-) diff --git a/pyat/at/load/allfiles.py b/pyat/at/load/allfiles.py index 2ca9f0626..7085f316d 100644 --- a/pyat/at/load/allfiles.py +++ b/pyat/at/load/allfiles.py @@ -1,10 +1,14 @@ """Generic function to save and load python AT lattices. The format is determined by the file extension """ + +from __future__ import annotations import os.path +from typing import Optional +from collections.abc import Callable from at.lattice import Lattice -__all__ = ['load_lattice', 'save_lattice', 'register_format'] +__all__ = ["load_lattice", "save_lattice", "register_format"] _load_extension = {} _save_extension = {} @@ -13,31 +17,31 @@ def load_lattice(filepath: str, **kwargs) -> Lattice: """Load a Lattice object from a file -The file format is indicated by the filepath extension. The file name is stored in -the *in_file* Lattice attribute. The selected variable, if relevant, is stored -in the *use* Lattice attribute. - -Parameters: - filepath: Name of the file - -Keyword Args: - use (str): Name of the variable containing the desired lattice. - Default: if there is a single variable, use it, otherwise select ``"RING"`` - name (str): Name of the lattice. - Default: taken from the file, or ``""`` - energy (float): Energy of the lattice - (default: taken from the file) - periodicity (int): Number of periods - (default: taken from the file, or 1) - *: All other keywords will be set as :py:class:`.Lattice` - attributes - -Check the format-specific function for specific keyword arguments. - -Returns: - lattice (Lattice): New :py:class:`.Lattice` object + The file format is indicated by the filepath extension. The file name is stored in + the *in_file* Lattice attribute. The selected variable, if relevant, is stored + in the *use* Lattice attribute. -.. Admonition:: Known extensions are: + Parameters: + filepath: Name of the file + + Keyword Args: + use (str): Name of the variable containing the desired lattice. + Default: if there is a single variable, use it, otherwise select ``"RING"`` + name (str): Name of the lattice. + Default: taken from the file, or ``""`` + energy (float): Energy of the lattice + (default: taken from the file) + periodicity (int): Number of periods + (default: taken from the file, or 1) + *: All other keywords will be set as :py:class:`.Lattice` + attributes + + Returns: + lattice (Lattice): New :py:class:`.Lattice` object + + Check the format-specific function for specific keyword arguments: + + .. Admonition:: Known extensions are: """ _, ext = os.path.splitext(filepath) try: @@ -48,25 +52,18 @@ def load_lattice(filepath: str, **kwargs) -> Lattice: return load_func(filepath, **kwargs) -def save_lattice(ring: Lattice, filepath: str, **kwargs): +def save_lattice(ring: Lattice, filepath: str, **kwargs) -> None: """Save a Lattice object -The file format is indicated by the filepath extension. - -Parameters: - ring: Lattice description - filepath: Name of the file + The file format is indicated by the filepath extension. -Specific keywords for .mat files - -Keyword Args: - mat_key (str): Name of the Matlab variable containing the lattice. - Default: ``'RING'`` + Parameters: + ring: Lattice description + filepath: Name of the file -See Also: - :py:func:`.save_mat`, :py:func:`.save_m`, :py:func:`.save_repr` + Check the format-specific function for specific keyword arguments: -.. Admonition:: Known extensions are: + .. Admonition:: Known extensions are: """ _, ext = os.path.splitext(filepath) try: @@ -77,24 +74,26 @@ def save_lattice(ring: Lattice, filepath: str, **kwargs): return save_func(ring, filepath, **kwargs) -def register_format(extension: str, load_func=None, save_func=None, - descr: str = ''): +def register_format( + extension: str, + load_func: Optional[Callable[..., Lattice]] = None, + save_func: Optional[Callable[..., None]] = None, + descr: str = "", +): """Register format-specific processing functions Parameters: extension: File extension string. - load_func: load function. Default: :py:obj:`None` - save_func: save_lattice function Default: :py:obj:`None` - descr: File type description + load_func: load function. + save_func: save function. + descr: File type description. """ if load_func is not None: _load_extension[extension] = load_func - load_lattice.__doc__ += '\n {0:<10}'\ - '\n {1}\n'.format(extension, descr) + load_lattice.__doc__ += f"\n {extension:<10}\n {descr}\n" if save_func is not None: _save_extension[extension] = save_func - save_lattice.__doc__ += '\n {0:<10}'\ - '\n {1}\n'.format(extension, descr) + save_lattice.__doc__ += f"\n {extension:<10}\n {descr}\n" Lattice.load = staticmethod(load_lattice) From 1be42dc7c1480235af4c20e80e5ab0c3f96ad762 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Wed, 29 May 2024 18:30:58 +0200 Subject: [PATCH 20/22] Added a file signature, for later use --- atmat/lattice/atloadlattice.m | 6 ++++++ atmat/lattice/atwritejson.m | 1 + pyat/at/load/json.py | 12 +++++++++++- 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/atmat/lattice/atloadlattice.m b/atmat/lattice/atloadlattice.m index 2c0190ae8..56e7439a9 100644 --- a/atmat/lattice/atloadlattice.m +++ b/atmat/lattice/atloadlattice.m @@ -83,6 +83,12 @@ function [lattice, opts]=load_json(fpath, opts) data=jsondecode(fileread(fpath)); + % File signature for later use + try + atjson=data.atjson; + catch + atjson=1; + end props=data.properties; name=props.name; energy=props.energy; diff --git a/atmat/lattice/atwritejson.m b/atmat/lattice/atwritejson.m index 524e06444..2ac8bfabe 100644 --- a/atmat/lattice/atwritejson.m +++ b/atmat/lattice/atwritejson.m @@ -38,6 +38,7 @@ function jsondata=sjson(ring) ok=~atgetcells(ring, 'Class', 'RingParam'); + data.atjson= 1; data.elements=ring(ok); data.properties=get_params(ring); jsondata=jsonencode(data, 'PrettyPrint', ~compact); diff --git a/pyat/at/load/json.py b/pyat/at/load/json.py index f32d093bf..12453e4ac 100644 --- a/pyat/at/load/json.py +++ b/pyat/at/load/json.py @@ -48,7 +48,9 @@ def save_json( :py:meth:`.Lattice.save` for a generic lattice-saving method. """ indent = None if compact else 2 - data = dict(elements=list(keep_elements(ring)), properties=keep_attributes(ring)) + data = dict( + atjson=1, elements=list(keep_elements(ring)), properties=keep_attributes(ring) + ) if filename is None: print(json.dumps(data, cls=_AtEncoder, indent=indent)) else: @@ -76,13 +78,21 @@ def json_generator(params: dict[str, Any], fn): with open(params.setdefault("in_file", fn), "rt") as jsonfile: data = json.load(jsonfile) + # Check the file signature - For later use + try: + atjson = data["atjson"] # noqa F841 + except KeyError: + atjson = 1 # noqa F841 + # Get elements elements = data["elements"] + # Get lattice properties try: properties = data["properties"] except KeyError: properties = {} particle_dict = properties.pop("particle", {}) params.setdefault("particle", Particle(**particle_dict)) + for k, v in properties.items(): params.setdefault(k, v) for idx, elem_dict in enumerate(elements): From 95ae1bd133a013524f644ba3778d46f6f3949fb7 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Wed, 29 May 2024 20:40:18 +0200 Subject: [PATCH 21/22] optimised imports --- pyat/at/load/allfiles.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pyat/at/load/allfiles.py b/pyat/at/load/allfiles.py index 7085f316d..cb3b154cf 100644 --- a/pyat/at/load/allfiles.py +++ b/pyat/at/load/allfiles.py @@ -3,12 +3,14 @@ """ from __future__ import annotations + +__all__ = ["load_lattice", "save_lattice", "register_format"] + import os.path -from typing import Optional from collections.abc import Callable -from at.lattice import Lattice +from typing import Optional -__all__ = ["load_lattice", "save_lattice", "register_format"] +from at.lattice import Lattice _load_extension = {} _save_extension = {} From cf7398ed39c98a4db049714219f98bd87f9ffdc9 Mon Sep 17 00:00:00 2001 From: Laurent Farvacque Date: Thu, 30 May 2024 12:04:30 +0200 Subject: [PATCH 22/22] Added tests for save and load --- pyat/at/load/matfile.py | 16 ++++++++++------ pyat/test/test_load_save.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 6 deletions(-) create mode 100644 pyat/test/test_load_save.py diff --git a/pyat/at/load/matfile.py b/pyat/at/load/matfile.py index 2cd4abdb2..92d10666e 100644 --- a/pyat/at/load/matfile.py +++ b/pyat/at/load/matfile.py @@ -374,21 +374,25 @@ def required(rng): yield from keep_elements(ring) -def save_mat(ring: Lattice, filename: str, mat_key: str = "RING") -> None: +def save_mat(ring: Lattice, filename: str, **kwargs) -> None: """Save a :py:class:`.Lattice` as a Matlab mat-file Parameters: - ring: Lattice description - filename: Name of the '.mat' file - mat_key (str): Name of the Matlab variable containing - the lattice. Default: ``'RING'`` + ring: Lattice description + filename: Name of the '.mat' file + + Keyword Args: + use (str): Name of the Matlab variable containing the lattice, Default: "RING" + mat_key (str): Deprecated, alias for *use* See Also: :py:func:`.save_lattice` for a generic lattice-saving function. """ # Ensure the lattice is a Matlab column vector: list(list) + use = kwargs.pop("mat_key", "RING") # For backward compatibility + use = kwargs.pop("use", use) lring = [[el.to_dict(encoder=_mat_encoder)] for el in matlab_ring(ring)] - scipy.io.savemat(filename, {mat_key: lring}, long_field_names=True) + scipy.io.savemat(filename, {use: lring}, long_field_names=True) def _element_to_m(elem: Element) -> str: diff --git a/pyat/test/test_load_save.py b/pyat/test/test_load_save.py new file mode 100644 index 000000000..b45b1553a --- /dev/null +++ b/pyat/test/test_load_save.py @@ -0,0 +1,36 @@ +import os +from tempfile import mktemp + +import pytest +from numpy.testing import assert_equal + +from at.lattice import Lattice + + +@pytest.mark.parametrize("lattice", ["dba_lattice", "hmba_lattice"]) +@pytest.mark.parametrize( + "suffix, options", + [ + (".m", {}), + (".repr", {}), + (".mat", {"use": "abcd"}), + (".json", {}), + ], +) +def test_m(request, lattice, suffix, options): + ring0 = request.getfixturevalue(lattice) + fname = mktemp(suffix=suffix) + + # Create a new .m or .repr file + ring0.save(fname, **options) + + # load the new file + ring1 = Lattice.load(fname, **options) + + # Check that we get back the original lattice + el1, rg1, _ = ring0.linopt6() + el2, rg2, _ = ring1.linopt6() + assert_equal(rg1.tune, rg2.tune) + assert_equal(rg1.chromaticity, rg2.chromaticity) + + os.unlink(fname)